In [3]:
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import pickle
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

In [184]:
NUM_USERS = 943
NUM_MOVIES = 1682
BATCH_SIZE = 128

num_user_features = 24
num_movie_features = 19

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [172]:
path = '../data/preprocessed/'

in_file = open(path + 'user_features.pickle', 'rb')
user_features = pickle.load(in_file)
for key, val in user_features.items():
    user_features[key] = torch.tensor(val.astype(np.float32))
in_file.close()

in_file = open(path + 'movie_features.pickle', 'rb')
movie_features = pickle.load(in_file)
for key, val in movie_features.items():
    movie_features[key] = torch.tensor(val.astype(np.float32))
in_file.close()

in_file = open(path + 'u_m_u_metapath_map.pickle', 'rb')
user_metapaths = pickle.load(in_file)
in_file.close()

in_file = open(path + 'm_u_m_metapath_map.pickle', 'rb')
movie_metapaths = pickle.load(in_file)
in_file.close()

train_pos = np.load(path + 'train_pos.npy')
test_pos = np.load(path + 'test_pos.npy')

In [173]:
class UserMovieDataset(Dataset):
    def __init__(self, positives, user_metapaths, movie_metapaths):

        self.max_metapath_size = 7000
        self.user_metapaths = user_metapaths
        self.movie_metapaths = movie_metapaths
        self.merge_metapaths()

        self.pos = positives
        self.sample_negatives()

        self.data = torch.tensor(np.vstack([self.pos, self.neg]))
        self.labels = torch.vstack([torch.ones((len(self.pos), 1)), torch.zeros((len(self.neg), 1))])

    def sample_negatives(self):
        self.neg = []
        for i in tqdm(range(len(self.pos)), desc='Sampling negative edges'):
            user_id = np.random.randint(NUM_USERS)
            movie_id = np.random.randint(NUM_MOVIES)
            while ([user_id, movie_id] not in self.pos and \
                    [user_id, movie_id] not in self.neg) or \
                    len(self.user_metapaths[user_id]) == 0 or \
                    len(self.movie_metapaths[movie_id]) == 0:
                user_id = np.random.randint(NUM_USERS)
                movie_id = np.random.randint(NUM_MOVIES)
            self.neg.append([user_id, movie_id])
        self.neg = np.array(self.neg)

    def merge_metapaths(self):
        for key, val in tqdm(self.user_metapaths.items(), desc='Extracting user metapaths'):
            if len(self.user_metapaths[key]) > 0:
                self.user_metapaths[key] = [torch.vstack([user_features[i] for i in val[:, 0]])[:self.max_metapath_size], 
                                            torch.vstack([movie_features[i] for i in val[:, 1]])[:self.max_metapath_size], 
                                            torch.vstack([user_features[i] for i in val[:, 2]])[:self.max_metapath_size]]
                while len(self.user_metapaths[key][0]) < self.max_metapath_size:
                    self.user_metapaths[key][0] = torch.vstack([self.user_metapaths[key][0], torch.zeros_like(self.user_metapaths[key][0][0])])
                    self.user_metapaths[key][1] = torch.vstack([self.user_metapaths[key][1], torch.zeros_like(self.user_metapaths[key][1][0])])
                    self.user_metapaths[key][2] = torch.vstack([self.user_metapaths[key][2], torch.zeros_like(self.user_metapaths[key][2][0])])
                
            
        for key, val in tqdm(self.movie_metapaths.items(), desc='Extracting movie metapaths'):
            if len(self.movie_metapaths[key]) > 0:
                self.movie_metapaths[key] = [torch.vstack([movie_features[i] for i in val[:, 0]])[:self.max_metapath_size],
                                            torch.vstack([user_features[i] for i in val[:, 1]])[:self.max_metapath_size],
                                            torch.vstack([movie_features[i] for i in val[:, 2]])[:self.max_metapath_size]]
                while len(self.movie_metapaths[key][0]) < self.max_metapath_size:
                    self.movie_metapaths[key][0] = torch.vstack([self.movie_metapaths[key][0], torch.zeros_like(self.movie_metapaths[key][0][0])])
                    self.movie_metapaths[key][1] = torch.vstack([self.movie_metapaths[key][1], torch.zeros_like(self.movie_metapaths[key][1][0])])
                    self.movie_metapaths[key][2] = torch.vstack([self.movie_metapaths[key][2], torch.zeros_like(self.movie_metapaths[key][2][0])])

    def __getitem__(self, idx):
        user_id, movie_id = self.data[idx]
        return self.data[idx], self.labels[idx], *self.user_metapaths[user_id.item()], *self.movie_metapaths[movie_id.item()]

    def __len__(self):
        return len(self.data)

In [174]:
train_dataset = UserMovieDataset(train_pos, user_metapaths, movie_metapaths)
# test_dataset = UserMovieDataset(test_pos)

Extracting user metapaths: 100%|██████████| 943/943 [11:24<00:00,  1.38it/s]
Extracting movie metapaths: 100%|██████████| 1682/1682 [57:37<00:00,  2.06s/it] 
Sampling negative edges: 100%|██████████| 49906/49906 [00:26<00:00, 1886.91it/s]


In [178]:
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
# test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

In [221]:
class MetapathEncoder(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim, 
                ):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, metapath):
        # print(f'{metapath.shape=}')
        x = torch.mean(metapath, dim=1)
        # print(f'{x.shape=}')
        return self.fc(x)

