In [1]:
import argparse
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from dataset import MyDataset



def get_args(argv=None):
    """Build the argument parser and return the parsed training configuration.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``sys.argv[1:]`` is parsed, exactly as before.
            Pass ``[]`` from a notebook to get pure defaults without
            depending on whatever arguments the kernel launcher supplied.

    Returns:
        argparse.Namespace holding the train, model and multimodal-feature
        options defined below.
    """
    parser = argparse.ArgumentParser()

    # Train params
    parser.add_argument('--batch_size', default=128, type=int)
    parser.add_argument('--lr', default=0.001, type=float)
    parser.add_argument('--maxlen', default=101, type=int)

    # Baseline Model construction
    parser.add_argument('--hidden_units', default=32, type=int)
    parser.add_argument('--num_blocks', default=1, type=int)
    parser.add_argument('--num_epochs', default=3, type=int)
    parser.add_argument('--num_heads', default=1, type=int)
    parser.add_argument('--dropout_rate', default=0.2, type=float)
    parser.add_argument('--l2_emb', default=0.0, type=float)
    parser.add_argument('--device', default='cpu', type=str)
    parser.add_argument('--inference_only', action='store_true')
    parser.add_argument('--state_dict_path', default=None, type=str)
    parser.add_argument('--norm_first', action='store_true')

    # Accept (and ignore) the "-f <connection file>" argument that Jupyter
    # passes to the kernel, so parsing sys.argv inside a notebook does not fail.
    parser.add_argument('-f', '--file', default=None)

    # MMemb Feature ID: which multimodal embedding feature ids to load.
    parser.add_argument(
        '--mm_emb_id',
        nargs='+',
        default=['85', '86'],
        type=str,
        choices=[str(s) for s in range(81, 87)],
    )

    # argv=None makes argparse fall back to sys.argv[1:], preserving the
    # original behavior for existing callers.
    args = parser.parse_args(argv)

    return args
# Global dataset setup: parse the configuration, build the dataset, split it
# into train/validation subsets and wrap both in DataLoaders.

# The dataset location can be overridden with the TRAIN_DATA_PATH environment
# variable; the original local path is kept as the default.
data_path = os.environ.get('TRAIN_DATA_PATH', '/Users/huang/Desktop/TX/dataset/TencentGR_1k')

# Fixed seed for the train/validation split so the partition is identical
# across kernel restarts (otherwise a resumed run could validate on samples it
# was previously trained on).
SPLIT_SEED = 42

args = get_args()
dataset = MyDataset(data_path, args)
train_dataset, valid_dataset = torch.utils.data.random_split(
    dataset, [0.9, 0.1], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
train_loader = DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=0, collate_fn=dataset.collate_fn
)
valid_loader = DataLoader(
    valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=0, collate_fn=dataset.collate_fn
)
usernum, itemnum = dataset.usernum, dataset.itemnum
# Feature statistics and feature types exposed by the dataset; both are passed
# to the model constructor.
feat_statistics, feat_types = dataset.feat_statistics, dataset.feature_types


Loading mm_emb:  50%|█████     | 1/2 [00:29<00:29, 29.09s/it]

Loaded #85 mm_emb


Loading mm_emb: 100%|██████████| 2/2 [00:36<00:00, 18.34s/it]

Loaded #86 mm_emb





In [2]:
# Force CPU execution for this session (overrides any --device value parsed earlier).
args.device = 'cpu'
from model import BaselineModel
model = BaselineModel(usernum, itemnum, feat_statistics, feat_types, args).to(args.device)


# Xavier initialisation needs a fan-in and a fan-out, so it is only defined for
# tensors with at least two dimensions. Parameters with fewer dimensions keep
# their constructor initialisation. Checking the dimensionality explicitly
# (instead of swallowing every exception) lets genuine errors surface.
for name, param in model.named_parameters():
    if param.dim() >= 2:
        torch.nn.init.xavier_normal_(param.data)

# Zero out row 0 of every embedding table
# (presumably index 0 is the padding id — TODO confirm against the dataset).
model.pos_emb.weight.data[0, :] = 0
model.item_emb.weight.data[0, :] = 0
model.user_emb.weight.data[0, :] = 0

for k in model.sparse_emb:
    model.sparse_emb[k].weight.data[0, :] = 0

# First epoch to run; bumped below when resuming from a checkpoint whose file
# name encodes the epoch it was saved at.
epoch_start_idx = 1

if args.state_dict_path is not None:
    # Step 1: load the weights. Only failures of this step are reported as a
    # loading problem, and the original exception is chained for debugging.
    # NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
    try:
        model.load_state_dict(torch.load(args.state_dict_path, map_location=torch.device(args.device)))
    except Exception as exc:
        print('failed loading state_dicts, pls check file path: ', end="")
        print(args.state_dict_path)
        raise RuntimeError('failed loading state_dicts, pls check file path!') from exc

    # Step 2: recover the epoch from a "...epoch=<N>.<ext>" style path. This is
    # kept outside the try block so that a path without an epoch marker is not
    # misreported as a failed load; training then simply starts from epoch 1.
    _prefix, epoch_marker, epoch_tail = args.state_dict_path.partition('epoch=')
    epoch_digits = epoch_tail.split('.', 1)[0]
    if epoch_marker and epoch_digits.isdecimal():
        epoch_start_idx = int(epoch_digits) + 1
    else:
        print('no "epoch=<N>" found in state_dict_path, starting from epoch', epoch_start_idx)

# Loss on raw logits, averaged over all elements.
bce_criterion = torch.nn.BCEWithLogitsLoss(
    reduction='mean',
)
# Adam over every model parameter with the configured learning rate.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=args.lr,
    betas=(0.9, 0.98),
)

# Best-so-far metric trackers, all starting at zero.
best_val_ndcg = best_val_hr = 0.0
best_test_ndcg = best_test_hr = 0.0

# Bookkeeping: accumulated time, wall-clock start and optimisation step counter.
global_step = 0
T = 0.0
t0 = time.time()
print("Start training")

# Training loop: for every epoch, run each training batch through the model.
#
# NOTE(review): the run recorded with this cell stopped on the first batch with
# "ValueError: could not broadcast input array from shape (1024,) into shape
# (3584,)". The traceback is not shown; presumably a multimodal-embedding width
# mismatch in dataset/model code outside this cell — confirm.
# NOTE(review): no loss is computed and no optimizer step is taken in the
# visible part of this loop — confirm whether the rest of the training step is
# still to be added.
for epoch in range(epoch_start_idx, args.num_epochs + 1):
    # Check the inference-only flag BEFORE switching modes, so an
    # inference-only run never puts the model into training mode
    # (which would leave dropout active).
    if args.inference_only:
        break
    model.train()

    for step, batch in tqdm(enumerate(train_loader), total=len(train_loader)):
        seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat = batch
        # Only the three id tensors are moved to the target device here; the
        # remaining batch elements are passed to the model as delivered by the loader.
        seq = seq.to(args.device)
        pos = pos.to(args.device)
        neg = neg.to(args.device)
        pos_logits, neg_logits, pos_pred = model(
            seq, pos, neg, token_type, next_token_type, next_action_type, seq_feat, pos_feat, neg_feat
        )
        

Start training


  0%|          | 0/7 [00:00<?, ?it/s]


ValueError: could not broadcast input array from shape (1024,) into shape (3584,)