In [None]:
# ## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# # from google.colab import drive

# # drive.mount('/content/gdrive/', force_remount=True)
# # ../gdrive/MyDrive/data
# %cd zetaFold

In [None]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions, graph_train_valid_test

from sklearn.metrics import log_loss

# from models.enhanced_graph import GTN
from models.self_attention_graph import GTN
from models.baseline_graph import GNN
from models.sequence_protbert import ProteinClassifier

In [None]:
n_labels = 4888
treshold_valid = 733

# Assign a role to every line of the label file, keyed by line index:
#   - labelled lines are "train" until (n_labels - treshold_valid) of them have
#     been seen, then "valid";
#   - lines with an empty label are "test".
ref = dict()
n_train_assigned = 0  # running count, avoids rescanning `ref` on every line
with open('data/graph_labels.txt', 'r') as f:
    for i, line in enumerate(f):
        # strip() rather than slicing off the last character: the final line
        # may lack a trailing newline, and the labels cell parses the same
        # file with strip() as well.
        label = line.strip().split(',')[1].strip()
        if label:
            if n_train_assigned < n_labels - treshold_valid:
                ref[i] = "train"
                n_train_assigned += 1
            else:
                ref[i] = "valid"
        else:
            ref[i] = "test"

# Line indices belonging to each split, in file order.
ref_train = np.array([i for i in range(len(ref)) if ref[i] == "train"])
ref_valid = np.array([i for i in range(len(ref)) if ref[i] == "valid"])
ref_test = np.array([i for i in range(len(ref)) if ref[i] == "test"])

In [None]:
# Device string handed to the model constructor and to every `.to(device)` call
# in the batching / training cells below. Uncomment one of the alternatives to
# run on an accelerator instead of the CPU.
device = 'cpu'
# device = 'mps'
# device = 'cuda'

# Data loading & preprocessing

In [None]:
# labels

# `y` collects the integer label of every labelled line; `valid_id` collects the
# identifier (first field) of every line whose label field is empty.
y, valid_id = [], []

with open('data/graph_labels.txt', "r") as f1:
    for name, label in (row.strip().split(',') for row in f1):
        if label.strip():
            y.append(int(label))
        else:
            valid_id.append(name)

y = np.array(y)

# The last `treshold_valid` labelled samples are held out for validation.
y_train = y[:-treshold_valid]
y_valid = y[-treshold_valid:]

In [None]:
# graph features

# Load the raw graph data, then split it with the train / valid / test index
# arrays built above.
adj, adj_weight, features = load_data()

(
    adj_train, adj_valid, adj_test,
    features_train, features_valid, features_test,
    adj_shapes_train, adj_shapes_valid, adj_shapes_test,
) = graph_train_valid_test(
    adj, adj_weight, features,
    ref_train, ref_valid, ref_test,
)

In [None]:
# sequences

# Read one sequence per line and insert a space between consecutive characters,
# which is the form later handed to the tokenizer.
sequences = []
with open('data/sequences.txt', "r") as f1:
    for line in f1:
        # rstrip('\n') instead of line[:-1]: a final line without a trailing
        # newline would otherwise lose its last character.
        sequences.append(' '.join(line.rstrip('\n')))

sequences = np.array(sequences)
sequences_train, sequences_valid, sequences_test = sequences[ref_train], sequences[ref_valid], sequences[ref_test]

from transformers import BertTokenizer

# do_lower_case=False keeps the upper-case letters of the sequences as they are.
tokenizer = BertTokenizer.from_pretrained('yarongef/DistilProtBert', do_lower_case=False)

# Model

In [None]:
from models.sequence_protbert import ProteinClassifier
from transformers import BertModel

