In [1]:
# !conda install torch_sparse, torch_geometric

In [10]:
from two_class_data_generation import *

import os
import random
import numpy as np
import pandas as pd
import torch
import time
import itertools
import torchvision
from torchvision import datasets, transforms
from torch import nn, optim
from torch.utils.data import Dataset

import torch_geometric
from torch_geometric.data import DataLoader
from torch_sparse import spmm

In [11]:
torch.manual_seed(10)
random.seed(10)
np.random.seed(10)

device_id = 0
print(torch.cuda.get_device_name(device_id))
torch.cuda.set_device(device_id)

NVIDIA Tesla V100-SXM2-16GB


In [12]:
def get_chord_indices_assym(n_vec, n_link):
    """
    Generates the position indicies, based on the asymmetric Chord protocol (incl. itself).

    :param n_vec: number of vectors (i.e. length of a sequence)
    :param n_link: number of links in the Chord protocol
    :return: target indices in two lists, each is of size n_vec * (n_link + 1)
    """
    
    rows = list(
        itertools.chain(
            *[
                [i for j in range(n_link + 1)] for i in range(n_vec)
            ]
        )
    )
    
    cols = list(
        itertools.chain(
            *[
                [i] + [(i + 2 ** k) % n_vec for k in range(n_link)] for i in range(n_vec)
            ]
        )
    )
    
    return rows, cols

In [19]:
# SMF with sparse multiplication

class VIdenticalModule(nn.Module):
    def __init__(self):
        super(VIdenticalModule, self).__init__()

    def forward(self, data):
        return data

    
class WModuleSparse(nn.Module):
    def __init__(self, n_link, n_dim, n_hidden):
        super(WModuleSparse, self).__init__()
        self.network = nn.Sequential(
            nn.Linear(n_dim, n_hidden),
            nn.GELU(),
            nn.Linear(n_hidden, n_link + 1)
        )

    def forward(self, data):
        return self.network(data)
    
    
class InteractionModuleSparse(nn.Module):
    def __init__(self, n_class, n_W, n_vec, n_dim, n_link,
                 n_hidden_f, n_hidden_v, batch_size, use_cuda):
        super(InteractionModuleSparse, self).__init__()
        self.n_vec = n_vec
        self.n_dim = n_dim
        self.n_link = n_link
        self.batch_size = batch_size
        self.use_cuda = use_cuda
        self.fs = nn.ModuleList(
            [WModuleSparse(n_link, n_dim, n_hidden_f) for i in range(n_W)]
        )
        self.g = VIdenticalModule()
        self.final = nn.Linear(self.n_vec * self.n_dim, n_class, bias=True)
        self.chord_indicies = torch.tensor(get_chord_indices_assym(n_vec, n_link))
        if self.use_cuda:
            self.chord_indicies = self.chord_indicies.cuda()
    
    def forward(self, data):
#         if self.cuda:
#             data = data.cuda()
        V = self.g(data)
        residual = V
        for f in self.fs[::-1]:
            W = f(data)
            
            V = spmm(
                self.chord_indicies,
                W.reshape(W.size(0), W.size(1) * W.size(2)), 
                self.n_vec,
                self.n_vec,
                V
            )
            V += residual
            
        V = self.final(V.view(data.size(0), -1))
        return V 

In [20]:
# Supplementary functions 

def weights_init(module):
    classname = module.__class__.__name__
    if classname.find('Linear') != -1:
        torch.nn.init.normal_(module.weight, 0.0, 1e-2)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.normal_(module.bias, 0.0, 1e-2)
            
            
class DatasetCreator(Dataset):
    """
    Class to construct a dataset for training/inference
    """

    def __init__(self, mode, data, labels):
        self.data = data
        self.labels = labels
        assert len(self.data) == len(self.labels),\
            "The number of samples doesn't match the number of labels"

    def __getitem__(self, index):
        """
        Returns: tuple (sample, target)
        """
        X = data[index].unsqueeze(-1)
        Y = labels[index].type(torch.LongTensor)
        return (X, Y)

    def __len__(self):
        return len(self.labels)

In [21]:
# Initialize the model 

