In [2]:
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data import Dataset

import torch_geometric
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.utils import degree

from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split

In [3]:
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [4]:
# Keep only the three columns the model needs, in this exact order
# (later cells index the data positionally), and store them as integers.
columns_name = ['place_index', 'user_index', 'rating']
review_df = pd.read_csv("train.tsv", sep="\t").loc[:, columns_name].astype(int)

In [5]:
# Largest raw ids present in the review table; used below to size the joint
# user/place node id space.
max_user_id = review_df.user_index.max()
max_place_id = review_df.place_index.max()


In [6]:
# Users keep their raw ids; place ids (which start from 0) are shifted by
# max_user_id + 1 in later cells, so the largest id in the joint graph is:
max_node_id = (max_user_id + 1) + max_place_id

In [7]:
# Hold out 10% of the reviews for validation. A fixed random_state keeps the
# split (and the class weights computed from it) identical across kernel restarts.
# NOTE: the name `train` is rebound to the training function defined further
# down, so downstream cells should rely on `train_df` / `test_df` only.
train, test = train_test_split(review_df.values, test_size=0.1, random_state=42)
train_df = pd.DataFrame(train, columns=review_df.columns)
test_df = pd.DataFrame(test, columns=review_df.columns)

In [8]:
# Weights will be used to normalize loss function
def get_weights(df, ratings=(1, 2, 3, 4, 5)):
    """Return L2-normalised inverse-frequency weights, one per rating value.

    Args:
        df: DataFrame with a 'rating' column.
        ratings: rating values to weight, in output order. Defaults to 1..5.

    Returns:
        np.ndarray of shape (len(ratings),). Entry i is proportional to
        1 / count(ratings[i]) and the vector has unit L2 norm. A rating that
        never occurs gets weight 0 (instead of producing inf/NaN), and an
        all-zero vector is returned when none of the ratings occur.
    """
    rating_counts = np.array(
        [(df['rating'] == value).sum() for value in ratings], dtype=float
    )

    # Invert only the non-zero counts; absent ratings keep a weight of 0 so a
    # single missing class cannot turn the whole vector into NaN.
    inverse_count = np.zeros_like(rating_counts)
    np.divide(1.0, rating_counts, out=inverse_count, where=rating_counts > 0)

    norm = np.linalg.norm(inverse_count)
    if norm == 0:
        return inverse_count
    normalized_inverse_count = inverse_count / norm

    return normalized_inverse_count

# Per-rating loss weights computed from the training split only;
# index i holds the weight for rating i + 1.
weights = get_weights(train_df)
print(weights)

[0.59525181 0.75159322 0.25561637 0.11561983 0.04555491]


In [10]:
# Attach to every review the weight of its rating class (rating r -> weights[r - 1]).
for frame in (train_df, test_df):
    rating_slot = frame['rating'].to_numpy().astype(int) - 1
    frame['weight'] = weights[rating_slot]

In [11]:
# Peek at the first five training rows (now including the `weight` column)
train_df.head()

Unnamed: 0,place_index,user_index,rating,weight
0,185,46197,4,0.11562
1,4824,27755,5,0.045555
2,1092,1376,3,0.255616
3,4721,10935,4,0.11562
4,852,16080,1,0.595252


In [12]:
class MyDataset(Dataset):
    """Row-wise view over a review DataFrame for use with a DataLoader.

    The frame is converted to a single NumPy array and read positionally:
    column 0 and column 1 are integer ids, column 2 is the regression target
    and column 3 is the per-sample loss weight. In this notebook the frames
    passed in have the columns place_index, user_index, rating, weight.
    """

    def __init__(self, data):
        # to_numpy() on a mixed int/float frame yields one float array, so the
        # id columns are cast back to integers on access.
        self.data = data.to_numpy()

    def __getitem__(self, index):
        """Return (col0 id, col1 id, target of shape (1,), weight) for `index`."""
        # np.int64 is spelled out instead of the `np.compat.long` alias: that
        # alias is a Python 2 compatibility shim which NumPy 2.x no longer
        # exposes, and int64 is what torch expects for embedding indices.
        return self.data[index, 0].astype(np.int64), \
            self.data[index, 1].astype(np.int64), \
            self.data[index, 2:3].astype(np.float32), \
            self.data[index, 3]

    def __len__(self):
        """Number of rows in the underlying frame."""
        return len(self.data)

In [13]:
# Node ids of both endpoints of every training review. Places are shifted past
# the user id range so the two id spaces do not collide in the joint graph.
u_t = torch.LongTensor(train_df['user_index'].to_numpy())
p_t = torch.LongTensor(train_df['place_index'].to_numpy()) + (max_user_id + 1)

# Every review contributes both directions (user -> place and place -> user).
train_edge_index = torch.stack(
    [torch.cat((u_t, p_t)), torch.cat((p_t, u_t))],
    dim=0,
).to(device)

In [14]:
# Move place ids into the joint node id space (same shift as used for the edges).
# NOTE: this mutates the frames in place, so re-running the cell shifts the ids again.
for frame in (train_df, test_df):
    frame['place_index'] = frame['place_index'] + max_user_id + 1

