In [None]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions
from dataloader import get_loader

from sklearn.metrics import log_loss


In [None]:
# adj, weights, features, edge_features, new = load_data()

# adj = [normalize_adjacency(adj[i], weights[i]) for i in range(len(adj))]

# `ref[i]` is True when graph i has an empty label in graph_labels.txt and False
# otherwise (each line is assumed to look like "<graph_id>,<label>").
ref = list()
with open('data/graph_labels.txt', 'r') as f:
    for line in f:
        t = line.split(',')
        # strip() instead of [:-1]: slicing off the last character drops a real label
        # character when the final line has no trailing newline, and leaves a "\r"
        # behind (wrongly counted as a label) with Windows line endings.
        ref.append(len(t[1].strip()) == 0)

ref = np.array(ref)
ref.shape

In [None]:
# Compute device; switch to 'mps' or 'cuda' when that backend is available.
device = 'cpu'

# Split selector used by the `is_kept` mask and the data loaders:
# "train" keeps the labelled graphs, "valid" keeps the unlabelled ones.
version = "train"
# version = "valid"

# NOTE(review): only referenced in commented-out code further down — presumably the
# number of labelled graphs; confirm before relying on it.
n_labels = 4888

In [None]:
# Boolean mask over all graphs: True for the graphs that belong to the selected split.
# Graphs with a non-empty label are kept for "train" (any version other than "valid");
# graphs with an empty label are kept for "valid".
keep_labelled = version != "valid"
is_kept = []
with open('data/graph_labels.txt', "r") as label_file:
    for row in label_file:
        graph_id, label = row.strip().split(',')
        has_label = len(label.strip()) > 0
        is_kept.append(keep_labelled if has_label else not keep_labelled)

In [None]:
## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# from google.colab import drive

# drive.mount('/content/gdrive/', force_remount=True)
# #../content/gdrive/MyDrive/dataltegrad
# %cd zetaFold

# SEQUENCES ONLY: Fine-tune the DistilProtBert Model

In [None]:
from transformers import BertModel, BertTokenizer

# Hugging Face checkpoint used both for the tokenizer and for ProteinClassifier's encoder.
PRE_TRAINED_MODEL_NAME = 'yarongef/DistilProtBert'

# Case is preserved (do_lower_case=False) — presumably because the ProtBert vocabulary
# uses upper-case amino-acid letters; confirm against the model card.
tokenizer = BertTokenizer.from_pretrained(
    PRE_TRAINED_MODEL_NAME,
    do_lower_case=False,
)

In [None]:
## Define ProtBert-based classifier

class ProteinClassifier(nn.Module):
    """Linear classification head on top of the pre-trained DistilProtBert encoder.

    Args:
        n_classes: number of output classes.

    ``forward(input_ids, attention_mask)`` moves both tensors to ``device``, encodes
    them, takes the hidden state of the first token of each sequence, applies a ReLU
    and the linear head, and returns log-probabilities of shape (batch, n_classes).
    """

    def __init__(self, n_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME).to(device)
        # Encoder put in eval mode (its dropout is disabled); note that a later
        # `.train()` call on the parent module would switch it back.
        self.bert.eval()
        hidden_size = self.bert.config.hidden_size
        self.classifier = nn.Linear(hidden_size, n_classes).to(device)

    def forward(self, input_ids, attention_mask):
        encoded = self.bert(
            input_ids=input_ids.to(device),
            attention_mask=attention_mask.to(device),
        )
        first_token_state = encoded.last_hidden_state[:, 0, :]
        logits = self.classifier(F.relu(first_token_state))
        return F.log_softmax(logits, dim=1)


model = ProteinClassifier(18).to(device)

# Freeze every transformer layer except the last one. Parameters outside
# `encoder.layer` (e.g. the embeddings) are left trainable.
for frozen_layer in model.bert.encoder.layer[:-1]:
    frozen_layer.requires_grad_(False)

In [None]:
criterion = nn.CrossEntropyLoss()
# Base lr is 1, so the LambdaLR multiplier below *is* the effective learning rate.
optimizer = torch.optim.Adam(model.parameters(), 
    lr = 1,
    #weight_decay=0.01
)
epochs = 10
# Effective lr decays linearly from 1e-3 (epoch 0) to 1e-6 (epoch == epochs).
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
                lr_lambda=lambda epoch: 1e-6 + (1e-6 - 1e-3) * ((epoch - epochs)/ epochs))

data_loader = get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=tokenizer, max_len=600, batch_size=8, shuffle=True, version=version, drop_last=True)

