In [38]:
import time
import os
from collections import defaultdict
import gc

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch import transpose as t
from torch import inverse as inv
from torch import mm,solve,matmul
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

### Some basic settings and hyperparameters

In [39]:
# Make CUDA kernel launches synchronous so that device-side errors are reported
# at the call that triggered them (a debugging aid; it slows GPU execution).
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

# Seed the global NumPy / PyTorch generators so weight initialisation and the
# random support-set sampling are repeatable across kernel restarts.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = 'ml-1m'
PATH = '../data/'            # directory holding train_df.csv / valid_df.csv / test_df.csv
COLD_USER_THRESHOLD = 30     # maximum support-set size per user (also the padding length)
batch_size = 1024
embedding_dim = 10
# Use the first GPU when one is available, otherwise fall back to the CPU so
# that every later `.to(device)` call still works on a CPU-only machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
lr = 1e-3
num_epochs = 100
overfit_patience = 2         # non-improving validation epochs tolerated before stopping
exp_id=0                     # suffix of the checkpoint file names

In [40]:
# Read the three pre-split interaction tables from the data directory.
train_df = pd.read_csv(PATH + 'train_df.csv')
val_df = pd.read_csv(PATH + f'valid_df.csv')
test_df = pd.read_csv(PATH + f'test_df.csv')

# Wrap every DataFrame in a tensor-backed PyTorch dataset.
train_dataset, val_dataset, test_dataset = (
    CTR_Dataset(frame) for frame in (train_df, val_df, test_df)
)
num_fields = train_dataset.num_fields
# The embedding table is indexed by feature id, so it needs one more row than
# the largest id that occurs in any of the three splits.
num_features = max(
    split.x_id.max().item() for split in (train_dataset, val_dataset, test_dataset)
) + 1


In [45]:
# Per user, cut the interaction list at min(COLD_USER_THRESHOLD, len - 1):
#   * more than COLD_USER_THRESHOLD rows -> the first COLD_USER_THRESHOLD rows are
#     support and the rest are queries;
#   * otherwise -> everything but the last row is support and the last row is the query.
val_df_support = val_df.groupby('uid', as_index=False).apply(
    lambda grp: grp[:min(COLD_USER_THRESHOLD, len(grp) - 1)])
val_df_query = val_df.groupby('uid', as_index=False).apply(
    lambda grp: grp[min(COLD_USER_THRESHOLD, len(grp) - 1):])

In [48]:
# Persist the validation support split next to the other data files; built from
# PATH so it stays in step with the directory the input CSVs are read from.
val_df_support.to_csv(PATH + 'val_df_support.csv', index=False)

In [49]:
# Persist the validation query split next to the other data files; built from
# PATH so it stays in step with the directory the input CSVs are read from.
val_df_query.to_csv(PATH + 'val_df_query.csv', index=False)

In [50]:
# Write one support/query split of the test set per support size i.
# Users with at most i interactions are dropped from split i entirely; for the
# others the first i rows form the support set and the remaining rows the queries.
# The output directory is created up front so that to_csv cannot fail on a
# missing folder.
os.makedirs(PATH + 'test', exist_ok=True)
for i in range(1,COLD_USER_THRESHOLD+1):
    test_support_set = test_df.groupby('uid',as_index=False).apply(
        lambda x: x[:i] if len(x)>i else x[:0])
    test_query_set = test_df.groupby('uid',as_index=False).apply(
        lambda x: x[i:] if len(x)>i else x[:0])
    test_support_set.to_csv(f'{PATH}test/test_df_support_{i}.csv', index=False)
    test_query_set.to_csv(f'{PATH}test/test_df_query_{i}.csv', index=False)

### Define the dataset classes and helper functions

In [34]:
class CTR_Dataset(Dataset):
    """Tensor-backed dataset of (feature ids, feature values, click label) rows.

    The feature columns (everything except ``is_click``) are split by position:
    with ``C`` feature columns, ``num_fields = C // 2 - 1``; columns
    ``1 .. num_fields`` are stored as ids and columns ``num_fields + 2 ..`` as
    values. Column ``0`` and column ``num_fields + 1`` are skipped
    (presumably the uid id/value pair -- TODO confirm against the CSV schema).
    """

    def __init__(self, data_df):
        features = data_df.drop(columns=['is_click']).values
        n_fields = features.shape[1] // 2 - 1
        self.num_fields = n_fields
        id_columns = features[:, 1:n_fields + 1]
        value_columns = features[:, n_fields + 2:]
        self.x_id = torch.LongTensor(id_columns)
        self.x_value = torch.Tensor(value_columns)
        self.y = torch.Tensor(data_df['is_click'].values)

    def __getitem__(self, idx):
        """Return the ``(ids, values, label)`` triple stored at ``idx``."""
        columns = (self.x_id, self.x_value, self.y)
        return tuple(column[idx] for column in columns)

    def __len__(self):
        """Number of rows in the dataset."""
        return len(self.x_id)

class QueryWithSupportDataset(Dataset):
    """Query samples paired with the padded support set of the same user.

    Each item is ``(x_id, x_value, y, [x_id_support, x_val_support, y_support])``.
    The support tensors always have exactly ``COLD_USER_THRESHOLD`` rows: shorter
    support sets are padded with id ``0`` / value ``0`` / label ``-1`` and longer
    ones are truncated to their first ``COLD_USER_THRESHOLD`` rows, so samples
    can always be stacked into a batch.

    The feature columns are split by position exactly as for the query rows:
    column 0 is read as the uid, columns ``1 .. num_fields`` as ids and columns
    ``num_fields + 2 ..`` as values. Both frames are expected to share the same
    column layout and to contain ``uid`` and ``is_click`` columns.

    The padded support tensors are built once per user in ``__init__``, so
    ``__getitem__`` is a dictionary lookup instead of a scan over the whole
    support frame. The same tensor objects are returned for every query of a
    user and must therefore not be modified in place by callers. Changes made to
    ``train_support_df`` after construction are not picked up.
    """

    def __init__(self, data_df, train_support_df, COLD_USER_THRESHOLD):
        self.data_x_arr = data_df.drop(columns=['is_click']).values
        self.num_fields = self.data_x_arr.shape[1]//2-1
        self.x_id = torch.LongTensor(self.data_x_arr[:,1:self.num_fields+1])
        self.x_value = torch.Tensor(self.data_x_arr[:,self.num_fields+2:])
        self.y = torch.Tensor(data_df['is_click'].values)
        self.train_support_df = train_support_df
        self.COLD_USER_THRESHOLD = COLD_USER_THRESHOLD

        # Group the support rows by uid once; `indices` maps every uid to the
        # positional row numbers of its support rows, in order of appearance.
        support_feats = train_support_df.drop(columns=['is_click']).values
        support_labels = train_support_df['is_click'].values
        rows_by_uid = train_support_df.groupby('uid', sort=False).indices
        self._support_by_uid = {
            # Normalise NumPy scalar keys to plain Python numbers so they match
            # the `.item()` value used for the lookup in __getitem__.
            (uid.item() if hasattr(uid, 'item') else uid):
                self._pad_support(support_feats[rows], support_labels[rows])
            for uid, rows in rows_by_uid.items()
        }
        # Support set handed out for users that have no support rows at all.
        self._empty_support = self._pad_support(support_feats[:0], support_labels[:0])

    def _pad_support(self, feats, labels):
        """Truncate/pad one user's support rows to COLD_USER_THRESHOLD and tensorise them."""
        limit = self.COLD_USER_THRESHOLD
        x_id = feats[:limit, 1:self.num_fields+1]
        x_val = feats[:limit, self.num_fields+2:]
        y = labels[:limit]
        missing = limit - x_id.shape[0]
        if missing > 0:
            zeros = np.zeros((missing, self.num_fields), dtype=np.int64)
            x_id = np.concatenate([x_id, zeros], axis=0)
            x_val = np.concatenate([x_val, zeros], axis=0)
            # Label -1 marks a padded (absent) support slot.
            y = np.concatenate([y, np.full(missing, -1, dtype=np.int64)], axis=0)
        return torch.LongTensor(x_id), torch.Tensor(x_val), torch.Tensor(y)

    def __getitem__(self, idx):
        """Return the query sample at ``idx`` together with its user's padded support set."""
        uid = self.data_x_arr[idx][0].item()
        x_id_support, x_val_support, y_support = self._support_by_uid.get(uid, self._empty_support)
        return self.x_id[idx], self.x_value[idx], self.y[idx], [x_id_support, x_val_support, y_support]

    def __len__(self):
        """Number of query samples."""
        return self.x_id.shape[0]

