In [None]:
import sys
# The project packages live one directory up; make them importable first.
sys.path.insert(0, '..')
from data.preprocess import setup_path, DOMAINS, ROOT_PATH

data_path = ROOT_PATH + "experiments/fedct/"
# Output directory for the small transfer dataset (use "transfer_data/" for the full one).
target_path = data_path + "transfer_data_small/"
setup_path(target_path, is_dir = True)
domains = DOMAINS

In [None]:
# Selected MF training run per domain: (learning rate, regularization, loss).
# Choose the domains for this experiment via ACTIVE_DOMAINS; adjust the log
# location in _domain_log_path if your logs live elsewhere.
DOMAIN_RUN_CONFIGS = {
    "Automotive": ("0.00001", "1.0", "pointwise"),
    "Books": ("0.00001", "0.1", "pointwise"),
    "CDs_and_Vinyl": ("0.00001", "0.1", "pairwisebpr"),
    "Clothing_Shoes_and_Jewelry": ("0.00001", "0.1", "pairwisebpr"),
    "Electronics": ("0.00001", "0.1", "pointwise"),
    "Grocery_and_Gourmet_Food": ("0.00001", "1.0", "pointwise"),
    "Home_and_Kitchen": ("0.00001", "1.0", "pointwise"),
    "Kindle_Store": ("0.00001", "0.1", "pairwisebpr"),
    "Movies_and_TV": ("0.000003", "1.0", "pointwise"),
    "Office_Products": ("0.000003", "1.0", "pointwise"),
    "Patio_Lawn_and_Garden": ("0.00001", "3.0", "pointwise"),
    "Pet_Supplies": ("0.00001", "0.1", "pairwisebpr"),
    "Sports_and_Outdoors": ("0.000003", "1.0", "pointwise"),
    "Tools_and_Home_Improvement": ("0.000003", "1.0", "pointwise"),
    "Toys_and_Games": ("0.00001", "0.1", "pointwise"),
    "Video_Games": ("0.00001", "0.1", "pairwisebpr"),
}
ACTIVE_DOMAINS = ["Books", "Clothing_Shoes_and_Jewelry", "Electronics"]

def _domain_log_path(domain):
    """Path of the training log for ``domain``'s configured MF run."""
    # Values are kept as strings so they appear in the file name exactly as written.
    lr, reg, loss = DOMAIN_RUN_CONFIGS[domain]
    return (data_path + "env_logs/fedct_train_and_eval_MF_"
            + f"{domain}_lr{lr}_reg{reg}_loss{loss}.log")

# The order of ACTIVE_DOMAINS becomes the domain order used by later cells.
domain_model_logs = {d: _domain_log_path(d) for d in ACTIVE_DOMAINS}

In [None]:
# Persist the domain -> log-path mapping as a dict literal for the reader.
log_file = target_path + 'domain_model_logs.txt'
with open(log_file, 'w') as fout:
    fout.write(repr(domain_model_logs))

## 1. Cross-domain User Information

In [None]:
from tqdm import tqdm
import pandas as pd

def get_user_vocab(vocab_path):
    """Load a raw-UserID -> index mapping from a user-fields vocab table.

    The file is read with ``pd.read_table`` (tab-separated) using its second
    column as the index. Only rows whose ``field_name`` equals "UserID" are
    kept; when an id repeats, the first row wins. Keys are the stringified
    ids and values are taken from the ``idx`` column.
    """
    vocab_table = pd.read_table(vocab_path, index_col = 1)
    user_rows = vocab_table.loc[vocab_table['field_name'] == "UserID", ['idx']]
    first_seen = ~user_rows.index.duplicated(keep = 'first')
    idx_by_id = user_rows[first_seen].to_dict(orient = 'index')
    return {str(raw_id): entry['idx'] for raw_id, entry in idx_by_id.items()}

def get_cross_domain_user(domains, data_dir):
    """Collect per-domain user ids and pairwise user overlaps.

    @input:
    - domains: list of domain names; each needs
      ``<data_dir>meta_data/<domain>_user_fields.vocab``
    - data_dir: directory prefix (string concatenation, keep the trailing "/")
    @output:
    - cross_domain_users: {"<source>@<target>": [uid, ...]} listing the users of
      ``target`` (in target-vocab order) that also appear in ``source``
    - user_domain_ids: {uid: [id_in_domain_0, ..., id_in_domain_D-1, global_id]}
      where 0 marks "not in that domain" and global_id runs 1..#users in
      first-seen order
    """
    n_domains = len(domains)
    vocab_by_domain = {}
    for name in domains:
        vocab_by_domain[name] = get_user_vocab(data_dir + "meta_data/" + name + "_user_fields.vocab")

    user_domain_ids = {}
    cross_domain_users = {}
    for domain_id, source_domain in tqdm(enumerate(domains)):
        source_vocab = vocab_by_domain[source_domain]
        # record this domain's id for every user in it
        for uid, idx in source_vocab.items():
            user_domain_ids.setdefault(uid, [0] * n_domains)[domain_id] = idx
        # overlap of every other domain with this one
        for target_domain in (d for d in domains if d != source_domain):
            cross_domain_users[f"{source_domain}@{target_domain}"] = [
                uid for uid in vocab_by_domain[target_domain] if uid in source_vocab
            ]

    # append a 1-based global id, following first-seen order
    for global_idx, ids in enumerate(user_domain_ids.values(), start = 1):
        ids.append(global_idx)
    return cross_domain_users, user_domain_ids

