In [1]:
'''
movies.csv
movieId,title,genres
1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy
2,Jumanji (1995),Adventure|Children|Fantasy
3,Grumpier Old Men (1995),Comedy|Romance
4,Waiting to Exhale (1995),Comedy|Drama|Romance
5,Father of the Bride Part II (1995),Comedy
6,Heat (1995),Action|Crime|Thriller
7,Sabrina (1995),Comedy|Romance
8,Tom and Huck (1995),Adventure|Children
9,Sudden Death (1995),Action
10,GoldenEye (1995),Action|Adventure|Thriller
11,"American President, The (1995)",Comedy|Drama|Romance

ratings
userId,movieId,rating,timestamp
1,1,4.0,964982703
1,3,4.0,964981247
1,6,4.0,964982224
1,47,5.0,964983815
1,50,5.0,964982931
1,70,3.0,964982400
'''

'\nmovies.csv\nmovieId,title,genres\n1,Toy Story (1995),Adventure|Animation|Children|Comedy|Fantasy\n2,Jumanji (1995),Adventure|Children|Fantasy\n3,Grumpier Old Men (1995),Comedy|Romance\n4,Waiting to Exhale (1995),Comedy|Drama|Romance\n5,Father of the Bride Part II (1995),Comedy\n6,Heat (1995),Action|Crime|Thriller\n7,Sabrina (1995),Comedy|Romance\n8,Tom and Huck (1995),Adventure|Children\n9,Sudden Death (1995),Action\n10,GoldenEye (1995),Action|Adventure|Thriller\n11,"American President, The (1995)",Comedy|Drama|Romance\n\nratings\nuserId,movieId,rating,timestamp\n1,1,4.0,964982703\n1,3,4.0,964981247\n1,6,4.0,964982224\n1,47,5.0,964983815\n1,50,5.0,964982931\n1,70,3.0,964982400\n'

In [2]:
import csv
import os

In [3]:
# Hard-coded absolute Windows paths to the input CSV files (movie metadata and user ratings).
# NOTE(review): these are machine-specific; point them at the local data directory before running.
movies_file = "F:\\badou\\tmp\\data\\recommender\\data\\ml-latest-small\\movies.csv"
ratings_file = "F:\\badou\\tmp\\data\\recommender\\data\\ml-latest-small\\ratings.csv"

In [4]:
# movies
def get_item_info(input_file):
    '''
    Load movie metadata from a CSV file laid out as ``movieId,title,genres``
    with a header row.
    :param input_file:
        path to the movies CSV file
    :return:
        a dict {itemid: [title, genre]}; an empty dict if the file does not exist
    '''
    # Nothing to load if the file is missing.
    if not os.path.exists(input_file):
        return {}
    item_info = {}
    # `with` guarantees the handle is closed even if parsing raises;
    # newline="" lets the csv module do its own newline handling.
    with open(input_file, "r", encoding="utf-8", newline="") as fp:
        # csv.reader understands quoted fields, so a title such as
        # "American President, The (1995)" arrives as a single field
        # without its surrounding quotes.
        reader = csv.reader(fp)
        # Skip the header row.
        next(reader, None)
        for row in reader:
            # Ignore blank or truncated rows.
            if len(row) < 3:
                continue
            itemid = row[0].strip()
            # The genre list is always the last column.
            genre = row[-1].strip()
            # If an unquoted title contained commas, glue its pieces back together.
            title = ",".join(row[1:-1])
            item_info[itemid] = [title, genre]
    return item_info

In [5]:
# ratings
def get_ave_score(input_file):
    '''
    Compute the average rating of every item in a ratings CSV file laid out as
    ``userId,movieId,rating,timestamp`` with a header row.
    :param input_file:
        path to the ratings CSV file
    :return:
        a dict {itemid: average rating rounded to 3 decimal places};
        an empty dict if the file does not exist
    '''
    if not os.path.exists(input_file):
        return {}
    # Intermediate accumulator: itemid -> [number of ratings, sum of ratings].
    record_dict = {}
    # `with` guarantees the handle is closed even if float() raises on a bad rating.
    with open(input_file, "r", encoding="utf-8") as fp:
        # Skip the header line.
        next(fp, None)
        for line in fp:
            item = line.strip().split(",")
            # Ignore blank or truncated lines.
            if len(item) < 4:
                continue
            itemid, rating = item[1], item[2]
            if itemid not in record_dict:
                record_dict[itemid] = [0, 0]
            # Count the rating.
            record_dict[itemid][0] += 1
            # Accumulate the score.
            record_dict[itemid][1] += float(rating)
    # Average = sum / count, rounded to three decimal places.
    return {
        itemid: round(total / count, 3)
        for itemid, (count, total) in record_dict.items()
    }