n_data = 12000
n_test = 3000
n_class = 2
n_dim = 1
n_hidden_f = 20
n_hidden_v = 3

batch_size = 20

# For hard testing
# n_W = 20
# n_link = 20
# n_vec = 1048576 # 15215MiB on GPU (batch_size=2)

# For easy testing
n_W = 7
n_link = 7
n_vec = 128


net = InteractionModuleSparse(
    n_class,
    n_W,
    n_vec,
    n_dim,
    n_link,
    n_hidden_f,
    n_hidden_v,
    batch_size,
    use_cuda=True
)

net.apply(weights_init)

net = net.cuda()

In [22]:
# Preparing synthetic data and DataLoaders from torch_geometric

data, labels = generate_two_class_data(n_data, n_vec, binary=False, same_sigma=False, xor=True)

data, labels, data_val, labels_val = data[n_test:], labels[n_test:], data[:n_test], labels[:n_test]

trainset = DatasetCreator(
    mode='train',
    data = data,
    labels = labels
)
trainloader = torch_geometric.data.DataLoader(
    trainset,
    batch_size=batch_size,
    shuffle=True
)


valset = DatasetCreator(
    mode='test',
    data = data_val,
    labels = labels_val
)
valloader = torch_geometric.data.DataLoader(
    valset,
    batch_size=batch_size,
    shuffle=True
)

In [23]:
# Training loop
def TrainSMF(
        net,
        trainloader,
        valloader,
        n_epochs,
        test_freq,
        optimizer,
        loss
):
    losses = []
    losses_eval = []
    accuracies = []
    for epoch in range(n_epochs):
        # Training
        running_loss = 0
        for i, (X, Y) in enumerate(trainloader):
            X = X.cuda()
            Y = Y.cuda()
            optimizer.zero_grad()
            pred = net(X)
            output = loss(pred, Y)
            output.backward()
            optimizer.step()
            running_loss += output.item()

        print("Epoch {} - Training loss:   {}".format(epoch, running_loss / len(trainloader)))
        losses.append(float(running_loss / len(trainloader)))
        
        # Validation
        if epoch % test_freq == 0:
            net.eval()
            with torch.no_grad():
                correct = 0
                total = 0
                val_loss = 0.0
                for i, (X, Y) in enumerate(valloader):
                    X = X.cuda()
                    Y = Y.cuda()
                    pred = net(X)
                    val_loss += loss(pred, Y).item()
                    
                    _, predicted = torch.max(pred.data, 1)
                    total += Y.size(0)
                    correct += (predicted == Y).sum().item()
                    
            print("Epoch {} - Validation loss: {}".format(epoch, val_loss / len(valloader)))
            print('Accuracy of the network: %d %%' % (100 * correct / total))
            print('_' * 40)
            losses_eval.append(float(val_loss / len(valloader)))
            accuracies.append(100 * correct / total)
            net.train()

In [24]:
loss = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=1e-3)

TrainSMF(
    net=net,
    trainloader=trainloader,
    valloader=valloader,
    n_epochs=10,
    test_freq=1,
    optimizer=optimizer,
    loss=loss
)

Epoch 0 - Training loss:   0.6541194513108995
Epoch 0 - Validation loss: 1.0018336772918701
Accuracy of the network: 0 %
________________________________________
Epoch 1 - Training loss:   0.6374050671524472
Epoch 1 - Validation loss: 1.076012372970581
Accuracy of the network: 0 %
________________________________________
Epoch 2 - Training loss:   0.6366049136718114
Epoch 2 - Validation loss: 1.093941330909729
Accuracy of the network: 0 %
________________________________________
Epoch 3 - Training loss:   0.636584500140614
Epoch 3 - Validation loss: 1.1082909107208252
Accuracy of the network: 0 %
________________________________________
Epoch 4 - Training loss:   0.6365858999225829
Epoch 4 - Validation loss: 1.0924708843231201
Accuracy of the network: 0 %
________________________________________
Epoch 5 - Training loss:   0.6366091051366594
Epoch 5 - Validation loss: 1.10152268409729
Accuracy of the network: 0 %
________________________________________
Epoch 6 - Training loss:   0.6366