In [None]:
# The active domains, in domain_model_logs order.
domains = list(domain_model_logs)
CDU, U = get_cross_domain_user(domains, data_path + "domain_data/")
print(f"#user: {len(U)}")
print("#user\tsource@target")
for pair, common in CDU.items():
    print(len(common), '\t', pair)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
# Number of domains each user appears in. U[uid] holds one id per domain
# (0 = absent) followed by the global user id, which is excluded here.
n_domain_per_user = {uid: sum(1 for f in idx[:-1] if f != 0) for uid, idx in U.items()}
# n_domain_count[c-1] = number of users found in exactly c domains.
n_domain_count = [0] * len(domains)
n_no_domain = 0
for uid, c in n_domain_per_user.items():
    if c == 0:
        # Only possible if every domain id of the user is 0; indexing [c-1]
        # would otherwise silently count them in the last bucket.
        n_no_domain += 1
        continue
    n_domain_count[c-1] += 1
if n_no_domain:
    print(f"{n_no_domain} users have no non-zero domain id and are not plotted")
print(n_domain_count)
counts = np.asarray(n_domain_count, dtype = float)
# Empty buckets would give log(0) = -inf (and a RuntimeWarning); draw them at height 0.
log_counts = np.log(counts, out = np.zeros_like(counts), where = counts > 0)
plt.figure(figsize = (0.7 * len(domains),3))
plt.bar(np.arange(1,len(domains)+1), log_counts)
plt.ylabel('Log frequency')
plt.xlabel('#domain the user interacted')
plt.show()

## 2. Set Up Domain Transfer Data

In [None]:
import pandas as pd
# One row per user (index = raw user id): the per-domain ids (0 = absent) plus the global id.
df = pd.DataFrame.from_dict(U, orient = 'index', columns = [*domains, 'all'])
id_file = target_path + 'id_train.tsv'
df.to_csv(id_file, sep = '\t')

In [None]:
# Re-read the id table: the header row is replaced by explicit names (the unnamed
# index column becomes "UserID"), and the trailing global-id column is dropped.
id_columns = ["UserID"] + domains
df = pd.read_csv(target_path + 'id_train.tsv', header = 0, sep = '\t',
                 names = id_columns + ['All'])[id_columns]
df.head(3)

In [None]:
from tqdm import tqdm
import numpy as np
def split_cold_start_data(df, ratio = (0.2, 0.8), seed = None):
    '''
    Randomly split the rows of ``df`` by user into two disjoint parts.

    Each distinct ``UserID`` is sent, together with ALL of its rows, to the
    second part with probability ``1 - ratio[0]`` and to the first part
    otherwise, so no user appears in both parts.

    @input:
    - df: DataFrame with a "UserID" column (at any column position)
    - ratio: only ratio[0] is used, the probability of keeping a user in the
      first part; ratio[1] is accepted for backward compatibility
    - seed: optional int; when given, a private RandomState(seed) makes the
      split reproducible, otherwise the global numpy RNG is used
    @output:
    - (first_part, second_part): row subsets of ``df`` in their original order
    '''
    rng = np.random if seed is None else np.random.RandomState(seed)
    print("Build user history")
    # user -> row positions, users in order of first appearance.
    # Group by the "UserID" column itself, not by whatever column happens to be first.
    user_hist = {}
    for pos, u in tqdm(enumerate(df["UserID"].values)):
        user_hist.setdefault(u, []).append(pos)
    print("Holdout user histories")
    test_mask = np.zeros(len(df), dtype = bool)
    for u, H in tqdm(user_hist.items()):
        if rng.random_sample() > ratio[0]:
            test_mask[H] = True
    testset = df[test_mask]
    trainset = df[~test_mask]
    return trainset, testset

In [None]:
import pandas as pd

