### 数据准备

In [293]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

In [294]:
user_features = pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/features/user_features.csv')
movie_features = pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/features/movie_features.csv')
ratings=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/ratings.csv')
movies=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/movies.csv')
users=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/users.csv')

### 查看数据情况

In [285]:
user_features.head()

Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,favorite_genre,num_liked_genres
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.045028,0.16967,0.226609,0.204426,0.036571,0.0,0.39829,0.093389,0.0,0.0,0.479172,0.0,0.382257,0.297663,0.406715,0.418157,0.0,Musical,13
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.081828,0.0,0.0,0.117219,0.063696,0.0,0.447701,0.010214,0.010318,0.02085,0.0,0.031933,0.263799,0.252818,0.613924,0.745944,0.51241,War,14
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.249533,0.039769,0.041393,0.431545,0.0,0.0,0.200254,0.062209,0.0,0.099299,0.036615,0.037954,0.196962,0.291509,0.336156,0.195883,0.71377,Western,15
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.152565,0.0,0.029868,0.0,0.03076,0.0,0.19023,0.077727,0.0,0.126004,0.0,0.0,0.095535,0.472931,0.380836,0.43579,0.462981,Sci-Fi,12
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.02803,0.012816,0.019472,0.185339,0.08525,0.026617,0.473922,0.0,0.02588,0.088538,0.029117,0.07995,0.325581,0.240121,0.817461,0.629844,0.240616,Thriller,17


In [286]:
movie_features.head()

Unnamed: 0,movie_id,title,genres,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,movie_mean_rating,movie_rating_std,movie_rating_count,genre_purity
0,1,Toy Story (1995),Animation|Children's|Comedy,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,4.146846,0.852349,2077,0.333333
1,2,Jumanji (1995),Adventure|Children's|Fantasy,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,3.201141,0.983172,701,0.333333
2,3,Grumpier Old Men (1995),Comedy|Romance,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,3.016736,1.071712,478,0.5
3,4,Waiting to Exhale (1995),Comedy|Drama,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,2.729412,1.013381,170,0.5
4,5,Father of the Bride Part II (1995),Comedy,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,3.006757,1.025086,296,1.0


In [239]:
ratings.head()

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


### 数据预处理

#### 1. 处理movie_features

In [295]:
print("用户评分数据情况:")
print(ratings['rating'].value_counts())

用户评分数据情况:
rating
4    348971
3    261197
5    226310
2    107557
1     56174
Name: count, dtype: int64


In [296]:
# 处理电影的标题
import re
pattern = re.compile(r'^(.*)\((\d+)\)$')

# 1. 去除掉特殊字符
title_map = {val:pattern.match(val).group(1) for ii,val in enumerate(set(movie_features['title']))}
movie_features['title'] = movie_features['title'].map(title_map)

# 2. 转换为数字字典
title_set=set()
for val in movie_features['title'].str.split():
    title_set.update(val)

title_set.add('<PAD>')
title2int={val: i for i, val in enumerate(title_set)}

# 3. 将电影title转成等长列表，长度为15
title_count=15
# 遍历movies_df['title']列中所有唯一的电影标题，对于每个电影标题，将其按照空格分割成单词，然后通过title2int字典将每个单词转换为对应的整数
title_map = {val: [title2int[row] for row in val.split()] for val in movie_features['title']}
for key in title_map:
    for cnt in range(title_count-len(title_map[key])):
        title_map[key].insert(len(title_map[key])+cnt, title2int['<PAD>'])
movie_features['title']=movie_features['title'].map(title_map)

In [297]:
pd.set_option('display.max_columns', None)
movie_features_numeric=movie_features.drop(['genres'], axis=1)
print("movie_features维度:", movie_features.shape)
movie_features.head()

movie_features维度: (3706, 25)


Unnamed: 0,movie_id,title,genres,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,movie_mean_rating,movie_rating_std,movie_rating_count,genre_purity
0,1,"[2374, 1922, 3100, 3100, 3100, 3100, 3100, 310...",Animation|Children's|Comedy,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,4.146846,0.852349,2077,0.333333
1,2,"[2664, 3100, 3100, 3100, 3100, 3100, 3100, 310...",Adventure|Children's|Fantasy,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,3.201141,0.983172,701,0.333333
2,3,"[2932, 3394, 2536, 3100, 3100, 3100, 3100, 310...",Comedy|Romance,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,3.016736,1.071712,478,0.5
3,4,"[4466, 867, 3609, 3100, 3100, 3100, 3100, 3100...",Comedy|Drama,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,2.729412,1.013381,170,0.5
4,5,"[4660, 4652, 1629, 4523, 1063, 3358, 3100, 310...",Comedy,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,3.006757,1.025086,296,1.0


#### 2. 处理user_features

