In [41]:
import numpy as np
from collections import Counter
from tqdm import tqdm
from matplotlib import pyplot as plt
from sklearn.metrics import classification_report 
from create_datasets import *

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

class MyDataset(Dataset):
    """Map-style dataset pairing a sequence of samples with a parallel sequence of labels.

    Items are converted to tensors on access: the sample as ``torch.float``
    and the label as ``torch.long``.
    """

    def __init__(self, data, label):
        """Keep references to the samples (`data`) and their labels (`label`)."""
        self.data = data
        self.label = label

    def __len__(self):
        """Number of samples; only `data` is consulted, `label` is not checked."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the `(sample, label)` tensor pair stored at `index`."""
        sample = torch.tensor(self.data[index], dtype=torch.float)
        target = torch.tensor(self.label[index], dtype=torch.long)
        return (sample, target)
    

In [42]:
class CRNN(nn.Module):
    """Convolutional-recurrent classifier over fixed-length segments of a 1-D signal.

    An input of shape ``(batch, in_channels, length)`` is cut into
    ``length // n_len_seg`` consecutive segments of ``n_len_seg`` samples.
    Every segment is passed through a 1-D convolution and averaged over its
    time axis; the per-segment feature vectors form a sequence that is fed to
    a single-layer unidirectional LSTM. The LSTM's final hidden state goes
    through two linear layers and a softmax.

    Args:
        in_channels: number of channels of the input signal.
        out_channels: number of convolution filters; also used as the LSTM
            input size and hidden size.
        n_len_seg: number of samples per segment.
        n_classes: size of the output vector.
        device: stored on the instance; the module itself never reads it.
        verbose: stored on the instance; the module itself never reads it.
    """

    def __init__(self, in_channels, out_channels, n_len_seg, n_classes, device, verbose=False):
        super(CRNN, self).__init__()

        self.n_len_seg = n_len_seg
        self.n_classes = n_classes
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.device = device
        self.verbose = verbose

        # (batch, channels, length)
        self.cnn = nn.Conv1d(in_channels=self.in_channels,
                             out_channels=self.out_channels,
                             kernel_size=16,
                             stride=2)
        # (batch, seq, feature)
        self.rnn = nn.LSTM(input_size=self.out_channels,
                           hidden_size=self.out_channels,
                           num_layers=1,
                           batch_first=True,
                           bidirectional=False)
        self.dense1 = nn.Linear(out_channels, 128)
        self.dense2 = nn.Linear(128, n_classes)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Run the network on `x` of shape ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, n_classes)`` whose rows sum to 1.

        Raises:
            ValueError: if `length` is not a positive multiple of `n_len_seg`.
                Without this check the segmentation below could silently mix
                samples from different batch items.
        """
        self.n_channel, self.n_length = x.shape[-2], x.shape[-1]
        self.n_seg = self.n_length // self.n_len_seg
        if self.n_seg == 0 or self.n_length % self.n_len_seg != 0:
            raise ValueError(
                f"input length {self.n_length} must be a positive multiple of "
                f"n_len_seg={self.n_len_seg}"
            )

        out = x
        # (batch, channels, length) -> (batch, length, channels)
        out = out.permute(0, 2, 1)
        # Split the length axis into segments: (batch * n_seg, n_len_seg, channels).
        # reshape (not view) because the permuted tensor is non-contiguous
        # whenever there is more than one channel.
        out = out.reshape(-1, self.n_len_seg, self.n_channel)
        # Back to channels-first for Conv1d: (batch * n_seg, channels, n_len_seg)
        out = out.permute(0, 2, 1)
        out = self.cnn(out)
        # Average over the time axis: one feature vector per segment.
        out = out.mean(-1)
        # Regroup segments per batch item: (batch, n_seg, out_channels)
        out = out.reshape(-1, self.n_seg, self.out_channels)
        # Keep only the final hidden state, shape (1, batch, out_channels).
        _, (out, _) = self.rnn(out)
        out = torch.squeeze(out, dim=0)
        out = self.dense1(out)
        out = self.dense2(out)
        # NOTE(review): if this output is trained with nn.CrossEntropyLoss, that
        # loss expects unnormalised logits rather than softmax output -- confirm
        # against the criterion used by the caller.
        out = self.softmax(out)
        return out

In [43]:
# Pass an actual device object: `torch.cuda.device` is a context-manager class,
# not a device handle. Fall back to the CPU when CUDA is unavailable.
Net = CRNN(1,256,1024,256, torch.device('cuda' if torch.cuda.is_available() else 'cpu'), verbose=True)

In [44]:
class FullNet(nn.Module):
    """Two-branch network: two sub-models run side by side and a third consumes their joined output."""

    def __init__(self, finger_print_model, graph_embedding_model, combined_model):
        """Register the two branch models and the model applied to their concatenation."""
        super(FullNet, self).__init__()
        self.FP_model = finger_print_model
        self.GE_model = graph_embedding_model
        self.CB_model = combined_model

    def forward(self, fp, ge):
        """Apply each branch to its input, concatenate along dim 1, insert a
        singleton dim at position 1 and pass the result to the combined model."""
        branch_outputs = (self.FP_model(fp), self.GE_model(ge))
        merged = torch.cat(branch_outputs, dim=1)
        return self.CB_model(torch.unsqueeze(merged, 1))

In [45]:
# Two CRNN branches each emit a 256-wide vector; their concatenation (512 wide)
# is consumed as a single 512-sample segment by the final 2-class CRNN.
model = FullNet(
    finger_print_model=CRNN(in_channels=1, out_channels=512, n_len_seg=1024, n_classes=256, device='cpu'),
    graph_embedding_model=CRNN(in_channels=1, out_channels=128, n_len_seg=256, n_classes=256, device='cpu'),
    combined_model=CRNN(in_channels=1, out_channels=256, n_len_seg=512, n_classes=2, device='cpu'),
)

In [46]:
from prettytable import PrettyTable

def count_parameters(model):
    """Print a table of the trainable parameters of `model` and return their total count.

    Parameters whose `requires_grad` is false are left out of both the table
    and the total.
    """
    trainable = [
        (name, parameter.numel())
        for name, parameter in model.named_parameters()
        if parameter.requires_grad
    ]

    table = PrettyTable(["Modules", "Parameters"])
    for name, count in trainable:
        table.add_row([name, count])
    total_params = sum(count for _, count in trainable)

    print(table)
    print(f"Total Trainable Params: {total_params}")
    return total_params
    
# Print the per-parameter table for `model`; as the last expression of the cell,
# the returned total is also shown as the cell output.
count_parameters(model)

+---------------------------+------------+
|          Modules          | Parameters |
+---------------------------+------------+
|    FP_model.cnn.weight    |    8192    |
|     FP_model.cnn.bias     |    512     |
| FP_model.rnn.weight_ih_l0 |  1048576   |
| FP_model.rnn.weight_hh_l0 |  1048576   |
|  FP_model.rnn.bias_ih_l0  |    2048    |
|  FP_model.rnn.bias_hh_l0  |    2048    |
|   FP_model.dense1.weight  |   65536    |
|    FP_model.dense1.bias   |    128     |
|   FP_model.dense2.weight  |   32768    |
|    FP_model.dense2.bias   |    256     |
|    GE_model.cnn.weight    |    2048    |
|     GE_model.cnn.bias     |    128     |
| GE_model.rnn.weight_ih_l0 |   65536    |
| GE_model.rnn.weight_hh_l0 |   65536    |
|  GE_model.rnn.bias_ih_l0  |    512     |
|  GE_model.rnn.bias_hh_l0  |    512     |
|   GE_model.dense1.weight  |   16384    |
|    GE_model.dense1.bias   |    128     |
|   GE_model.dense2.weight  |   32768    |
|    GE_model.dense2.bias   |    256     |
|    CB_mod

2956290

In [47]:
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary Python objects,
# which can execute code -- only load this file if it comes from a trusted source.
final_edges = np.load('../datasets/final_edges.dump', allow_pickle=True)

In [None]:
# `generate_fingerprints` is not defined in this notebook; presumably it comes
# from the `create_datasets` star import -- TODO confirm.
data = generate_fingerprints(final_edges)

  0%|          | 0/87153 [00:00<?, ?it/s]

In [None]:
# `LinkDataset` is not defined in this notebook; presumably it comes from the
# `create_datasets` star import -- TODO confirm.
dataset = LinkDataset(data)

# 80/20 train/test split. The split is seeded so the partition is identical on
# every run; otherwise samples move between train and test across kernel restarts.
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
split_generator = torch.Generator().manual_seed(0)
train , test = torch.utils.data.random_split(dataset, [train_size, test_size], generator=split_generator)

BATCH_SIZE = 256
# Reshuffle the training samples every epoch; the test loader keeps a fixed order.
trainloader = DataLoader(train, num_workers = 12, batch_size= BATCH_SIZE, shuffle=True)
testloader = DataLoader(test, num_workers = 12, batch_size= BATCH_SIZE)

In [None]:
def eval(model, testloader, criterion=None):
    """Return the mean per-batch loss of `model` over `testloader`.

    The model is switched to evaluation mode and left in that mode; callers
    that continue training must call `model.train()` again.

    Args:
        model: module called as ``model(fp, ge)``.
        testloader: sized iterable yielding ``(fp, ge, label)`` batches.
        criterion: loss callable ``criterion(output, label)``. When omitted,
            the notebook-level global named ``criterion`` is used.

    Returns:
        Sum of the batch losses divided by ``len(testloader)``.
    """
    # NOTE: this function name shadows the builtin `eval`; it is kept because
    # other cells call it under this name.
    loss_fn = criterion if criterion is not None else globals()["criterion"]

    model.eval()
    test_loss = 0.0
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
        for fp, ge, label in testloader:
            output = model(fp.float(), ge.float())
            loss = loss_fn(output, label)
            test_loss += loss.item()
    return test_loss / len(testloader)

In [None]:
# Per-epoch mean losses, appended in epoch order.
train_losses = []
test_losses = []
num_epochs= 50
# NOTE(review): `optimizer` and `criterion` are read from the kernel namespace;
# they must be defined before this cell runs -- confirm on a fresh kernel.
# Epochs are numbered from 1, so the range ends at num_epochs + 1 to run
# exactly `num_epochs` epochs.
for epoch in tqdm(range(1, num_epochs + 1)):
    train_loss = 0.0
    # eval() leaves the model in evaluation mode, so re-enable training mode each epoch.
    model.train()
    batch_id = 0
    for fp, ge, label in trainloader:
        batch_id +=1

        optimizer.zero_grad()
        output = model(fp.float(),ge.float())
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

        # '\r' keeps the per-batch progress on a single, overwritten line.
        print(f'Epoch:{epoch} batch {batch_id}/{len(trainloader)} loss:{loss.item()}', end='\r')

    test_loss = eval(model, testloader)
    epoch_train_loss = train_loss/len(trainloader)
    print("Train loss:",epoch_train_loss,"Test loss :",test_loss)
    print()
    train_losses.append(epoch_train_loss)
    test_losses.append(test_loss)