In [8]:
import torch
import torch.nn as nn
import os
import json

In [9]:
def _resize_embedding(weights, n_rows, generator):
    """Return a standalone copy of `weights` with exactly `n_rows` rows.

    Rows beyond `n_rows` are dropped; if `weights` has fewer rows, the missing
    ones are drawn from a standard normal distribution using `generator`.
    """
    if weights.size(0) >= n_rows:
        # clone() detaches the slice from the original storage; torch.save would
        # otherwise serialise the whole (untruncated) table that backs the view.
        return weights[:n_rows].clone()
    resized = torch.randn(n_rows, weights.size(1), generator=generator, dtype=weights.dtype)
    resized[:weights.size(0)] = weights
    return resized


def _ensure_parent_dir(path):
    """Create the directory that will contain `path` if it does not exist yet."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def convert_rechorus_to_agent4rec(rechorus_path, save_path, config_path=None, target_dims=None, seed=0):
    """
    Convert a ReChorus BPRMF / LightGCN checkpoint into the Agent4Rec checkpoint format.

    The ReChorus file is loaded as a state dict. Its user/item embedding tables are
    written to `save_path` as
    ``{'epoch': 0, 'state_dict': {'embed_user.weight': ..., 'embed_item.weight': ...}}``.
    Each table is truncated or padded to the requested number of rows.

    Args:
        rechorus_path: Path of the ReChorus model file (a state dict).
        save_path: Where to write the converted checkpoint. Missing parent
            directories are created.
        config_path: Optional path for an Agent4Rec JSON config (``args.txt``).
            No config is written if this is falsy.
        target_dims: Optional dict ``{'user': n_users, 'item': n_items}`` giving the
            number of embedding ROWS (entity counts, not vector width) in the output.
            Defaults to 1000 users and 3292 items.
        seed: Seed for the random initialisation of padded rows, so repeated
            conversions give identical checkpoints.

    Raises:
        KeyError: if the expected user/item embedding keys are not in the checkpoint.
    """
    # map_location='cpu' lets checkpoints trained on a GPU be converted on a CPU-only machine.
    rechorus_model = torch.load(rechorus_path, map_location='cpu')

    if target_dims:
        target_user_dim, target_item_dim = target_dims['user'], target_dims['item']
    else:
        target_user_dim, target_item_dim = 1000, 3292

    # Detect the model family from the state-dict keys. These key names must match
    # what the ReChorus model actually saves.
    if any('encoder.embedding_dict' in key for key in rechorus_model):
        model_type = 'LightGCN'
        userid = 'encoder.embedding_dict.user_emb'
        itemid = 'encoder.embedding_dict.item_emb'
    else:
        model_type = 'BPRMF'
        userid = 'u_embeddings.weight'
        itemid = 'i_embeddings.weight'

    # Fail loudly instead of silently writing a checkpoint without embeddings.
    missing = [key for key in (userid, itemid) if key not in rechorus_model]
    if missing:
        raise KeyError(
            f"{model_type} checkpoint {rechorus_path!r} lacks embedding keys {missing}; "
            f"available keys: {list(rechorus_model)}"
        )

    generator = torch.Generator().manual_seed(seed)
    user_weights = _resize_embedding(rechorus_model[userid], target_user_dim, generator)
    item_weights = _resize_embedding(rechorus_model[itemid], target_item_dim, generator)

    new_state_dict = {
        'epoch': 0,
        'state_dict': {
            'embed_user.weight': user_weights,
            'embed_item.weight': item_weights,
        },
    }
    _ensure_parent_dir(save_path)
    torch.save(new_state_dict, save_path)

    if config_path:
        config = {
            "vis": -1,
            "seed": 0,
            "clear_checkpoints": True,
            "candidate": False,
            "test_only": False,
            "data_path": "../datasets/",
            "dataset": "ml-1m",
            # Derived from the checkpoint so the config always matches the saved tables.
            "embed_size": user_weights.size(1),
            "batch_size": 256,
            "lr": 0.001,
            "regs": 1e-8,
            "epoch": 200,
            "Ks": "5,10,20,50",
            "verbose": 5,
            "saveID": model_type.lower(),
            "patience": 10,
            "checkpoint": "./",
            "modeltype": model_type,
            "IPStype": "cn",
            "infonce": 0,
            "cuda": 0,
            "n_layers": 3 if model_type == 'LightGCN' else 0,
            "neg_sample": 1,
            "num_workers": 0,
            "train_norm": False,
            "pred_norm": False,
            "nodrop": True
        }

        _ensure_parent_dir(config_path)
        with open(config_path, 'w') as f:
            json.dump(config, f, indent=2)


In [10]:
# LightGCN: convert the ReChorus checkpoint, naming the output after the best dev epoch.
rechorus_path = "ReChorus/model/LightGCN/LightGCN__MovieLens_1M/ML_1MTOPK/__0__lr=0.001__l2=1e-08__emb_size=64__n_layers=3__batch_size=256.pt"
config_path = "Agent4Rec/recommenders/weights/ml-1m/LightGCN/converted_models/args.txt"
target_dims = {'user': 1000, 'item': 3292}

log_path = "ReChorus/log/LightGCN/LightGCN__MovieLens_1M/ML_1MTOPK/__0__lr=0.001__l2=1e-08__emb_size=64__n_layers=3__batch_size=256.txt"
best_iter_marker = 'Best Iter(dev)='
with open(log_path, 'r', encoding='utf-8') as f:
    log_text = f.read()
# Use the LAST occurrence. If the log holds several runs (e.g. training was re-run with the
# same settings), the checkpoint on disk presumably belongs to the most recent one.
marker_pos = log_text.rfind(best_iter_marker)
if marker_pos == -1:
    raise ValueError(f"No '{best_iter_marker}' entry found in {log_path}; did training finish?")
# The number may be padded with spaces, so split on whitespace instead of using a fixed-width slice.
epoch = int(log_text[marker_pos + len(best_iter_marker):].split()[0])
print(epoch)
save_path = f"Agent4Rec/recommenders/weights/ml-1m/LightGCN/converted_models/epoch={epoch}.checkpoint.pth.tar"

convert_rechorus_to_agent4rec(
    rechorus_path=rechorus_path,
    save_path=save_path,
    config_path=config_path,
    target_dims=target_dims
)

24


In [11]:
# BPRMF: convert the ReChorus checkpoint, naming the output after the best dev epoch.
rechorus_path = "ReChorus/model/BPRMF/BPRMF__MovieLens_1M/ML_1MTOPK/__0__lr=0.001__l2=1e-08__emb_size=64__batch_size=256.pt"
config_path = "Agent4Rec/recommenders/weights/ml-1m/MF/converted_models/args.txt"
target_dims = {'user': 1000, 'item': 3292}

log_path = "ReChorus/log/BPRMF/BPRMF__MovieLens_1M/ML_1MTOPK/__0__lr=0.001__l2=1e-08__emb_size=64__batch_size=256.txt"
best_iter_marker = 'Best Iter(dev)='
with open(log_path, 'r', encoding='utf-8') as f:
    log_text = f.read()
# Use the LAST occurrence. If the log holds several runs (e.g. training was re-run with the
# same settings), the checkpoint on disk presumably belongs to the most recent one.
marker_pos = log_text.rfind(best_iter_marker)
if marker_pos == -1:
    raise ValueError(f"No '{best_iter_marker}' entry found in {log_path}; did training finish?")
# The number may be padded with spaces, so split on whitespace instead of using a fixed-width slice.
epoch = int(log_text[marker_pos + len(best_iter_marker):].split()[0])
print(epoch)
save_path = f"Agent4Rec/recommenders/weights/ml-1m/MF/converted_models/epoch={epoch}.checkpoint.pth.tar"

convert_rechorus_to_agent4rec(
    rechorus_path=rechorus_path,
    save_path=save_path,
    config_path=config_path,
    target_dims=target_dims
)

12