class MixedModel(nn.Module):
    """Classifier that fuses a BERT sequence embedding with a GNN graph embedding.

    The BERT encoder is taken from a ``ProteinClassifier`` and the message
    passing layers from a ``GNN``; a single linear layer maps the concatenation
    of both embeddings to 18 class scores.

    Args:
        dropout: Dropout probability applied to the BERT [CLS] embedding.
        device: Device on which the sub-modules are created.
    """

    def __init__(self, dropout=0.8, device="cuda"):
        super().__init__()

        # Only the `.bert` encoder of the sequence classifier is reused.
        sequences_model = ProteinClassifier(18).to(device)
        # NOTE(review): placeholder from the original code - no pretrained
        # classifier weights are loaded here yet.
        self.bert = sequences_model.bert

        # NOTE(review): arguments are presumably (input_dim, hidden_dim,
        # dropout, n_classes) - confirm against models.baseline_graph.GNN.
        self.graph_model = GNN(90, 64, 0.2, 18).to(device)
        # NOTE(review): placeholder - no pretrained GNN weights are loaded yet.

        # The "+ 64" assumes the graph embedding produced by fc3 is 64 wide.
        self.mlp = nn.Linear(self.bert.config.hidden_size + 64, 18).to(device)
        self.dropout = dropout

    def forward(self, x_in, adj, idx, input_ids, attention_mask):
        """Return per-graph log-probabilities over the 18 classes.

        Args:
            x_in: Node features of all graphs in the batch, stacked row-wise.
            adj: Adjacency matrix of the whole batch (multiplied with the node
                features through ``torch.mm``).
            idx: 1-D LongTensor giving, for every node row, the position of its
                graph inside the batch.
            input_ids: Token ids of the sequences, one row per graph.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            Tensor of shape ``(n_graphs, 18)`` containing ``log_softmax`` scores.
        """
        gnn = self.graph_model

        # Two message passing layers.
        x = gnn.relu(torch.mm(adj, gnn.fc1(x_in)))
        x = gnn.relu(torch.mm(adj, gnn.fc2(x)))

        # Sum aggregator: add up the node rows belonging to the same graph.
        idx = idx.unsqueeze(1).repeat(1, x.size(1))
        n_graphs = int(torch.max(idx)) + 1
        out = torch.zeros(n_graphs, x.size(1), dtype=x.dtype, device=x.device)
        out = out.scatter_add_(0, idx, x)

        # Batch normalisation followed by the MLP producing the graph embedding.
        out = gnn.bn(out)
        out = gnn.relu(gnn.fc3(out))

        output = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        output = output.last_hidden_state[:, 0, :]  # [CLS] token embedding
        # Functional dropout tied to self.training: a freshly constructed
        # nn.Dropout module is always in training mode and would keep dropping
        # activations after model.eval().
        output = F.dropout(output, p=self.dropout, training=self.training)

        joined = torch.cat([output, out], dim=1)
        joined = self.mlp(joined)

        return F.log_softmax(joined, dim=1)

In [None]:
model = MixedModel(device=device)

# Freeze every module listed in the BERT encoder's `layer` container so the
# optimizer does not update them; parameters outside `encoder.layer` are left
# untouched by this cell.
for encoder_layer in model.bert.encoder.layer:
    encoder_layer.requires_grad_(False)

# Train

In [None]:
# One value per class (18 in total) - presumably the number of training samples
# of each class; TODO confirm where these counts come from.
pops = np.array([
    440.,  50., 939.,  60., 112., 625.,
    202.,  74., 998.,  57.,  43., 305.,
     44.,  59., 548., 226.,  60.,  46.,
])

# Inverse-population weights, rescaled so that their mean equals 1.
weights = np.reciprocal(pops)
weights = weights / weights.mean()
weights = torch.from_numpy(weights).to(dtype=torch.float32).to(device)

balanced_criterion = nn.CrossEntropyLoss(weight=weights)

In [None]:
def encode(s):
    """Tokenize one sequence string with the module-level ``tokenizer``.

    The sequence is truncated / padded to exactly 989 tokens (special tokens
    included) and returned as PyTorch tensors with an attention mask and
    without token type ids.
    """
    return tokenizer.encode_plus(
        s,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=989,
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )

In [None]:
def process_batch_data(features, adj_shapes, adj, sequences, indices, batchSize, incr_i, device):
    """Assemble the model inputs for one mini-batch.

    Args:
        features: Per-graph node-feature matrices, indexable by graph position.
        adj_shapes: Array whose entry ``k`` is used as the number of node rows
            of graph ``k``.
        adj: Array of per-graph adjacency matrices accepted by
            ``scipy.sparse.block_diag``.
        sequences: NumPy array of sequence strings, one per graph.
        indices: Positions (within the split) of the samples in this batch.
        batchSize: Unused; kept so existing callers keep working.
        incr_i: Unused; kept so existing callers keep working.
        device: Device on which the returned tensors are placed.

    Returns:
        Tuple ``(features_, adj_, idx_batch, cum_nodes, sequences_,
        attention_mask_)`` where ``features_`` stacks the node features of all
        selected graphs, ``adj_`` is their block-diagonal adjacency as a torch
        sparse tensor, ``idx_batch`` maps every node row to its graph position
        in the batch, ``cum_nodes`` holds the per-graph node counts, and the
        last two are the token ids and attention masks of the sequences.
    """
    # Pick only the requested graphs instead of copying the whole container
    # into a new array on every batch.
    features_ = np.vstack([features[k] for k in indices])
    features_ = torch.FloatTensor(features_).to(device)

    # The number of graphs is simply the number of indices; this no longer
    # relies on batchSize / incr_i arithmetic being in sync with `indices`.
    cum_nodes = adj_shapes[indices]
    idx_batch = np.repeat(np.arange(len(indices)), cum_nodes)
    idx_batch = torch.LongTensor(idx_batch).to(device)

    adj_ = adj[indices]
    adj_ = sp.block_diag(adj_)
    adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

    # Tokenize every sequence once and reuse the result for both tensors.
    encoded = [encode(s) for s in sequences[indices]]
    sequences_ = torch.cat([e['input_ids'] for e in encoded]).to(device)
    attention_mask_ = torch.cat([e['attention_mask'] for e in encoded]).to(device)

    return features_, adj_, idx_batch, cum_nodes, sequences_, attention_mask_

In [None]:
# --- Training configuration ---
epochs = 20
batchSize = 2 #training batchsize

# LambdaLR multiplies the optimizer's base learning rate by the value of
# `lr_lambda`; with a base rate of 1 the lambda therefore gives the learning
# rate directly. As written it evaluates to 1e-3 for every epoch.
optimizer = torch.optim.Adam(model.parameters(), lr=1)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 1e-3 + (1e-3 - 1e-3) * ((epoch - epochs)/ epochs),
)

# Unweighted cross-entropy; `balanced_criterion` is the class-weighted variant.
criterion = nn.CrossEntropyLoss()

# --- Early stopping / best-model bookkeeping ---
patience = 10
patience_count = 0
best_valid_loss = 300
previous_epoch_loss_valid = 300

# --- Per-split sample orderings (shuffled in place at every epoch) ---
res = []
shuffle_train = np.arange(len(ref_train))
shuffle_valid = np.arange(len(ref_valid))

def _batch_metrics(output, targets):
    """Return ``(nll, accuracy)`` of one batch without tracking gradients.

    ``output`` holds the model scores and ``targets`` the one-hot labels; the
    NLL is the unweighted cross-entropy and the accuracy is the fraction of
    rows whose arg-max matches the target.
    """
    output = output.detach()
    nll = F.cross_entropy(output, targets).item()
    hits = (targets.argmax(dim=1) == output.argmax(dim=1)).cpu().numpy()
    return nll, hits.mean()