class MAGNN(nn.Module):
    def __init__(self,
                #  user_features,
                #  movie_features,
                #  user_metapaths,
                #  movie_metapaths,
                 num_user_features,
                 num_movie_features,
                 hidden_dim,
                 out_dim,
                 batch_size,
                 device
                 ):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.batch_size = batch_size
        self.device = device

        # later take only train/test features and metapaths

        # self.user_features = user_features 
        # self.movie_features = movie_features

        # self.user_metapaths = user_metapaths
        # self.movie_metapaths = movie_metapaths

        # self.num_user_features = len(user_features[0])
        # self.num_movie_features = len(movie_features[0])

        self.num_user_features = num_user_features
        self.num_movie_features = num_movie_features

        self.user_feature_encoder = nn.Linear(self.num_user_features, hidden_dim)
        self.movie_feature_encoder = nn.Linear(self.num_movie_features, hidden_dim)

        self.metapath_encoder = MetapathEncoder(hidden_dim, hidden_dim)

        self.user_node_embedding = nn.Linear(hidden_dim, out_dim)
        self.movie_node_embedding = nn.Linear(hidden_dim, out_dim)

        self.recommender = nn.Linear(2 * out_dim, 1)

    def forward(self, edge, user_metapaths1, user_metapaths2, user_metapaths3, movie_metapaths1, movie_metapaths2, movie_metapaths3):
        # user_node, movie_node = edge[:, 0], edge[:, 1]
        # batch_aggregated_user_metapath = []
        # for i in range(self.batch_size):
        #     user_metapath_instances = user_metapaths[user_node[i].item()]
        #     user_aggregated_metapath = torch.zeros(self.hidden_dim).to(self.device)
        #     for metapath in user_metapath_instances:
        #         metapath_isntance = torch.vstack([ # optimize
        #             self.user_feature_encoder(self.user_features[metapath[0]]),
        #             self.movie_feature_encoder(self.movie_features[metapath[1]]),
        #             self.user_feature_encoder(self.user_features[metapath[2]])
        #         ])
        #         user_aggregated_metapath += self.metapath_encoder(metapath_isntance)
        #     batch_aggregated_user_metapath.append(user_aggregated_metapath)
        user_metapath_isntance = torch.cat([
            self.user_feature_encoder(user_metapaths1),
            self.movie_feature_encoder(user_metapaths2),
            self.user_feature_encoder(user_metapaths3)
        ], dim=1)
        print(f'{user_metapath_isntance.shape=}')
        user_aggregated_metapath = self.metapath_encoder(user_metapath_isntance)
        print(f'{user_aggregated_metapath.shape=}')
        # batch_aggregated_movie_metapath = []
        # for i in range(self.batch_size):
        #     movie_metapath_instances = movie_metapaths[movie_node[i].item()]
        #     movie_aggregated_metapath = torch.zeros(self.hidden_dim).to(self.device)
        #     for metapath in movie_metapath_instances:
        #         metapath_isntance = torch.vstack([
        #             self.movie_feature_encoder(self.movie_features[metapath[0]]),
        #             self.user_feature_encoder(self.user_features[metapath[1]]),
        #             self.movie_feature_encoder(self.movie_features[metapath[2]]),
        #         ])
        #         movie_aggregated_metapath += self.metapath_encoder(metapath_isntance)
        #     batch_aggregated_movie_metapath.append(movie_aggregated_metapath)
        # batch_aggregated_movie_metapath = torch.vstack(batch_aggregated_movie_metapath)

        movie_metapath_isntance = torch.cat([
            self.movie_feature_encoder(movie_metapaths1),
            self.user_feature_encoder(movie_metapaths2),
            self.movie_feature_encoder(movie_metapaths3)
        ], dim=1)
        print(f'{movie_metapath_isntance.shape=}')
        movie_aggregated_metapath = self.metapath_encoder(movie_metapath_isntance)
        print(f'{movie_aggregated_metapath.shape=}')

        user_embed = F.sigmoid(self.user_node_embedding(user_aggregated_metapath))
        movie_embed = F.sigmoid(self.movie_node_embedding(movie_aggregated_metapath))

        out = self.recommender(torch.cat([user_embed, movie_embed], dim=1))

        return F.sigmoid(out)

In [226]:
model = MAGNN(num_user_features, num_movie_features, 128, 128, BATCH_SIZE, device)

In [227]:
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [230]:
def train(epoch):
    losses = []
    accs = []

    model.train()
    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch}')
    for batch in progress_bar:
        edge, label, umetapath1, umetapath2, umetapath3, mmetapath1, mmetapath2, mmetapath3 = batch
        edge, label, umetapath1, umetapath2, umetapath3, mmetapath1, mmetapath2, mmetapath3 = edge.to(device), label.to(device), umetapath1.to(device), umetapath2.to(device), umetapath3.to(device), mmetapath1.to(device), mmetapath2.to(device), mmetapath3.to(device)

        optimizer.zero_grad()

        pred = model(edge, umetapath1, umetapath2, umetapath3, mmetapath1, mmetapath2, mmetapath3)

        loss = criterion(label, pred)
        acc = ((pred > 0.5) == label).sum() / BATCH_SIZE
        
        losses.append(loss.item())
        accs.append(acc.cpu())

        loss.backward()
        optimizer.step()

        progress_bar.set_description(f'Loss: {np.mean(losses):.5f}, Acc: {np.mean(accs):.5f}')

In [231]:
EPOCHS = 10
model.to(device)

for epoch in range(EPOCHS):
    train(epoch)

Epoch 0:   0%|          | 0/780 [00:09<?, ?it/s]


OutOfMemoryError: CUDA out of memory. Tried to allocate 1.28 GiB (GPU 0; 6.00 GiB total capacity; 1.72 GiB already allocated; 2.67 GiB free; 1.74 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF