In [15]:
import pickle, time, os
import numpy as np
import paddle
import pandas as pd
import paddle.nn as nn
from tqdm import tqdm

from visualdl import LogWriter
logwriter = LogWriter(logdir='./runs')
# visualdl --logdir ./runs/ --host 0.0.0.0 --port 8040

users_df = pd.read_csv('data/csv/users.csv')
items_df = pd.read_csv('data/csv/items.csv')

In [2]:
emb_scale = 64
batch_size = 2048
len_users = len(users_df)
len_items = len(items_df)

In [3]:
class Dataset(paddle.io.Dataset):
    def __init__(self, data):
        self.data = data
        pass
    def __getitem__(self, idx):
        return self.data[idx][0:2], self.data[idx][2]
    def __len__(self):
        return len(self.data)
    pass 

In [29]:
class Net(nn.Layer):
    def __init__(self):
        super(Net,self).__init__()
        self.users_emb = nn.Embedding(len_users+1, emb_scale)
        self.items_emb = nn.Embedding(len_items+1, emb_scale)
        # self.user_fc1 = nn.Embedding(emb_scale, 64)
        # self.item_fc1 = nn.Embedding(emb_scale, 64)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.cos = nn.CosineSimilarity()
        pass
    def forward(self, input):
        user = self.users_emb(input[:,0])
        item = self.items_emb(input[:,1])
        # user = self.user_fc1(user)
        # item = self.item_fc1(item)
        x = self.cos(user, item)
        x = self.sigmoid(x)
        return x
    pass
net = Net()
optim = paddle.optimizer.Adam(parameters=net.parameters(), learning_rate=0.005, weight_decay=paddle.regularizer.L2Decay(1e-1))
m = paddle.metric.Recall()

In [28]:
train_loss_, train_acc_, eval_loss_, eval_acc_ = [], [], [], []

for epoch_id in range(200):

    train_loss, train_acc, eval_loss, eval_acc = [], [], [], []

    file_list = os.listdir('data/net/train/')
    file = open('data/net/train/'+np.random.choice(file_list), 'rb')
    train_dataset = paddle.io.DataLoader(Dataset(pickle.load(file)),
                                         drop_last=True,
                                         batch_size=batch_size)
    file.close()    

    net.train()
    m.reset()
    for batch_id, data in enumerate(tqdm(train_dataset)):
        x_data = data[0]
        y_data = data[1]
        x_data = paddle.cast(x_data, dtype='int32')

        y_pred = net(x_data)
        loss = nn.functional.mse_loss(y_pred, paddle.cast(y_data, dtype='float32'))
        acc = paddle.static.accuracy(paddle.reshape(y_pred, (batch_size, 1)), paddle.reshape(paddle.cast(y_data, dtype='int64'), (batch_size, 1)))
        m.update(y_pred, y_data)
        loss.backward()

        optim.step()
        optim.clear_grad()
        train_loss.append(loss.numpy())
        train_acc.append(acc.numpy())

        logwriter.add_scalar("train_loss", value=loss.numpy(), step=batch_id+epoch_id*(batch_size))
        logwriter.add_scalar("train_acc", value=acc.numpy(), step=batch_id+epoch_id*(batch_size))
        logwriter.add_scalar("train_recall", value=m.accumulate(), step=batch_id+epoch_id*(batch_size))
        
        if m.accumulate() > 0.3:
            break

    file_list = os.listdir('data/net/eval/')
    file = open('data/net/eval/'+np.random.choice(file_list), 'rb')
    train_dataset = paddle.io.DataLoader(Dataset(pickle.load(file)),
                                         drop_last=True,
                                         batch_size=batch_size,
                                         shuffle=True)
    file.close() 

    net.eval()
    m.reset()
    for batch_id, data in enumerate(tqdm(train_dataset)):
        x_data = data[0]
        y_data = data[1]
        x_data = paddle.cast(x_data, dtype='int32')

        y_pred = net(x_data)
        loss = nn.functional.mse_loss(y_pred, paddle.cast(y_data, dtype='float32'))
        acc = paddle.static.accuracy(paddle.reshape(y_pred, (batch_size, 1)), paddle.reshape(paddle.cast(y_data, dtype='int64'), (batch_size, 1)))
        m.update(y_pred, y_data)

        logwriter.add_scalar("eval_loss", value=loss.numpy(), step=batch_id+epoch_id*(batch_size))
        logwriter.add_scalar("eval_acc", value=acc.numpy(), step=batch_id+epoch_id*(batch_size))
        logwriter.add_scalar("eval_recall", value=m.accumulate(), step=batch_id+epoch_id*(batch_size))

    train_loss_.append(train_loss)
    train_acc_.append(train_acc)
    eval_loss_.append(eval_loss)
    eval_acc_.append(eval_acc)
    print("epoch_id: {}, batch_id: {}, loss: {}, acc: {}, recall: {}".format(epoch_id, batch_id+1, loss.numpy(), acc.numpy(), m.accumulate()))

  0%|          | 1/839 [00:00<05:19,  2.63it/s]