pbar = tqdm(range(epochs))
for epoch in pbar:
    epoch_loss, epoch_loss_valid, epoch_accuracy, epoch_accuracy_valid = [], [], [], []
    epoch_nll_train, epoch_nll_valid = [], []

    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)

    # Train
    model.train()
    for i in range(0, len(ref_train), batchSize):
        incr_i = min(i+batchSize, len(ref_train))
        indices = shuffle_train[i: incr_i]

        targets_ = torch.tensor(y_train[indices])
        targets_ = F.one_hot(targets_, 18).float().to(device)

        features_, adj_, idx_batch, cum_nodes, sequences_, attention_mask_ = process_batch_data(
            features_train, adj_shapes_train, adj_train, sequences_train, indices, batchSize, incr_i, device
            )

        optimizer.zero_grad()
        output = model(features_, adj_, idx_batch, sequences_, attention_mask_)

        # Backward
        loss = criterion(output, targets_)
        loss.backward()
        optimizer.step()

        # Compute metrics
        epoch_loss.append(loss.item())
        nll, accuracy = _batch_metrics(output, targets_)
        epoch_nll_train.append(nll)
        epoch_accuracy.append(accuracy)

    # Validation
    batchSizeVal=64
    model.eval()
    # no_grad: nothing is back-propagated here, so avoid building an autograd
    # graph through the whole model for every batch.
    with torch.no_grad():
        for i in range(0, len(ref_valid), int(batchSizeVal)):
            incr_i = min(i+int(batchSizeVal), len(ref_valid))
            indices = shuffle_valid[i: incr_i]

            targets_ = torch.tensor(y_valid[indices])
            targets_ = F.one_hot(targets_, 18).float().to(device)

            features_, adj_, idx_batch, cum_nodes, sequences_, attention_mask_ = process_batch_data(
                features_valid, adj_shapes_valid, adj_valid, sequences_valid, indices, batchSizeVal, incr_i, device
                )

            output = model(features_, adj_, idx_batch, sequences_, attention_mask_)

            # Compute metrics
            loss = criterion(output, targets_)
            epoch_loss_valid.append(loss.item())
            nll, accuracy = _batch_metrics(output, targets_)
            epoch_nll_valid.append(nll)
            epoch_accuracy_valid.append(accuracy)

    # scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_loss_valid = np.array(epoch_loss_valid).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()
    epoch_accuracy_valid = np.array(epoch_accuracy_valid).mean()
    epoch_nll_train = np.array(epoch_nll_train).mean()
    epoch_nll_valid = np.array(epoch_nll_valid).mean()

    # Checkpoint on the best unweighted validation NLL. The stored best value
    # is the same quantity that is compared, so the test stays meaningful even
    # when `criterion` is a weighted loss.
    if epoch_nll_valid < best_valid_loss :
        torch.save({"state": model.state_dict(),}, "last_model_checkpoint.pt")
        best_valid_loss = epoch_nll_valid

    # tqdm.write(
    pbar.set_description(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}\n'
        f"nll train: {round(epoch_nll_train, 4)}, nll valid: {round(epoch_nll_valid, 4)}\n"
        f"best valid: {round(best_valid_loss, 4)}"
        )

    # Early stopping :
    if previous_epoch_loss_valid < epoch_loss_valid:
        patience_count +=1
        if patience_count == patience:
          print(f'Early stopping end of epoch {epoch}')
          break

    else :
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0

        # Save last best results: the validation loss did not get worse, so
        # refresh the test-set predictions. The model returns log_softmax
        # scores, hence `res` holds log-probabilities.
        res = list()
        shuffle_test = np.arange(len(ref_test))

        model.eval()
        with torch.no_grad():
            for i in range(0, len(ref_test), int(batchSizeVal)):
                incr_i = min(i+int(batchSizeVal), len(ref_test))
                indices = shuffle_test[i: incr_i]

                features_, adj_, idx_batch, cum_nodes, sequences_, attention_mask_ = process_batch_data(
                    features_test, adj_shapes_test, adj_test, sequences_test, indices, batchSizeVal, incr_i, device
                    )

                output = model(features_, adj_, idx_batch, sequences_, attention_mask_)

                res.append(output.to('cpu').detach().numpy())