In [1]:
import pandas as pd

# 读取原始用户数据
def user_data_read():
    file_path = 'ml-100k/'
    
    # 读取user数据
    user_names = ['id', 'age', 'gender', 'occupation', 'zip_code']
    user_info = pd.read_csv(file_path + 'u.user', sep='|', names=user_names)

    return user_info

In [2]:
# 预处理用户数据
def user_data_preprocess(user_info):
    # 将gender列的M和F分别映射到0和1
    sex_mapDict = {'M': 0, 'F': 1}
    user_info['gender'] = user_info['gender'].map(sex_mapDict)
    return user_info

In [3]:
from sklearn.preprocessing import OneHotEncoder

user_info = user_data_read()
user_info = user_data_preprocess(user_info)
# print(user_info)

occupation = user_info[['occupation']]
# 将occupation一列转换为one-hot向量
enc = OneHotEncoder()
enc.fit(occupation)

occupation_one_hot = enc.transform(occupation).toarray()
occupation_one_hot = pd.DataFrame(occupation_one_hot)

print(occupation_one_hot)

print(type(occupation_one_hot))

      0    1    2    3    4    5    6    7    8    9   ...   11   12   13  \
0    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
1    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
2    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
3    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
4    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
5    0.0  0.0  0.0  0.0  0.0  0.0  1.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
6    1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
7    1.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
8    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
9    0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  ...  0.0  0.0  0.0   
10   0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
11   0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   

In [5]:
print(type(user_info[['id']]))

<class 'pandas.core.frame.DataFrame'>


In [6]:
# 将各列数据拼接在一起，每条数据是一个向量形式
user_vec = pd.concat([user_info['id'], user_info['age'], user_info['gender'], occupation_one_hot], axis=1)
user_vec.to_csv('preprocessed_data/user.csv')
print(user_vec)

      id  age  gender    0    1    2    3    4    5    6  ...   11   12   13  \
0      1   24       0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
1      2   53       1  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
2      3   23       0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
3      4   24       0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
4      5   33       1  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
5      6   42       0  0.0  0.0  0.0  0.0  0.0  0.0  1.0  ...  0.0  0.0  0.0   
6      7   57       0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
7      8   36       0  1.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
8      9   29       0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
9     10   53       0  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  0.0   
10    11   39       1  0.0  0.0  0.0  0.0  0.0  0.0  0.0  ...  0.0  0.0  1.0   
11    12   28       1  0.0  0.0  0.0  0.

In [7]:
# 读取原始电影数据
def movie_data_read():
    file_path = 'ml-100k/'
   
    movie_names = ['id', 'title', 'release_date', 'video_release_date', 'URL', 'unkonwn', 'action', 
                   'adventure', 'animation', 'childrens', 'comedy', 'crime', 'documentary', 'drama',
                   'fantasy', 'film-noir', 'horror', 'musical', 'mystery', 'romance', 'sci-fi', 'thriller',
                   'war', 'western']
    movie_info = pd.read_csv(file_path + 'u.item', sep='|', names=movie_names, encoding='latin-1')
    return movie_info

In [8]:
movie_info = movie_data_read()
print(movie_info)

        id                                              title release_date  \
0        1                                   Toy Story (1995)  01-Jan-1995   
1        2                                   GoldenEye (1995)  01-Jan-1995   
2        3                                  Four Rooms (1995)  01-Jan-1995   
3        4                                  Get Shorty (1995)  01-Jan-1995   
4        5                                     Copycat (1995)  01-Jan-1995   
5        6  Shanghai Triad (Yao a yao yao dao waipo qiao) ...  01-Jan-1995   
6        7                              Twelve Monkeys (1995)  01-Jan-1995   
7        8                                        Babe (1995)  01-Jan-1995   
8        9                            Dead Man Walking (1995)  01-Jan-1995   
9       10                                 Richard III (1995)  22-Jan-1996   
10      11                               Seven (Se7en) (1995)  01-Jan-1995   
11      12                         Usual Suspects, The (1995)  1

In [9]:
# 电影数据预处理
import datetime
from datetime import datetime

movie_info = movie_data_read()
movie_info.dropna(axis=1, how='all', inplace=True)
movie_info.dropna(inplace=True)

date = movie_info['release_date']
date = date.apply(lambda x:datetime.strptime(str(x), '%d-%b-%Y'))

movie_info['release_date'] = date.map(lambda x: x.year)
movie_info

Unnamed: 0,id,title,release_date,URL,unkonwn,action,adventure,animation,childrens,comedy,...,fantasy,film-noir,horror,musical,mystery,romance,sci-fi,thriller,war,western
0,1,Toy Story (1995),1995,http://us.imdb.com/M/title-exact?Toy%20Story%2...,0,0,0,1,1,1,...,0,0,0,0,0,0,0,0,0,0
1,2,GoldenEye (1995),1995,http://us.imdb.com/M/title-exact?GoldenEye%20(...,0,1,1,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,3,Four Rooms (1995),1995,http://us.imdb.com/M/title-exact?Four%20Rooms%...,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,4,Get Shorty (1995),1995,http://us.imdb.com/M/title-exact?Get%20Shorty%...,0,1,0,0,0,1,...,0,0,0,0,0,0,0,0,0,0
4,5,Copycat (1995),1995,http://us.imdb.com/M/title-exact?Copycat%20(1995),0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
5,6,Shanghai Triad (Yao a yao yao dao waipo qiao) ...,1995,http://us.imdb.com/Title?Yao+a+yao+yao+dao+wai...,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
6,7,Twelve Monkeys (1995),1995,http://us.imdb.com/M/title-exact?Twelve%20Monk...,0,0,0,0,0,0,...,0,0,0,0,0,0,1,0,0,0
7,8,Babe (1995),1995,http://us.imdb.com/M/title-exact?Babe%20(1995),0,0,0,0,1,1,...,0,0,0,0,0,0,0,0,0,0
8,9,Dead Man Walking (1995),1995,http://us.imdb.com/M/title-exact?Dead%20Man%20...,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,0,0
9,10,Richard III (1995),1996,http://us.imdb.com/M/title-exact?Richard%20III...,0,0,0,0,0,0,...,0,0,0,0,0,0,0,0,1,0