pbar = trange(epochs)
for epoch in pbar:
    epoch_loss, epoch_log_loss, epoch_accuracy = [], [], []

    for e in data_loader:
        optimizer.zero_grad()

        # Forward pass (named `input_ids`, not `input`, so the builtin is not shadowed).
        input_ids = e['input_ids'].to(device)
        src_mask = e['attention_mask'].to(device)
        output = model(input_ids, src_mask)  # log-probabilities, already on `device`

        # Loss. CrossEntropyLoss re-applies log_softmax, which leaves log-probabilities
        # unchanged, so this is the negative log-likelihood of the true class.
        output = output.view(-1, output.shape[-1])
        target = e['target'].reshape(-1).to(device)
        loss = criterion(output, target)

        # Backward
        loss.backward()
        optimizer.step()

        # Metrics, computed on detached CPU copies.
        epoch_loss.append(loss.item())

        n_classes = output.shape[-1]
        target_onehot = F.one_hot(target, n_classes).cpu().numpy()
        proba = nn.Softmax(dim=1)(output).detach().cpu().numpy()
        epoch_log_loss.append(log_loss(target_onehot, proba))
        epoch_accuracy.append((target_onehot.argmax(axis=1) == proba.argmax(axis=1)).mean())

    scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_log_loss = np.array(epoch_log_loss).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()

    pbar.set_description(
        f'Epoch {epoch}, cel: {round(epoch_loss, 4)}, '
        f'nll: {round(epoch_log_loss, 4)}, acc: {round(epoch_accuracy, 4)}'
        )

# GRAPHS ONLY: Message-Passing GNN

In [None]:
# adj, adj_weight, features, edge_features, new = load_data()

In [None]:
# adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
# adj_shapes = np.array([at.shape[0] for at in adj])
# adj = [adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))]

In [None]:
## Prepare graph features

# 1. Full graph-related database in arrays, restricted to the current split

adj, adj_weight, features, edge_features, Flist = load_data()
# Edge features and the fifth `load_data` output are not used below; free them early.
del Flist
del edge_features

adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
adj_shapes = np.array([at.shape[0] for at in adj])
# Add the identity (self-loops) to every normalized adjacency matrix.
adj = [adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))]


def _as_object_array(items):
    """Pack a list of per-graph objects into a 1-D object ndarray (one entry per graph).

    ``np.array(list_of_matrices)`` lets NumPy guess the result's shape. It may stack
    equally-shaped arrays into an N-D array or mishandle sparse matrices (whose
    ``len`` is ambiguous). Filling a pre-allocated object array never does that, so
    boolean masking and fancy indexing by graph position always work.
    """
    packed = np.empty(len(items), dtype=object)
    for pos, item in enumerate(items):
        packed[pos] = item
    return packed


keep_mask = np.asarray(is_kept, dtype=bool)
assert len(keep_mask) == len(adj), "graph_labels.txt and load_data() disagree on the number of graphs"

adj = _as_object_array(adj)[keep_mask]
features = _as_object_array(features)[keep_mask]
adj_shapes = adj_shapes[keep_mask]

# features = features[:n_labels] if version != "valid" else features[n_labels:]
# adj = adj[:n_labels] if version != "valid" else adj[n_labels:]

## 2. Then, we load batchs thanks to "indexs" key of dataloader element

## Example : loading only the first batch of graph-related data
# for e in data_loader:
#     break
# batch_indices = e['indexs'] ## Indexs that are part of the batch

# features_ = np.array(features)[batch_indices]
# features_ = np.vstack(features_)
# features_ = torch.FloatTensor(features_).to(device)

# adj_ = adj[batch_indices]
# adj_ = sp.block_diag(adj_)
# adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

# cum_nodes = adj_shapes[batch_indices]
# idx_batch = np.repeat(np.arange(
#     len(batch_indices)
#     ), cum_nodes)
# idx_batch = torch.LongTensor(idx_batch).to(device)

# # With GNN model

# model = GNN(...).to(device)
# output = model(features_, adj_, idx_batch)

