In [1]:
import torch
from torch.utils.data import Dataset
import scipy.sparse as sp
import numpy as np
import torch.nn as nn
import argparse
import copy
import math

In [2]:
def init_seed(seed, reproducibility):
    r""" init random seed for random functions in numpy, torch, cuda and cudnn

    Args:
        seed (int): random seed
        reproducibility (bool): Whether to require reproducibility
    """
    #random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if reproducibility:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
    else:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False


In [15]:
def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=100, help="rounds of training")
    parser.add_argument('--lr', type=float, default=0.8, help="learning rate")
    #parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--dim', type=int, default=20, help="latent dimension")
    parser.add_argument('--seed', type=int, default=2022, help='random seed (default: 2022)')
    parser.add_argument('--alpha', type=float, default=1.0, help='alpha')
    parser.add_argument('--Lambda', type=float, default=0.001, help='Lambda')
    parser.add_argument('--topk', type=int, default=10, help='topk')
    #parser.add_argument('--iterations', type=int, default=10, help='number of gradient descent iterations per epoch')
    parser.add_argument('--local_train_iterations', type=int, default=10, help='local_train_iterations')
    parser.add_argument('--start_hybrid_averaging_iterations', type=int, default=10, help='start_hybrid_averaging_iterations')
    parser.add_argument('--rho', type=int, default=1, help='sample items')
    parser.add_argument('--max_rating', type=float, default=5.0, help='max_rating')
    parser.add_argument('--min_rating', type=float, default=1.0, help='min_rating')
    parser.add_argument('-f', type=str, default="读取jupyter的额外参数")
    

    args = parser.parse_args()
    return args

In [16]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [17]:
def convert_sp_mat_to_sp_tensor(X):
    coo = X.tocoo().astype(np.float32)
    row = torch.Tensor(coo.row).long()
    col = torch.Tensor(coo.col).long()
    index = torch.stack([row, col])
    data = torch.FloatTensor(coo.data)
    return torch.sparse.FloatTensor(index, data, torch.Size(coo.shape))


def get_all_pos(matrix, users):
    items_pos = []
    for user in users:
        items_pos.append(matrix[user].nonzero()[1])
    return items_pos

In [18]:
def read_data(path, index=1):
    train_path = path + 'copy' + str(index) + '.train'
    test_path = path + 'copy' + str(index) + '.test'
    num_users, num_items = 0, 0
    train_user_dict = {}
    train_user_ratings = {}
    iteraction = 0
    item_set = set()
    with open(train_path, 'r') as f:
        for line in f.readlines():
            line = line.split()
            uid = int(line[0]) - 1
            iid = int(line[1]) - 1
            if train_user_dict.get(uid) == None:
                train_user_dict[uid] = []
                train_user_ratings[uid] = []
            train_user_dict[uid].append(iid)
            train_user_ratings[uid].append(float(line[2]))
            item_set.add(iid)
            iteraction += 1
    test_user_dict = {}
    test_user_ratings = {}
    with open(test_path, 'r') as f:
        for line in f.readlines():
            line = line.split()
            uid = int(line[0]) - 1
            if test_user_dict.get(uid) == None:
                test_user_dict[uid] = []
                test_user_ratings[uid] = []
            test_user_dict[uid].append(int(line[1]) - 1)
            test_user_ratings[uid].append(float(line[2]))
            item_set.add(int(line[1]) - 1)
    print(f'共有{len(train_user_dict)}个用户，{len(item_set)}个物品，交互总数为{iteraction}')
    return train_user_dict, train_user_ratings, test_user_dict, test_user_ratings, len(train_user_dict), len(item_set)