In [10]:
movie_info.loc[:, 'unkonwn': 'western']

Unnamed: 0,unkonwn,action,adventure,animation,childrens,comedy,crime,documentary,drama,fantasy,film-noir,horror,musical,mystery,romance,sci-fi,thriller,war,western
0,0,0,0,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1,0,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,0
3,0,1,0,0,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0
4,0,0,0,0,0,0,1,0,1,0,0,0,0,0,0,0,1,0,0
5,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
6,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,1,0,0,0
7,0,0,0,0,1,1,0,0,1,0,0,0,0,0,0,0,0,0,0
8,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,0,0
9,0,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0,0,1,0


In [11]:
movie_vec = pd.concat([movie_info['id'], movie_info['release_date'], movie_info.loc[:, 'unkonwn': 'western']], axis=1)

In [12]:
movie_vec.to_csv('preprocessed_data/item.csv')

In [135]:
# 读取用户向量数据
def user_dict_load():
    data_dir = 'preprocessed_data/user.csv'
    data = pd.read_csv(data_dir, sep=',')
    return data

In [136]:
data = user_dict_load()
data.head(5)
data.drop(labels=['Unnamed: 0'],axis=1, inplace=True)
data.head(5)

Unnamed: 0,id,age,gender,0,1,2,3,4,5,6,...,11,12,13,14,15,16,17,18,19,20
0,1,24,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
1,2,53,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,3,23,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
3,4,24,0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
4,5,33,1,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


In [143]:
user_df = data.iloc[:, 1:]
user_arr= np.array(user_df)
user_vec = user_arr.tolist()
len(user_vec)

943

In [145]:
len(user_vec[0])

23

In [146]:
user_dict = {}
for line in data.itertuples():
    user_dict[line.id] = user_vec[line.Index]

In [147]:
user_dict