100%|██████████| 209/209 [00:05<00:00, 38.44it/s]


epoch_id: 0, batch_id: 209, loss: [0.2506001], acc: [0.99853516], recall: 0.4859038142620232


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 30.35it/s]


epoch_id: 1, batch_id: 209, loss: [0.25137556], acc: [0.9975586], recall: 0.5155963302752293


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:04<00:00, 42.72it/s]


epoch_id: 2, batch_id: 209, loss: [0.25135118], acc: [0.99902344], recall: 0.5041736227045075


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 37.02it/s]


epoch_id: 3, batch_id: 209, loss: [0.25088477], acc: [0.99853516], recall: 0.4959016393442623


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 30.04it/s]


epoch_id: 4, batch_id: 209, loss: [0.25050804], acc: [1.], recall: 0.5134529147982063


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:07<00:00, 27.22it/s]


epoch_id: 5, batch_id: 209, loss: [0.25067317], acc: [0.99902344], recall: 0.5178571428571429


  0%|          | 1/839 [00:00<04:37,  3.02it/s]
100%|██████████| 209/209 [00:06<00:00, 33.57it/s]


epoch_id: 6, batch_id: 209, loss: [0.25022116], acc: [0.99902344], recall: 0.4774590163934426


  0%|          | 3/839 [00:00<02:12,  6.32it/s]
100%|██████████| 209/209 [00:05<00:00, 39.98it/s]


epoch_id: 7, batch_id: 209, loss: [0.25055504], acc: [0.99902344], recall: 0.45739910313901344


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 35.52it/s]


epoch_id: 8, batch_id: 209, loss: [0.2508282], acc: [0.9995117], recall: 0.47651006711409394


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 37.09it/s]


epoch_id: 9, batch_id: 209, loss: [0.25236607], acc: [0.99902344], recall: 0.5454545454545454


  0%|          | 1/839 [00:00<04:05,  3.42it/s]
100%|██████████| 209/209 [00:07<00:00, 27.29it/s]


epoch_id: 10, batch_id: 209, loss: [0.2514472], acc: [0.99902344], recall: 0.4864864864864865


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 31.51it/s]


epoch_id: 11, batch_id: 209, loss: [0.25074685], acc: [0.99853516], recall: 0.5645645645645646


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 38.91it/s]


epoch_id: 12, batch_id: 209, loss: [0.25078884], acc: [0.99902344], recall: 0.5449101796407185


  0%|          | 1/839 [00:00<04:21,  3.21it/s]
100%|██████████| 209/209 [00:06<00:00, 31.63it/s]


epoch_id: 13, batch_id: 209, loss: [0.2497801], acc: [0.9980469], recall: 0.5391566265060241


  0%|          | 1/839 [00:00<03:46,  3.69it/s]
100%|██████████| 209/209 [00:05<00:00, 36.77it/s]


epoch_id: 14, batch_id: 209, loss: [0.25034353], acc: [1.], recall: 0.48692403486924035


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:07<00:00, 28.30it/s]


epoch_id: 15, batch_id: 209, loss: [0.2511642], acc: [1.], recall: 0.5017123287671232


  0%|          | 3/839 [00:00<03:48,  3.67it/s]
100%|██████████| 209/209 [00:07<00:00, 27.83it/s]


epoch_id: 16, batch_id: 209, loss: [0.2515879], acc: [0.99902344], recall: 0.5266903914590747


  0%|          | 1/839 [00:00<03:48,  3.67it/s]
100%|██████████| 209/209 [00:06<00:00, 30.92it/s]


epoch_id: 17, batch_id: 209, loss: [0.25179526], acc: [0.99853516], recall: 0.5195729537366548


  0%|          | 3/839 [00:00<02:56,  4.73it/s]
100%|██████████| 209/209 [00:05<00:00, 37.14it/s]


epoch_id: 18, batch_id: 209, loss: [0.2516933], acc: [0.99658203], recall: 0.4939965694682676


  0%|          | 1/839 [00:00<04:40,  2.98it/s]
100%|██████████| 209/209 [00:06<00:00, 33.90it/s]


epoch_id: 19, batch_id: 209, loss: [0.24992093], acc: [0.99902344], recall: 0.49452554744525545


  0%|          | 2/839 [00:00<04:32,  3.07it/s]
100%|██████████| 209/209 [00:06<00:00, 33.02it/s]