In [19]:
class FCF_client():
    def __init__(self, uid, I, I_u, I_u_ratings, args):
        super().__init__()
        self.uid = uid
        self.user_embedding = ((torch.rand(args.dim) - 0.5) * 0.01).to(device)
        #self.user_embedding = torch.randn(args.dim).to(device)
        self.args = args
        self.lr = self.args.lr
        self.I_u = np.array(I_u)
        self.I_u_ratings = torch.tensor(I_u_ratings).to(device)
        self.avg_r = np.sum(I_u_ratings) / self.I_u_ratings.shape[0]
        self.I_u_sample = np.delete(I, self.I_u - 1)
        self.iter = 0
    
    def train(self, item_embeddings):
        loss = 0.0
        u_item_embeddings = item_embeddings[self.I_u]
        pred = self.user_embedding @ u_item_embeddings.T
        err = self.I_u_ratings - pred
        loss += torch.sum(err.pow(2)).item()
        grad_u = -(err.reshape(-1, 1) * u_item_embeddings).sum(0) + self.args.Lambda * self.I_u.shape[0] * self.user_embedding
        grad_i = -err.reshape(-1, 1) * self.user_embedding + self.args.Lambda * u_item_embeddings
        temp_u = 0
        
        
        if self.args.rho != 0:
            if self.iter > self.args.start_hybrid_averaging_iterations:
                temp_u = self.user_embedding.clone()
                for it in range(self.args.local_train_iterations):
                    pred = temp_u @ u_item_embeddings.T
                    err = self.I_u_ratings - pred
                    temp_grad_u = -(err.reshape(-1, 1) * u_item_embeddings).sum(0) + self.args.Lambda * self.I_u.shape[0] * temp_u
                    temp_u = temp_u - self.lr * temp_grad_u / self.I_u.shape[0]
        
        
        sample_number = self.args.rho * self.I_u.shape[0] if self.args.rho * self.I_u.shape[0] < self.I_u_sample.shape[0] else self.I_u_sample.shape[0]
                                     
        np.random.shuffle(self.I_u_sample)
        update_list = np.append(self.I_u, self.I_u_sample[:sample_number])
        
        sample_item_emb = item_embeddings[self.I_u_sample[:sample_number]]
        pred = self.user_embedding @ sample_item_emb.T
        if self.iter > self.args.start_hybrid_averaging_iterations:
            temp_pred = temp_u @ sample_item_emb.T
            err = temp_pred - pred
        else:
            err = torch.tensor([self.avg_r] * sample_number).to(device) - pred
        loss += torch.sum(err.pow(2)).item()
        grad_u += -(err.reshape(-1, 1) * sample_item_emb).sum(0) + self.args.Lambda * sample_number * self.user_embedding
        grad_sample_i = -err.reshape(-1, 1) * self.user_embedding + self.args.Lambda * sample_item_emb
        self.user_embedding -= self.lr * grad_u / (self.I_u.shape[0] + sample_number)
        self.lr *= 0.9
        self.iter += 1
        #print(self.user_embedding)
        return loss, update_list, torch.cat((grad_i, grad_sample_i), dim=0)

In [20]:
class FCF_server(nn.Module):
    def __init__(self, args, num_items):
        super().__init__()
        self.args = args
        self.num_items = num_items
        self.item_embeddings = ((torch.rand(num_items, args.dim)-0.5)*0.01).to(device)
        #self.item_embeddings = torch.randn(num_items, args.dim).to(device)
        self.lr = self.args.lr
        
    def update(self, sum_grad_i, count_i):
        count_i[count_i < 1] = 1
        count_i = torch.tensor(count_i).reshape(-1,1).to(device)
        self.item_embeddings -= self.lr * sum_grad_i / count_i
        self.lr *= 0.9
        

In [21]:
def test(test_user_dict, test_user_ratings, server_model, users_model, args):
    with torch.no_grad():
        rmse = 0.0
        mae = 0.0
        item_cnt = 0
        for user in test_user_dict:
            items = test_user_dict[user]
            ratings = torch.tensor(test_user_ratings[user]).to(device)
            pred = users_model[user].user_embedding @ server_model.item_embeddings[items].T

            pred[pred < args.min_rating] = args.min_rating
            pred[pred > args.max_rating] = args.max_rating
            rmse += torch.sum((ratings - pred).pow(2))
            mae += torch.sum((ratings - pred).abs())
            item_cnt += len(items)
        rmse = math.sqrt(rmse / item_cnt)
        mae = mae / item_cnt
        return rmse, mae

In [22]:
def trainer(index=5, path='./ML100K/'):
    args = args_parser()
    init_seed(args.seed, True)
    rmse_l, mae_l = [], []
    for idx in range(1, 1+index):
        rmse_min, mae_min = 1<<20, 1<<20
        args = args_parser()
        train_user_dict, train_user_ratings, test_user_dict, test_user_ratings, num_users, num_items = read_data(path, idx)
        users_model = []
        server_model = FCF_server(args, num_items)
        I = np.arange(0, num_items)
        for uid in range(num_users):
            users_model.append(FCF_client(uid, I, train_user_dict[uid], train_user_ratings[uid], args))
        for epoch in range(args.epochs):
            sum_loss = 0.0
            cnt_items = 0
            sum_grad_i = torch.zeros(num_items, args.dim).to(device)
            count_i = np.zeros(num_items)
            for uid in range(num_users):
            #for uid in range(1):
                loss, update_list, grad_i = users_model[uid].train(server_model.item_embeddings)
                sum_loss += loss
                sum_grad_i[update_list] += grad_i
                count_i[update_list] += 1
                cnt_items += update_list.shape[0]
            server_model.update(sum_grad_i, count_i)
            sum_loss = math.sqrt(sum_loss / cnt_items)
            print(f'Epoch {epoch+1}/{args.epochs}: rmse {sum_loss:.8f}')
            rmse, mae = test(test_user_dict, test_user_ratings, server_model, users_model, args)
            print(f'Test rmse {rmse:.8f}, mae {mae:.8f}')
            rmse_min = min(rmse_min, rmse)
            mae_min = min(mae_min, mae.item())
        print(f'Best rmse {rmse_min:.8f}, mae {mae_min:.8f}')
        rmse_l.append(rmse_min)
        mae_l.append(mae_min)
    rmse_l = np.array(rmse_l)
    mae_l = np.array(mae_l)
    print(f'RMSE: {np.mean(rmse_l):.5f}±{np.std(rmse_l):.5f}')
    print(f'MAE : {np.mean(mae_l):.5f}±{np.std(mae_l):.5f}')