# Split each domain's cold-start data by user into a validation part
# (~20% of users with the default ratio) and a test part (the rest).
for target_domain in domains:
    print(target_domain)
    cold_file = data_path + "domain_data/tsv_data/" + target_domain + "_test_cold.tsv"
    target_data = pd.read_csv(cold_file, sep = '\t')
    valset, testset = split_cold_start_data(target_data)
    print(len(valset), len(testset))
    n_val_users = len(valset['UserID'].unique())
    n_test_users = len(testset['UserID'].unique())
    print("number of users in eval: ", n_val_users, n_test_users)
    # NOTE(review): to_csv defaults to sep=',' and writes the index, so these
    # ".tsv" files are comma-separated with an extra index column. Confirm this
    # is what ColdStartTransferEnvironment expects.
    valset.to_csv(target_path + target_domain + "_val.tsv")
    testset.to_csv(target_path + target_domain + "_test.tsv")

## 3. Reader Example — Cold-Start Transfer Learning

In [None]:
from reader.ColdStartTransferEnvironment import ColdStartTransferEnvironment
from argparse import Namespace
# The target must be one of the active domains: the "<target>_val.tsv" /
# "<target>_test.tsv" files are only written above for domains in
# domain_model_logs (the reader presumably loads them; TODO confirm).
# With every domain enabled, domains[-1] is 'Video_Games', the original choice.
transfer_target = domains[-1]
# Build the Namespace directly instead of eval()-ing a formatted string, which
# breaks on paths containing quotes and evaluates arbitrary text.
reader_args = Namespace(data_file = target_path,
                        domain_model_file = target_path + 'domain_model_logs.txt',
                        target = transfer_target, n_neg_val = 100, n_neg_test = 1000)
argstr = str(reader_args)

In [None]:
# Build the cold-start transfer environment from the arguments above.
reader = ColdStartTransferEnvironment(reader_args)

In [None]:
# Sanity check: the target domain and the reader's domain list (presumably the source domains it loaded).
print(reader.target_domain,'\n',reader.domains)

In [None]:
# Width of the target domain's user encoding; the example below uses its last
# entry as the user bias and the remaining entries as the user embedding.
domain_dim_size = reader.user_emb_size[reader.target_domain]

### 3.2 Example Cold-Start Test

In [None]:
import torch
from reader.NextItemReader import sample_negative
n_user = 100
n_pos, n_neg = 5, 50
# Random stand-in for an encoder output: last column = user bias, rest = user embedding.
dummy_encoding = torch.randn(n_user, domain_dim_size)
candidates = reader.all_item_candidates
random_items = sample_negative([], candidates, n_user * n_pos, replace = True)
negative_items = sample_negative([], candidates, n_user * n_neg, replace = True)
user_emb = dummy_encoding[:, :-1]
user_bias = dummy_encoding[:, -1]

In [None]:
def _score_target_items(item_ids, n_per_user):
    """Score ``n_per_user`` items for each dummy user with the target-domain model and print its outputs."""
    id_grid = torch.tensor(item_ids).view(n_user, n_per_user)
    out = reader.target_model.forward_with_emb({'user_emb': user_emb, 'user_bias': user_bias,
                                                'ItemID': id_grid})
    print(out['preds'].shape)
    print(out['preds'][0])
    print(out['reg'])
    return out

with torch.no_grad():
    pos_out = _score_target_items(random_items, n_pos)
    neg_out = _score_target_items(negative_items, n_neg)

In [None]:
# Toy separation for the demo: shift positive logits up and negative logits down
# by 0.1 before the sigmoid.
score_shift = 0.1
pos_preds = torch.sigmoid(pos_out['preds'] + score_shift)
neg_preds = torch.sigmoid(neg_out['preds'] - score_shift)

In [None]:
pos_pred, neg_pred = pos_preds, neg_preds
# Dummy ragged positives: row r keeps only its last n_valid slots, where n_valid
# cycles n_pos, n_pos-1, ..., 1 over the rows.
pos_mask = torch.zeros_like(pos_pred)
for row in range(n_user):
    n_valid = (n_pos - row - 1) % n_pos + 1
    pos_mask[row, -n_valid:] = 1
k_list = [1, 10, 50]
max_k = max(k_list)