In [290]:
user_features.head()

Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,favorite_genre,num_liked_genres
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.045028,0.16967,0.226609,0.204426,0.036571,0.0,0.39829,0.093389,0.0,0.0,0.479172,0.0,0.382257,0.297663,0.406715,0.418157,0.0,Musical,13
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.081828,0.0,0.0,0.117219,0.063696,0.0,0.447701,0.010214,0.010318,0.02085,0.0,0.031933,0.263799,0.252818,0.613924,0.745944,0.51241,War,14
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.249533,0.039769,0.041393,0.431545,0.0,0.0,0.200254,0.062209,0.0,0.099299,0.036615,0.037954,0.196962,0.291509,0.336156,0.195883,0.71377,Western,15
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.152565,0.0,0.029868,0.0,0.03076,0.0,0.19023,0.077727,0.0,0.126004,0.0,0.0,0.095535,0.472931,0.380836,0.43579,0.462981,Sci-Fi,12
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.02803,0.012816,0.019472,0.185339,0.08525,0.026617,0.473922,0.0,0.02588,0.088538,0.029117,0.07995,0.325581,0.240121,0.817461,0.629844,0.240616,Thriller,17


In [298]:
# 处理用户最喜欢的电影类型
unique_genres = movies['genres'].str.get_dummies(sep='|').columns
genre_to_idx = {genre: idx for idx, genre in enumerate(unique_genres)}

user_features['favorite_genre_idx'] = user_features['favorite_genre'].map(
    lambda x: genre_to_idx.get(x, 0)  # 未知类型默认为0
)

In [299]:
print("user_features维度:", user_features.shape)
user_features_numeric=user_features.drop(['favorite_genre'], axis=1)
user_features_numeric.head()

user_features维度: (6040, 29)


Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,num_liked_genres,favorite_genre_idx
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.045028,0.16967,0.226609,0.204426,0.036571,0.0,0.39829,0.093389,0.0,0.0,0.479172,0.0,0.382257,0.297663,0.406715,0.418157,0.0,13,11
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.081828,0.0,0.0,0.117219,0.063696,0.0,0.447701,0.010214,0.010318,0.02085,0.0,0.031933,0.263799,0.252818,0.613924,0.745944,0.51241,14,16
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.249533,0.039769,0.041393,0.431545,0.0,0.0,0.200254,0.062209,0.0,0.099299,0.036615,0.037954,0.196962,0.291509,0.336156,0.195883,0.71377,15,17
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.152565,0.0,0.029868,0.0,0.03076,0.0,0.19023,0.077727,0.0,0.126004,0.0,0.0,0.095535,0.472931,0.380836,0.43579,0.462981,12,14
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.02803,0.012816,0.019472,0.185339,0.08525,0.026617,0.473922,0.0,0.02588,0.088538,0.029117,0.07995,0.325581,0.240121,0.817461,0.629844,0.240616,17,15


#### 3. 查看数据情况

In [300]:
# 获取电影类型数
genres=movies['genres'].str.get_dummies(sep='|')
genre_columns=genres.columns
print("电影类型数为:", len(genre_columns))
print("电影数量有:", len(movie_features))
print("总的评价数有:", len(ratings))
print("用户数量为:", len(user_features))

# 不包含user_id, favorite_genre
user_numeric_cols = [
    'mean_rating', 'rating_std', 'rating_count', 'rating_min', 'rating_max',
    'rating_strictness', 'rating_variability', 'num_liked_genres', 'favorite_genre_idx'
] + [col for col in user_features.columns if col in genre_columns]
print("选定的用户数值型特征列数量为:", len(user_numeric_cols))

# 不包含movie_id, genres
movie_cols = [
    'title', 'movie_mean_rating', 'movie_rating_std', 'movie_rating_count', 'genre_purity'
] + [col for col in movie_features.columns if col in genre_columns]
print("选定的电影特征数量为:", len(movie_cols))

print("用户数值型特征:", user_numeric_cols)
print("电影特征:", movie_cols)
print("电影类型:", list(genre_columns))