In [None]:
class GNN(nn.Module):
    """
    Simple message passing model: 2 message passing layers, a sum readout over the
    nodes of each graph, batch normalization and a 2-layer MLP producing
    log-probabilities.

    Args:
        input_dim: number of node features.
        hidden_dim: width of the message-passing layers and of the MLP hidden layer.
        dropout: dropout probability, applied after the first message-passing layer
            and inside the MLP.
        n_class: number of output classes.

    ``forward(x_in, adj, idx, n_graphs=None)``:
        x_in: node features of every node in the batch, (total_nodes, input_dim).
        adj: (total_nodes, total_nodes) matrix usable with ``torch.mm`` (sparse or
            dense), e.g. a block-diagonal adjacency of the batch's graphs.
        idx: integer tensor (total_nodes,) giving the graph index of each node.
        n_graphs: number of graphs in the batch. Defaults to ``idx.max() + 1``;
            pass it explicitly if trailing graphs may contain no nodes.
        Returns log-probabilities of shape (n_graphs, n_class).
    """
    def __init__(self, input_dim, hidden_dim, dropout, n_class):
        super(GNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_class)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_in, adj, idx, n_graphs=None):
        # first message passing layer: transform node features, then aggregate via adj
        x = self.fc1(x_in)
        x = self.relu(torch.mm(adj, x))
        x = self.dropout(x)

        # second message passing layer
        x = self.fc2(x)
        x = self.relu(torch.mm(adj, x))

        # sum readout: out[g] = sum of the embeddings of the nodes whose idx == g.
        # new_zeros keeps x's dtype and device; index_add sums rows without having to
        # materialise a (total_nodes, hidden_dim) index tensor.
        if n_graphs is None:
            n_graphs = int(idx.max()) + 1 if idx.numel() > 0 else 0
        out = x.new_zeros(n_graphs, x.size(1))
        out = out.index_add(0, idx, x)

        # batch normalization layer
        out = self.bn(out)

        # mlp to produce output
        out = self.relu(self.fc3(out))
        out = self.dropout(out)
        out = self.fc4(out)

        return F.log_softmax(out, dim=1)

In [None]:
model = GNN(86, 64, 0.1, 18).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
epochs = 50
criterion = nn.CrossEntropyLoss()

data_loader = get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=tokenizer, max_len=600, batch_size=64, shuffle=False, version=version, drop_last=True)

# Per-graph feature matrices as an ndarray, built once here instead of re-copying the
# whole collection with np.array(features) on every batch.
features_arr = np.asarray(features)

pbar = trange(epochs)
for epoch in pbar:
    epoch_loss, epoch_log_loss, epoch_accuracy = [], [], []

    for e in data_loader:
        batch_indices = e['indexs']
        optimizer.zero_grad()

        # One-hot float targets: CrossEntropyLoss treats them as class probabilities.
        target = F.one_hot(e['target'], 18).float().to(device)

        # Batch the graphs as one disconnected graph: stacked node features,
        # block-diagonal adjacency and, for every node, the index of its graph.
        features_ = torch.FloatTensor(np.vstack(features_arr[batch_indices])).to(device)
        adj_ = sparse_mx_to_torch_sparse_tensor(sp.block_diag(adj[batch_indices])).to(device)
        nodes_per_graph = adj_shapes[batch_indices]
        idx_batch = torch.LongTensor(
            np.repeat(np.arange(len(batch_indices)), nodes_per_graph)
        ).to(device)

        output = model(features_, adj_, idx_batch)

        # Backward
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Metrics, computed on detached CPU copies.
        epoch_loss.append(loss.item())

        target_np = target.cpu().numpy()
        proba = nn.Softmax(dim=1)(output).detach().cpu().numpy()
        epoch_log_loss.append(log_loss(target_np, proba))
        epoch_accuracy.append((target_np.argmax(axis=1) == proba.argmax(axis=1)).mean())

    epoch_loss = np.array(epoch_loss).mean()
    epoch_log_loss = np.array(epoch_log_loss).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()

    pbar.set_description(
        f'Epoch {epoch}, cel: {round(epoch_loss, 4)}, '
        f'nll: {round(epoch_log_loss, 4)}, acc: {round(epoch_accuracy, 4)}'
        )

# MULTIMODAL MODEL: Sequence + Graph

In [None]:
class ConcatModel(nn.Module):
    """Late-fusion classifier combining a sequence branch and a graph branch.

    The sequence branch (ProteinClassifier) and the graph branch (GNN) each produce
    ``out_dim`` outputs per protein. Both end with a log-softmax, so these are
    log-probabilities. The two outputs are concatenated, passed through dropout and
    a linear layer, then a log-softmax.

    Args:
        out_dim: output size of each branch.
        feature_dim: node-feature size fed to the GNN.
        graph_dim: hidden size of the GNN.
        n_classes: number of target classes (default 18, the previously hard-coded value).

    ``forward(x_in, adj, idx, input_ids, attention_mask)`` returns log-probabilities of
    shape (batch, n_classes).
    """
    def __init__(self, out_dim, feature_dim, graph_dim, n_classes=18):
        super(ConcatModel, self).__init__()
        self.seqModel = ProteinClassifier(out_dim).to(device)
        self.graphModel = GNN(feature_dim, graph_dim, dropout=0.25, n_class=out_dim).to(device)

        self.classifier = nn.Linear(2*out_dim, n_classes)
        # Registered as a submodule so model.eval() turns it off. A Dropout created
        # inside forward() is always in training mode and would stay active at inference.
        self.dropout = nn.Dropout(0.2)
        self.activation = nn.LogSoftmax(dim=1)

    def forward(self, x_in, adj, idx, input_ids, attention_mask):
        seqEmbedding = self.seqModel(input_ids, attention_mask)
        graphEmbedding = self.graphModel(x_in, adj, idx)

        out = torch.cat([graphEmbedding, seqEmbedding], dim=1)
        out = self.dropout(out)
        out = self.classifier(out)

        return self.activation(out)

