In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from utils.reparam_module import ReparamModule

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from utils import *
# Experiment configuration: dataset + meta-model to load, and which sequential
# recommender is plugged in as the sub-model.
config = {
    'dataset': 'amazon-toys-noise-50',
    'model': 'MetaModel7',
}
sub_model = 'SASRec'  # must be one of the keys of `state_dict_path` below

# Trained checkpoints on amazon-toys-noise-50, keyed by sub-model name.
state_dict_path = {
    'SASRec': 'saved/MetaModel7/amazon-toys-noise-50/2024-01-26-00-54-26-865661.ckpt',
    'GRU4Rec': 'saved/MetaModel7/amazon-toys-noise-50/2024-01-27-19-11-53-738587.ckpt',
}
# Fail fast, before the (expensive) dataset/model preparation below.
if sub_model not in state_dict_path:
    raise KeyError(
        f'No checkpoint registered for sub_model={sub_model!r}; '
        f'expected one of {sorted(state_dict_path)}'
    )

# load_config presumably expands the short config into nested dicts
# (config['model'] is indexed as a dict right after) — TODO confirm in utils.
config = load_config(config)
dataset_list = prepare_datasets(config)
config['model']['sub_model'] = sub_model
model = prepare_model(config, dataset_list)
model._init_model(dataset_list[0])
device = model.device
# Load the checkpoint that matches the selected sub-model.
model.load_checkpoint(state_dict_path[sub_model])

In [3]:
# Iterate the first dataset in stored order (shuffle=False); presumably this keeps
# the collected per-batch outputs row-aligned with the dataset — TODO confirm.
base_dataset = model.dataset_list[0]
loader = base_dataset.get_loader(shuffle=False)

In [4]:
# Run the meta-module over every batch and collect
#   * the raw meta-module logits, and
#   * the sampled per-position loss weights (Gumbel-softmax, channel 0).
GUMBEL_SEED = 0  # fixes the Gumbel noise so the saved weights are reproducible
torch.manual_seed(GUMBEL_SEED)

loss_weight_list = []
logits_list = []
model.eval()
# The temperature is loop-invariant: clamp it to >= 1 once.
tau = torch.clip(model.tau, min=1)
with torch.no_grad():  # pure inference: don't build an autograd graph per batch
    for batch in loader:
        query = model.sub_model.forward(batch, need_pooling=False)
        batch_logits = model.meta_module(query)
        logits_list.append(batch_logits.cpu())
        # Relaxed one-hot sample over the last dim; channel 0 is used as the weight.
        weight = F.gumbel_softmax(batch_logits, tau, dim=-1)[..., 0]
        # Rows whose user_id is 0 get weight 1 at every position.
        user_mask = batch['user_id'] == 0
        weight = weight.masked_fill(user_mask.unsqueeze(-1), 1)
        # Padding positions (item_id == 0) get weight 0.
        pad_mask = batch['item_id'] == 0
        weight = weight.masked_fill(pad_mask, 0)
        loss_weight_list.append(weight.cpu())
loss_weight = torch.cat(loss_weight_list)
logits = torch.cat(logits_list)

In [7]:
import os

# Persist the sampled loss weights, and the raw logits together with the learned
# temperature, under paper/. File names encode the sub-model; "toys" is the
# amazon-toys-noise-50 dataset.
os.makedirs('paper', exist_ok=True)  # torch.save does not create missing parent dirs
torch.save(loss_weight, f'paper/loss_weight_{sub_model}_toys.pth')
torch.save([logits, model.tau.item()], f'paper/logits_{sub_model}_toys.pth')

# Analysis: visualize the learned loss weights

In [None]:
# Boolean row mask: keep rows whose `data[0]` entry equals 1 ("meaningful points
# to compare").
# NOTE(review): assumes data[0] is row-aligned with `loss_weight` (the loader uses
# shuffle=False) — confirm.
mask = (model.dataset_list[0].data[0] == 1).cpu()
# Randomly pick up to N_HEATMAP_ROWS of the masked rows; a seeded generator makes
# the heatmaps reproducible across runs.
N_HEATMAP_ROWS = 50
SELECT_SEED = 0
selected_idx = torch.randperm(
    int(mask.sum().item()),
    generator=torch.Generator().manual_seed(SELECT_SEED),
)[:N_HEATMAP_ROWS]

In [None]:
import matplotlib.pyplot as plt

# Heatmap of the sampled loss weights for the randomly selected rows.
# Rows are the selected dataset rows. Columns follow the last dim of the weight
# tensor, which is masked with batch['item_id'] in the inference cell.
matrix = loss_weight[mask][selected_idx]

fig, ax = plt.subplots()
image = ax.imshow(matrix, cmap='hot', interpolation='nearest', aspect='auto')
fig.colorbar(image, ax=ax, label='loss weight')

ax.set_title(f'Sampled loss weights ({sub_model})')
ax.set_xlabel('Item position')
ax.set_ylabel('Selected row')

plt.show()

In [None]:
import matplotlib.pyplot as plt

# Same heatmap layout as the loss-weight plot, but for a tensor named `debug`.
# NOTE(review): `debug` is never defined in this notebook, so on a fresh kernel
# the original cell raised NameError. It is indexed with the same `mask` /
# `selected_idx`, so it is presumably row-aligned with `loss_weight`.
# TODO: define `debug` explicitly above, or delete this cell.
debug_tensor = globals().get('debug')
if debug_tensor is None:
    print('`debug` is not defined in this kernel; skipping the debug heatmap.')
else:
    matrix = debug_tensor[mask][selected_idx]

    fig, ax = plt.subplots()
    image = ax.imshow(matrix, cmap='hot', interpolation='nearest', aspect='auto')
    fig.colorbar(image, ax=ax)

    ax.set_title('Heatmap')
    ax.set_xlabel('X-axis')
    ax.set_ylabel('Y-axis')

    plt.show()