电影类型数为: 18
电影数量有: 3706
总的评价数有: 1000209
用户数量为: 6040
选定的用户数值型特征列数量为: 27
选定的电影特征数量为: 23
用户数值型特征: ['mean_rating', 'rating_std', 'rating_count', 'rating_min', 'rating_max', 'rating_strictness', 'rating_variability', 'num_liked_genres', 'favorite_genre_idx', 'Action', 'Adventure', 'Animation', "Children's", 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']
电影特征: ['title', 'movie_mean_rating', 'movie_rating_std', 'movie_rating_count', 'genre_purity', 'Action', 'Adventure', 'Animation', "Children's", 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']
电影类型: ['Action', 'Adventure', 'Animation', "Children's", 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western']


In [301]:
print("movie_features特征列:", movie_features.columns)
print("user_features特征列:", user_features.columns)

movie_features特征列: Index(['movie_id', 'title', 'genres', 'Action', 'Adventure', 'Animation',
       'Children's', 'Comedy', 'Crime', 'Documentary', 'Drama', 'Fantasy',
       'Film-Noir', 'Horror', 'Musical', 'Mystery', 'Romance', 'Sci-Fi',
       'Thriller', 'War', 'Western', 'movie_mean_rating', 'movie_rating_std',
       'movie_rating_count', 'genre_purity'],
      dtype='object')
user_features特征列: Index(['user_id', 'mean_rating', 'rating_std', 'rating_count', 'rating_min',
       'rating_max', 'rating_strictness', 'rating_variability', 'Action',
       'Adventure', 'Animation', 'Children's', 'Comedy', 'Crime',
       'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical',
       'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western',
       'favorite_genre', 'num_liked_genres', 'favorite_genre_idx'],
      dtype='object')


#### 4. 拼接ratings, movies_features_numeric, user_features_numeric

In [306]:
merged_df_1 = pd.merge(ratings, movie_features_numeric, left_on='movie_id', right_on='movie_id', how='inner')
# 再根据user_id将上一步结果和user_features_numeric合并
final_merged_df = pd.merge(merged_df_1, user_features_numeric, left_on='user_id', right_on='user_id', how='inner')


In [309]:
merged_df_1.columns

Index(['user_id', 'movie_id', 'rating', 'timestamp', 'title', 'Action',
       'Adventure', 'Animation', 'Children's', 'Comedy', 'Crime',
       'Documentary', 'Drama', 'Fantasy', 'Film-Noir', 'Horror', 'Musical',
       'Mystery', 'Romance', 'Sci-Fi', 'Thriller', 'War', 'Western',
       'movie_mean_rating', 'movie_rating_std', 'movie_rating_count',
       'genre_purity'],
      dtype='object')

In [311]:
user_features_numeric.head()

Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,num_liked_genres,favorite_genre_idx
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.045028,0.16967,0.226609,0.204426,0.036571,0.0,0.39829,0.093389,0.0,0.0,0.479172,0.0,0.382257,0.297663,0.406715,0.418157,0.0,13,11
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.081828,0.0,0.0,0.117219,0.063696,0.0,0.447701,0.010214,0.010318,0.02085,0.0,0.031933,0.263799,0.252818,0.613924,0.745944,0.51241,14,16
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.249533,0.039769,0.041393,0.431545,0.0,0.0,0.200254,0.062209,0.0,0.099299,0.036615,0.037954,0.196962,0.291509,0.336156,0.195883,0.71377,15,17
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.152565,0.0,0.029868,0.0,0.03076,0.0,0.19023,0.077727,0.0,0.126004,0.0,0.0,0.095535,0.472931,0.380836,0.43579,0.462981,12,14
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.02803,0.012816,0.019472,0.185339,0.08525,0.026617,0.473922,0.0,0.02588,0.088538,0.029117,0.07995,0.325581,0.240121,0.817461,0.629844,0.240616,17,15


### 创建映射字典

In [207]:
# 给定user_id的情况下，可以获取其特征值
user_id_to_feature = {} # key是user_id, value是用户特征 
for _, row in user_features.iterrows():
    try:
        user_id_to_feature[row['user_id']] = row[user_numeric_cols].values.astype(np.float32)
    except KeyError as e:
        print(f"Missing column in user_features: {e}")
        break
print("给定用户id的情况下，该用户的数值型特征数量为:",user_id_to_feature[1].shape) # 有27个特征（不包含user_id, favorite_genre）

给定用户id的情况下，该用户的数值型特征数量为: (27,)


In [208]:
# 给定movie_id的情况下，可以获取其特征值
movie_id_to_feature = {} # key是movie_id，value是电影特征
for _, row in movie_features.iterrows():
    try:
        movie_id_to_feature[row['movie_id']] = row[movie_cols].values
    except KeyError as e:
        print(f"Missing column in movie_features: {e}")
        break
print("给定电影id的情况下，该电影的特征数量为:", movie_id_to_feature[1].shape) # 有23个特征

给定电影id的情况下，该电影的特征数量为: (23,)


### 数据集类

In [209]:
class RatingDataset(Dataset):
    def __init__(self, data, user_id_to_feature, movie_id_to_feature):
        self.data = data
        self.user_id_to_feature = user_id_to_feature # user_id到feature的字典 
        self.movie_id_to_feature = movie_id_to_feature # movie_id到feature的字典
        
        # 获取favorite_genre的映射
        self.genre_to_idx = {genre: idx for idx, genre in enumerate(genre_columns)}
        self.num_genres = len(genre_columns)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        """
        给定idx的情况下，可以得到某一条评分数据，包含user_id, movie_id, rating
        """
        user_id = self.data['user_id'][idx]
        movie_id=self.data['movie_id'][idx]
        rating=self.data['rating'][idx]
        
        # 获取数值特征
        user_feature = self.user_id_to_feature[user_id] # 27个特征 
        movie_feature = self.movie_id_to_feature[movie_id] # 23个特征
        
        # 处理favorite_genre（转换为one-hot）
        favorite_genre = user_features[user_features['user_id'] == user_id]['favorite_genre'].values[0]
        genre_idx = self.genre_to_idx.get(favorite_genre, 0)
        genre_onehot = torch.zeros(self.num_genres)
        genre_onehot[genre_idx] = 1
        
        # 合并所有用户特征
        user_feature = np.concatenate([
            user_feature,  
            genre_onehot.numpy() 
        ])
        
        normalized_rating = (rating - 1) / 5.0 # 最大值为5， 最小值为1
        return (
            torch.FloatTensor(user_feature),
            torch.FloatTensor(movie_feature),
            torch.FloatTensor([normalized_rating])
        )

### 划分训练集和验证集&创建数据加载器

In [212]:
train_data, test_data=train_test_split(data, test_size=0.2, random_state=42)

train_dataset = RatingDataset(train_data, user_id_to_feature, movie_id_to_feature)
test_dataset = RatingDataset(test_data, user_id_to_feature, movie_id_to_feature)

In [213]:
print("训练集样本数量:", len(train_dataset))
print("验证集样本数量:", len(test_dataset))

训练集样本数量: 800167
验证集样本数量: 200042


In [214]:
batch_size = 256
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# print("数据索引范围:", train_dataset.data.min(), train_dataset.data.max())

In [218]:
pd.set_option('display.max_rows', None)
train_dataset.data.min()

user_id                                                        1
movie_id                                                       1
rating                                                         1
timestamp                                              956703932
title          [2, 3296, 3258, 3258, 3258, 3258, 3258, 3258, ...
genres                                                    Action
Action                                                         0
Adventure                                                      0
Animation                                                      0
Children's                                                     0
Comedy                                                         0
Crime                                                          0
Documentary                                                    0
Drama                                                          0
Fantasy                                                        0
Film-Noir                

### 模型架构

In [102]:
class DualTowerModel(nn.Module):
    def __init__(self, user_feature_dim, movie_feature_dim, embedding_dim=64):
        super(DualTowerModel, self).__init__()
        
        # 用户塔
        self.user_tower = nn.Sequential(
            nn.Linear(user_feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
            nn.ReLU()
        )
        
        # 电影塔
        self.movie_tower = nn.Sequential(
            nn.Linear(movie_feature_dim, 128),
            nn.ReLU(),
            nn.Linear(128, embedding_dim),
            nn.ReLU()
        )
        
        # 评分预测头
        self.rating_head = nn.Sequential(
            nn.Linear(embedding_dim * 2, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()  # 输出在0-1之间
        )
    
    def forward(self, user_features, movie_features):
        user_embedding = self.user_tower(user_features)
        movie_embedding = self.movie_tower(movie_features)
        combined = torch.cat([user_embedding, movie_embedding], dim=1)
        return self.rating_head(combined).squeeze()

### 初始化模型

In [126]:
sample_user, sample_movie, _ = train_dataset[1] # 用户特征，电影特征，评分
user_feature_dim = sample_user.shape[0] # 用户特征情况
movie_feature_dim = sample_movie.shape[0] # 电影特征情况
print("user_feature_dim:", user_feature_dim)
print("movie_feature_dim:", movie_feature_dim)
model = DualTowerModel(user_feature_dim, movie_feature_dim)

user_feature_dim: 135
movie_feature_dim: 22


In [127]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [128]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

num_epochs = 20
for epoch in range(num_epochs):
    model.train()
    train_loss = 0.0
    
    for user_features, movie_features, ratings in train_loader:
        user_features = user_features.to(device)
        movie_features = movie_features.to(device)
        ratings = ratings.to(device)
        
        optimizer.zero_grad()
        outputs = model(user_features, movie_features)
        loss = criterion(outputs, ratings)
        loss.backward()
        optimizer.step()
        
        train_loss += loss.item() * user_features.size(0)
    
    train_loss = train_loss / len(train_loader.dataset)


    # 测试集评估
    model.eval()
    test_loss = 0.0
    with torch.no_grad():
        for user_features, movie_features, ratings in test_loader:
            user_features = user_features.to(device)
            movie_features = movie_features.to(device)
            ratings = ratings.to(device)
            
            outputs = model(user_features, movie_features)
            loss = criterion(outputs, ratings)
            test_loss += loss.item() * user_features.size(0)
    
    test_loss = test_loss / len(test_loader.dataset)
    
    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')

KeyError: 255728