In [1]:
import os
import torch
from torch.nn.utils.rnn import pad_sequence
import numpy as np
import pandas as pd
from tqdm import tqdm, trange
from collections import *
import copy
from functools import lru_cache

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
valid_data_path = "/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/all_task_1_valid_sessions.csv"
roberta_pred_path = "/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_valid_150_with_score.parquet"
sasrec_pred_path = "/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_valid_150_with_score.parquet"
graph_pred_path = "/root/autodl-tmp/huangxu/Amazon-KDDCUP-23/co-occurrence_graph/graph_valid_150_with_score.parquet"

In [3]:
def normalization(data):
    print("Max=", np.max(data, axis=0), " Min=", np.min(data, axis=0))
    _range = np.max(data, axis=0) - np.min(data, axis=0)
    return (data - np.min(data, axis=0)) / _range

def standardization(data):
    mu = np.mean(data, axis=0)
    sigma = np.std(data, axis=0)
    print("Mean={}, Sigma={}".format(mu, sigma))
    return (data - mu) / sigma

def softmax(data, axis):
    _exp = np.exp(data)
    return data / _exp.sum(axis=axis, keepdims=True)

In [4]:
@lru_cache(maxsize=1)
def read_valid_data():
    return pd.read_csv(valid_data_path)

@lru_cache(maxsize=1)
def read_roberta_pred():
    return pd.read_parquet(roberta_pred_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_pred():
    return pd.read_parquet(sasrec_pred_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_graph_pred():
    return pd.read_parquet(graph_pred_path, engine='pyarrow')

In [2]:
# softmax normalization
def softmax_norm(pred_df, log=False):
    new_pred_df = copy.deepcopy(pred_df)
    softmax_fn = torch.nn.Softmax(dim=-1)
    if log:
        all_scores = [np.log(pred_df.iloc[i]['scores'] + 1).tolist() for i in tqdm(range(pred_df.shape[0]))] # [N, 150]
    else:  
        all_scores = [pred_df.iloc[i]['scores'].tolist() for i in tqdm(range(pred_df.shape[0]))] # [N, 150]
    norm_score = softmax_fn(torch.tensor(all_scores, dtype=torch.float)).numpy()
    norm_score = [x for x in norm_score]
    new_pred_df['scores'] = norm_score
    return new_pred_df

In [6]:
# log and min max normalization
def log_min_max_norm(pred_df):
    new_pred_df = copy.deepcopy(pred_df)
    all_scores = [pred_df.iloc[i]['scores'] for i in tqdm(range(pred_df.shape[0]))] # [N, 150]
    norm_scores = []
    for i in tqdm(range(len(pred_df))):
        cur_scores = np.log(all_scores[i] + 1)
        cur_scores = (cur_scores - cur_scores.min()) / (cur_scores.max() - cur_scores.min())
        norm_scores.append(cur_scores)
    new_pred_df['scores'] = norm_scores
    return new_pred_df 

In [7]:
# @numba.jit(parallel=True)
def merge_all_scores(predictions: list, ground_truth=None):
    N = predictions[0].shape[0]
    all_scores = []
    all_labels = []
    all_cands = []
    for i in tqdm(range(N)):
        scores = []
        labels = []
        truth = ground_truth.iloc[i] if ground_truth is not None else None
        cur_predictions = [p.iloc[i] for p in predictions]
        
        # get all valid candidates that appears in all three sets
        rec_sets = [set(p['next_item_prediction']) for p in cur_predictions] 
        all_valid_candidates = rec_sets[0] # appear in three sets 
        for rec_set in rec_sets[1:]:
            all_valid_candidates = all_valid_candidates.intersection(rec_set)
        all_valid_candidates = list(all_valid_candidates)
        if (len(all_valid_candidates) == 0) or (truth not in all_valid_candidates): # ground truth must be in the combined candidates set 
            continue

        id2score = [Counter(dict(zip(p['next_item_prediction'], p['scores']))) for p in cur_predictions]

        for cand in all_valid_candidates:
            scores.append([d[cand] for d in id2score])
            if ground_truth is not None:
                if cand == truth:
                    labels.append(1)
                else:
                    labels.append(0)
        all_scores.append(scores)
        all_labels.append(labels)
        all_cands.append(all_valid_candidates)
    return all_scores, all_labels, all_cands

In [75]:
# @numba.jit(parallel=True)
def merge_all_scores_union(predictions: list, ground_truth=None):
    N = predictions[0].shape[0]
    all_scores = []
    all_labels = []
    all_cands = []
    for i in tqdm(range(N)):
        scores = []
        labels = []
        truth = ground_truth.iloc[i] if ground_truth is not None else None
        cur_predictions = [p.iloc[i] for p in predictions]
        
        # get all valid candidates that appears in all three sets
        rec_sets = [set(p['next_item_prediction']) for p in cur_predictions] 
        all_valid_candidates = rec_sets[0] # appear in three sets 
        for rec_set in rec_sets[1:]:
            all_valid_candidates = all_valid_candidates.union(rec_set)
        all_valid_candidates = list(all_valid_candidates)
        if (len(all_valid_candidates) == 0) or (truth not in all_valid_candidates): # ground truth must be in the combined candidates set 
            continue

        id2score = [Counter(dict(zip(p['next_item_prediction'], p['scores']))) for p in cur_predictions]

        for cand in all_valid_candidates:
            scores.append([d[cand] for d in id2score])
            if ground_truth is not None:
                if cand == truth:
                    labels.append(1)
                else:
                    labels.append(0)
        
        # padding
        num_padding_candidates = len(predictions) * len(predictions[0].iloc[0]['next_item_prediction']) - len(all_valid_candidates)
        for i in range(num_padding_candidates):
            scores.append([0.0, 0.0, 0.0])
            labels.append(-1)
            
        all_scores.append(scores)
        all_labels.append(labels)
        all_cands.append(all_valid_candidates)
    return all_scores, all_labels, all_cands

In [64]:
valid_data = read_valid_data()
roberta_pred = read_roberta_pred()
sasrec_pred = read_sasrec_pred()
graph_pred = read_graph_pred()

In [65]:
assert len(valid_data) == len(roberta_pred) and len(valid_data) == len(sasrec_pred) and len(valid_data) == len(graph_pred)

In [12]:
common_rate = {'bert-sasrec':[], 'sasrec-graph':[], 'bert-graph':[], 'all':[]}

for i in trange(valid_data.shape[0]):
    roberta_rec = set(roberta_pred.iloc[i]['next_item_prediction'])
    sasrec_rec = set(sasrec_pred.iloc[i]['next_item_prediction'])
    graph_rec = set(graph_pred.iloc[i]['next_item_prediction'])
    common = roberta_rec.intersection(sasrec_rec)
    common_rate['bert-sasrec'].append(len(common)/len(sasrec_rec))
    common = graph_rec.intersection(sasrec_rec)
    common_rate['sasrec-graph'].append(len(common)/len(sasrec_rec))
    common = roberta_rec.intersection(graph_rec)
    common_rate['bert-graph'].append(len(common)/len(sasrec_rec))
    common = roberta_rec.intersection(sasrec_rec.intersection(graph_rec))
    common_rate['all'].append(len(common)/len(sasrec_rec))

100%|██████████| 361581/361581 [00:57<00:00, 6326.85it/s]


In [13]:
rate = {}
for k, v in common_rate.items():
    rate[k] = np.mean(v)
print(rate)

{'bert-sasrec': 0.2713753949092089, 'sasrec-graph': 0.19065977839912304, 'bert-graph': 0.1365519390307197, 'all': 0.1117052057491959}


In [66]:
normlized_sasrec_pred = softmax_norm(sasrec_pred)
normlized_roberta_pred = softmax_norm(roberta_pred)

100%|██████████| 361581/361581 [00:32<00:00, 11246.02it/s]
100%|██████████| 361581/361581 [00:35<00:00, 10156.19it/s]


In [162]:
normlized_graph_pred = softmax_norm(graph_pred)

100%|██████████| 361581/361581 [00:27<00:00, 13250.01it/s]


In [147]:
graph_pred.head(20)

Unnamed: 0,locale,next_item_prediction,scores
0,UK,"[B06XG1LZ6Z, B06XGDZVZR, B06XGD9VLV, B076PN1SK...","[43, 41, 37, 22, 20, 15, 15, 13, 10, 9, 8, 7, ..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B01H6MF6Z...","[503, 51, 44, 42, 29, 28, 26, 20, 20, 19, 19, ..."
2,UK,"[B00L529BAC, B01EV58VX2, B003TJATC8, B01M6625M...","[2, 2, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, ..."
3,UK,"[0008532222, 1801314918, 024157563X, 024156343...","[2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P6DKN...","[13, 5, 5, 4, 4, 4, 3, 2, 2, 1, 1, 1, 0, 0, 0,..."
5,DE,"[B0BD48G63Q, B0BD3DGNT9, B08CRVG7BB, B0953XQY4...","[35, 6, 5, 4, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2,..."
6,DE,"[B07BMZYYFZ, B07Q82LRDK, B0045DNZ9Q, B08BR288Z...","[3, 2, 2, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
7,UK,"[B08FBZ8QSL, B07W6JJ253, B013SL2712, B07W6JP97...","[31, 18, 18, 16, 16, 10, 9, 9, 8, 8, 7, 7, 7, ..."
8,JP,"[B0BFPGHSYX, B00B57A5IY, B00LE7TO0K, B086LGXKY...","[19, 15, 15, 12, 11, 10, 9, 8, 8, 8, 8, 7, 7, ..."
9,DE,"[B07KQHHYQC, B07MH3K3S8, B07YLZ67Q8, B07KQJH4P...","[11, 7, 6, 6, 5, 5, 4, 3, 3, 3, 2, 2, 2, 2, 2,..."


In [163]:
normlized_graph_pred.head(20)

Unnamed: 0,locale,next_item_prediction,scores
0,UK,"[B06XG1LZ6Z, B06XGDZVZR, B06XGD9VLV, B076PN1SK...","[0.87887824, 0.11894324, 0.0021785214, 6.66414..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B01H6MF6Z...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,UK,"[B00L529BAC, B01EV58VX2, B003TJATC8, B01M6625M...","[0.04227002, 0.04227002, 0.015550272, 0.015550..."
3,UK,"[0008532222, 1801314918, 024157563X, 024156343...","[0.03528044, 0.03528044, 0.03528044, 0.0352804..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P6DKN...","[0.99855167, 0.00033497676, 0.00033497676, 0.0..."
5,DE,"[B0BD48G63Q, B0BD3DGNT9, B08CRVG7BB, B0953XQY4...","[1.0, 2.5436657e-13, 9.357623e-14, 3.442477e-1..."
6,DE,"[B07BMZYYFZ, B07Q82LRDK, B0045DNZ9Q, B08BR288Z...","[0.10839459, 0.039876144, 0.039876144, 0.01466..."
7,UK,"[B08FBZ8QSL, B07W6JJ253, B013SL2712, B07W6JP97...","[0.99999475, 2.2603176e-06, 2.2603176e-06, 3.0..."
8,JP,"[B0BFPGHSYX, B00B57A5IY, B00LE7TO0K, B086LGXKY...","[0.96326417, 0.0176428, 0.0176428, 0.000878383..."
9,DE,"[B07KQHHYQC, B07MH3K3S8, B07YLZ67Q8, B07KQJH4P...","[0.95961535, 0.017575968, 0.0064658374, 0.0064..."


In [164]:
normlized_sasrec_pred.head(2), normlized_roberta_pred.head(2), normlized_graph_pred.head(2)

(  locale                               next_item_prediction sess_id  \
 0     UK  [B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...       0   
 1     JP  [B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...       1   
 
                                               scores  
 0  [0.2711737, 0.16762927, 0.16571583, 0.12527893...  
 1  [0.910162, 0.019133814, 0.0131163215, 0.010162...  ,
                                 next_item_prediction  \
 0  [B096ZT4DK4, B09XCGNN6F, B09XCHDLYR, B06XGD9VL...   
 1  [B08ZHJKF28, B09LCPT9DQ, B09MRYK5CV, B09G9YRX1...   
 
                                               scores  
 0  [0.01427838, 0.013319071, 0.013319071, 0.01237...  
 1  [0.0713702, 0.03977809, 0.035806235, 0.0334270...  ,
   locale                               next_item_prediction  \
 0     UK  [B06XG1LZ6Z, B06XGDZVZR, B06XGD9VLV, B076PN1SK...   
 1     JP  [B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B01H6MF6Z...   
 
                                               scores  
 0  [0.87887824, 0.11

In [165]:
all_scores, all_labels, _ = merge_all_scores_union([normlized_sasrec_pred, normlized_roberta_pred, normlized_graph_pred], valid_data['next_item'])

100%|██████████| 361581/361581 [15:32<00:00, 387.59it/s]  


In [166]:
all_scores = np.array(all_scores, dtype=np.float32)
all_labels = np.array(all_labels, dtype=np.float32)

In [96]:
np.save('scores_450_union.npy', all_scores)
np.save('labels_450_union.npy', all_labels)

In [173]:
all_scores = np.load('scores_450_union_all_softmax.npy')
all_labels = np.load('labels_450_union_all_softmax.npy')

In [174]:
type(all_scores[0][0][0])

numpy.float32

In [175]:
class ScoreDataset(torch.utils.data.Dataset):
    def __init__(self, scores, labels):
        self.scores = scores
        self.labels = labels
        print(len(scores), len(labels))

    def __getitem__(self, index):
        return {'scores' : self.scores[index], 'labels' : self.labels[index]}

    def __len__(self):
        return len(self.scores)


In [184]:
class EmsembleWeight(torch.nn.Module):
    def __init__(self, n_models):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.ones(n_models, dtype=torch.float), requires_grad=True)
        self.relu_fn = torch.nn.ReLU()

    def forward(self, scores):
        # scores : [B, L, 3], w : [3]
        w = torch.softmax(self.weights, dim=-1)
        weighted_score = torch.matmul(scores, w.view(-1,1))
        return weighted_score.squeeze()

    def cal_loss(self, batch):
        # batch['labels'] : [B, L]
        score = self.forward(batch['scores']) # [B, L]
        
        # pos score
        pos_score = score[batch['labels'] == 1] # [B]
        assert len(pos_score) == len(batch['scores']) 
        pos_score = pos_score.view(-1, 1) # [B, 1]

        # all score 
        score[batch['labels'] == -1] = -float('inf')

        delta = score - pos_score # [B, L]
        delta = self.relu_fn(delta + 0.15) # pos with big difference will be omited

        # num of candidates 
        num_can = (batch['labels'] != -1).sum(dim=-1) # [B]
        loss = (delta.sum(dim=-1) / num_can).mean()
        return loss

In [185]:
dataset = ScoreDataset(all_scores, all_labels)
loader = torch.utils.data.DataLoader(dataset, batch_size=2048, shuffle=True)

344846 344846


In [186]:
model = EmsembleWeight(3)
model = model.cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [187]:
epochs = 50
model.train()
for e in range(epochs):
    e_loss = 0
    step = 0
    for batch in tqdm(loader, total=len(loader)):
        optimizer.zero_grad()
        batch = {k: v.cuda() for k,v in batch.items()}
        loss = model.cal_loss(batch)
        loss.backward()
        optimizer.step()

        e_loss += loss.data
        step += 1
    softmax_weight = torch.softmax(model.weights.data.detach(), dim=-1)
    print("Epoch: {}: loss: {:.5f}".format(e, e_loss / step))
    print("Current weight: {}".format(softmax_weight))

100%|██████████| 169/169 [00:07<00:00, 21.35it/s]


Epoch: 0: loss: 0.10061
Current weight: tensor([0.3617, 0.2716, 0.3667], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.22it/s]


Epoch: 1: loss: 0.10041
Current weight: tensor([0.3779, 0.2276, 0.3944], device='cuda:0')


100%|██████████| 169/169 [00:09<00:00, 18.49it/s]


Epoch: 2: loss: 0.10030
Current weight: tensor([0.3861, 0.1952, 0.4187], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.82it/s]


Epoch: 3: loss: 0.10026
Current weight: tensor([0.3901, 0.1711, 0.4388], device='cuda:0')


100%|██████████| 169/169 [00:07<00:00, 23.92it/s]


Epoch: 4: loss: 0.10020
Current weight: tensor([0.3941, 0.1516, 0.4542], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.69it/s]


Epoch: 5: loss: 0.10018
Current weight: tensor([0.3966, 0.1361, 0.4673], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.47it/s]


Epoch: 6: loss: 0.10015
Current weight: tensor([0.3986, 0.1234, 0.4779], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.21it/s]


Epoch: 7: loss: 0.10013
Current weight: tensor([0.3985, 0.1128, 0.4887], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.01it/s]


Epoch: 8: loss: 0.10011
Current weight: tensor([0.4012, 0.1033, 0.4955], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.85it/s]


Epoch: 9: loss: 0.10013
Current weight: tensor([0.4027, 0.0952, 0.5021], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.28it/s]


Epoch: 10: loss: 0.10010
Current weight: tensor([0.4051, 0.0882, 0.5067], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.22it/s]


Epoch: 11: loss: 0.10010
Current weight: tensor([0.4029, 0.0817, 0.5154], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.95it/s]


Epoch: 12: loss: 0.10009
Current weight: tensor([0.4054, 0.0761, 0.5185], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.16it/s]


Epoch: 13: loss: 0.10009
Current weight: tensor([0.4056, 0.0710, 0.5234], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.81it/s]


Epoch: 14: loss: 0.10008
Current weight: tensor([0.4082, 0.0663, 0.5255], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.87it/s]


Epoch: 15: loss: 0.10007
Current weight: tensor([0.4093, 0.0620, 0.5287], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.14it/s]


Epoch: 16: loss: 0.10007
Current weight: tensor([0.4100, 0.0581, 0.5318], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.66it/s]


Epoch: 17: loss: 0.10008
Current weight: tensor([0.4123, 0.0546, 0.5331], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.69it/s]


Epoch: 18: loss: 0.10008
Current weight: tensor([0.4106, 0.0513, 0.5380], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.13it/s]


Epoch: 19: loss: 0.10006
Current weight: tensor([0.4109, 0.0483, 0.5408], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.53it/s]


Epoch: 20: loss: 0.10006
Current weight: tensor([0.4145, 0.0456, 0.5399], device='cuda:0')


100%|██████████| 169/169 [00:12<00:00, 13.48it/s]


Epoch: 21: loss: 0.10006
Current weight: tensor([0.4130, 0.0430, 0.5439], device='cuda:0')


100%|██████████| 169/169 [00:05<00:00, 28.18it/s]


Epoch: 22: loss: 0.10005
Current weight: tensor([0.4149, 0.0407, 0.5444], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 28.13it/s]


Epoch: 23: loss: 0.10006
Current weight: tensor([0.4162, 0.0384, 0.5454], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.03it/s]


Epoch: 24: loss: 0.10005
Current weight: tensor([0.4153, 0.0364, 0.5483], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.29it/s]


Epoch: 25: loss: 0.10005
Current weight: tensor([0.4162, 0.0345, 0.5493], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.88it/s]


Epoch: 26: loss: 0.10005
Current weight: tensor([0.4151, 0.0327, 0.5522], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.83it/s]


Epoch: 27: loss: 0.10004
Current weight: tensor([0.4142, 0.0310, 0.5548], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.39it/s]


Epoch: 28: loss: 0.10005
Current weight: tensor([0.4157, 0.0294, 0.5549], device='cuda:0')


100%|██████████| 169/169 [00:08<00:00, 19.10it/s]


Epoch: 29: loss: 0.10005
Current weight: tensor([0.4199, 0.0279, 0.5522], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.46it/s]


Epoch: 30: loss: 0.10005
Current weight: tensor([0.4186, 0.0264, 0.5549], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.74it/s]


Epoch: 31: loss: 0.10004
Current weight: tensor([0.4176, 0.0251, 0.5573], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.12it/s]


Epoch: 32: loss: 0.10005
Current weight: tensor([0.4192, 0.0238, 0.5570], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.16it/s]


Epoch: 33: loss: 0.10005
Current weight: tensor([0.4192, 0.0226, 0.5582], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.13it/s]


Epoch: 34: loss: 0.10005
Current weight: tensor([0.4186, 0.0215, 0.5599], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.94it/s]


Epoch: 35: loss: 0.10004
Current weight: tensor([0.4204, 0.0204, 0.5591], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.41it/s]


Epoch: 36: loss: 0.10004
Current weight: tensor([0.4191, 0.0194, 0.5615], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.65it/s]


Epoch: 37: loss: 0.10002
Current weight: tensor([0.4194, 0.0185, 0.5621], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.99it/s]


Epoch: 38: loss: 0.10005
Current weight: tensor([0.4218, 0.0176, 0.5606], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.31it/s]


Epoch: 39: loss: 0.10003
Current weight: tensor([0.4195, 0.0167, 0.5638], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.34it/s]


Epoch: 40: loss: 0.10006
Current weight: tensor([0.4243, 0.0160, 0.5598], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.89it/s]


Epoch: 41: loss: 0.10005
Current weight: tensor([0.4224, 0.0152, 0.5624], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.30it/s]


Epoch: 42: loss: 0.10005
Current weight: tensor([0.4216, 0.0144, 0.5639], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.23it/s]


Epoch: 43: loss: 0.10005
Current weight: tensor([0.4208, 0.0137, 0.5654], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.48it/s]


Epoch: 44: loss: 0.10004
Current weight: tensor([0.4222, 0.0131, 0.5647], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.80it/s]


Epoch: 45: loss: 0.10003
Current weight: tensor([0.4249, 0.0125, 0.5627], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 26.73it/s]


Epoch: 46: loss: 0.10003
Current weight: tensor([0.4228, 0.0119, 0.5654], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 27.04it/s]


Epoch: 47: loss: 0.10003
Current weight: tensor([0.4240, 0.0113, 0.5647], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 25.99it/s]


Epoch: 48: loss: 0.10005
Current weight: tensor([0.4258, 0.0108, 0.5635], device='cuda:0')


100%|██████████| 169/169 [00:06<00:00, 24.62it/s]

Epoch: 49: loss: 0.10003
Current weight: tensor([0.4249, 0.0103, 0.5649], device='cuda:0')





# calculate metrics on valid_data

In [189]:
normlized_sasrec_pred

Unnamed: 0,locale,next_item_prediction,sess_id,scores
0,UK,"[B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...",0,"[0.2711737, 0.16762927, 0.16571583, 0.12527893..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...",1,"[0.910162, 0.019133814, 0.0131163215, 0.010162..."
2,UK,"[B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...",2,"[0.4977258, 0.111056164, 0.0858272, 0.02598772..."
3,UK,"[0241572614, 1406392979, 024157563X, 024147681...",3,"[0.17356576, 0.105796315, 0.09800567, 0.081133..."
4,JP,"[B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...",4,"[0.94184154, 0.03955462, 0.0068080365, 0.00285..."
...,...,...,...,...
361576,UK,"[B0050IG9DE, B0B7LVKNK8, B07ZCWWZSM, B00465F49...",361576,"[0.2831857, 0.028501466, 0.027289871, 0.026143..."
361577,JP,"[B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL...",361577,"[0.32079417, 0.2724555, 0.10274507, 0.07312795..."
361578,DE,"[B0BC38GHB4, B07KLCY8NF, B00MXZEMBI, B07K6LTLW...",361578,"[0.56181896, 0.2779604, 0.06445203, 0.06327815..."
361579,DE,"[B08RQR2NPB, B08RQDVX71, B08H8SYLMQ, B08H8TLK4...",361579,"[0.46686655, 0.2816499, 0.078525245, 0.0615999..."


In [190]:
normlized_graph_pred

Unnamed: 0,locale,next_item_prediction,scores
0,UK,"[B06XG1LZ6Z, B06XGDZVZR, B06XGD9VLV, B076PN1SK...","[0.87887824, 0.11894324, 0.0021785214, 6.66414..."
1,JP,"[B09LCPT9DQ, B09MRYK5CV, B092D5HM5S, B01H6MF6Z...","[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ..."
2,UK,"[B00L529BAC, B01EV58VX2, B003TJATC8, B01M6625M...","[0.04227002, 0.04227002, 0.015550272, 0.015550..."
3,UK,"[0008532222, 1801314918, 024157563X, 024156343...","[0.03528044, 0.03528044, 0.03528044, 0.0352804..."
4,JP,"[B0B6PF619D, B09BJF6N8K, B0B6P77ZRN, B0B6P6DKN...","[0.99855167, 0.00033497676, 0.00033497676, 0.0..."
...,...,...,...
361576,UK,"[B08F5D8T22, B0050IG9DE, B01A955L8G, B08NYPTF9...","[0.016600577, 0.016600577, 0.016600577, 0.0166..."
361577,JP,"[B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL...","[0.9999546, 4.539787e-05, 0.0, 0.0, 0.0, 0.0, ..."
361578,DE,"[B0BC38GHB4, B07KLCY8NF, B00MXZEMBI, B01FE96DM...","[0.18709013, 0.18709013, 0.025319895, 0.025319..."
361579,DE,"[B08RQDVX71, B08RQR2NPB, B08H8TLK4F, B07PY86YP...","[1.0, 4.1399375e-08, 4.1399375e-08, 1.7139084e..."


In [205]:
def cal_hit_and_mrr(ground_truth_list, candidates_list):
    hits, mrrs = [], []
    for i in tqdm(range(len(ground_truth_list))):
        ground_truth = ground_truth_list.iloc[i]
        candidates = candidates_list.iloc[i]
        hit, mrr = 0.0, 0.0
        for j in range(len(candidates)):
            if ground_truth == candidates[j]:
                hit = 1.0
                mrr = 1.0 / (j + 1)
                break
        hits.append(hit)
        mrrs.append(mrr)
    return np.array(hits).mean(), np.array(mrrs).mean()

In [15]:
# merge multi predictions 
def merge_multi_predictions(pred_df_list : list[pd.DataFrame], weights : list):
    new_pred_df = copy.deepcopy(pred_df_list[0])
    new_predictions = []
    score_counter = Counter()
    for i in tqdm(range(pred_df_list[0].shape[0])):
        score_counter.clear()
        for pred_df, w in zip(pred_df_list, weights):
            for item, score in zip(pred_df.iloc[i]['next_item_prediction'], pred_df.iloc[i]['scores']):
                score_counter[item] += w * score 
        new_pred, _ = zip(*score_counter.most_common(100))
        new_predictions.append(list(new_pred))
    new_pred_df['next_item_prediction'] = new_predictions
    
    return new_pred_df
        

In [213]:
sasrec_graph_prediction = merge_two_prediction([normlized_sasrec_pred, normlized_graph_pred], [0.43, 0.57])
sasrec_graph_prediction.head(5)

  1%|          | 2197/361581 [01:02<2:50:01, 35.23it/s]


KeyboardInterrupt: 

In [219]:
normlized_sasrec_pred, normlized_graph_pred

(       locale                               next_item_prediction sess_id  \
 0          UK  [B06XG1LZ6Z, B06XGD9VLV, B076PN1SKG, B01MYUDYP...       0   
 1          JP  [B09LCPT9DQ, B09MRYK5CV, B0BB5VQ1L8, B084T9C6W...       1   
 2          UK  [B09XBS6WCX, B01EYGW86Y, B01MDOBUCC, B01C5YK17...       2   
 3          UK  [0241572614, 1406392979, 024157563X, 024147681...       3   
 4          JP  [B0B6PF619D, B0B6P77ZRN, B0B6P2PCMP, B0B6NY4PN...       4   
 ...       ...                                                ...     ...   
 361576     UK  [B0050IG9DE, B0B7LVKNK8, B07ZCWWZSM, B00465F49...  361576   
 361577     JP  [B09B9V4PXC, B09BCM5NL1, B09XGRXXG3, B09XH1YGL...  361577   
 361578     DE  [B0BC38GHB4, B07KLCY8NF, B00MXZEMBI, B07K6LTLW...  361578   
 361579     DE  [B08RQR2NPB, B08RQDVX71, B08H8SYLMQ, B08H8TLK4...  361579   
 361580     DE  [B0095FMJE6, B016UZOCX4, B07T42KRCB, B07JHTY3V...  361580   
 
                                                    scores  
 0       [0.2

In [221]:
cal_hit_and_mrr(valid_data['next_item'], sasrec_graph_prediction['next_item_prediction'])

100%|██████████| 361581/361581 [00:16<00:00, 22542.10it/s]


(0.8617709448228751, 0.3649325409228786)

In [251]:
cal_hit_and_mrr(valid_data['next_item'], graph_pred['next_item_prediction'])

100%|██████████| 361581/361581 [00:16<00:00, 21690.42it/s]


(0.9464297073131608, 0.4242726830543117)

# merge test sasrec and co-graph 

In [3]:
sasrec_pred_test_path = "/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/SASRec_Next/SASRec_test_150_with_score.parquet"
sasrec_pred_test_path_2 = "/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/predictions/SASRec_Next/kdd_cup_2023/three_locale_prediction_0416_2120.parquet"
roberta_pred_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/roberta/roberta_test_150_with_score.parquet'
graph_pred_test_path = "/root/autodl-tmp/huangxu/Amazon-KDDCUP-23/co-occurrence_graph/graph_test_150_with_score.parquet"


In [4]:
@lru_cache(maxsize=1)
def read_sasrec_pred_test():
    return pd.read_parquet(sasrec_pred_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_sasrec_pred_test_2():
    return pd.read_parquet(sasrec_pred_test_path_2, engine='pyarrow')

@lru_cache(maxsize=1)
def read_roberta_pred_test():
    return pd.read_parquet(roberta_pred_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_graph_pred_test():
    return pd.read_parquet(graph_pred_test_path, engine='pyarrow')

In [5]:
sasrec_pred_test = read_sasrec_pred_test_2()
roberta_pred_test = read_roberta_pred_test()
graph_pred_test = read_graph_pred_test()

In [6]:
sasrec_pred_test.head(2), len(sasrec_pred_test), len(sasrec_pred_test.iloc[0]['next_item_prediction'])

(  locale                               next_item_prediction  \
 0     DE  [B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...   
 1     DE  [B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...   
 
                                               scores  
 0  [19.93616485595703, 19.43468475341797, 18.6949...  
 1  [22.813570022583008, 18.525903701782227, 18.01...  ,
 316971,
 150)

In [7]:
sasrec_pred_test.rename(columns={'score' : 'scores'}, inplace=True)
sasrec_pred_test

Unnamed: 0,locale,next_item_prediction,scores
0,DE,"[B099NS1XPG, B07LG5T3V9, B099NR3X6D, B08QYYBTM...","[19.93616485595703, 19.43468475341797, 18.6949..."
1,DE,"[B004ZXMV4Q, B08BZCKDKQ, B095TQTZXY, B010MJNUZ...","[22.813570022583008, 18.525903701782227, 18.01..."
2,DE,"[B0B5QNFWJ1, B0BJF4KGCN, B099277D7Q, B0B5TFLBC...","[14.396807670593262, 14.095083236694336, 13.71..."
3,DE,"[395535086X, 3772476953, 3772477917, B0829LZFT...","[18.627986907958984, 18.095304489135742, 16.81..."
4,DE,"[B09J8SKX9G, B09J8V9RQQ, B09J8VPTTW, B09J8TWRV...","[19.423419952392578, 18.855445861816406, 18.45..."
...,...,...,...
316966,UK,"[B08X9L5RGD, B09G9YY2C9, B09MW64JGM, B07V5FL8G...","[20.52749252319336, 15.89537239074707, 15.7825..."
316967,UK,"[B0989BHLSY, B09895QPQF, B09CPNS7XV, B09L14HQF...","[17.313642501831055, 17.087045669555664, 16.69..."
316968,UK,"[B09HKZBNZH, B09HZSRJWW, B07PY1NG3X, B09HL141Q...","[22.398101806640625, 21.34939193725586, 19.468..."
316969,UK,"[B08FB464L7, B07TR5LQSL, B0BGDK1J1G, B00HEL380...","[16.046977996826172, 15.801373481750488, 15.09..."


In [8]:
roberta_pred_test.head(2), len(roberta_pred_test), len(roberta_pred_test.iloc[0]['next_item_prediction'])

(                                next_item_prediction  \
 0  [B07TV22X9M, B08Q391KS3, B07TV364MZ, B01H1R0K6...   
 1  [B004ZXMV4Q, B09P1XWJPS, B00R9RNWF2, B010MJNUZ...   
 
                                               scores  
 0  [269.218017578125, 268.97796630859375, 268.910...  
 1  [268.6519775390625, 267.3247985839844, 266.088...  ,
 316971,
 150)

In [9]:
graph_pred_test.head(2), len(graph_pred_test), len(graph_pred_test.iloc[0]['next_item_prediction'])

(  locale                               next_item_prediction  \
 0     DE  [B099NS1XPG, B099NR3X6D, B0B7S7LBMB, B0B53KBXR...   
 1     DE  [B004ZXMV4Q, B08H93ZRK9, B0BFJGXWDV, B0B1MPZWJ...   
 
                                               scores  
 0  [23, 12, 5, 4, 4, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2...  
 1  [14, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1,...  ,
 316971,
 150)

In [10]:
normlized_sasrec_pred_test = softmax_norm(sasrec_pred_test)
normlized_roberta_pred_test = softmax_norm(roberta_pred_test)
normlized_graph_pred_test = softmax_norm(graph_pred_test)

100%|██████████| 316971/316971 [00:12<00:00, 25076.77it/s]
100%|██████████| 316971/316971 [00:19<00:00, 16272.65it/s]
100%|██████████| 316971/316971 [00:09<00:00, 32299.80it/s]


In [25]:
merged_pred_test_df = merge_multi_predictions([normlized_sasrec_pred_test, normlized_graph_pred_test], [0.75, 0.25])

100%|██████████| 316971/316971 [04:41<00:00, 1124.11it/s]


In [26]:
merged_pred_test_df.drop(columns=['scores'], inplace=True)

In [27]:
merged_pred_test_df['next_item_prediction'].apply(len).describe()

count    316971.0
mean        100.0
std           0.0
min         100.0
25%         100.0
50%         100.0
75%         100.0
max         100.0
Name: next_item_prediction, dtype: float64

In [23]:
merged_pred_test_df.to_parquet('../predictions/merge_prediction/merge_co_graph_sasrec_softmax_3.parquet', engine='pyarrow')