model = ConcatModel(64, 86, 128).to(device)

In [None]:
# Freeze all ProtBert transformer layers of the sequence branch except the last one;
# parameters outside `encoder.layer` (e.g. the embeddings) remain trainable.
for bert_layer in model.seqModel.bert.encoder.layer[:-1]:
    bert_layer.requires_grad_(False)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1)
epochs = 500
# Base lr is 1, so the multiplier is the effective lr: linear decay from 1e-3 at
# epoch 0 to 1e-6 at epoch == epochs.
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
                lr_lambda=lambda epoch: 1e-6 + (1e-6 - 1e-3) * ((epoch - epochs)/ epochs))
criterion = nn.CrossEntropyLoss()

data_loader = get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=tokenizer, max_len=600, batch_size=8, shuffle=False, version=version, drop_last=True)

# Per-graph feature matrices as an ndarray, built once instead of on every batch.
features_arr = np.asarray(features)

pbar = trange(epochs)
for epoch in pbar:
    epoch_loss, epoch_log_loss, epoch_accuracy = [], [], []

    for e in data_loader:
        batch_indices = e['indexs']
        optimizer.zero_grad()

        # One-hot float targets: CrossEntropyLoss treats them as class probabilities.
        target = F.one_hot(e['target'], 18).float().to(device)

        # Sequence inputs (named `input_ids`, not `input`, to avoid shadowing the builtin).
        input_ids = e['input_ids'].to(device)
        src_mask = e['attention_mask'].to(device)

        # Graph inputs: stacked node features, block-diagonal adjacency and, for
        # every node, the index of its graph within the batch.
        features_ = torch.FloatTensor(np.vstack(features_arr[batch_indices])).to(device)
        adj_ = sparse_mx_to_torch_sparse_tensor(sp.block_diag(adj[batch_indices])).to(device)
        nodes_per_graph = adj_shapes[batch_indices]
        idx_batch = torch.LongTensor(
            np.repeat(np.arange(len(batch_indices)), nodes_per_graph)
        ).to(device)

        output = model(features_, adj_, idx_batch, input_ids, src_mask)

        # Backward
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Metrics, computed on detached CPU copies. Only per-epoch means are reported
        # (in the progress bar) to keep the notebook output readable over 500 epochs.
        epoch_loss.append(loss.item())

        target_np = target.cpu().numpy()
        proba = nn.Softmax(dim=1)(output).detach().cpu().numpy()
        epoch_log_loss.append(log_loss(target_np, proba))
        epoch_accuracy.append((target_np.argmax(axis=1) == proba.argmax(axis=1)).mean())

    scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_log_loss = np.array(epoch_log_loss).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()

    pbar.set_description(
        f'Epoch {epoch}, cel: {round(epoch_loss, 4)}, '
        f'nll: {round(epoch_log_loss, 4)}, acc: {round(epoch_accuracy, 4)}'
        )

In [None]:
# GTN version (draft)

from torch.nn.utils.rnn import pad_sequence

# pad_sequence zero-pads variable-length node sets to the longest one. With the default
# batch_first=False the result is laid out as (max_len, batch, feature_dim).
a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
padded = pad_sequence([a, b, c])
assert padded.size() == (25, 3, 300)



from graph_transformer_pytorch import GraphTransformer

# Bound to its own name (not `model`) so running this scratch cell does not overwrite
# the trained model from the cells above.
graph_transformer = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,             # optional - if left out, edge dimensions is assumed to be the same as the node dimensions above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, suggested by literature to be needed
    gated_residual = True,      # to use the gated residual to prevent over-smoothing
    rel_pos_emb = True          # set to True if the nodes are ordered, default to False
)

nodes = torch.randn(1, 128, 256)
edges = torch.randn(1, 128, 128, 512)
mask = torch.ones(1, 128).bool()

nodes, edges = graph_transformer(nodes, edges, mask = mask)

nodes.shape # (1, 128, 256) - project to R^3 for coordinates