def val(model, val_dataloader):
    """Evaluate ``model`` on a dataloader of ``(feature_ids, feature_vals, labels)`` batches.

    Batches are moved to the notebook-level ``device`` before the forward pass.

    Returns:
        ``(val_loss, auc)``: the mean of the per-batch BCE-with-logits losses and
        the ROC AUC computed over all samples of the dataloader.

    Raises:
        ValueError: if the dataloader yields no batch.
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()
    running_loss = 0
    num_batches = 0
    # Collect per-batch arrays and concatenate once at the end instead of
    # re-copying the whole history with np.hstack after every batch.
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            batch = [[e.to(device) for e in item] if isinstance(item, list) else item.to(device) for item in batch]
            feature_ids, feature_vals, labels = batch
            # reshape(-1) keeps a 1-D (batch,) tensor even for a batch of one
            # sample, where a bare squeeze() would collapse to a 0-d scalar and
            # no longer match the shape of `labels`.
            outputs = model(feature_ids, feature_vals).reshape(-1)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            num_batches += 1
            pred_chunks.append(outputs.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    torch.cuda.empty_cache()
    if num_batches == 0:
        raise ValueError("val_dataloader produced no batches")
    val_loss = running_loss / num_batches
    auc = roc_auc_score(np.concatenate(label_chunks), np.concatenate(pred_chunks))
    return val_loss, auc

def val_query(model, val_dataloader):
    """Evaluate a support-aware model on ``(feature_ids, feature_vals, labels, support_data)`` batches.

    ``model`` is called as ``model(feature_ids, feature_vals, support_data)`` and
    must return a 4-tuple whose first element holds the logits. Batches are moved
    to the notebook-level ``device`` before the forward pass.

    Returns:
        ``(val_loss, auc)``: the mean of the per-batch BCE-with-logits losses and
        the ROC AUC computed over all samples of the dataloader.

    Raises:
        ValueError: if the dataloader yields no batch.
    """
    model.eval()
    criterion = torch.nn.BCEWithLogitsLoss()
    running_loss = 0
    num_batches = 0
    # Collect per-batch arrays and concatenate once at the end instead of
    # re-copying the whole history with np.hstack after every batch.
    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            batch = [[e.to(device) for e in item] if isinstance(item, list) else item.to(device) for item in batch]
            feature_ids, feature_vals, labels, support_data = batch
            outputs, _, _, _ = model(feature_ids, feature_vals, support_data)
            # Keep the logits 1-D even when the model squeezed a single-sample
            # batch down to a 0-d scalar, so they match the shape of `labels`.
            outputs = outputs.reshape(-1)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            num_batches += 1
            pred_chunks.append(outputs.cpu().numpy())
            label_chunks.append(labels.cpu().numpy())
    torch.cuda.empty_cache()
    if num_batches == 0:
        raise ValueError("val_dataloader produced no batches")
    val_loss = running_loss / num_batches
    auc = roc_auc_score(np.concatenate(label_chunks), np.concatenate(pred_chunks))
    return val_loss, auc

#### The DeepFM base learner, used as both the feature encoder and the shared predictor

In [11]:
class DeepFM_encoder(nn.Module):
    """DeepFM-style scorer: an FM interaction term plus a three-layer MLP over field embeddings.

    ``forward`` returns one logit per row and, on request, the hidden
    representation ``[FM term | MLP output]`` of width
    ``embedding_dim + last_layer_dim`` that the logit is computed from.
    """

    def __init__(self, num_features, embedding_dim, num_fields, hidden_size=400):
        super().__init__()
        self.num_features = num_features
        self.embedding_dim = embedding_dim
        self.num_fields = num_fields
        # Width of the last MLP layer; fixed, independent of `hidden_size`.
        self.last_layer_dim = 400
        self.feature_embeddings = nn.Embedding(num_features, embedding_dim)
        nn.init.xavier_normal_(self.feature_embeddings.weight)
        # The MLP consumes all field embeddings of a row, flattened.
        self.input_dim = num_fields * embedding_dim
        self.fc1 = nn.Linear(self.input_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, self.last_layer_dim)
        self.fc4 = nn.Linear(self.embedding_dim + self.last_layer_dim, 1)

    def forward(self, feature_ids, feature_vals, return_hidden=False):
        """Score a batch of rows.

        Args:
            feature_ids: ``(batch, num_fields)`` integer feature ids.
            feature_vals: ``(batch, num_fields)`` feature values that scale the embeddings.
            return_hidden: when true, also return the hidden representation.

        Returns:
            ``prediction`` of shape ``(batch,)``, or ``(prediction, hidden_encoder)``
            when ``return_hidden`` is set.
        """
        # (batch, num_fields, embedding_dim), scaled in place by the feature values
        field_embeddings = self.feature_embeddings(feature_ids)
        field_embeddings.mul_(feature_vals.unsqueeze(dim=2))

        # FM second-order term: ((sum e)^2 - sum e^2) / 2 -> (batch, embedding_dim)
        square_of_sum = field_embeddings.sum(dim=1).pow(2)
        sum_of_squares = field_embeddings.pow(2).sum(dim=1)
        hidden_fm = (square_of_sum - sum_of_squares) / 2

        # MLP over the flattened embeddings -> (batch, last_layer_dim)
        hidden_dnn = field_embeddings.view(-1, self.input_dim)
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden_dnn = torch.relu(layer(hidden_dnn))

        hidden_encoder = torch.cat([hidden_fm, hidden_dnn], dim=1)
        prediction = self.fc4(hidden_encoder).squeeze(1)
        return (prediction, hidden_encoder) if return_hidden else prediction

### Train the shared predictor $\Psi$

In [9]:
train_df=pd.read_csv(PATH+'train_df.csv')
val_df = pd.read_csv(PATH+'valid_df.csv')
test_df = pd.read_csv(PATH+'test_df.csv')

# dataframe->pytorch dataset
train_dataset = CTR_Dataset(train_df)
val_dataset = CTR_Dataset(val_df)
test_dataset = CTR_Dataset(test_df)
num_fields = train_dataset.num_fields
# One embedding row per feature id: largest id over all splits, plus one.
num_features = 1+max([x.x_id.max().item() for x in [train_dataset, val_dataset, test_dataset]])

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)
# Define model and optimizer.
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = torch.nn.BCEWithLogitsLoss()
# Start training.
best_loss = np.inf
best_epoch = -1
best_auc = 0.5
# The shared predictor is trained for a single epoch (num_epochs is not used here).
for epoch in range(1):
    print(f"Starting epoch: {epoch} | phase: train | ⏰: {time.strftime('%H:%M:%S')}")
    model.train()
    running_loss = 0
    num_batches = 0
    for batch in tqdm(train_dataloader):
        batch = [item.to(device) for item in batch]
        feature_ids, feature_vals, labels = batch
        # reshape(-1) keeps the logits 1-D even for a final batch of one sample,
        # so that batch no longer has to be skipped and the average below is
        # taken over exactly the batches that contributed to running_loss.
        outputs = model(feature_ids, feature_vals).reshape(-1)
        loss = criterion(outputs, labels)
        loss.backward()
        running_loss += loss.detach().item()
        num_batches += 1
        optimizer.step()
        optimizer.zero_grad()
    epoch_loss = running_loss / max(num_batches, 1)
    print(f"training loss of epoch {epoch}: {epoch_loss}")
    torch.cuda.empty_cache()
    
    # Checkpoint consumed by the later cells that load the shared predictor.
    state = {
    "epoch_loss": epoch_loss,
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    }
    torch.save(state, f"predictor-{exp_id}.tar")

  0%|          | 0/690 [00:00<?, ?it/s]

Starting epoch: 0 | phase: train | ⏰: 15:54:19


100%|██████████| 690/690 [00:08<00:00, 78.44it/s] 

training loss of epoch 0: 0.6036580337994341





In [17]:
# Fine-grained test of the base model: one evaluation per support size i, each on
# the query rows that remain after the first i interactions of every user.
print(f"Starting test | ⏰: {time.strftime('%H:%M:%S')}")
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
model = model.to(device)
checkpoint = torch.load(f"predictor-{exp_id}.tar", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])

base_test_losses = []
base_test_aucs = []

for i in tqdm(range(1,COLD_USER_THRESHOLD+1,1)):
    # omit users with <= i interactions.
    # Only the query rows are needed: the base model does not use a support set.
    test_query_set = test_df.groupby('uid',as_index=False).apply(
        lambda x: x[i:] if len(x)>i else x[:0])
    test_dataset = CTR_Dataset(test_query_set)
    test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)
    
    test_loss, test_auc = val(model, test_dataloader)

    base_test_losses += [test_loss]
    base_test_aucs += [test_auc]

# Report the averages over support sizes 1-10, 11-20 and 21-30.
print(f"cold start I: Loss: {sum(base_test_losses[:10])/10}, auc: {sum(base_test_aucs[:10])/10}")
print(f"cold start II: Loss: {sum(base_test_losses[10:20])/10}, auc: {sum(base_test_aucs[10:20])/10}")
print(f"cold start III: Loss: {sum(base_test_losses[20:30])/10}, auc: {sum(base_test_aucs[20:30])/10}")

Starting test | ⏰: 15:59:04


188it [00:01, 172.56it/s]
186it [00:01, 167.22it/s]
185it [00:00, 237.31it/s]
184it [00:00, 203.46it/s]
183it [00:00, 203.25it/s]
182it [00:01, 145.26it/s]
180it [00:00, 190.57it/s]
179it [00:01, 170.36it/s]
178it [00:00, 192.21it/s]
177it [00:00, 214.38it/s]
176it [00:00, 231.89it/s]
175it [00:00, 190.51it/s]
173it [00:01, 159.46it/s]
172it [00:01, 164.46it/s]
171it [00:00, 192.83it/s]
170it [00:00, 194.24it/s]
169it [00:00, 200.67it/s]
167it [00:00, 207.23it/s]
166it [00:01, 141.60it/s]
165it [00:00, 185.17it/s]
164it [00:00, 167.30it/s]
163it [00:00, 167.97it/s]
162it [00:00, 236.65it/s]
161it [00:00, 216.79it/s]
159it [00:00, 199.41it/s]
158it [00:01, 153.71it/s]
157it [00:00, 180.42it/s]
156it [00:01, 148.26it/s]
155it [00:00, 162.54it/s]
154it [00:00, 184.04it/s]

cold start I: Loss: 0.6021976912701057, auc: 0.7213255798748595
cold start II: Loss: 0.6048084336442361, auc: 0.7212235783528799
cold start III: Loss: 0.6067371029609409, auc: 0.7212622311097989





### RESUS model with NN base learner

In [18]:
class AdjustLayer(nn.Module):
    """Learnable affine rescaling with one (scale, bias) pair per support-set size.

    ``forward`` computes ``x * 10 ** scale[k] + bias[k]`` with ``k = num_samples - 1``.
    ``scale`` and ``bias`` both have shape ``(num_adjust, 1)``. ``num_adjust`` must
    be given, and ``base`` is accepted but not used.
    """

    def __init__(self, init_scale=0.4, num_adjust=None, init_bias=0, base=1):
        super().__init__()
        initial_scales = [init_scale] * len(range(num_adjust))
        initial_biases = [init_bias] * len(range(num_adjust))
        self.scale = nn.Parameter(torch.FloatTensor(initial_scales).unsqueeze(1))
        self.bias = nn.Parameter(torch.FloatTensor(initial_biases).unsqueeze(1))

    def forward(self, x, num_samples):
        """Rescale ``x`` with the parameters selected by ``num_samples`` (1-based)."""
        slot = num_samples - 1
        gain = 10 ** self.scale[slot]
        offset = self.bias[slot]
        return x * gain + offset

class RESUS_NN(nn.Module):
    """Residual predictor: corrects the shared predictor's logit with a
    similarity-weighted average of its residuals on the user's support set.

    ``predictor`` and ``encoder`` are modules called as
    ``module(feature_ids, feature_vals, return_hidden=...)``; the encoder must
    expose ``last_layer_dim`` and ``embedding_dim`` (the width of its hidden
    representation is their sum).
    """

    def __init__(self, num_fields, COLD_USER_THRESHOLD, encoder, predictor):
        super(RESUS_NN, self).__init__()
        self.num_fields = num_fields
        self.COLD_USER_THRESHOLD = COLD_USER_THRESHOLD
        self.predictor = predictor
        self.encoder = encoder
        self.L = nn.CrossEntropyLoss()
        self.adjust = AdjustLayer(num_adjust=COLD_USER_THRESHOLD)
        # Maps the element-wise distance between two hidden vectors to a similarity score.
        self.fc1 = nn.Linear(self.encoder.last_layer_dim+self.encoder.embedding_dim, 1)

    def forward(self, feature_ids, feature_vals, support_data, debug=False):
        """Predict the query logits from the query features and the support set.

        Args:
            feature_ids: None*num_fields
            feature_vals: None*num_fields
            support_data: ``[x_id_support, x_val_support, y_support]`` with
                x_id_support / x_val_support: None*COLD_USER_THRESHOLD*num_fields and
                y_support: None*COLD_USER_THRESHOLD, where label ``-1`` marks a
                padded (absent) support slot.
            debug: when true, return intermediate tensors instead (see below).

        Returns:
            ``(prediction, delta_y, similar_score_normalized, delta_y_hat)`` with
            shapes ``(None,)``, ``(None, COLD_USER_THRESHOLD)``,
            ``(None, COLD_USER_THRESHOLD)`` and ``(None, 1)``.
            With ``debug=True``:
            ``(g_x_hat, g_x_support, y_support, similar_score_normalized, delta_y, delta_y_hat)``.

        Rows whose support set is empty (all labels ``-1``) get zero attention
        weights and the shared predictor's logit unchanged, instead of NaN.
        """
        x_id_support, x_val_support, y_support = support_data
        slots = self.COLD_USER_THRESHOLD + 1  # one query row followed by the support rows

        # Stack the query in front of its support rows and flatten, so predictor
        # and encoder each run once over all rows: (None*slots)*num_fields
        feature_ids_concat = torch.cat([feature_ids.unsqueeze(1), x_id_support], dim=1).view(-1, self.num_fields)
        feature_vals_concat = torch.cat([feature_vals.unsqueeze(1), x_val_support], dim=1).view(-1, self.num_fields)

        output_predictor = self.predictor(feature_ids_concat, feature_vals_concat, return_hidden=False)
        output_predictor = output_predictor.view(-1, slots)  # None*slots
        _, g_x_concat = self.encoder(feature_ids_concat, feature_vals_concat, return_hidden=True)
        g_x_concat = g_x_concat.view(-1, slots, g_x_concat.shape[1])  # None*slots*hidden_size
        g_x_hat = g_x_concat[:, [0], :]     # None*1*hidden_size
        g_x_support = g_x_concat[:, 1:, :]  # None*COLD_USER_THRESHOLD*hidden_size

        support_mask = (y_support == -1)               # True on padded slots
        num_samples = (y_support != -1).sum(1)         # None
        has_support = (num_samples > 0).unsqueeze(1)   # None*1

        distance = torch.abs(g_x_hat - g_x_support)    # None*COLD_USER_THRESHOLD*hidden_size
        # squeeze(-1) removes only the score dimension; a bare squeeze() would
        # also drop the batch dimension of a single-sample batch.
        similar_score = self.fc1(distance).squeeze(-1)  # None*COLD_USER_THRESHOLD
        similar_score = similar_score.masked_fill(support_mask, float('-inf'))
        # A row without any support would be all -inf and softmax would return
        # NaN (also poisoning the gradients); give such rows finite scores here
        # and zero their weights after the softmax.
        similar_score = similar_score.masked_fill(~has_support, 0.0)
        similar_score_normalized = nn.Softmax(dim=1)(similar_score)
        similar_score_normalized = similar_score_normalized.masked_fill(~has_support, 0.0)

        # Residual of the shared predictor on every support sample.
        delta_y = y_support - nn.Sigmoid()(output_predictor[:, 1:])  # None*COLD_USER_THRESHOLD
        delta_y_hat = (delta_y * similar_score_normalized).sum(1, keepdim=True)  # None*1
        base_logit = output_predictor[:, [0]]  # None*1
        prediction = torch.where(
            has_support, self.adjust(delta_y_hat, num_samples) + base_logit, base_logit)

        if debug:
            return g_x_hat, g_x_support, y_support, similar_score_normalized, delta_y, delta_y_hat
        # squeeze(1) keeps a (None,) result even for a single-sample batch.
        return prediction.squeeze(1), delta_y, similar_score_normalized, delta_y_hat

### Training RESUS_NN

In [36]:
# Load the pre-trained shared predictor; the feature encoder created below
# starts from a fresh random initialisation.
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
model = model.to(device)
checkpoint = torch.load(f"predictor-{exp_id}.tar", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
encoder = DeepFM_encoder(num_features, embedding_dim, num_fields)
encoder = encoder.to(device)

resus_nn = RESUS_NN(num_fields, COLD_USER_THRESHOLD, encoder, model).to(device)
# Only the encoder, the similarity layer and the adjust layer are optimised;
# the predictor's parameters are not handed to the optimiser and stay fixed.
optimizer = torch.optim.Adam(
    [
        {"params": resus_nn.encoder.parameters(), "lr": 0.001},
        {"params": resus_nn.fc1.parameters(), "lr": 0.001},
        {"params": resus_nn.adjust.parameters(), "lr": 0.001},
    ],
)
criterion = torch.nn.BCEWithLogitsLoss()

best_loss = np.inf
best_epoch = -1
best_auc = 0.5
# Early-stopping budget. It has to exist before the loop: otherwise a first
# epoch that does not beat best_auc hits `patience -= 1` on an undefined name.
patience = overfit_patience
train_df_gb_uid = train_df.groupby('uid')
num_users = max(train_df_gb_uid.groups.keys())+1

val_df_support = pd.read_csv(f'../data/val_df_support.csv')
val_df_query = pd.read_csv(f'../data/val_df_query.csv')

val_query_dataset = QueryWithSupportDataset(val_df_query,val_df_support,COLD_USER_THRESHOLD)
val_query_dataloader = DataLoader(val_query_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)


def sample_func(x):
    """Draw a random-size support set (1..COLD_USER_THRESHOLD rows) from one user's rows.

    Users with no more rows than the drawn size keep all of their rows as support.
    """
    num_sample = np.random.randint(1,COLD_USER_THRESHOLD+1)
    if len(x)>num_sample:
        return x.sample(n=num_sample)
    else:
        return x


for epoch in range(num_epochs):
    print(f"Starting epoch: {epoch} | phase: train | ⏰: {time.strftime('%H:%M:%S')}")

    # Re-sample the support sets every epoch; the query set is what remains of
    # train_df once the sampled support rows are removed.
    # NOTE(review): drop_duplicates(keep=False) also removes rows that were
    # already duplicated inside train_df - confirm that this is intended.
    train_support_df = train_df_gb_uid.apply(sample_func).reset_index(level=0, drop=True)
    train_query_df = pd.concat([train_df, train_support_df]).drop_duplicates(keep=False)
    train_query_dataset = QueryWithSupportDataset(train_query_df,train_support_df,COLD_USER_THRESHOLD)
    train_query_dataloader = DataLoader(train_query_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # Start training
    resus_nn.train()
    running_loss = 0
    for itr, batch in enumerate(tqdm(train_query_dataloader)):
        batch = [[e.to(device) for e in item] if isinstance(item, list) else item.to(device) for item in batch]
        feature_ids, feature_vals, labels, support_data = batch
        # Only the logits (first element of the returned tuple) are needed for the loss.
        outputs = resus_nn(feature_ids, feature_vals, support_data)[0]
        loss = criterion(outputs, labels)
        loss.backward()
        running_loss += loss.detach().item()
        optimizer.step()
        optimizer.zero_grad()
    epoch_loss = running_loss / (itr+1)
    print(f"training loss of epoch {epoch}: {epoch_loss}")
    torch.cuda.empty_cache()

    print(f"Starting epoch: {epoch} | phase: val | ⏰: {time.strftime('%H:%M:%S')}")
    state = {
    "epoch": epoch,
    "best_loss": best_loss,
    "best_auc": best_auc,
    "model": resus_nn.state_dict(),
    "optimizer": optimizer.state_dict(),
    }
    resus_nn.eval()
    val_loss, val_auc = val_query(resus_nn, val_query_dataloader)
    print(f"validation loss of epoch {epoch}: {val_loss}, auc: {val_auc}")
    if val_auc > best_auc:
        print("******** New optimal found, saving state ********")
        patience = overfit_patience
        state["best_loss"] = best_loss = val_loss
        state["best_auc"] = best_auc = val_auc
        best_epoch = epoch
        torch.save(state, f"RESUS_NN-{exp_id}.tar")
    else:
        patience -= 1
    if optimizer.param_groups[0]['lr'] <= 1e-7:
        print('LR less than 1e-7, stop training...')
        break
    if patience == 0:
        print('patience == 0, stop training...')
        break
    # Release the per-epoch frames and loaders before the next re-sampling.
    del train_support_df
    del train_query_df
    del train_query_dataset
    del train_query_dataloader
    gc.collect()

100%|██████████| 82/82 [00:40<00:00,  2.03it/s]


Starting epoch: 0 | phase: train | ⏰: 16:31:07


100%|██████████| 626/626 [04:33<00:00,  2.29it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 0: 0.5520654066492574
Starting epoch: 0 | phase: val | ⏰: 16:35:45


100%|██████████| 82/82 [00:38<00:00,  2.15it/s]


validation loss of epoch 0: 0.5791232498680673, auc: 0.7604671031665173
******** New optimal found, saving state ********
Starting epoch: 1 | phase: train | ⏰: 16:36:23


100%|██████████| 627/627 [04:27<00:00,  2.35it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 1: 0.5502091677565324
Starting epoch: 1 | phase: val | ⏰: 16:40:55


100%|██████████| 82/82 [00:36<00:00,  2.24it/s]


validation loss of epoch 1: 0.5750143135466227, auc: 0.7656053336739735
******** New optimal found, saving state ********
Starting epoch: 2 | phase: train | ⏰: 16:41:31


100%|██████████| 627/627 [04:31<00:00,  2.31it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 2: 0.5479272462344436
Starting epoch: 2 | phase: val | ⏰: 16:46:07


100%|██████████| 82/82 [00:37<00:00,  2.19it/s]


validation loss of epoch 2: 0.5736086535744551, auc: 0.7669373538071427
******** New optimal found, saving state ********
Starting epoch: 3 | phase: train | ⏰: 16:46:45


 12%|█▏        | 73/625 [00:33<04:16,  2.15it/s]


KeyboardInterrupt: 

In [52]:
# fine-grained test on nn model
print(f"Starting test | ⏰: {time.strftime('%H:%M:%S')}")
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
encoder = DeepFM_encoder(num_features, embedding_dim, num_fields)
resus_nn = RESUS_NN(num_fields, COLD_USER_THRESHOLD, encoder, model).to(device)
checkpoint = torch.load(f"RESUS_NN-{exp_id}.tar", map_location=torch.device('cpu'))
resus_nn.load_state_dict(checkpoint['model'])

resus_nn_test_losses = []
resus_nn_test_aucs = []

for i in tqdm(range(1,COLD_USER_THRESHOLD+1,1)):
    test_support_set = pd.read_csv(f'../data/test/test_df_support_{i}.csv')
    test_query_set = pd.read_csv(f'../data/test/test_df_query_{i}.csv')
    test_query_dataset = QueryWithSupportDataset(test_query_set,test_support_set, COLD_USER_THRESHOLD)
    test_query_dataloader = DataLoader(test_query_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)
    test_loss, test_auc = val_query(resus_nn, test_query_dataloader)
    
    print(f"test loss of user group {i}: {test_loss}, auc: {test_auc}")
    resus_nn_test_losses += [test_loss]
    resus_nn_test_aucs += [test_auc]
    
    del test_support_set
    del test_query_set
    del test_query_dataset
    del test_query_dataloader

print(f"cold start I: Loss: {sum(resus_nn_test_losses[:10])/10}, auc: {sum(resus_nn_test_aucs[:10])/10}")
print(f"cold start II: Loss: {sum(resus_nn_test_losses[10:20])/10}, auc: {sum(resus_nn_test_aucs[10:20])/10}")
print(f"cold start III: Loss: {sum(resus_nn_test_losses[20:30])/10}, auc: {sum(resus_nn_test_aucs[20:30])/10}")

  0%|          | 0/30 [00:00<?, ?it/s]

Starting test | ⏰: 20:01:26



  0%|          | 0/178 [00:00<?, ?it/s][A
  1%|          | 1/178 [00:03<11:28,  3.89s/it][A
  1%|          | 2/178 [00:04<05:15,  1.79s/it][A
  3%|▎         | 6/178 [00:04<01:34,  1.83it/s][A
  5%|▌         | 9/178 [00:05<01:05,  2.57it/s][A
  6%|▌         | 10/178 [00:05<01:02,  2.70it/s][A
  8%|▊         | 14/178 [00:06<00:46,  3.54it/s][A
 10%|▉         | 17/178 [00:07<00:39,  4.05it/s][A
 10%|█         | 18/178 [00:07<00:42,  3.79it/s][A
 12%|█▏        | 22/178 [00:08<00:35,  4.43it/s][A
 14%|█▍        | 25/178 [00:08<00:33,  4.60it/s][A
 15%|█▍        | 26/178 [00:09<00:34,  4.43it/s][A
 17%|█▋        | 30/178 [00:09<00:31,  4.70it/s][A
 19%|█▊        | 33/178 [00:10<00:29,  4.85it/s][A
 19%|█▉        | 34/178 [00:10<00:29,  4.81it/s][A
 21%|██        | 37/178 [00:10<00:22,  6.18it/s][A
 21%|██▏       | 38/178 [00:11<00:33,  4.19it/s][A
 23%|██▎       | 41/178 [00:12<00:29,  4.59it/s][A
 24%|██▎       | 42/178 [00:12<00:27,  4.95it/s][A
 24%|██▍       | 43/178 

test loss of user group 1: 0.5944164870830064, auc: 0.7232352468956003



  0%|          | 0/176 [00:00<?, ?it/s][A
  1%|          | 1/176 [00:02<07:19,  2.51s/it][A
  1%|          | 2/176 [00:02<03:29,  1.20s/it][A
  5%|▍         | 8/176 [00:02<00:35,  4.68it/s][A
  6%|▋         | 11/176 [00:04<00:52,  3.16it/s][A
 10%|▉         | 17/176 [00:05<00:41,  3.79it/s][A
 11%|█         | 19/176 [00:06<00:38,  4.04it/s][A
 14%|█▍        | 25/176 [00:07<00:34,  4.34it/s][A
 15%|█▍        | 26/176 [00:07<00:36,  4.06it/s][A
 18%|█▊        | 32/176 [00:07<00:20,  7.05it/s][A
 20%|█▉        | 35/176 [00:09<00:32,  4.34it/s][A
 23%|██▎       | 40/176 [00:09<00:22,  6.07it/s][A
 24%|██▍       | 42/176 [00:11<00:35,  3.76it/s][A
 26%|██▌       | 46/176 [00:11<00:25,  5.19it/s][A
 27%|██▋       | 48/176 [00:11<00:22,  5.80it/s][A
 28%|██▊       | 50/176 [00:12<00:36,  3.42it/s][A
 31%|███       | 54/176 [00:13<00:24,  5.07it/s][A
 32%|███▏      | 56/176 [00:13<00:20,  5.98it/s][A
 33%|███▎      | 58/176 [00:14<00:34,  3.43it/s][A
 34%|███▍      | 60/176

test loss of user group 2: 0.5991324579173868, auc: 0.7247234581584001



  0%|          | 0/175 [00:00<?, ?it/s][A
  1%|          | 1/175 [00:05<14:39,  5.05s/it][A
  5%|▍         | 8/175 [00:06<02:16,  1.22it/s][A
  7%|▋         | 2/30 [01:27<20:27, 43.83s/it]


KeyboardInterrupt: 

In [11]:
# class AdjustLayer(nn.Module):
#     def __init__(self, init_scale=1, num_adjust=None, init_bias=0, base=1):
#         super().__init__()
#         self.scale = nn.Parameter(torch.FloatTensor([init_scale]))
#         self.bias = nn.Parameter(torch.FloatTensor([init_bias]))
#         self.base = base

#     def forward(self, x, num_samples):
#         return x * self.scale + self.bias

class AdjustLayer(nn.Module):
    """Learnable affine rescaling with one (scale, bias) pair per support-set size.

    ``forward`` computes ``x * 10 ** scale[k] + bias[k]`` with ``k = num_samples - 1``.
    ``scale`` and ``bias`` both have shape ``(num_adjust, 1)``. ``num_adjust`` must
    be given, and ``base`` is accepted but not used.
    """

    def __init__(self, init_scale=0.4, num_adjust=None, init_bias=0, base=1):
        super().__init__()
        initial_scales = [init_scale] * len(range(num_adjust))
        initial_biases = [init_bias] * len(range(num_adjust))
        self.scale = nn.Parameter(torch.FloatTensor(initial_scales).unsqueeze(1))
        self.bias = nn.Parameter(torch.FloatTensor(initial_biases).unsqueeze(1))

    def forward(self, x, num_samples):
        """Rescale ``x`` with the parameters selected by ``num_samples`` (1-based)."""
        slot = num_samples - 1
        gain = 10 ** self.scale[slot]
        offset = self.bias[slot]
        return x * gain + offset

class LambdaLayer(nn.Module):
    """Scales its input by the absolute value of a single regularisation weight.

    The weight ``l`` holds one value (initialised to ``init_lambda``) and is
    trainable only when ``learn_lambda`` is true. ``num_lambda`` is accepted but
    not used; ``base`` is only stored.
    """

    def __init__(self, learn_lambda=True, num_lambda=None, init_lambda=0.001, base=1):
        super().__init__()
        self.l = nn.Parameter(torch.FloatTensor([init_lambda]), requires_grad=learn_lambda)
        self.base = base

    def forward(self, x, n_samples):
        """Return ``x * |l|``.

        x: None*COLD*COLD
        n_samples: None (not used by the computation)
        """
        # |l| is shaped (1, 1, 1) so that it broadcasts over a batch of matrices.
        magnitude = self.l.abs()[:, None, None]
        return x * magnitude

# RR
class RESUS_RR(nn.Module):
    """Base-predictor logit plus a per-user ridge-regression residual correction.

    For every query the model
      1. scores the query and its support interactions with ``predictor``,
      2. embeds the same rows with ``encoder``,
      3. fits, in closed form, a ridge regression from the support embeddings to
         the residual ``y_support - sigmoid(predictor logit)``,
      4. applies that regression to the query embedding, calibrates the result
         with ``adjust`` and adds it to the predictor's query logit.

    Args:
        num_fields: number of feature fields per interaction.
        COLD_USER_THRESHOLD: padded size of every support set.
        encoder: module called as ``encoder(ids, vals, return_hidden=True)``
            whose second return value is the hidden representation.
        predictor: module called as ``predictor(ids, vals, return_hidden=False)``
            returning one logit per row.
    """

    def __init__(self, num_fields, COLD_USER_THRESHOLD, encoder, predictor):
        super(RESUS_RR, self).__init__()
        self.num_fields = num_fields
        self.COLD_USER_THRESHOLD = COLD_USER_THRESHOLD
        self.predictor = predictor
        self.encoder = encoder
        self.lambda_rr = LambdaLayer(learn_lambda=True, num_lambda=COLD_USER_THRESHOLD)
        # Not used by any method of this class; kept so the attribute still exists.
        self.L = nn.CrossEntropyLoss()
        self.adjust = AdjustLayer(1, num_adjust=COLD_USER_THRESHOLD)

    @staticmethod
    def _solve(A, B):
        """Solve ``A @ W = B`` for ``W`` on both current and legacy PyTorch."""
        linalg = getattr(torch, "linalg", None)
        if linalg is not None and hasattr(linalg, "solve"):
            return linalg.solve(A, B)
        # Legacy API (removed from current PyTorch): arguments are (B, A) and the
        # LU factorisation is returned as a second value.
        solution, _ = torch.solve(B, A)
        return solution

    def rr_standard(self, x, n_samples, yrr_binary, linsys=False):
        """Primal ridge solution ``(x^T x + lambda I)^-1 x^T y`` for one 2-D problem.

        Args:
            x: design matrix, samples*features.
            n_samples: forwarded to ``lambda_rr`` (which may ignore it).
            yrr_binary: targets, samples*1.
            linsys: solve the linear system instead of forming the inverse.
        """
        x_t = t(x, 0, 1)
        I = torch.eye(x.shape[1]).to(x)
        # lambda_rr is written for batched (3-D) input, so wrap the identity in a
        # singleton batch dimension and strip it again afterwards.
        ridge = self.lambda_rr(I.unsqueeze(0), n_samples).squeeze(0)
        A = mm(x_t, x) + ridge

        if not linsys:
            w = mm(mm(inv(A), x_t), yrr_binary)
        else:
            w = self._solve(A, mm(x_t, yrr_binary))

        return w

    def rr_woodbury(self, X, n_samples, yrr_binary, linsys=False):
        """Batched dual ridge solution ``X^T (X X^T + lambda I)^-1 y``.

        Args:
            X: None*COLD_USER_THRESHOLD*(hidden_size+1)
            n_samples: None; forwarded to ``lambda_rr`` (which may ignore it).
            yrr_binary: None*COLD_USER_THRESHOLD*1
            linsys: solve the linear system instead of forming the inverse.

        Returns:
            None*(hidden_size+1)*1 regression weights, one set per batch row.
        """
#         x = X/torch.sqrt(n_samples.float()).unsqueeze(1).unsqueeze(2)    #   x: None*COLD*(hidden+1)
        x = X
        x_t = t(x, 1, 2)
        I = torch.eye(x.shape[1]).unsqueeze(0).repeat(x.shape[0],1,1).to(x)    # None*COLD*COLD
        A = matmul(x, x_t) + self.lambda_rr(I, n_samples)
        if not linsys:
            w = matmul(matmul(x_t, inv(A)), yrr_binary)
        else:
            w = matmul(x_t, self._solve(A, yrr_binary))
        return w

    def forward(self, feature_ids, feature_vals, support_data, debug=False):
        """Score each query using its support set.

        Args:
            feature_ids: None*num_fields
            feature_vals: None*num_fields
            support_data: ``[x_id_support, x_val_support, y_support]`` with
                x_id_support: None*COLD_USER_THRESHOLD*num_fields,
                x_val_support: None*COLD_USER_THRESHOLD*num_fields,
                y_support: None*COLD_USER_THRESHOLD, where -1 marks padding.
            debug: return intermediate tensors instead of the prediction.

        Returns:
            ``(prediction, residual_target, fitted_residual, delta_W)`` with
            ``prediction`` of shape (None,), or a 6-tuple of intermediates when
            ``debug`` is true.
        """
        x_id_support, x_val_support, y_support = support_data
        # Put the query in slot 0 ahead of its support rows so that predictor and
        # encoder each see the whole batch in a single call.
        feature_ids_concat = torch.cat([feature_ids.unsqueeze(1),x_id_support],dim=1) # None*(COLD_USER_THRESHOLD+1)*num_fields
        feature_vals_concat = torch.cat([feature_vals.unsqueeze(1),x_val_support],dim=1) # None*(COLD_USER_THRESHOLD+1)*num_fields
        feature_ids_concat = feature_ids_concat.view(-1,self.num_fields) # (None*(COLD_USER_THRESHOLD+1))*num_fields
        feature_vals_concat = feature_vals_concat.view(-1,self.num_fields) # (None*(COLD_USER_THRESHOLD+1))*num_fields
        output_predictor = self.predictor(feature_ids_concat, feature_vals_concat, return_hidden=False)
        output_predictor = output_predictor.view(-1, self.COLD_USER_THRESHOLD+1) # None*(COLD_USER_THRESHOLD+1)
        _, g_x_concat = self.encoder(feature_ids_concat, feature_vals_concat, return_hidden=True)
        g_x_concat = g_x_concat.view(-1, self.COLD_USER_THRESHOLD+1, g_x_concat.shape[1]) # None*(COLD_USER_THRESHOLD+1)*hidden_size
        g_x_hat = g_x_concat[:,[0],:] # None*1*hidden_size
        g_x_support = g_x_concat[:,1:,:] # None*COLD_USER_THRESHOLD*hidden_size

        is_real = (y_support!=-1)
        X_mask = is_real.int().float().unsqueeze(2) # None*COLD_USER_THRESHOLD*1
        num_samples = is_real.sum(1) # None
        ones = torch.ones((g_x_support.shape[0],g_x_support.shape[1])).unsqueeze(2).to(g_x_hat) # None*COLD_USER_THRESHOLD*1
        X_nomask = torch.cat((g_x_support, ones), 2) # None*COLD_USER_THRESHOLD*(hidden_size+1)
        # Zeroing padded rows removes them from X^T(...), so their -1 labels never
        # reach the fitted weights.
        X = X_nomask*X_mask

        # Regression target: what the base predictor still gets wrong on the support set.
        residual = y_support.unsqueeze(2) - torch.sigmoid(output_predictor[:,1:].unsqueeze(2)) # None*COLD_USER_THRESHOLD*1
        delta_W = self.rr_woodbury(X, num_samples, residual) # None*(hidden_size+1)*1
        delta_w = delta_W[:,:-1] # None*(hidden_size)*1
        delta_b = delta_W[:,-1] # None*1
        out = matmul(g_x_hat, delta_w).squeeze(2) + delta_b # None*1
        prediction = self.adjust(out, num_samples) + output_predictor[:,[0]] # None*1
        if debug:
            # Restores the previously commented-out definition of W.
            # Assumes the encoder exposes its output layer as `fc4` - TODO confirm.
            fc4 = self.encoder.fc4
            W = torch.cat([fc4.weight.squeeze(), fc4.bias]).unsqueeze(1) # (hidden_size+1)*1
            return X_nomask, X, y_support, torch.sigmoid(matmul(X, W)), matmul(X, delta_W), delta_W
        # squeeze(1) rather than squeeze(): a batch of one must stay shape (1,)
        # so it still lines up with the labels.
        return prediction.squeeze(1), residual, matmul(X, delta_W), delta_W

In [14]:
# load encoder
# `model` is the base predictor, restored from the checkpoint below. `encoder` is a
# second DeepFM_encoder instance that this cell does not restore from a checkpoint.
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
model = model.to(device)
checkpoint = torch.load(f"checkpoint/predictor-{dataset}-{exp_id}.tar", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint['model'])
encoder = DeepFM_encoder(num_features, embedding_dim, num_fields)
encoder = encoder.to(device)

resus_rr = RESUS_RR(num_fields, COLD_USER_THRESHOLD, encoder, model).to(device)
# Only the encoder, the adjust layer and the ridge lambda are handed to the optimizer;
# the predictor's parameters are not, so this loop never updates them.
optimizer = torch.optim.Adam(
    [
        {"params": resus_rr.encoder.parameters(), "lr": 0.001},
        {"params": resus_rr.adjust.parameters(), "lr": 0.001},
        {"params": resus_rr.lambda_rr.parameters(), "lr": 0.1},
    ],
)

best_loss = np.inf
best_epoch = -1
best_auc = 0.5
# Early-stopping budget. It has to exist before the first epoch: if that epoch does not
# beat `best_auc`, the loop decrements it straight away.
patience = overfit_patience
train_df_gb_uid = train_df.groupby('uid')
num_users = max(train_df_gb_uid.groups.keys())+1

# Validation split: each user's first COLD_USER_THRESHOLD rows are the support set and the
# remainder the query set; users with fewer rows keep only their last row as the query.
val_df_support = val_df.groupby('uid').apply(lambda x: x[:COLD_USER_THRESHOLD] if len(x)>COLD_USER_THRESHOLD else x[:-1])
val_df_query = val_df.groupby('uid').apply(lambda x: x[COLD_USER_THRESHOLD:] if len(x)>COLD_USER_THRESHOLD else x[-1:])

val_query_dataset = QueryWithSupportDataset(val_df_query,val_df_support,COLD_USER_THRESHOLD)
val_query_dataloader = DataLoader(val_query_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)


def sample_func(x):
    """Draw a random-size support set (1..COLD_USER_THRESHOLD rows) from one user's rows.

    Users with no more rows than the drawn size contribute all of their rows.
    """
    num_sample = np.random.randint(1,COLD_USER_THRESHOLD+1)
    if len(x)>num_sample:
        return x.sample(n=num_sample)
    else:
        return x


for epoch in range(num_epochs):
    print(f"Starting epoch: {epoch} | phase: train | ⏰: {time.strftime('%H:%M:%S')}")

    # Random sample support set (re-drawn every epoch); every remaining row is a query.
    train_support_df = train_df_gb_uid.apply(sample_func).reset_index(level=0, drop=True)
    train_query_df = pd.concat([train_df, train_support_df]).drop_duplicates(keep=False)
    train_query_dataset = QueryWithSupportDataset(train_query_df,train_support_df,COLD_USER_THRESHOLD)
    train_query_dataloader = DataLoader(train_query_dataset, batch_size, shuffle=True, num_workers=8, pin_memory=True)

    # Start training
    # Validation below leaves the model in eval mode, so switch the model being trained
    # back to train mode at the start of every epoch.
    resus_rr.train()
    running_loss = 0
    for itr, batch in enumerate(tqdm(train_query_dataloader)):
        batch = [[e.to(device) for e in item] if isinstance(item, list) else item.to(device) for item in batch]
        feature_ids, feature_vals, labels, support_data = batch
        outputs, delta_y_support, delta_W_X, delta_W = resus_rr(feature_ids, feature_vals, support_data)
        loss = torch.nn.BCEWithLogitsLoss()(outputs, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        running_loss += loss.detach().item()
    epoch_loss = running_loss / (itr+1)
    print(f"training loss of epoch {epoch}: {epoch_loss}")
    torch.cuda.empty_cache()

    print(f"Starting epoch: {epoch} | phase: val | ⏰: {time.strftime('%H:%M:%S')}")
    resus_rr.eval()
    val_loss, val_auc = val_query(resus_rr, val_query_dataloader, gauc_col=0)
    print(f"validation loss of epoch {epoch}: {val_loss}, auc: {val_auc}")

    if val_auc > best_auc:
        print("******** New optimal found, saving state ********")
        patience = overfit_patience
        best_loss = val_loss
        best_auc = val_auc
        best_epoch = epoch
        state = {
            "epoch": epoch,
            "best_loss": best_loss,
            "best_auc": best_auc,
            "model": resus_rr.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(state, f"RESUS_RR-{exp_id}.tar")
    else:
        patience -= 1
    if optimizer.param_groups[0]['lr'] <= 1e-7:
        print('LR less than 1e-7, stop training...')
        break
    if patience == 0:
        print('patience == 0, stop training...')
        break
    # Drop this epoch's resampled frames and loader before building the next ones.
    del train_support_df
    del train_query_df
    del train_query_dataset
    del train_query_dataloader
    gc.collect()

Starting epoch: 0 | phase: train | ⏰: 22:00:25


100%|██████████| 627/627 [04:31<00:00,  2.31it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 0: 0.6020512128370611
Starting epoch: 0 | phase: val | ⏰: 22:05:02


100%|██████████| 82/82 [00:35<00:00,  2.32it/s]


**************************************************
validation loss of epoch 0: 0.5822683681802052, auc: 0.7578080768274674, gauc: 0.7342
******** New optimal found, saving state ********
Starting epoch: 1 | phase: train | ⏰: 22:05:39


100%|██████████| 626/626 [04:22<00:00,  2.38it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 1: 0.5614815890408171
Starting epoch: 1 | phase: val | ⏰: 22:10:08


100%|██████████| 82/82 [00:36<00:00,  2.25it/s]


**************************************************
validation loss of epoch 1: 0.5773067957744366, auc: 0.7612202199841563, gauc: 0.7342
******** New optimal found, saving state ********
Starting epoch: 2 | phase: train | ⏰: 22:10:45


100%|██████████| 626/626 [04:44<00:00,  2.20it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 2: 0.5590110072693505
Starting epoch: 2 | phase: val | ⏰: 22:15:36


100%|██████████| 82/82 [00:43<00:00,  1.87it/s]


**************************************************
validation loss of epoch 2: 0.5744686297527174, auc: 0.7646792612175957, gauc: 0.7376
******** New optimal found, saving state ********
Starting epoch: 3 | phase: train | ⏰: 22:16:21


100%|██████████| 626/626 [05:17<00:00,  1.97it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 3: 0.552932007529865
Starting epoch: 3 | phase: val | ⏰: 22:21:45


100%|██████████| 82/82 [00:36<00:00,  2.26it/s]


**************************************************
validation loss of epoch 3: 0.5741228403114691, auc: 0.7657776678512074, gauc: 0.7375
******** New optimal found, saving state ********
Starting epoch: 4 | phase: train | ⏰: 22:22:22


100%|██████████| 627/627 [04:23<00:00,  2.38it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 4: 0.5475271836137087
Starting epoch: 4 | phase: val | ⏰: 22:26:52


100%|██████████| 82/82 [00:36<00:00,  2.24it/s]


**************************************************
validation loss of epoch 4: 0.5687021751229356, auc: 0.771742974267069, gauc: 0.7404
******** New optimal found, saving state ********
Starting epoch: 5 | phase: train | ⏰: 22:27:30


100%|██████████| 626/626 [05:21<00:00,  1.95it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 5: 0.5443475951973241
Starting epoch: 5 | phase: val | ⏰: 22:32:57


100%|██████████| 82/82 [00:39<00:00,  2.06it/s]


**************************************************
validation loss of epoch 5: 0.5697358541372346, auc: 0.771053800178204, gauc: 0.7405
Starting epoch: 6 | phase: train | ⏰: 22:33:38


100%|██████████| 627/627 [04:32<00:00,  2.30it/s]
  0%|          | 0/82 [00:00<?, ?it/s]

training loss of epoch 6: 0.5428969441798696
Starting epoch: 6 | phase: val | ⏰: 22:38:17


100%|██████████| 82/82 [00:41<00:00,  1.96it/s]


**************************************************
validation loss of epoch 6: 0.5720918353010969, auc: 0.7685317605616563, gauc: 0.7393
patience == 0, stop training...


In [15]:
# fine-grained test on rr model
print(f"Starting test | ⏰: {time.strftime('%H:%M:%S')}")
# Rebuild the architecture and restore the checkpoint saved by the training cell.
model = DeepFM_encoder(num_features, embedding_dim, num_fields)
encoder = DeepFM_encoder(num_features, embedding_dim, num_fields)
resus_rr = RESUS_RR(num_fields, COLD_USER_THRESHOLD, encoder, model).to(device)
checkpoint = torch.load(f"RESUS_RR-{exp_id}.tar", map_location=torch.device('cpu'))
resus_rr.load_state_dict(checkpoint['model'])
# A freshly constructed module starts in train mode; evaluate in eval mode, as the
# validation phase of the training cell does.
resus_rr.eval()

resus_rr_test_losses = []
resus_rr_test_aucs = []
resus_rr_test_gaucs = []

# Evaluate once per support-set size i: the first i interactions of each user form
# the support set and the rest are queries.
for i in range(1,COLD_USER_THRESHOLD+1,1):
    # omit users with <= i interactions.
    test_support_set = test_df.groupby('uid',as_index=False).apply(
        lambda x: x[:i] if len(x)>i else x[:0])
    test_query_set = test_df.groupby('uid',as_index=False).apply(
        lambda x: x[i:] if len(x)>i else x[:0])
    test_query_dataset = QueryWithSupportDataset(test_query_set,test_support_set,COLD_USER_THRESHOLD)
    test_query_dataloader = DataLoader(test_query_dataset, batch_size, shuffle=False, num_workers=8, pin_memory=True)

    metrics = val_query(resus_rr, test_query_dataloader)
    test_loss, test_auc = metrics[0], metrics[1]

    resus_rr_test_losses += [test_loss]
    resus_rr_test_aucs += [test_auc]
    # The training cell unpacks val_query as (loss, auc); record a GAUC only when a
    # third value is actually returned, instead of reading an unassigned name.
    if len(metrics) > 2:
        resus_rr_test_gaucs += [metrics[2]]

Starting test | ⏰: 22:39:00


100%|██████████| 188/188 [01:07<00:00,  2.80it/s]


**************************************************
test loss of user group 1: 0.6015202531472166, auc: 0.7209660395056977, gauc: 0.7264


100%|██████████| 186/186 [01:08<00:00,  2.72it/s]


**************************************************
test loss of user group 2: 0.5984873391928212, auc: 0.7252224748008752, gauc: 0.7264


100%|██████████| 185/185 [01:06<00:00,  2.79it/s]


**************************************************
test loss of user group 3: 0.5936238195445086, auc: 0.7311965770525354, gauc: 0.7265


100%|██████████| 184/184 [01:05<00:00,  2.79it/s]


**************************************************
test loss of user group 4: 0.5947355258723964, auc: 0.7309812105466249, gauc: 0.7259


100%|██████████| 183/183 [01:02<00:00,  2.95it/s]


**************************************************
test loss of user group 5: 0.5898863222103953, auc: 0.7356885070646262, gauc: 0.7274


100%|██████████| 182/182 [01:02<00:00,  2.92it/s]


**************************************************
test loss of user group 6: 0.5868191691217842, auc: 0.7394554945166116, gauc: 0.728


100%|██████████| 180/180 [00:57<00:00,  3.12it/s]


**************************************************
test loss of user group 7: 0.5867807323733966, auc: 0.740860450730772, gauc: 0.7282


100%|██████████| 179/179 [00:53<00:00,  3.35it/s]


**************************************************
test loss of user group 8: 0.585190556402313, auc: 0.7426057222023839, gauc: 0.729


100%|██████████| 178/178 [01:26<00:00,  2.06it/s]


**************************************************
test loss of user group 9: 0.5851454795076606, auc: 0.7444976227776436, gauc: 0.7289


100%|██████████| 177/177 [01:21<00:00,  2.17it/s]


**************************************************
test loss of user group 10: 0.5844121846438801, auc: 0.7450366623176242, gauc: 0.7283


100%|██████████| 176/176 [01:25<00:00,  2.06it/s]


**************************************************
test loss of user group 11: 0.5834116547961127, auc: 0.7482898624509837, gauc: 0.7293


100%|██████████| 175/175 [01:29<00:00,  1.96it/s]


**************************************************
test loss of user group 12: 0.5828133143697466, auc: 0.7474158766297216, gauc: 0.7284


100%|██████████| 173/173 [01:26<00:00,  2.00it/s]


**************************************************
test loss of user group 13: 0.5792231607988391, auc: 0.7518364240506221, gauc: 0.7301


100%|██████████| 172/172 [01:15<00:00,  2.28it/s]


**************************************************
test loss of user group 14: 0.575671256974686, auc: 0.755244156452583, gauc: 0.7321


100%|██████████| 171/171 [01:23<00:00,  2.05it/s]


**************************************************
test loss of user group 15: 0.574589408977687, auc: 0.7567987231737654, gauc: 0.7323


100%|██████████| 170/170 [01:30<00:00,  1.87it/s]


**************************************************
test loss of user group 16: 0.5744954805163777, auc: 0.7564698981024154, gauc: 0.7324


100%|██████████| 169/169 [01:30<00:00,  1.87it/s]


**************************************************
test loss of user group 17: 0.5779296667618159, auc: 0.7537078567249336, gauc: 0.7305


100%|██████████| 167/167 [01:28<00:00,  1.88it/s]


**************************************************
test loss of user group 18: 0.5750112080288504, auc: 0.756843322216303, gauc: 0.7322


100%|██████████| 166/166 [01:26<00:00,  1.92it/s]


**************************************************
test loss of user group 19: 0.5739807369838278, auc: 0.7581511221578331, gauc: 0.7329


100%|██████████| 165/165 [01:09<00:00,  2.39it/s]


**************************************************
test loss of user group 20: 0.5789835684227221, auc: 0.7553604257243597, gauc: 0.731


100%|██████████| 164/164 [01:22<00:00,  1.99it/s]


**************************************************
test loss of user group 21: 0.5725730468587178, auc: 0.7609991321031357, gauc: 0.7346


100%|██████████| 163/163 [01:17<00:00,  2.09it/s]


**************************************************
test loss of user group 22: 0.5737305916160163, auc: 0.7595537198377758, gauc: 0.7331


100%|██████████| 162/162 [01:16<00:00,  2.12it/s]


**************************************************
test loss of user group 23: 0.5767630014890506, auc: 0.7572368021887215, gauc: 0.7317


100%|██████████| 161/161 [01:38<00:00,  1.63it/s]


**************************************************
test loss of user group 24: 0.5719880950376854, auc: 0.7610868832716255, gauc: 0.7344


100%|██████████| 159/159 [01:22<00:00,  1.93it/s]


**************************************************
test loss of user group 25: 0.5705822190773562, auc: 0.7631484007970598, gauc: 0.7355


100%|██████████| 158/158 [01:19<00:00,  1.98it/s]


**************************************************
test loss of user group 26: 0.5715372983036162, auc: 0.7640026430526785, gauc: 0.7356


100%|██████████| 157/157 [01:22<00:00,  1.91it/s]


**************************************************
test loss of user group 27: 0.5712724572913662, auc: 0.7633536241008124, gauc: 0.7345


100%|██████████| 156/156 [01:10<00:00,  2.21it/s]


**************************************************
test loss of user group 28: 0.5725909863144923, auc: 0.7619593882282615, gauc: 0.7334


100%|██████████| 155/155 [01:19<00:00,  1.96it/s]


**************************************************
test loss of user group 29: 0.5725673644773421, auc: 0.7620837799395884, gauc: 0.7339


100%|██████████| 154/154 [01:18<00:00,  1.96it/s]


**************************************************
test loss of user group 30: 0.5708017151851159, auc: 0.7638790657422305, gauc: 0.7353


In [16]:
# Print the per-support-size test metrics, one value per line, under the names the
# test cell actually populates (resus_rr_test_*).
for title, values in (
    ('test losses', resus_rr_test_losses),
    ('test aucs', resus_rr_test_aucs),
    # The GAUC list may not exist if the test cell did not record one.
    ('test gaucs', globals().get('resus_rr_test_gaucs', [])),
):
    print(title)
    for value in values:
        print(value)

test losses
0.6015202531472166
0.5984873391928212
0.5936238195445086
0.5947355258723964
0.5898863222103953
0.5868191691217842
0.5867807323733966
0.585190556402313
0.5851454795076606
0.5844121846438801
0.5834116547961127
0.5828133143697466
0.5792231607988391
0.575671256974686
0.574589408977687
0.5744954805163777
0.5779296667618159
0.5750112080288504
0.5739807369838278
0.5789835684227221
0.5725730468587178
0.5737305916160163
0.5767630014890506
0.5719880950376854
0.5705822190773562
0.5715372983036162
0.5712724572913662
0.5725909863144923
0.5725673644773421
0.5708017151851159
test aucs
0.7209660395056977
0.7252224748008752
0.7311965770525354
0.7309812105466249
0.7356885070646262
0.7394554945166116
0.740860450730772
0.7426057222023839
0.7444976227776436
0.7450366623176242
0.7482898624509837
0.7474158766297216
0.7518364240506221
0.755244156452583
0.7567987231737654
0.7564698981024154
0.7537078567249336
0.756843322216303
0.7581511221578331
0.7553604257243597
0.7609991321031357
0.7595537198377