# assert that there's no index overlapping
for frame in (train_df, test_df):
    intersection = set(frame['place_index'].unique()) & set(frame['user_index'].unique())
    assert len(intersection) == 0

train_dataset, test_dataset = MyDataset(train_df), MyDataset(test_df)

In [15]:
class LightGCNConv(MessagePassing):
    """Parameter-free graph convolution with symmetric degree normalisation.

    Each edge carries the source node's features scaled by
    1 / sqrt(deg(source) * deg(target)); incoming messages are summed.
    """

    def __init__(self, **kwargs):
        # kwargs are accepted but not forwarded; aggregation is always a sum.
        super().__init__(aggr='add')

    def forward(self, x, edge_index, num_nodes):
        """Propagate `x` one hop along `edge_index`.

        `num_nodes` is accepted for interface compatibility; the node count
        actually used is x.size(0).
        """
        src, dst = edge_index
        node_degree = degree(dst, x.size(0), dtype=x.dtype)
        # Nodes with degree 0 would give inf here; zero them out instead.
        inv_sqrt_degree = node_degree.pow(-0.5)
        inv_sqrt_degree = inv_sqrt_degree.masked_fill(inv_sqrt_degree == float('inf'), 0)
        norm = inv_sqrt_degree[src] * inv_sqrt_degree[dst]
        # No update step after aggregation: the summed messages are the output.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        """Scale each source-node feature row by its edge normalisation."""
        return norm.view(-1, 1) * x_j

In [16]:
# One-hot node features: each output row then shows exactly which neighbours
# contributed to a node and with which normalisation factor.
test_x = torch.eye(5)

# Toy graph edges (first row: source nodes, second row: target nodes)
test_edge_index = torch.tensor(
    [
        [0, 0, 1, 1, 2, 3, 3, 4],
        [2, 3, 3, 4, 0, 0, 1, 1],
    ],
    dtype=torch.long,
)

# Result of a single propagation step through the convolution
LightGCNConv()(test_x, test_edge_index, 5)

tensor([[0.0000, 0.0000, 0.7071, 0.5000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.5000, 0.7071],
        [0.7071, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.7071, 0.0000, 0.0000, 0.0000]])

In [17]:
class LightGCN(nn.Module):
    """LightGCN encoder with a linear head that scores (user, item) pairs.

    Args:
        latent_dim: size of each node embedding.
        num_layers: number of LightGCNConv propagation layers.
        max_index: number of rows in the embedding table (one per node id).
    """

    def __init__(self, latent_dim, num_layers, max_index):
        super().__init__()
        self.embedding = nn.Embedding(max_index, latent_dim)
        self.convs = nn.ModuleList([LightGCNConv() for _ in range(num_layers)])
        self.init_parameters()
        # Maps a concatenated [user ; item] embedding to a single score.
        self.nn = nn.Linear(2 * latent_dim, 1)

        self.max_index = max_index

    def init_parameters(self):
        """Re-initialise the embedding table from N(0, 0.1^2)."""
        nn.init.normal_(self.embedding.weight, std=0.1)

    def forward(self, edge_index):
        """Return (layer-0 embeddings, mean of the embeddings of all layers)."""
        initial = self.embedding.weight
        layer_outputs = [initial]
        current = initial
        for layer in self.convs:
            current = layer(x=current, edge_index=edge_index, num_nodes=self.max_index)
            layer_outputs.append(current)

        # Layer combination: plain average over layer 0 .. num_layers.
        combined = torch.stack(layer_outputs, dim=0).mean(dim=0)
        return initial, combined

    def pred(self, users, items, embeddings):
        """Score each (users[i], items[i]) pair; output has shape (batch, 1)."""
        pair_features = torch.cat((embeddings[users], embeddings[items]), 1)
        return self.nn(pair_features)

In [18]:
# Model architecture
latent_dim = 64   # embedding size per node
n_layers = 3      # number of LightGCNConv propagation layers

# Optimisation
EPOCHS = 5
BATCH_SIZE = 100
LR = 0.005

# NOTE(review): DECAY and K are not referenced by any cell shown here.
DECAY = 0.0001
K = 2

In [19]:
# Node ids run from 0 to max_node_id inclusive, hence max_node_id + 1 embedding rows.
lightgcn = LightGCN(latent_dim=latent_dim, num_layers=n_layers, max_index=max_node_id + 1).to(device)

