In [2]:
import pandas as pd

In [3]:
movies_df=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/movies.csv')
ratings_df=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/ratings.csv')
users_df=pd.read_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/ml-1m/users.csv')

## 查看数据基本情况

In [4]:
movies_df.head()

Unnamed: 0,movie_id,title,genres
0,1,Toy Story (1995),Animation|Children's|Comedy
1,2,Jumanji (1995),Adventure|Children's|Fantasy
2,3,Grumpier Old Men (1995),Comedy|Romance
3,4,Waiting to Exhale (1995),Comedy|Drama
4,5,Father of the Bride Part II (1995),Comedy


In [5]:
ratings_df.head()

Unnamed: 0,user_id,movie_id,rating,timestamp
0,1,1193,5,978300760
1,1,661,3,978302109
2,1,914,3,978301968
3,1,3408,4,978300275
4,1,2355,5,978824291


In [6]:
users_df.head()

Unnamed: 0,user_id,gender,age,occupation,zip_code
0,1,F,1,10,48067
1,2,M,56,16,70072
2,3,M,25,15,55117
3,4,M,45,7,2460
4,5,M,25,20,55455


In [7]:
print("用户样本数:",len(users_df))
print("电影样本数:",len(movies_df))
print("评分样本数:", len(ratings_df))

用户样本数: 6040
电影样本数: 3883
评分样本数: 1000209


### 数据预处理

In [8]:
genres=movies_df['genres'].str.get_dummies(sep='|')
movies=pd.concat([movies_df, genres], axis=1)

# 合并评分和电影数据
data=pd.merge(ratings_df, movies, on='movie_id')

In [9]:
genres

Unnamed: 0,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0
2,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0
3,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3878,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
3879,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
3880,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
3881,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0


In [10]:
data.head()

Unnamed: 0,user_id,movie_id,rating,timestamp,title,genres,Action,Adventure,Animation,Children's,...,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western
0,1,1193,5,978300760,One Flew Over the Cuckoo's Nest (1975),Drama,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
1,1,661,3,978302109,James and the Giant Peach (1996),Animation|Children's|Musical,0,0,1,1,...,0,0,0,1,0,0,0,0,0,0
2,1,914,3,978301968,My Fair Lady (1964),Musical|Romance,0,0,0,0,...,0,0,0,1,0,1,0,0,0,0
3,1,3408,4,978300275,Erin Brockovich (2000),Drama,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
4,1,2355,5,978824291,"Bug's Life, A (1998)",Animation|Children's|Comedy,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0


### 用户评分统计特征

In [11]:
user_stats=data.groupby('user_id')['rating'].agg([
    ('mean_rating', 'mean'), # 平均评分
    ('rating_std', 'std'), # 评分标准差
    ('rating_count','count'), # 评分次数
    ('rating_min','min'), # 最小评分
    ('rating_max','max') # 最大评分
]).reset_index()

# 填充可能存在的NaN值
user_stats['rating_std']=user_stats['rating_std'].fillna(0)

# 计算用户评分严格程度
global_mean_rating=data['rating'].mean()
user_stats['rating_strictness']=global_mean_rating-user_stats['mean_rating']

# 计算用户评分波动程度
user_stats['rating_variability']=user_stats['rating_std']/user_stats['mean_rating']

# 查看结果
print(user_stats.head())

   user_id  mean_rating  rating_std  rating_count  rating_min  rating_max  \
0        1     4.188679    0.680967            53           3           5   
1        2     3.713178    1.001513           129           1           5   
2        3     3.901961    0.984985            51           1           5   
3        4     4.190476    1.077917            21           1           5   
4        5     3.146465    1.132699           198           1           5   

   rating_strictness  rating_variability  
0          -0.607115            0.162573  
1          -0.131614            0.269719  
2          -0.320396            0.252433  
3          -0.608912            0.257230  
4           0.435100            0.359991  


### 用户电影类型偏好特征

In [12]:
# 获取所有电影类型列
genre_columns=[col for col in data.columns if col not in ['user_id', 'movie_id', 'rating', 'timestamp','title', 'genres']]


# 计算用户对每种类型的评分次数和平均评分（用户ID、电影类型以及评分次数）
user_genre_stats=data.groupby('user_id')[genre_columns].sum().reset_index()

# 计算用户对每种类型的偏好程度（按行处理，将每一个用户对某一类型的评分，除以该用户对所有类型评分的总和）
for genre in genre_columns:
    user_genre_stats[f'{genre}_favorite_degree']=user_genre_stats[genre]/user_genre_stats[genre_columns].sum(axis=1)

for genre in genre_columns:
    user_genre_stats[f'{genre}_rating_cnt']=user_genre_stats[genre]

# 计算用户最喜欢的类型（返回最大值所在的索引），axis=1表示按照行操作
user_genre_stats['favorite_genre']=user_genre_stats[genre_columns].idxmax(axis=1)

# 计算用户喜欢的类型数量（评分过的类型数）（得到该行中评分大于0的类型数量）
user_genre_stats['num_liked_genres']=(user_genre_stats[genre_columns]>0).sum(axis=1)

user_genre_stats.drop(columns=genre_columns,inplace=True)

# 合并所有用户特征
user_features=pd.merge(user_stats, user_genre_stats, on='user_id')


In [13]:
pd.set_option('display.max_columns', None)
user_features.head()

Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action_favorite_degree,Adventure_favorite_degree,Animation_favorite_degree,Children's_favorite_degree,Comedy_favorite_degree,Crime_favorite_degree,Documentary_favorite_degree,Drama_favorite_degree,Fantasy_favorite_degree,Film-Noir_favorite_degree,Horror_favorite_degree,Musical_favorite_degree,Mystery_favorite_degree,Romance_favorite_degree,Sci-Fi_favorite_degree,Thriller_favorite_degree,War_favorite_degree,Western_favorite_degree,Action_rating_cnt,Adventure_rating_cnt,Animation_rating_cnt,Children's_rating_cnt,Comedy_rating_cnt,Crime_rating_cnt,Documentary_rating_cnt,Drama_rating_cnt,Fantasy_rating_cnt,Film-Noir_rating_cnt,Horror_rating_cnt,Musical_rating_cnt,Mystery_rating_cnt,Romance_rating_cnt,Sci-Fi_rating_cnt,Thriller_rating_cnt,War_rating_cnt,Western_rating_cnt,favorite_genre,num_liked_genres
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.065972,0.0,0.0,0.086806,0.041667,0.0,0.274306,0.003472,0.003472,0.006944,0.0,0.010417,0.083333,0.059028,0.107639,0.052083,0.010417,56,19,0,0,25,12,0,79,1,1,2,0,3,24,17,31,15,3,Drama,14
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.203252,0.02439,0.02439,0.243902,0.0,0.0,0.065041,0.01626,0.0,0.02439,0.00813,0.00813,0.04065,0.04878,0.04065,0.01626,0.04878,23,25,3,3,30,0,0,8,2,0,3,1,1,5,6,5,2,6,Comedy,15
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.103448,0.0,0.017241,0.0,0.017241,0.0,0.103448,0.034483,0.0,0.051724,0.0,0.0,0.034483,0.155172,0.068966,0.051724,0.034483,19,6,0,1,0,1,0,6,2,0,3,0,0,2,9,4,3,2,Action,12
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.025568,0.011364,0.017045,0.159091,0.059659,0.017045,0.295455,0.0,0.008523,0.028409,0.008523,0.022727,0.085227,0.042614,0.110795,0.017045,0.002841,31,9,4,6,56,21,6,104,0,3,10,3,8,30,15,39,6,1,Drama,17


### 电影特征工程

In [14]:
# 计算电影的基本评分统计
movie_stats=data.groupby('movie_id')['rating'].agg([
    ('movie_mean_rating', 'mean'),
    ('movie_rating_std', 'std'),
    ('movie_rating_count', 'count')
])

print("movie_stats:")
print(movie_stats)

# 合并电影原始信息
movie_features=pd.merge(movies, movie_stats, on='movie_id')
print("movie_features:")
print(movie_features.head())

movie_stats:
          movie_mean_rating  movie_rating_std  movie_rating_count
movie_id                                                         
1                  4.146846          0.852349                2077
2                  3.201141          0.983172                 701
3                  3.016736          1.071712                 478
4                  2.729412          1.013381                 170
5                  3.006757          1.025086                 296
...                     ...               ...                 ...
3948               3.635731          1.014196                 862
3949               4.115132          1.009804                 304
3950               3.666667          1.046107                  54
3951               3.900000          1.057331                  40
3952               3.780928          0.935074                 388

[3706 rows x 3 columns]
movie_features:
   movie_id                               title                        genres  \
0      

In [15]:
# 计算电影类型纯度（类型数量越少，纯度越高）
movie_features['genre_purity']=1/movie_features[genre_columns].sum(axis=1)
print(movie_features.head())

   movie_id                               title                        genres  \
0         1                    Toy Story (1995)   Animation|Children's|Comedy   
1         2                      Jumanji (1995)  Adventure|Children's|Fantasy   
2         3             Grumpier Old Men (1995)                Comedy|Romance   
3         4            Waiting to Exhale (1995)                  Comedy|Drama   
4         5  Father of the Bride Part II (1995)                        Comedy   

   Action  Adventure  Animation  Children's  Comedy  Crime  Documentary  \
0       0          0          1           1       1      0            0   
1       0          1          0           1       0      0            0   
2       0          0          0           0       1      0            0   
3       0          0          0           0       1      0            0   
4       0          0          0           0       1      0            0   

   Drama  Fantasy  Film-Noir  Horror  Musical  Mystery  Romanc

In [16]:
print("电影维度:", movie_features.shape)
print("用户维度:", user_features.shape)

电影维度: (3706, 25)
用户维度: (6040, 46)


### 保存特征数据

In [17]:
user_features.to_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/features/user_features.csv', index=False)
movie_features.to_csv('/Users/bytedance/Desktop/MovieLens-Recommendation-System/data/features/movie_features.csv', index=False)

In [18]:
pd.set_option('display.max_columns', None)
user_features.head()

Unnamed: 0,user_id,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action_favorite_degree,Adventure_favorite_degree,Animation_favorite_degree,Children's_favorite_degree,Comedy_favorite_degree,Crime_favorite_degree,Documentary_favorite_degree,Drama_favorite_degree,Fantasy_favorite_degree,Film-Noir_favorite_degree,Horror_favorite_degree,Musical_favorite_degree,Mystery_favorite_degree,Romance_favorite_degree,Sci-Fi_favorite_degree,Thriller_favorite_degree,War_favorite_degree,Western_favorite_degree,Action_rating_cnt,Adventure_rating_cnt,Animation_rating_cnt,Children's_rating_cnt,Comedy_rating_cnt,Crime_rating_cnt,Documentary_rating_cnt,Drama_rating_cnt,Fantasy_rating_cnt,Film-Noir_rating_cnt,Horror_rating_cnt,Musical_rating_cnt,Mystery_rating_cnt,Romance_rating_cnt,Sci-Fi_rating_cnt,Thriller_rating_cnt,War_rating_cnt,Western_rating_cnt,favorite_genre,num_liked_genres
0,1,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13
1,2,3.713178,1.001513,129,1,5,-0.131614,0.269719,0.194444,0.065972,0.0,0.0,0.086806,0.041667,0.0,0.274306,0.003472,0.003472,0.006944,0.0,0.010417,0.083333,0.059028,0.107639,0.052083,0.010417,56,19,0,0,25,12,0,79,1,1,2,0,3,24,17,31,15,3,Drama,14
2,3,3.901961,0.984985,51,1,5,-0.320396,0.252433,0.186992,0.203252,0.02439,0.02439,0.243902,0.0,0.0,0.065041,0.01626,0.0,0.02439,0.00813,0.00813,0.04065,0.04878,0.04065,0.01626,0.04878,23,25,3,3,30,0,0,8,2,0,3,1,1,5,6,5,2,6,Comedy,15
3,4,4.190476,1.077917,21,1,5,-0.608912,0.25723,0.327586,0.103448,0.0,0.017241,0.0,0.017241,0.0,0.103448,0.034483,0.0,0.051724,0.0,0.0,0.034483,0.155172,0.068966,0.051724,0.034483,19,6,0,1,0,1,0,6,2,0,3,0,0,2,9,4,3,2,Action,12
4,5,3.146465,1.132699,198,1,5,0.4351,0.359991,0.088068,0.025568,0.011364,0.017045,0.159091,0.059659,0.017045,0.295455,0.0,0.008523,0.028409,0.008523,0.022727,0.085227,0.042614,0.110795,0.017045,0.002841,31,9,4,6,56,21,6,104,0,3,10,3,8,30,15,39,6,1,Drama,17


In [19]:
movie_features.head()

Unnamed: 0,movie_id,title,genres,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,movie_mean_rating,movie_rating_std,movie_rating_count,genre_purity
0,1,Toy Story (1995),Animation|Children's|Comedy,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,4.146846,0.852349,2077,0.333333
1,2,Jumanji (1995),Adventure|Children's|Fantasy,0,1,0,1,0,0,0,0,1,0,0,0,0,0,0,0,0,0,3.201141,0.983172,701,0.333333
2,3,Grumpier Old Men (1995),Comedy|Romance,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0,0,0,0,3.016736,1.071712,478,0.5
3,4,Waiting to Exhale (1995),Comedy|Drama,0,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0,2.729412,1.013381,170,0.5
4,5,Father of the Bride Part II (1995),Comedy,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0,3.006757,1.025086,296,1.0


