In [2]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [3]:
NUM_USERS = 943
NUM_MOVIES = 1682
BATCH_SIZE = 128

In [4]:
class MetapathEncoder(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim, 
                ):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, metapath):
        x = torch.mean(metapath, dim=0)
        return self.fc(x)

class MAGNN(nn.Module):
    def __init__(self,
                #  user_features,
                #  movie_features,
                #  user_metapaths,
                #  movie_metapaths,
                 num_user_features,
                 num_movie_features,
                 hidden_dim,
                 out_dim,
                 batch_size,
                 device
                 ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.device = device

        # later take only train/test features and metapaths

        # self.user_features = user_features 
        # self.movie_features = movie_features

        # self.user_metapaths = user_metapaths
        # self.movie_metapaths = movie_metapaths

        # self.num_user_features = len(user_features[0])
        # self.num_movie_features = len(movie_features[0])

        self.num_user_features = num_user_features
        self.num_movie_features = num_movie_features

        self.user_feature_encoder = nn.Linear(self.num_user_features, hidden_dim)
        self.movie_feature_encoder = nn.Linear(self.num_movie_features, hidden_dim)

        self.metapath_encoder = MetapathEncoder(hidden_dim, hidden_dim)

        self.user_node_embedding = nn.Linear(hidden_dim, out_dim)
        self.movie_node_embedding = nn.Linear(hidden_dim, out_dim)

        self.recommender = nn.Linear(2 * out_dim, 1)

    def forward(self, edge):
        user_node, movie_node = edge[:, 0], edge[:, 1]
        batch_aggregated_user_metapath = []
        for i in range(self.batch_size):
            user_metapath_instances = self.user_metapaths[user_node[i].item()]
            user_aggregated_metapath = torch.zeros(self.hidden_dim).to(self.device)
            for metapath in user_metapath_instances:
                metapath_isntance = torch.vstack([
                    self.user_feature_encoder(self.user_features[metapath[0]]),
                    self.movie_feature_encoder(self.movie_features[metapath[1]]),
                    self.user_feature_encoder(self.user_features[metapath[2]])
                ])
                user_aggregated_metapath += self.metapath_encoder(metapath_isntance)
            batch_aggregated_user_metapath.append(user_aggregated_metapath)
        batch_aggregated_user_metapath = torch.vstack(batch_aggregated_user_metapath)

        batch_aggregated_movie_metapath = []
        for i in range(self.batch_size):
            movie_metapath_instances = self.movie_metapaths[movie_node[i].item()]
            movie_aggregated_metapath = torch.zeros(self.hidden_dim).to(self.device)
            for metapath in movie_metapath_instances:
                metapath_isntance = torch.vstack([
                    self.movie_feature_encoder(self.movie_features[metapath[0]]),
                    self.user_feature_encoder(self.user_features[metapath[1]]),
                    self.movie_feature_encoder(self.movie_features[metapath[2]]),
                ])
                movie_aggregated_metapath += self.metapath_encoder(metapath_isntance)
            batch_aggregated_movie_metapath.append(movie_aggregated_metapath)
        batch_aggregated_movie_metapath = torch.vstack(batch_aggregated_movie_metapath)

        user_embed = F.sigmoid(self.user_node_embedding(batch_aggregated_user_metapath))
        movie_embed = F.sigmoid(self.movie_node_embedding(batch_aggregated_movie_metapath))

        out = self.recommender(torch.cat([user_embed, movie_embed], dim=1))

        return F.sigmoid(out)

In [5]:
path = '../data/preprocessed/'

in_file = open(path + 'user_features.pickle', 'rb')
user_features = pickle.load(in_file)
for key, val in user_features.items():
    user_features[key] = torch.tensor(val.astype(np.float32))
in_file.close()

in_file = open(path + 'movie_features.pickle', 'rb')
movie_features = pickle.load(in_file)
for key, val in movie_features.items():
    movie_features[key] = torch.tensor(val.astype(np.float32))
in_file.close()

in_file = open(path + 'u_m_u_metapath_map.pickle', 'rb')
user_metapaths = pickle.load(in_file)
in_file.close()

in_file = open(path + 'm_u_m_metapath_map.pickle', 'rb')
movie_metapaths = pickle.load(in_file)
in_file.close()

train_pos = np.load(path + 'train_pos.npy')
test_pos = np.load(path + 'test_pos.npy')

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

for key, val in user_features.items():
    user_features[key] = val.to(device)

for key, val in movie_features.items():
    movie_features[key] = val.to(device)

In [7]:
class UserMovieDataset(Dataset):
    def __init__(self, positives):

        self.pos = positives
        self.sample_negatives()

        self.data = torch.tensor(np.vstack([self.pos, self.neg]))
        self.labels = torch.vstack([torch.ones((len(self.pos), 1)), torch.zeros((len(self.neg), 1))])

    def sample_negatives(self):
        self.neg = []
        for i in tqdm(range(len(self.pos))):
            user_id = np.random.randint(NUM_USERS)
            movie_id = np.random.randint(NUM_MOVIES)
            while [user_id, movie_id] not in self.pos and \
                    [user_id, movie_id] not in self.neg:
                user_id = np.random.randint(NUM_USERS)
                movie_id = np.random.randint(NUM_MOVIES)
            self.neg.append([user_id, movie_id])
        self.neg = np.array(self.neg)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

    def __len__(self):
        return len(self.data)

In [8]:
train_dataset = UserMovieDataset(train_pos)
test_dataset = UserMovieDataset(test_pos)

100%|██████████| 49906/49906 [00:16<00:00, 3002.41it/s]
100%|██████████| 5469/5469 [00:00<00:00, 15404.54it/s]


In [9]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [10]:
model = MAGNN(user_features, movie_features, user_metapaths, movie_metapaths, 128, 128, BATCH_SIZE, device)

In [11]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [1]:
def train(epoch):
    losses = []
    accs = []

    model.train()
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for batch in progress_bar:
        data, label = batch
        data, label = data.to(device), label.to(device)

        optimizer.zero_grad()

        pred = model(data)

        loss = criterion(label, pred)
        acc = ((pred > 0.5) == label).sum() / BATCH_SIZE
        
        losses.append(loss.item())
        accs.append(acc.cpu())

        loss.backward()
        optimizer.step()

        progress_bar.set_description(f'Loss: {np.mean(losses):.5f}, Acc: {np.mean(accs):.5f}')

In [12]:
EPOCHS = 10
model.to(device)

for epoch in range(EPOCHS):
    train(epoch)

Epoch 0:   0%|          | 0/780 [00:00<?, ?it/s]

Loss: 52.54102, Acc: 0.30469:   0%|          | 1/780 [32:42<424:33:33, 1962.02s/it]