In [6]:
# train
def get_train_data(input_file, score_thr=4.0):
    '''
    Feature engineering: build a balanced set of training samples.

    For each user, ratings >= score_thr become positive samples (label 1) and
    the rest become negative candidates. Both sides are cut to the same size:
    positives in file order, negatives by highest item average score first.
    Users lacking either positives or negatives contribute no samples.
    :param input_file:
        ratings_file: userid itemid rating
    :param score_thr:
        rating threshold separating positive from negative samples (default 4.0)
    :return:
        a list[(userid, itemid, label), (userid1, itemid1, label)]
    '''
    if not os.path.exists(input_file):
        return []
    # Average score of every item, used to rank the negative candidates.
    score_dict = get_ave_score(input_file)
    # userid -> [(itemid, 1), ...]
    pos_dict = {}
    # userid -> [(itemid, item average score), ...]
    neg_dict = {}
    # `with` guarantees the handle is closed even if float() raises on a bad rating.
    with open(input_file, "r", encoding="utf-8") as fp:
        # Skip the header line.
        next(fp, None)
        for line in fp:
            item = line.strip().split(",")
            # Ignore blank or truncated lines.
            if len(item) < 4:
                continue
            userid, itemid, rating = item[0], item[1], float(item[2])
            # Every user gets an entry in both dicts, even if one stays empty.
            user_pos = pos_dict.setdefault(userid, [])
            user_neg = neg_dict.setdefault(userid, [])
            if rating >= score_thr:
                user_pos.append((itemid, 1))
            else:
                # Keep the item's average score (0 if unknown) for negative sampling.
                user_neg.append((itemid, score_dict.get(itemid, 0)))

    # Balance positives and negatives per user.
    train_data = []
    for userid, pos_list in pos_dict.items():
        neg_list = neg_dict[userid]
        data_num = min(len(pos_list), len(neg_list))
        # Skip users that have no positives or no negatives.
        if data_num == 0:
            continue
        # Positives first, limited to data_num, in file order.
        train_data += [(userid, itemid, label) for itemid, label in pos_list[:data_num]]
        # Negatives: the data_num items with the highest average score.
        sorted_neg_list = sorted(neg_list, key=lambda e: e[1], reverse=True)[:data_num]
        train_data += [(userid, itemid, 0) for itemid, _ in sorted_neg_list]

    return train_data

In [7]:
# Build the training samples from the ratings file.
train_data = get_train_data(ratings_file)

In [8]:
# Sanity check: total number of training samples, then the first 40 of them.
print(len(train_data))
print(train_data[:40])

67186
[('1', '1', 1), ('1', '3', 1), ('1', '6', 1), ('1', '47', 1), ('1', '50', 1), ('1', '101', 1), ('1', '110', 1), ('1', '151', 1), ('1', '157', 1), ('1', '163', 1), ('1', '216', 1), ('1', '231', 1), ('1', '235', 1), ('1', '260', 1), ('1', '333', 1), ('1', '349', 1), ('1', '356', 1), ('1', '362', 1), ('1', '367', 1), ('1', '441', 1), ('1', '457', 1), ('1', '480', 1), ('1', '527', 1), ('1', '543', 1), ('1', '552', 1), ('1', '553', 1), ('1', '590', 1), ('1', '592', 1), ('1', '593', 1), ('1', '596', 1), ('1', '608', 1), ('1', '661', 1), ('1', '296', 0), ('1', '1258', 0), ('1', '1219', 0), ('1', '223', 0), ('1', '1408', 0), ('1', '648', 0), ('1', '70', 0), ('1', '1580', 0)]