In [20]:
interactions = data[['user_id', 'movie_id', 'rating']]
full_data = pd.merge(interactions, user_features, on='user_id')
full_data = pd.merge(full_data, movie_features, on='movie_id')

In [21]:
full_data.head()

Unnamed: 0,user_id,movie_id,rating,mean_rating,rating_std,rating_count,rating_min,rating_max,rating_strictness,rating_variability,Action_favorite_degree,Adventure_favorite_degree,Animation_favorite_degree,Children's_favorite_degree,Comedy_favorite_degree,Crime_favorite_degree,Documentary_favorite_degree,Drama_favorite_degree,Fantasy_favorite_degree,Film-Noir_favorite_degree,Horror_favorite_degree,Musical_favorite_degree,Mystery_favorite_degree,Romance_favorite_degree,Sci-Fi_favorite_degree,Thriller_favorite_degree,War_favorite_degree,Western_favorite_degree,Action_rating_cnt,Adventure_rating_cnt,Animation_rating_cnt,Children's_rating_cnt,Comedy_rating_cnt,Crime_rating_cnt,Documentary_rating_cnt,Drama_rating_cnt,Fantasy_rating_cnt,Film-Noir_rating_cnt,Horror_rating_cnt,Musical_rating_cnt,Mystery_rating_cnt,Romance_rating_cnt,Sci-Fi_rating_cnt,Thriller_rating_cnt,War_rating_cnt,Western_rating_cnt,favorite_genre,num_liked_genres,title,genres,Action,Adventure,Animation,Children's,Comedy,Crime,Documentary,Drama,Fantasy,Film-Noir,Horror,Musical,Mystery,Romance,Sci-Fi,Thriller,War,Western,movie_mean_rating,movie_rating_std,movie_rating_count,genre_purity
0,1,1193,5,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13,One Flew Over the Cuckoo's Nest (1975),Drama,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,4.390725,0.789524,1725,1.0
1,1,661,3,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13,James and the Giant Peach (1996),Animation|Children's|Musical,0,0,1,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,3.464762,1.023202,525,0.333333
2,1,914,3,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13,My Fair Lady (1964),Musical|Romance,0,0,0,0,0,0,0,0,0,0,0,1,0,1,0,0,0,0,4.154088,0.873854,636,0.5
3,1,3408,4,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13,Erin Brockovich (2000),Drama,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0,3.863878,0.895887,1315,1.0
4,1,2355,5,4.188679,0.680967,53,3,5,-0.607115,0.162573,0.043103,0.043103,0.155172,0.172414,0.12069,0.017241,0.0,0.181034,0.025862,0.0,0.0,0.12069,0.0,0.051724,0.025862,0.025862,0.017241,0.0,5,5,18,20,14,2,0,21,3,0,0,14,0,6,3,3,2,0,Drama,13,"Bug's Life, A (1998)",Animation|Children's|Comedy,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,3.854375,0.879284,1703,0.333333


In [22]:
# 定义用户特征列和物品特征列
user_feature_cols = [
    'mean_rating', 'rating_std', 'rating_count', 'rating_min', 'rating_max',
    'rating_strictness', 'rating_variability', 'num_liked_genres'
] + [col for col in user_features.columns if '_favorite_degree' in col]

movie_feature_cols = [
    'movie_mean_rating', 'movie_rating_std', 'movie_rating_count', 'genre_purity'
] + genre_columns

In [23]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler

# 标准化特征
user_scaler = StandardScaler()
movie_scaler = StandardScaler()

user_features[user_feature_cols] = user_scaler.fit_transform(user_features[user_feature_cols])
movie_features[movie_feature_cols] = movie_scaler.fit_transform(movie_features[movie_feature_cols])

In [24]:
# 创建用户和电影的ID映射
user_id_map = {uid: i for i, uid in enumerate(user_features['user_id'].unique())}
movie_id_map = {mid: i for i, mid in enumerate(movie_features['movie_id'].unique())}

num_users = len(user_id_map)
num_items = len(movie_id_map)
user_features_dim = len(user_feature_cols)
item_features_dim = len(movie_feature_cols)

In [25]:
from sklearn.model_selection import train_test_split

# 准备训练数据
def prepare_data(df):
    user_ids = df['user_id'].map(user_id_map).values
    item_ids = df['movie_id'].map(movie_id_map).values
    
    user_feats = df[user_feature_cols].values
    item_feats = df[movie_feature_cols].values
    
    return (user_ids, user_feats, item_ids, item_feats), df['rating'].values