{1: [24.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0],
 2: [53.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 3: [23.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0],
 4: [24.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0],
 5: [33.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 6: [42.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0],
 7: 

In [148]:
# 读取电影向量数据
def movie_dict_load():
    data_dir = 'preprocessed_data/item.csv'
    data = pd.read_csv(data_dir, sep=',')
    return data

In [149]:
data = movie_dict_load()
data.drop(labels=['Unnamed: 0'], axis=1, inplace=True)
data.head(5)

Unnamed: 0,id,release_date,unkonwn,action,adventure,animation,childrens,comedy,crime,documentary,...,fantasy,film-noir,horror,musical,mystery,romance,sci-fi,thriller,war,western
0,1,1995,0,0,0,1,1,1,0,0,...,0,0,0,0,0,0,0,0,0,0
1,2,1995,0,1,1,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
2,3,1995,0,0,0,0,0,0,0,0,...,0,0,0,0,0,0,0,1,0,0
3,4,1995,0,1,0,0,0,1,0,0,...,0,0,0,0,0,0,0,0,0,0
4,5,1995,0,0,0,0,0,0,1,0,...,0,0,0,0,0,0,0,1,0,0


In [150]:
movie_df = data.iloc[:, 1:]
movie_arr = np.array(movie_df)
movie_vec = movie_arr.tolist()
len(movie_vec)

1679

In [151]:
len(movie_vec[0])

20

In [152]:
movie_dict = {}
for line in data.itertuples():
    movie_dict[line.id] = movie_vec[line.Index]

In [153]:
movie_dict

{1: [1995, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 2: [1995, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 3: [1995, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 4: [1995, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 5: [1995, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 6: [1995, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 7: [1995, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
 8: [1995, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 9: [1995, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 10: [1996, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],
 11: [1995, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 12: [1995, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
 13: [1995, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
 14: [1994, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
 15: [1996, 0, 0, 0, 0, 0, 0,

In [154]:
import os
# 加载rating数据
def rating_data_read():
    file_dir = 'ml-100k/'
    names = ['user_id', 'item_id', 'rating', 'timestamp']
    data = pd.read_csv(os.path.join(file_dir, 'u.data'), sep='\t', names=names)
    return data

In [155]:
import numpy as np
# 随机划分训练集和测试集
def random_split_data(data, test_ratio=0.3):
    mask = [True if x == 1 else False for x in np.random.uniform(
    0, 1, (len(data))) < 1 - test_ratio]
    neg_mask = [not x for x in mask]
    train_data, test_data = data[mask], data[neg_mask]
    return train_data, test_data

In [156]:
# 将训练集和测试集转换为列表
def data_transform(data):
    users, items, scores = [], [], []
    for line in data.itertuples():
        user_index, item_index, score = int(line[1]), int(line[2]), int(line[3])
        users.append(user_index)
        items.append(item_index)
        scores.append(score)
    return users, items, scores

In [157]:
import torch
from torch.utils.data import Dataset as Dataset

class MLDataset(Dataset):
    def __init__(self, user, item, score):
        self.user = user
        self.item = item
        self.score = score
    
    def __getitem__(self, index):
        return self.user[index], self.item[index], self.score[index]
    
    def __len__(self):
        return len(self.user)

In [158]:
from torch.utils.data import DataLoader as DataLoader
# 封装到Dataset和DataLoader中
rating_data = rating_data_read()
train_data, test_data = random_split_data(rating_data)
train_user, train_item, train_score = data_transform(train_data)
test_user, test_item, test_score = data_transform(test_data)

train_Dataset = MLDataset(train_user, train_item, train_score)
train_iter = DataLoader(dataset=train_Dataset, batch_size=32, shuffle=True)

for data in train_iter:
    print(data)

[tensor([797, 393,  13, 732, 468, 899, 705, 715, 885,  33, 200, 669, 334, 597,
        435, 780, 701, 483, 435,  82, 174, 391, 267, 333, 899, 124, 647, 452,
        682, 524, 840, 222]), tensor([181,  17, 761, 269,  42, 515, 377, 100, 420, 343, 515, 902,  56, 477,
        571, 497, 304, 197, 117, 169, 340,   8, 824, 435, 238, 616, 631, 474,
        357,  94, 653, 455]), tensor([5, 1, 4, 5, 4, 3, 4, 2, 4, 4, 5, 2, 4, 5, 2, 2, 4, 3, 3, 4, 5, 3, 4, 4,
        2, 4, 4, 3, 3, 2, 5, 3])]
[tensor([659, 450, 497, 833, 268, 622, 194, 933, 642, 627, 343, 234, 533, 426,
        419, 846, 149, 805, 647, 330, 189, 210, 230, 458, 835, 201, 435, 201,
        876, 468, 383,  43]), tensor([ 393,  794,   63,  157,   40,    8,  466,   97, 1178,  402,   53,  358,
           1,  608,   14,  172,  302,  420,  294,  443, 1154,  402,  183,  410,
         157, 1008,  240,  895,  288,  603,  663,  820]), tensor([3, 5, 3, 2, 3, 4, 4, 2, 3, 3, 5, 1, 4, 4, 5, 4, 4, 4, 3, 4, 3, 5, 3, 1,
        4, 3, 3, 3, 3, 5, 5,

        5, 3, 5, 4, 3, 3, 4, 3])]
[tensor([330, 852, 379, 176, 703,  72, 584, 541, 796, 530, 500,  13, 746, 392,
        204, 537, 374, 689, 892, 276,  54, 141, 222, 345, 593, 721, 417, 269,
         10, 758,   6, 438]), tensor([ 405,  930,  251,  948,  845,  530,  165,  452,  281,  156, 1014,  872,
         523,  197,  170,   87,  742,  222,  227,  685,  272,  407,  183,  955,
         237,  680,  449,  447,  530,  116,  469,  321]), tensor([5, 3, 5, 4, 4, 4, 1, 3, 4, 4, 2, 3, 3, 5, 5, 3, 5, 5, 4, 4, 5, 2, 4, 4,
        4, 3, 3, 3, 4, 5, 5, 5])]
[tensor([299,  59, 899, 634, 474, 911, 747, 721, 490, 141, 463, 530, 880, 270,
        499,  92, 343, 585, 532, 455,  49, 417, 379, 434, 484, 224, 542, 342,
        474, 406, 916, 864]), tensor([ 313,  491,  230,  147,  519,  404,  604,  199,  150,  118,  740,  527,
         283,  535,  430,  307, 1008,  970,  472,  279, 1068, 1215,  522,  546,
         578,  582,  194,   88, 1020,  508,   73,   99]), tensor([3, 4, 4, 2, 4, 3, 5, 4, 5, 5, 4, 4

        3, 3, 4, 3, 3, 5, 3, 5])]
[tensor([153, 313, 544, 409, 727, 215, 560, 692,  16, 483, 916, 884, 660, 880,
        551, 424,  28, 592, 798, 244, 145, 868, 532, 932, 345, 354, 790, 796,
        450, 454, 796, 253]), tensor([  22,  636,  877, 1524,  831,   77,  288,  321,   87,   50,    7,  323,
         739, 1267,  399,  508,  164,  680,  944,   77,  825,  214,   58,  429,
         387,  513,  373,  248,  490,  313,  423,  202]), tensor([2, 4, 2, 4, 3, 3, 4, 3, 4, 5, 4, 2, 2, 4, 3, 3, 4, 1, 4, 4, 4, 3, 4, 5,
        4, 5, 3, 3, 5, 5, 4, 5])]
[tensor([ 35,  31, 629, 288, 325, 868, 472, 707, 931, 125, 130, 219,  21, 406,
        277,  95, 766, 211, 889, 361, 181, 699, 648, 201, 730, 354,  38, 642,
        295,  40, 925, 425]), tensor([ 300,  490,  275,  234,  525,  762,   79,  936,  281,   56, 1088, 1014,
         873,  404,   93,  586,  132,  687,   79,  513,  270, 1643,  748,   10,
         535,  664,  188,  401,  461,  258,  816, 1314]), tensor([5, 4, 5, 4, 5, 4, 5, 4, 3, 1, 2, 3

        3, 2, 5, 5, 5, 3, 4, 4])]
[tensor([216, 663, 828, 577, 889, 262, 276, 181, 537, 407,   7, 593, 707, 864,
         85, 303, 721, 870, 308,  87, 870, 624, 665, 378, 898, 545, 833, 890,
        414,  22, 313, 658]), tensor([ 577,  363,  955, 1336,  178,  582,  603, 1385,  191,  729,  570,  193,
          88,  780,  173,  458,  582,  218,  613,  300,  248,  258,    7,  674,
         315,  167,  515,  403,  272,  211,  837, 1079]), tensor([1, 2, 3, 1, 5, 4, 5, 1, 4, 4, 3, 4, 3, 2, 3, 3, 3, 4, 4, 3, 4, 4, 4, 3,
        5, 3, 3, 1, 5, 3, 4, 2])]
[tensor([361, 805, 233,  64, 465, 659, 463, 477, 411,  95, 663, 377, 109, 545,
        778,  96,  82, 535, 587, 506,  57,  13,  25, 850,  85, 279, 510, 471,
        354, 943, 798, 399]), tensor([ 173,  343,   82,  447,   32,  199,  930,   66,  276,  657,  956,  258,
         295,   89,   98, 1154,   73,  425,  245,  542,  748,  670,  169,  663,
          98,   41,  330,  477,  135,  139,  419,   22]), tensor([5, 5, 4, 4, 3, 4, 1, 5, 3, 5, 4, 4

        3, 4, 5, 4, 5, 2, 3, 4])]
[tensor([752, 596, 384,  42, 506, 512, 935, 917, 479, 519,   7, 117, 185,  80,
        764, 119, 746, 345, 162, 486, 405, 639, 764, 452, 264, 782,  71, 417,
         44, 532, 881, 267]), tensor([ 887,   13,  989,  924,   53,   23,  117,  282,   82,  263,  607, 1047,
         114,   86, 1046,  995,  168, 1048,  710,  269,  469,  210,   89,  423,
         559, 1513,   89,  979,  231, 1168, 1480,    5]), tensor([1, 2, 4, 3, 4, 4, 4, 4, 4, 5, 3, 2, 4, 5, 4, 4, 3, 2, 4, 4, 1, 3, 4, 5,
        5, 2, 5, 3, 2, 4, 2, 3])]
[tensor([123,  49, 290, 846, 805, 709, 397, 851,  21, 184, 631, 534, 276, 620,
        796,  56,  79,  54, 854, 616, 176, 862,  94,  13, 480, 561, 339,  95,
        176,   7, 327, 327]), tensor([ 197,   13,  622,  197,  222,  117,  657,  932,  858,   58,  332,  597,
         578,   35, 1041,   64,  740,  597,  628,  301,  328,  498,   25,  800,
          56,  426,  134, 1221,  508,  238,  144,  875]), tensor([5, 3, 3, 4, 4, 4, 5, 3, 1, 4, 3, 5

        2, 4, 4, 5, 1, 4, 2, 5])]
[tensor([363,  91, 442,  55, 256, 654, 726, 318, 503, 447, 249, 587, 159, 291,
        716, 207, 347, 332, 232, 727,  22, 458, 180,  10, 279, 394, 178, 380,
        395, 121, 174,   7]), tensor([ 181,  127,  684,  254, 1051,  535,  898, 1044,   14,  237,  174, 1265,
         121, 1017,  204,  248,   73,    1,   52,   56,  684,  515,  258,  238,
         234, 1210,  269, 1116,  892, 1194,  843,  483]), tensor([5, 5, 3, 2, 4, 3, 2, 4, 3, 4, 4, 4, 3, 4, 5, 3, 2, 4, 5, 3, 3, 4, 5, 4,
        2, 3, 4, 4, 3, 4, 2, 4])]
[tensor([478, 653,  28,  15, 538, 633, 222,  68,  60,  75, 771, 343, 401, 151,
        344, 828, 195, 601, 483, 788, 303, 416, 387, 663, 456, 793, 747, 856,
        766, 440,  49, 178]), tensor([ 42, 195, 322, 926, 204, 651, 768, 275, 474, 866, 172,  77, 357, 425,
        715, 895, 264, 378, 270, 157, 286, 283, 432, 289, 129, 122,  94, 688,
         22, 971,  57, 354]), tensor([5, 5, 2, 1, 3, 3, 2, 5, 5, 2, 4, 3, 4, 4, 4, 2, 3, 2, 3, 5, 5, 5, 

[tensor([620, 405, 664, 936,  57, 776, 104, 301, 620, 897, 457, 263, 125, 889,
        606,   1, 543, 935, 692, 308,  85, 896, 393,  28, 429,  91, 901, 373,
         48, 606, 268, 234]), tensor([ 123,   26,  805,  898,  204,  441,  475,  411,  323,   33,  208,  250,
         205,  427,  473,  203,  313,  685, 1012,  276,  241,  582,  181,  195,
          47,   22,  257,  168,  603,   28,  527, 1003]), tensor([3, 3, 5, 1, 4, 2, 4, 1, 5, 5, 4, 2, 5, 4, 4, 4, 3, 4, 1, 4, 3, 2, 4, 4,
        4, 5, 4, 5, 4, 4, 4, 2])]
[tensor([921, 344, 818, 151, 497, 474, 896, 334, 766, 452,  64, 496, 450, 388,
        409, 665, 846, 291, 938, 383, 214, 682, 174, 301, 807, 230, 567, 625,
        586, 458, 251, 280]), tensor([ 678,   89, 1105,  181, 1157,  190,  230,  337,  602,  462,  237, 1041,
        1222,  301,  381,  473,  615,  101,    1,  197,  603,  209, 1074,  150,
         199,  693,  479,  190,  550,  208,  172,  174]), tensor([5, 5, 1, 5, 2, 3, 4, 4, 4, 4, 4, 1, 3, 4, 2, 4, 5, 4, 4, 5, 4, 3, 4,

[tensor([786, 527, 555, 144, 889, 483, 916, 663, 610, 201, 279, 641, 357, 365,
        758, 348, 399, 262, 592, 327, 805, 303, 181, 771, 825, 777, 275,  22,
        505, 318, 292, 456]), tensor([ 322,   69,  258,  318,  271,   20,   79,  151,  187,  682,   44,  242,
         744,  321,  248,  118,  225,   86,  129,   47,  856, 1044,  875,   97,
         742,   56,  135,  411,  307,  158,  203,    9]), tensor([3, 4, 3, 5, 3, 2, 3, 3, 4, 3, 1, 5, 5, 5, 4, 4, 3, 3, 5, 4, 4, 3, 3, 1,
        4, 5, 3, 1, 4, 5, 4, 3])]
[tensor([ 17, 864, 768, 326,   7, 864, 883, 387, 889, 486, 655, 698, 856, 618,
        113, 864, 271, 221,  54, 877, 804, 548, 405, 328, 827, 443, 380, 805,
        311,  99, 459, 899]), tensor([ 245,  715,  310,  182,  503,  208,  707,  488,  127,  273, 1071,  855,
         272,  559,  328,  140,  402,  407,  871,  155,  552,  595, 1109,  172,
         269,  948,  100,   89,  357,  201,  252,  717]), tensor([2, 4, 4, 2, 4, 4, 3, 3, 4, 3, 2, 2, 5, 3, 5, 3, 4, 2, 5, 2, 4, 4, 1,

[tensor([271, 886, 281, 854, 198, 588,  72, 487, 181, 882, 459, 347, 479, 582,
        613, 665, 429, 405, 503, 851, 747, 250,  87, 222,  44, 417, 821, 495,
        271, 833, 634, 730]), tensor([ 410,  578,  682,  132,  249,  660,  121,  286,  236,  546,  866,  550,
         295,  547,  435,  845,  415,  552,  744, 1034,  390,  338,  273,  240,
         655,  423,  132, 1444,  179,   76,  932,  685]), tensor([2, 4, 3, 5, 2, 4, 3, 2, 1, 2, 5, 5, 1, 4, 5, 4, 3, 1, 2, 1, 4, 4, 3, 2,
        3, 4, 5, 2, 4, 2, 3, 2])]
[tensor([ 89, 893, 224, 462, 367, 878, 653, 436, 755, 868, 256, 115,  77, 409,
        679, 585, 439, 436,  90, 271,  60, 862, 682, 119, 454, 719, 379, 312,
         72,   6, 644, 860]), tensor([ 216,  471,  977,  271,  876, 1149, 1133,  144,  689,  448,  356, 1067,
           4,  264,  241,  116,  100,  200,  318,  357,  638,  474,  476,  755,
         602,  284,  559,  657,   15,   95,  988,  344]), tensor([5, 4, 2, 1, 3, 4, 2, 5, 3, 2, 3, 4, 3, 1, 3, 3, 3, 3, 5, 5, 5, 5, 1,

In [159]:
import torch
from torch.autograd import Variable

# 定义模型
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear = torch.nn.Linear(43, 1)
    
    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

In [160]:
model = Model()
# 定义损失函数
criterion = torch.nn.MSELoss(reduction='sum')
# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

In [161]:
for data in train_iter:
    print(data)

[tensor([393, 653, 378, 281, 189, 648, 236, 886, 437, 804, 655, 626, 472, 201,
        937, 567, 181, 773, 932, 385, 116, 224, 156, 704, 503, 886, 276, 407,
        627, 435, 828, 472]), tensor([ 241, 1228,  482,  300,   50, 1110,  135,  212,  602,  250, 1086,  333,
         831,  685,  293,  484,  832,   98, 1149,  652,  323,  724,  357,  304,
         303, 1093,  559,  209,  724,  385,  753,  143]), tensor([4, 2, 4, 4, 5, 3, 2, 2, 3, 4, 3, 1, 5, 3, 4, 4, 1, 4, 4, 5, 3, 3, 4, 2,
        5, 1, 4, 5, 2, 5, 4, 4])]
[tensor([749, 109, 327, 122, 184, 305, 566, 250, 597, 709, 733,  13, 269, 181,
        711, 393, 379, 921, 853, 145, 344, 390, 470, 201, 607, 707, 889, 252,
         89, 452, 779, 181]), tensor([1188,  931,  293,  509,  443,  655,   95,  258,  235,  117,  149,  223,
          42,  823, 1170,  633,  391,  185,  333,  563,  568,  329,  950,  537,
         121,  486,  696,  847,   49,  102,  926,  763]), tensor([3, 2, 3, 4, 3, 4, 2, 4, 4, 4, 4, 5, 5, 2, 3, 2, 4, 3, 4, 3, 5, 3, 3,

[tensor([405, 188, 183, 270, 630, 342, 180, 532, 880, 600, 543, 652, 535, 608,
        254, 104, 409,  13,  37, 605, 363, 938, 251, 523, 743, 144, 236, 942,
        758, 643, 472, 555]), tensor([ 568,   77,   54,   70,  216,   88,  318,  407,  368,   50,  479,  257,
         156,  163,  504,  628, 1558,  825,   11,  187,  469,  248,  520,   42,
         222,   93, 1328,   95,   95,  474,  181,  249]), tensor([4, 4, 2, 5, 5, 1, 5, 2, 1, 4, 4, 2, 2, 1, 3, 4, 5, 1, 4, 5, 2, 1, 5, 3,
        4, 1, 4, 5, 3, 5, 5, 4])]
[tensor([499, 145, 177, 521, 125, 511, 633,   1, 600, 269, 913, 417, 130, 327,
        519, 152, 821, 361, 934, 343, 528,  93, 934, 829, 788, 776, 153, 184,
        660, 713, 798, 773]), tensor([  87,  486,  270, 1016,   85,  271, 1046,  248,  431,  156,  195,  715,
         881,  381,  887,  393, 1060,  213,  506,   72,   77,  235,  474,  190,
         554,  185,  357,  708,  252, 1434, 1034,  675]), tensor([4, 3, 1, 3, 3, 5, 4, 4, 3, 5, 4, 2, 4, 4, 5, 5, 5, 5, 4, 5, 3, 4, 4,

        3, 5, 3, 1, 2, 4, 4, 4])]
[tensor([416, 210, 198, 833, 532, 790, 596, 445, 450, 458,   7, 840, 305, 870,
         10, 933,  94, 312, 378, 834, 387, 724, 835, 291, 127, 885, 236, 318,
        383, 223, 393, 804]), tensor([ 11, 708, 216, 396, 531, 184, 289, 994, 654, 845, 496, 234, 121, 317,
        414,  87,   1, 631,  58, 275,  12, 302, 310, 820, 748, 655, 520, 100,
        488, 908,  58, 932]), tensor([4, 5, 4, 3, 5, 3, 3, 1, 4, 3, 5, 5, 3, 4, 4, 4, 4, 5, 4, 3, 5, 3, 4, 4,
        5, 3, 4, 5, 4, 1, 3, 3])]
[tensor([ 83, 244, 642, 424,  90, 586, 500, 881, 345, 354, 328, 119, 224, 716,
        442, 413,  87, 188, 401, 305, 334, 149, 663, 709, 627, 115, 271, 419,
        290, 415, 435, 807]), tensor([ 845,  215,  609,  100,  153,  358,  164,  192,  215,  709,  258,  458,
         570,  102,   33,  258,    2,   69,   44,    2,  171,  302,  129,  129,
         685,   48,   47,  212,  318,  748,  184, 1274]), tensor([3, 4, 3, 5, 5, 4, 4, 5, 4, 5, 5, 5, 4, 2, 3, 4, 4, 4, 4, 2, 4, 4, 

[tensor([841, 497, 833, 151, 592, 318, 401, 627, 242, 734, 367, 416, 806, 830,
        787, 535, 503,   1, 615, 406, 495, 431,  83, 506, 632, 305, 480, 540,
        890, 158,  49, 342]), tensor([1294,  384,  284,  770,   61,  196,  312,  387,  111,  174,  324, 1503,
         254,  205,  310,  628,  185,  175,   72,  672,   98,  303,  243,  503,
         186,  471,  152,  181,  133,  648,  396,  764]), tensor([5, 2, 1, 4, 4, 3, 3, 2, 4, 4, 5, 4, 3, 5, 5, 4, 5, 5, 2, 2, 5, 4, 3, 4,
        5, 4, 4, 4, 5, 5, 4, 1])]
[tensor([890, 119, 463, 109, 655, 784, 343, 210, 854, 862, 933, 845, 443, 452,
        769, 886, 889, 551, 311, 458,  91,  17, 632, 426, 125, 887, 602, 417,
        341, 263, 826, 573]), tensor([258, 340, 117, 748, 135, 258,  98, 179, 507, 436, 216, 302, 687,  23,
         15, 732, 234, 912, 222, 499,  56,   9, 739, 481, 174,  24, 508, 234,
        335, 886,  53, 513]), tensor([3, 4, 3, 3, 4, 5, 5, 3, 4, 4, 3, 3, 3, 2, 3, 3, 4, 3, 4, 4, 1, 3, 3, 5,
        5, 5, 3, 4, 4, 2, 5,

[tensor([453, 229, 269, 682, 586,  13, 174, 254, 144, 443, 213, 232, 833, 816,
        448, 406, 311, 450, 167, 869, 389,  71, 768, 181, 617, 297, 363, 505,
        279, 786, 315, 933]), tensor([  94,  311,  139,  209,  467,  601,  323,  625,   54,  309,  117,  172,
         198, 1025,  303,  655,  127,  619,  290,  411,  613,  285,  269,  598,
         145,  182,  906,  228,   88,  381,  156,    7]), tensor([4, 5, 1, 3, 4, 4, 1, 3, 2, 5, 4, 4, 4, 4, 4, 3, 4, 3, 3, 4, 5, 3, 3, 1,
        1, 3, 2, 2, 1, 3, 5, 4])]
[tensor([758, 196, 144, 862, 156, 428,  35, 379, 178, 378, 731,  82, 354, 615,
        889, 536, 109, 224,  57, 283, 116, 524, 311, 690,  94, 935, 286, 214,
        128, 181, 546, 305]), tensor([ 269,  381,   56,  423,   22,  343,  264,  447,  546,  768,  484, 1126,
         692,  708,  279,   79,  627,  313,  109,  709,   50, 1154,  729,   80,
        1110,  597,   72, 1017,  238, 1381,  898,  197]), tensor([4, 4, 4, 4, 3, 2, 2, 4, 3, 4, 3, 4, 2, 2, 2, 4, 5, 5, 4, 5, 3, 1, 4,

        1, 5, 4, 3, 4, 4, 1, 5])]
[tensor([233, 937, 397, 295, 815, 889, 756, 781,  50, 537, 229, 221, 311, 561,
         92, 188, 826, 846,  13, 326, 894, 650, 506,   1, 195, 151, 748, 526,
        239, 293, 580, 121]), tensor([ 249,  303,  171,  162,  222,   86,  742,  258,  544,  488,  312,   33,
          71,  382,  466,   50,  127, 1210,   28,  663,   59,  152,  323,  110,
         508,  222,   48,  283,  198,  328,  358,  298]), tensor([5, 4, 5, 4, 4, 4, 3, 2, 4, 4, 3, 4, 4, 4, 4, 4, 5, 2, 5, 1, 5, 3, 3, 1,
        3, 5, 4, 3, 5, 2, 4, 2])]
[tensor([392, 146, 772, 223, 374, 116, 378, 668, 179, 505, 738, 821, 315, 560,
        427, 454, 293, 907, 456, 530, 842, 605, 268, 404, 180, 561, 265, 788,
        711, 846, 385, 270]), tensor([ 50, 346, 312, 259,  48, 903, 298, 257, 313, 496,  22, 181, 513, 134,
        341,   1, 550,  50, 186, 527, 340, 462,  62, 687, 658, 193, 181, 480,
        582, 464, 135, 872]), tensor([5, 4, 4, 3, 5, 2, 3, 3, 4, 5, 3, 4, 5, 5, 5, 3, 1, 4, 4, 4, 5, 5, 

        3, 4, 5, 4, 3, 4, 4, 1])]
[tensor([217, 474, 328,  26, 427,   7, 514, 588, 213, 197, 201, 811, 278, 660,
        234, 934,  43, 425, 416, 148, 188, 158, 276, 174, 295, 345, 379, 468,
        524, 533,  66, 447]), tensor([ 554,  244,  559,  369,  319,  186,  384,  739,  132,   68,  895,  895,
         173,  358,  513,  204,  254, 1595,  783,  181,  462,  471,  169,   40,
         125,  121,   64, 1008,  663,  919,  300,   85]), tensor([3, 4, 3, 2, 3, 4, 3, 4, 5, 2, 3, 5, 5, 2, 5, 4, 3, 2, 3, 5, 4, 4, 5, 4,
        5, 3, 5, 4, 2, 2, 5, 4])]
[tensor([416, 455, 327, 658, 327, 405, 666, 554, 476, 508, 305,  13,  28, 727,
        704, 919, 334, 918, 194,  91, 837, 252, 130, 269, 502, 290, 275, 864,
        630, 392, 801, 590]), tensor([ 213,  276,   99,  923,  183,  621,  653,  133,  325,  423,  923,    9,
          70,  234,  506,  660,  652, 1101,  519,  195, 1047,    7,  321,  959,
         328,  230,  118,  770,  257,  200,  332,  116]), tensor([5, 4, 4, 3, 3, 1, 4, 4, 1, 5, 5, 3

        5, 4, 3, 3, 3, 4, 2, 3])]
[tensor([807, 791, 312, 588, 937, 699, 333, 913, 887, 379, 921, 688, 177, 275,
        507, 766, 823,  62, 264, 605, 864, 429,  83, 537, 184, 391, 586, 279,
        495, 339, 367, 506]), tensor([ 298,  328,  692,  326,    9,  276,  739,   15,  405,  474,  924,  307,
         100,  515,  181,  191,  211,   71, 1475,   14,   28,   88,  584,  694,
         137,    8,  410,   99,  575,  293,  302,  137]), tensor([4, 4, 4, 4, 5, 3, 5, 3, 5, 5, 3, 4, 5, 3, 5, 4, 5, 4, 2, 5, 5, 3, 4, 4,
        5, 3, 3, 3, 3, 5, 5, 2])]
[tensor([728, 697, 354, 221, 776, 124,  73,  32, 299, 646, 479, 618, 543, 347,
        184, 569,  46,  11, 883, 108, 618, 267, 236, 798, 181, 308, 168, 405,
        594, 646, 145, 896]), tensor([ 243,  150,  904,  178,  760,   11,  286,  742,   19,  288,  111,  238,
         216,  369,  252,  100,  100,  524,  319,    7,  576,  164,  750,  769,
         983,  216,  235,  142,  520,  319, 1001,  674]), tensor([2, 5, 5, 4, 3, 5, 4, 3, 1, 3, 4, 1

        4, 5, 5, 4, 4, 5, 4, 3])]
[tensor([773, 635, 711, 894, 221, 880, 316, 350, 840, 303,  96,  31, 933, 639,
         70, 162, 717, 399, 290, 222, 363, 303, 601, 130, 774, 483, 180, 711,
        445, 385, 935, 597]), tensor([1021,  150,  250,  905,  144,  762, 1084,  427,  121,  235,  187,   79,
         228,  990,  225,  943,  298,  658,  135,  403,  761,  722, 1084,  366,
         193,  121,  790, 1289,  979,  403,  864,   15]), tensor([5, 3, 2, 3, 4, 4, 4, 5, 2, 4, 5, 2, 4, 1, 3, 4, 3, 3, 4, 3, 3, 2, 5, 5,
        5, 2, 1, 2, 2, 3, 5, 5])]
[tensor([715,   7, 727, 716, 632,  42, 870, 677,  13, 655, 111, 892, 174, 880,
        130, 452, 479, 385, 488, 294, 379, 527, 405, 575, 234,   6, 442, 478,
        934, 424, 408, 145]), tensor([  73,  651,   27,  525,  475,  925, 1208,  288,  210,  116,  302,  663,
          49,   80,  578,   94,  175,  654,  742,  322,  655,  429,  567,  322,
         623,  485,  943,   12,  225,  688,  271,  156]), tensor([4, 5, 4, 3, 3, 4, 2, 5, 3, 2, 5, 5

In [167]:
for data in train_iter:
    for i in range(32):
        user_vec = torch.Tensor(user_dict[int(data[0][i])])
        movie_vec = torch.Tensor(movie_dict[int(data[1][i])])
        x_data = torch.cat((user_vec, movie_vec))
        y_data = 
        print(x_data)

[tensor([899, 484, 623, 579, 773, 863, 318, 537, 666, 738,  48, 293, 606, 820,
        185, 119, 381, 933, 279, 254, 592, 940, 588, 880, 489, 311, 276,  92,
        854, 158, 504, 602]), tensor([  83,  597,  163,  288,   64,  898,  160,  762,  525,   89,  425,  265,
         191,  358,  114,  741,  132,   80, 1025,  472,   70,   66,   29,  179,
         875,  415,  248,  260,  628,  431,  561,  508]), tensor([4, 3, 3, 4, 4, 1, 3, 3, 4, 5, 3, 3, 5, 1, 4, 4, 5, 2, 2, 3, 4, 4, 3, 4,
        2, 3, 4, 1, 2, 5, 4, 3])]
tensor([3.2000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9930e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
 

        0.0000e+00])
tensor([4.8000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.9810e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([5.1000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9820e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([4.4000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9950e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([5.0000e+01, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([4.9000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9710e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([4.3000e+01, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9800e+03,
        0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00

        0.0000e+00])
tensor([1.3000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([2.9000e+01, 1.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

tensor([3.3000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9950e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([2.8000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00,
        0.0000e+00,

        0.0000e+00])
tensor([2.0000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([1.9000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([3.1000e+01, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9570e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([3.6000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9900e+03,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([2.1000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9920e+03,
        0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([3.5000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9970e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([3.3000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9940e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([3.3000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
tensor([3.2000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9920e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00, 1.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([2.2000e+01, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.9950e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

        0.0000e+00])
[tensor([546, 576, 130, 286, 150, 766, 622,  29,  10, 234, 336, 401, 892, 653,
        674, 407, 274, 288, 655, 393, 453, 620, 806, 442, 660, 833, 575, 401,
        664, 798, 497, 395]), tensor([ 250,  435,  147,  116,  293,  175,  480,  480,  286,  164,  619,  866,
         484,  139,  304,  169,  150,  258, 1634,  622,  210,   71,  158,  182,
         722,  730,  483,  591,  805,  576,  252,  118]), tensor([4, 4, 4, 5, 4, 3, 4, 4, 4, 3, 3, 3, 5, 2, 3, 5, 5, 4, 2, 4, 4, 5, 2, 4,
        1, 4, 3, 3, 5, 3, 3, 3])]
tensor([3.6000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9970e+03,
        0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.00

IOPub data rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_data_rate_limit`.

Current values:
NotebookApp.iopub_data_rate_limit=1000000.0 (bytes/sec)
NotebookApp.rate_limit_window=3.0 (secs)



tensor([3.0000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9940e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00])
tensor([2.7000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 1.9690e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00, 1.0000e+00,
        0.0000e+00,

        0.0000e+00])
tensor([4.2000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9940e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00])
tensor([4.4000e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.9960e+03,
        0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00

KeyError: 267

In [124]:
for data in train_iter:
    for i in range(32):
        user_vec = user_dict[int(data[0][i])]
        movie_vec = movie_dict[int(data[1][i])]
        print(user_vec)
        print(movie_vec)
        x_data = user_vec.extend(movie_vec)
        print(x_data)

[27.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1997, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1997, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0, 1996, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1995, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1996, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1996, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1996, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
[1995, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
None
[25.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1993, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1996, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1989, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1997, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1996

KeyError: 1359

In [104]:
# 训练过程
for epoch in range(1):
    for data in train_iter:
        print(data[0])

tensor([768, 647, 141, 354, 338,  46, 796, 896, 474, 104,  17, 541,   7, 814,
        222, 397, 899, 889,   7, 896, 246, 109, 495,  13, 495, 880, 125, 246,
        721, 932, 298, 555])
tensor([ 75, 295, 459, 839, 172, 500, 603,  59, 504, 654, 862, 155, 786, 296,
        643, 254, 615, 291, 684, 907, 878, 787, 536, 160, 593, 320, 654, 267,
        807, 479, 405, 786])
tensor([407, 459,  89, 881, 280, 936, 880, 848,   2, 683, 682, 943, 334, 495,
        721, 548, 636, 184, 693, 177, 411, 293, 389, 778, 279, 326, 868, 830,
        312, 758, 534, 296])
tensor([707, 656, 781, 262, 916, 878, 576, 246,   1, 458, 210, 523, 405, 312,
        659, 654, 272, 667, 666, 750, 758,  16, 506, 655, 931, 624, 500, 240,
        487, 325, 149, 253])
tensor([416, 916, 181, 224, 407, 607, 408, 398, 775, 936,  11, 140, 523,  59,
        279, 932, 222, 504, 554, 844, 350, 454, 429, 741, 314, 699, 833, 600,
        839, 125, 846,  85])
tensor([705, 535, 275,  18, 633, 327, 181, 643, 867, 328, 373, 435, 294,  8

tensor([121, 450, 868, 916,  98, 269,  56, 234, 936, 760, 933, 144, 303, 851,
        288, 596, 383, 561, 659, 488,  86, 360, 248, 606, 457, 751, 612, 116,
        870, 529, 786, 653])
tensor([833, 378, 385, 195, 673, 179, 938, 511, 863, 117, 526, 246, 804, 339,
         53, 435, 764, 905, 621, 215, 850, 303, 372, 196, 638, 248, 120, 449,
        363, 648, 234, 291])
tensor([532, 709, 320, 887, 699, 479, 629, 158, 627, 761, 942, 401, 577, 864,
        896, 694,  21, 137, 198, 788, 805, 407, 319, 883, 230, 693,  79, 181,
        216, 342, 374, 204])
tensor([145, 352, 452, 236, 535, 758, 429, 625, 173, 385, 196, 480, 830, 495,
        764,  15, 120, 747, 648, 254, 201, 793, 463, 629, 504,  18, 780, 272,
        493, 361,   5, 308])
tensor([ 22, 739, 207, 901, 847,  28, 227,  32, 804, 660, 754,  71, 230, 102,
        786, 374, 588, 637, 934, 871, 332, 353, 818, 655, 189, 296, 389, 537,
        246, 234, 622, 880])
tensor([  7, 676, 620, 616, 527, 867, 828, 301,  95, 504, 794, 668, 627, 52

        913, 707,  62, 916])
tensor([295, 758,  39, 178, 122, 276, 295, 270, 503, 472, 437,  84, 627, 844,
         18, 417, 416, 535,  92, 823, 506, 696,  62, 585, 207,  76, 758, 511,
        184, 321, 141, 880])
tensor([313, 416, 643,  13, 932, 244, 684, 650,  95, 912, 450, 378, 727, 269,
        207, 639, 144, 163, 389, 851, 940,  77, 402, 488,  28, 270, 805, 222,
        268,  87, 838, 399])
tensor([121, 486, 295, 106, 178, 853, 264,  89, 932, 247, 463, 722, 591, 198,
        705, 663,  21,  49, 425, 301, 140, 263, 532, 150, 566, 436, 643, 406,
        180, 893, 454, 520])
tensor([379, 107, 246,  72, 154, 256, 505, 203, 497, 561, 553, 385, 429, 360,
        497,  13, 749,  22, 834, 271, 473, 495, 197, 327, 746, 310,  54, 295,
        178, 234, 526, 262])
tensor([411, 200, 219,  37,  94, 806, 795, 221, 790, 221, 498, 438, 551, 405,
        514, 693, 655, 569, 868, 916, 914, 378, 260, 326, 654, 653, 747, 429,
        747, 279, 807, 618])
tensor([268, 140, 326, 535, 160, 308, 693, 640

        494, 652, 593, 666])
tensor([463, 354, 447, 327, 648, 611, 864, 326, 339, 296, 261, 328, 838,  94,
        162, 780, 868, 919, 579, 263, 675, 715, 805, 692, 795, 230, 271, 588,
        227, 715, 450, 363])
tensor([378,  38, 130, 587, 599,  13, 537, 180, 161, 603, 141, 246, 308, 934,
        305, 187, 405, 360,  13, 804, 397, 308,   9, 309, 735,  59,  41,  73,
        391,  29, 275, 652])
tensor([ 90, 279,  56, 450, 279, 144, 293, 460, 396, 410, 320, 595, 920, 542,
        405, 358, 332, 488, 201, 378, 422, 307, 847, 559, 276, 828, 588, 269,
        351, 536, 207, 286])
tensor([222, 458, 764, 254, 172,  13,  32, 393, 168, 758, 314, 886, 334, 222,
        665, 405, 365, 934, 250,  57, 178, 379, 184, 101, 449, 468, 210, 470,
          7, 401, 385, 508])
tensor([814,  93, 463, 109, 379, 642, 731, 727, 334, 311, 708, 653, 222, 698,
        326,  95, 655,  23, 351, 391, 360, 117, 311, 644, 664,  95, 646, 320,
        936, 271, 939, 459])
tensor([ 42, 394, 567, 774, 472, 921,  44, 207

tensor([496, 669, 537, 385, 629, 663, 405, 259, 506, 458, 682, 181, 181, 373,
        521,  95,  77, 897, 264, 231, 344, 815, 606, 320, 749,  63, 337, 936,
        782,  92, 339, 474])
tensor([551, 601, 435, 178, 380, 269, 328,  49, 263, 401, 839, 885, 279, 307,
        488, 330, 843, 880, 339, 406, 537, 676,  52, 102, 468, 450, 694, 385,
        758, 669, 567, 291])
tensor([102, 524, 716, 676, 595, 416, 892, 788, 499, 597, 344, 405,  99, 456,
        181,  13, 826, 276, 423, 889, 563, 645, 441, 927, 328, 373, 471, 454,
        159, 905, 666, 943])
tensor([833, 854, 234,  42, 618,  11, 911, 327, 218, 321, 407, 479, 402, 919,
        644,  13, 371, 884, 846, 437, 195, 618, 451, 916, 533, 487, 181, 914,
         56, 534, 506, 339])
tensor([639, 711, 682, 429, 321, 254,  66, 385, 328, 708, 358, 792, 428, 269,
        378, 422, 881, 177, 442, 779, 705, 877, 904, 938, 711, 276, 391, 710,
        104, 497, 221, 327])
tensor([279, 514, 416, 805, 236, 727, 378,  73, 836, 655, 234, 279, 156, 87

In [None]:
# 训练过程
for epoch in range(50):
    y_pred = model()