In [20]:
def get_testset_loss(model, testset, loss_fn, embeddings):
    """Average loss of `model` over `testset` using precomputed node embeddings.

    Args:
        model: object exposing `pred(users, items, embeddings)`.
        testset: dataset yielding (item, user, rating, weight) tuples.
        loss_fn: callable `(pred, ratings, weights) -> scalar tensor`; assumed
            to return a per-batch mean (as `weighted_MSE` does).
        embeddings: node embedding matrix indexed by user / item ids.

    Returns:
        float: per-sample mean loss over the whole test set.

    Raises:
        ValueError: if `testset` yields no batches.
    """
    total_loss = 0.0
    n_samples = 0
    model.eval()
    with torch.no_grad():
        for items, users, ratings, weights in DataLoader(testset, batch_size=BATCH_SIZE):
            users, items, ratings, weights = users.to(device), items.to(device), ratings.to(device), weights.to(device)
            pred = model.pred(users, items, embeddings)
            loss = loss_fn(pred, ratings, weights)

            # Weight each batch mean by its size so a short final batch does
            # not count as much as a full one.
            batch_len = len(ratings)
            total_loss += loss.item() * batch_len
            n_samples += batch_len

    if n_samples == 0:
        raise ValueError("testset is empty; cannot compute an average loss")
    return total_loss / n_samples


def train(model, optimizer, train_dataset, test_dataset, train_edge_index, loss_fn):
    """Train `model` for EPOCHS epochs and record train / validation loss.

    Uses the notebook-level BATCH_SIZE, EPOCHS and device. Every training step
    runs a full propagation over `train_edge_index`, because the embeddings
    depend on parameters that change after each optimizer step. Whenever the
    validation loss improves, the model state dict is written to
    "epoch_<epoch>.ckpt".

    Args:
        model: LightGCN-style module; calling it returns (emb0, embeddings)
            and it exposes `pred(users, items, embeddings)`.
        optimizer: optimizer over the model parameters.
        train_dataset / test_dataset: datasets yielding
            (item, user, rating, weight) tuples.
        train_edge_index: edge index of the training graph.
        loss_fn: callable `(pred, ratings, weights) -> scalar tensor`.

    Returns:
        (train_losses, valid_losses): per-epoch losses rounded to 4 decimals.
    """
    loss_list_epoch = []
    valid_loss_list_epoch = []
    # NOTE(review): batches are drawn in dataset order (no shuffle) — confirm intended.
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
    min_valid_loss = None
    for epoch in tqdm(range(EPOCHS)):
        loss_list = []
        model.train()
        for items, users, ratings, weights in tqdm(train_dataloader):
            optimizer.zero_grad()
            users, items, ratings, weights = users.to(device), items.to(device), ratings.to(device), weights.to(device)
            _, embeddings = model(train_edge_index)
            pred = model.pred(users, items, embeddings)
            loss = loss_fn(pred, ratings, weights)
            loss.backward()
            optimizer.step()
            loss_list.append(loss.item())

        # evaluate on validation data
        # Recompute the embeddings from the weights as they stand after the
        # last optimizer step; the tensor left over from the final training
        # batch predates that step (and would not exist for an empty loader).
        model.eval()
        with torch.no_grad():
            _, eval_embeddings = model(train_edge_index)
        valid_loss = get_testset_loss(model, test_dataset, loss_fn, eval_embeddings)
        if min_valid_loss is None or valid_loss < min_valid_loss:
            min_valid_loss = valid_loss
            # torch.save returns nothing, so its result is not kept.
            torch.save(model.state_dict(), f"epoch_{epoch}.ckpt")

        valid_loss_list_epoch.append(round(valid_loss, 4))
        loss_list_epoch.append(round(np.mean(loss_list), 4))

    return loss_list_epoch, valid_loss_list_epoch

In [21]:
# Calculate weights of different labels and define weighted MSE loss
def weighted_MSE(preds, targets, weights):
    """Mean of per-sample squared errors, each scaled by its sample weight.

    Args:
        preds: predictions, e.g. shape (batch, 1).
        targets: ground-truth values, same shape as `preds`.
        weights: one weight per sample; may be shaped (batch,) or like `preds`.

    Returns:
        Scalar tensor: mean over samples of weights[i] * (preds[i] - targets[i])**2.
    """
    squared_error = (preds - targets) ** 2
    # The data loader yields weights as (batch,) while preds/targets are
    # (batch, 1). Multiplying those directly broadcasts to (batch, batch) and
    # the mean degenerates to mean(weights) * mean(squared_error), losing the
    # per-sample weighting. Align the shapes whenever the element counts agree.
    if weights.shape != squared_error.shape and weights.numel() == squared_error.numel():
        weights = weights.reshape(squared_error.shape)
    return (weights * squared_error).mean()

# Adam over every trainable parameter (embedding table and linear head).
optimizer = torch.optim.Adam(params=lightgcn.parameters(), lr=LR)
loss_function = weighted_MSE

In [22]:
# Run the full training loop; returns one (rounded) loss value per epoch for each split.
loss_history, valid_loss_history = train(
    model=lightgcn,
    optimizer=optimizer,
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    train_edge_index=train_edge_index,
    loss_fn=loss_function,
)

  0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/11663 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Epochs are numbered from 1 on the x axis.
epoch_list = list(range(1, EPOCHS + 1))

# Training and validation loss curves on the same axes.
for history, curve_label in ((loss_history, 'Training Loss'), (valid_loss_history, 'Validation Loss')):
    plt.plot(epoch_list, history, label=curve_label)

plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()