epoch_id: 20, batch_id: 209, loss: [0.25118285], acc: [0.99902344], recall: 0.49110320284697506


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 39.59it/s]


epoch_id: 21, batch_id: 209, loss: [0.2504844], acc: [0.99658203], recall: 0.5173674588665448


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 34.69it/s]


epoch_id: 22, batch_id: 209, loss: [0.250972], acc: [0.9980469], recall: 0.5439560439560439


  1%|          | 5/839 [00:00<02:05,  6.67it/s]
100%|██████████| 209/209 [00:05<00:00, 35.80it/s]


epoch_id: 23, batch_id: 209, loss: [0.25137788], acc: [1.], recall: 0.4856115107913669


  0%|          | 1/839 [00:00<04:02,  3.45it/s]
100%|██████████| 209/209 [00:06<00:00, 30.42it/s]


epoch_id: 24, batch_id: 209, loss: [0.2516216], acc: [0.99902344], recall: 0.5045703839122486


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 35.72it/s]


epoch_id: 25, batch_id: 209, loss: [0.24977107], acc: [0.99658203], recall: 0.48941469489414696


  0%|          | 1/839 [00:00<03:55,  3.56it/s]
100%|██████████| 209/209 [00:05<00:00, 37.73it/s]


epoch_id: 26, batch_id: 209, loss: [0.25110215], acc: [0.9995117], recall: 0.5


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 31.95it/s]


epoch_id: 27, batch_id: 209, loss: [0.2503917], acc: [0.9980469], recall: 0.46948941469489414


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 34.94it/s]


epoch_id: 28, batch_id: 209, loss: [0.25048104], acc: [0.99902344], recall: 0.5282392026578073


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:07<00:00, 26.78it/s]


epoch_id: 29, batch_id: 209, loss: [0.25152344], acc: [0.99902344], recall: 0.5021337126600285


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 38.97it/s]


epoch_id: 30, batch_id: 209, loss: [0.25008005], acc: [0.9970703], recall: 0.4928977272727273


  0%|          | 1/839 [00:00<05:00,  2.79it/s]
100%|██████████| 209/209 [00:05<00:00, 34.85it/s]


epoch_id: 31, batch_id: 209, loss: [0.25042003], acc: [0.9995117], recall: 0.4887525562372188


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 35.46it/s]


epoch_id: 32, batch_id: 209, loss: [0.25071797], acc: [0.9980469], recall: 0.5103448275862069


  0%|          | 2/839 [00:00<02:57,  4.71it/s]
100%|██████████| 209/209 [00:06<00:00, 32.13it/s]


epoch_id: 33, batch_id: 209, loss: [0.25079978], acc: [0.99902344], recall: 0.49166666666666664


  0%|          | 1/839 [00:00<13:43,  1.02it/s]
100%|██████████| 209/209 [00:05<00:00, 34.89it/s]


epoch_id: 34, batch_id: 209, loss: [0.25118792], acc: [0.99658203], recall: 0.5174708818635607


  0%|          | 1/839 [00:00<04:05,  3.41it/s]
100%|██████████| 209/209 [00:05<00:00, 37.32it/s]


epoch_id: 35, batch_id: 209, loss: [0.24733633], acc: [0.9951172], recall: 0.45454545454545453


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 32.88it/s]


epoch_id: 36, batch_id: 209, loss: [0.2502373], acc: [0.99902344], recall: 0.515625


  1%|▏         | 12/839 [00:01<01:44,  7.89it/s]
100%|██████████| 209/209 [00:05<00:00, 35.62it/s]


epoch_id: 37, batch_id: 209, loss: [0.25092247], acc: [0.99853516], recall: 0.5018315018315018


  0%|          | 1/839 [00:00<06:40,  2.09it/s]
100%|██████████| 209/209 [00:07<00:00, 28.31it/s]


epoch_id: 38, batch_id: 209, loss: [0.25095615], acc: [1.], recall: 0.4975124378109453


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:05<00:00, 38.84it/s]


epoch_id: 39, batch_id: 209, loss: [0.25058264], acc: [0.99853516], recall: 0.5210237659963437


  0%|          | 0/839 [00:00<?, ?it/s]
100%|██████████| 209/209 [00:06<00:00, 34.57it/s]


epoch_id: 40, batch_id: 209, loss: [0.24994406], acc: [0.99853516], recall: 0.5100671140939598


  0%|          | 1/839 [00:00<04:37,  3.02it/s]
100%|██████████| 209/209 [00:06<00:00, 32.13it/s]


epoch_id: 41, batch_id: 209, loss: [0.25005758], acc: [0.99853516], recall: 0.48036253776435045


  0%|          | 0/839 [00:00<?, ?it/s]
 88%|████████▊ | 184/209 [00:06<00:00, 29.70it/s]


KeyboardInterrupt: 