# 分割训练集和测试集
train_df, test_df = train_test_split(full_data, test_size=0.2, random_state=42)
train_data, train_labels = prepare_data(train_df)
test_data, test_labels = prepare_data(test_df)

In [26]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Embedding, Dense, Concatenate, Flatten, Dot, BatchNormalization, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

def build_two_tower_model(num_users, num_items, user_features_dim, item_features_dim, embedding_dim=64):
    """构建双塔模型，充分利用用户和物品特征"""
    
    # 用户塔
    user_id_input = Input(shape=(1,), name='user_id')
    user_features_input = Input(shape=(user_features_dim,), name='user_features')
    
    # 用户ID嵌入层
    user_embedding = Embedding(
        input_dim=num_users, 
        output_dim=embedding_dim, 
        embeddings_regularizer=l2(1e-6)
    )(user_id_input)
    user_embedding = Flatten()(user_embedding)
    
    # 拼接用户ID嵌入和其他用户特征
    user_tower = Concatenate()([user_embedding, user_features_input])
    user_tower = Dense(256, activation='relu')(user_tower)
    user_tower = BatchNormalization()(user_tower)
    user_tower = Dropout(0.3)(user_tower)
    user_tower = Dense(128, activation='relu')(user_tower)
    user_tower = Dense(embedding_dim, activation=None)(user_tower)  # 最终用户向量
    
    # 物品塔
    item_id_input = Input(shape=(1,), name='item_id')
    item_features_input = Input(shape=(item_features_dim,), name='item_features')
    
    # 物品ID嵌入层
    item_embedding = Embedding(
        input_dim=num_items, 
        output_dim=embedding_dim, 
        embeddings_regularizer=l2(1e-6)
    )(item_id_input)
    item_embedding = Flatten()(item_embedding)
    
    # 拼接物品ID嵌入和其他物品特征
    item_tower = Concatenate()([item_embedding, item_features_input])
    item_tower = Dense(256, activation='relu')(item_tower)
    item_tower = BatchNormalization()(item_tower)
    item_tower = Dropout(0.3)(item_tower)
    item_tower = Dense(128, activation='relu')(item_tower)
    item_tower = Dense(embedding_dim, activation=None)(item_tower)  # 最终物品向量
    
    # 计算余弦相似度作为预测分数
    output = Dot(axes=1, normalize=True)([user_tower, item_tower])
    
    # 构建模型
    model = Model(
        inputs=[user_id_input, user_features_input, item_id_input, item_features_input], 
        outputs=output
    )
    
    return model

# 构建模型
model = build_two_tower_model(num_users, num_items, user_features_dim, item_features_dim, embedding_dim=64)
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='mse',
    metrics=['mae', tf.keras.metrics.RootMeanSquaredError()]
)
model.summary()

In [29]:
# 训练模型
history = model.fit(
    [train_data[0], train_data[1], train_data[2], train_data[3]],
    train_labels,
    batch_size=1024,
    epochs=20,
    validation_data=([test_data[0], test_data[1], test_data[2], test_data[3]], test_labels),
    callbacks=[
        tf.keras.callbacks.EarlyStopping(patience=3, restore_best_weights=True),
        tf.keras.callbacks.ReduceLROnPlateau(patience=2)
    ]
)

# 评估模型
test_loss, test_mae, test_rmse = model.evaluate(
    [test_data[0], test_data[1], test_data[2], test_data[3]], test_labels
)
print(f"\nTest Loss: {test_loss:.4f}, Test MAE: {test_mae:.4f}, Test RMSE: {test_rmse:.4f}")

Epoch 1/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - loss: nan - mae: nan - root_mean_squared_error: nan - val_loss: nan - val_mae: nan - val_root_mean_squared_error: nan - learning_rate: 1.0000e-04
Epoch 2/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - loss: nan - mae: nan - root_mean_squared_error: nan - val_loss: nan - val_mae: nan - val_root_mean_squared_error: nan - learning_rate: 1.0000e-04
Epoch 3/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 7ms/step - loss: nan - mae: nan - root_mean_squared_error: nan - val_loss: nan - val_mae: nan - val_root_mean_squared_error: nan - learning_rate: 1.0000e-05
Epoch 4/20
[1m782/782[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 8ms/step - loss: nan - mae: nan - root_mean_squared_error: nan - val_loss: nan - val_mae: nan - val_root_mean_squared_error: nan - learning_rate: 1.0000e-05
[1m6252/6252[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s