In [None]:
trainer(5, './ML100K/')

共有943个用户，1682个物品，交互总数为80000
Epoch 1/100: rmse 3.63410450
Test rmse 2.76550172, mae 2.52370000
Epoch 2/100: rmse 3.63410041
Test rmse 2.76550172, mae 2.52370000
Epoch 3/100: rmse 3.63407922
Test rmse 2.76550172, mae 2.52370000
Epoch 4/100: rmse 3.63387386
Test rmse 2.76550172, mae 2.52370000
Epoch 5/100: rmse 3.63198477
Test rmse 2.76550172, mae 2.52370000
Epoch 6/100: rmse 3.61669804
Test rmse 2.76550172, mae 2.52370000
Epoch 7/100: rmse 3.50923950
Test rmse 2.76086853, mae 2.52000904
Epoch 8/100: rmse 2.87845674
Test rmse 1.00564140, mae 0.80894482
Epoch 9/100: rmse 0.74441402
Test rmse 1.04936310, mae 0.80666095
Epoch 10/100: rmse 0.76490921
Test rmse 1.05297815, mae 0.86717945
Epoch 11/100: rmse 0.81044988
Test rmse 1.06066759, mae 0.81298363
Epoch 12/100: rmse 0.78517563
Test rmse 1.02417135, mae 0.83974761
Epoch 13/100: rmse 0.74494285
Test rmse 0.97900056, mae 0.75910318
Epoch 14/100: rmse 0.67965139
Test rmse 0.96037366, mae 0.76416153
Epoch 15/100: rmse 0.65744234
Test rmse 0.9

Test rmse 0.93718431, mae 0.73747271
Epoch 23/100: rmse 0.65010267
Test rmse 0.93706719, mae 0.73734134
Epoch 24/100: rmse 0.64994288
Test rmse 0.93704419, mae 0.73725879
Epoch 25/100: rmse 0.64980774
Test rmse 0.93698223, mae 0.73719960
Epoch 26/100: rmse 0.64969440
Test rmse 0.93692571, mae 0.73713130
Epoch 27/100: rmse 0.64959615
Test rmse 0.93690866, mae 0.73709786
Epoch 28/100: rmse 0.64951292
Test rmse 0.93689855, mae 0.73706269
Epoch 29/100: rmse 0.64944200
Test rmse 0.93688522, mae 0.73704094
Epoch 30/100: rmse 0.64938098
Test rmse 0.93688951, mae 0.73702860
Epoch 31/100: rmse 0.64932923
Test rmse 0.93687832, mae 0.73703128
Epoch 32/100: rmse 0.64928467
Test rmse 0.93687310, mae 0.73700333
Epoch 33/100: rmse 0.64924394
Test rmse 0.93685077, mae 0.73698461
Epoch 34/100: rmse 0.64920827
Test rmse 0.93683505, mae 0.73695880
Epoch 35/100: rmse 0.64917646
Test rmse 0.93682595, mae 0.73694116
Epoch 36/100: rmse 0.64914940
Test rmse 0.93682408, mae 0.73693389
Epoch 37/100: rmse 0.6491

Epoch 44/100: rmse 0.64770329
Test rmse 0.94560656, mae 0.74652588
Epoch 45/100: rmse 0.64769426
Test rmse 0.94559768, mae 0.74651694
Epoch 46/100: rmse 0.64768566
Test rmse 0.94559128, mae 0.74651289
Epoch 47/100: rmse 0.64767797
Test rmse 0.94558702, mae 0.74650538
Epoch 48/100: rmse 0.64767148
Test rmse 0.94558472, mae 0.74650764
Epoch 49/100: rmse 0.64766541
Test rmse 0.94558611, mae 0.74650645
Epoch 50/100: rmse 0.64766029
Test rmse 0.94557886, mae 0.74649876
Epoch 51/100: rmse 0.64765546
Test rmse 0.94557502, mae 0.74649680
Epoch 52/100: rmse 0.64765121
Test rmse 0.94557158, mae 0.74649107
Epoch 53/100: rmse 0.64764748
Test rmse 0.94556679, mae 0.74648619
Epoch 54/100: rmse 0.64764388
Test rmse 0.94556225, mae 0.74648541
Epoch 55/100: rmse 0.64764058
Test rmse 0.94556386, mae 0.74648595
Epoch 56/100: rmse 0.64763777
Test rmse 0.94555941, mae 0.74648184
Epoch 57/100: rmse 0.64763530
Test rmse 0.94555529, mae 0.74648017
Epoch 58/100: rmse 0.64763313
Test rmse 0.94555522, mae 0.7464