In [None]:
from utils import init_ranking_report
# Precomputed ranking tensors: DCG discounts 1/log2(i+1) for i = 1..max_k,
# cut-off sizes 1..max_k as a (1, max_k) float row, and the 1-based slots of
# the n_pos positives as a (1, n_pos) row.
b = 1.0 / torch.arange(2, max_k + 2).log2()
ap = torch.arange(1, max_k + 1, dtype = torch.float).unsqueeze(0)
gt_position = torch.arange(1, n_pos + 1).unsqueeze(0)
def calculate_batch_ranking_metric(pos_pred, all_pred, pos_mask, k_list, report = None):
    '''
    Accumulate ranking metrics for one batch of users.

    Every real positive of a user is ranked among that user's positives and
    negatives (R+N candidates in total).

    @input:
    - pos_pred: (B,R) scores of the candidate positive items
    - all_pred: (B,N) scores of the negative items; the positives are
      concatenated internally, so they must NOT be included here
    - pos_mask: (B,R) 1 for a real positive, 0 for a padding slot; any
      alignment (left, right or scattered) is accepted
    - k_list: cut-offs, e.g. [1,5,10,20,50]; max(k_list) must not exceed R+N
    - report: running sums {"HR@1": 0, "P@1": 0, ..., "MR", "MRR", "AUC"};
      a new one is created with init_ranking_report when None or empty
    @output:
    - report: the updated running sums, each summed over the batch's users
      (divide by the accumulated user count to obtain means)
    - B: number of users in this batch

    Users without any real positive contribute 0 to every metric.
    '''
    if report is None or len(report) == 0:
        report = init_ranking_report(k_list)
    neg_pred = all_pred
    B, R = pos_pred.shape
    N = neg_pred.shape[1]
    device, dtype = pos_pred.device, pos_pred.dtype
    max_k = max(k_list)

    pos_mask = pos_mask.to(device = device, dtype = dtype)
    is_pos = pos_mask > 0
    pos_length = pos_mask.sum(dim = 1)             # (B,) number of real positives
    has_pos = (pos_length > 0).to(dtype)
    safe_length = pos_length.clamp(min = 1)        # avoids 0/0 for users without positives

    # Padding slots get -inf so they can neither outrank a real item nor enter the top-k.
    masked_pos = pos_pred.masked_fill(~is_pos, float('-inf'))
    all_scores = torch.cat((masked_pos, neg_pred), dim = 1)                      # (B,R+N)

    # 1-based rank of each positive = #candidates scoring >= it (ties count against it).
    rank = torch.sum(masked_pos.unsqueeze(2) <= all_scores.unsqueeze(1), dim = 2)
    rank = rank * pos_mask                                                        # (B,R), 0 at padding

    # hit_map[b, i] = 1 iff the i-th highest-scored candidate of user b is a real positive.
    _, indices = torch.topk(all_scores, k = max_k, dim = 1)
    is_pos_candidate = torch.cat((pos_mask, pos_mask.new_zeros((B, N))), dim = 1)
    hit_map = is_pos_candidate.gather(1, indices)                                 # (B,max_k)

    discount = 1.0 / torch.log2(torch.arange(2, max_k + 2, device = device, dtype = dtype))
    cutoff = torch.arange(1, max_k + 1, device = device, dtype = dtype).view(1, -1)
    tp = torch.cumsum(hit_map, dim = 1)                      # true positives within top-k
    dcg = torch.cumsum(hit_map * discount, dim = 1)
    # The ideal ranking places all real positives first.
    ideal_hits = (cutoff <= pos_length.view(-1, 1)).to(dtype)
    idcg = torch.cumsum(ideal_hits * discount, dim = 1)

    hr = (tp > 0).to(dtype)
    precision = tp / cutoff
    recall = tp / safe_length.view(-1, 1)
    f1 = 2 * tp / (cutoff + pos_length.view(-1, 1))          # 2TP / ((TP+FP) + (TP+FN))
    ndcg = dcg / idcg.clamp(min = torch.finfo(dtype).tiny)   # idcg == 0 only when dcg == 0

    # mean rank and mean reciprocal rank over each user's real positives
    report['MR'] += torch.sum(torch.sum(rank, dim = 1) / safe_length)
    mrr = torch.sum(pos_mask / (rank + 1e-6), dim = 1)
    report['MRR'] += torch.sum(mrr / safe_length)

    # AUC: after sorting, the j-th best real positive (1-based) has rank_j - j negatives above it.
    sorted_rank, _ = torch.sort(rank.masked_fill(~is_pos, R + N + 1), dim = 1)
    slot = torch.arange(1, R + 1, device = device).view(1, -1)
    real_slot = (slot <= pos_length.view(-1, 1)).to(dtype)
    neg_above = torch.sum((sorted_rank - slot) * real_slot, dim = 1)
    report['AUC'] += torch.sum((1 - neg_above / safe_length / N) * has_pos)

    hr, precision, recall, f1, ndcg = (m.sum(dim = 0) for m in (hr, precision, recall, f1, ndcg))
    for k in k_list:
        report[f'HR@{k}'] += hr[k-1]
        report[f'P@{k}'] += precision[k-1]
        report[f'RECALL@{k}'] += recall[k-1]
        report[f'F1@{k}'] += f1[k-1]
        report[f'NDCG@{k}'] += ndcg[k-1]
    return report, B

In [None]:
report, B = calculate_batch_ranking_metric(pos_pred, neg_pred, pos_mask, k_list)
# Each entry is summed over the batch's users; divide by B for the per-user mean.
for metric, total in report.items():
    print(metric, (total / B).cpu().numpy())