In [1]:
# ## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# # from google.colab import drive

# # drive.mount('/content/gdrive/', force_remount=True)
# # ../gdrive/MyDrive/data
# %cd zetaFold

In [1]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions, graph_train_valid_test
from dataloader import get_loader

from sklearn.metrics import log_loss

# from models.enhanced_graph import GTN
from models.self_attention_graph import GTN
from models.baseline_graph import GNN

In [2]:
n_labels = 4888
treshold_valid = 733

# Assign every graph (one per line of graph_labels.txt, in file order) to a split:
# the first `n_labels - treshold_valid` labelled graphs are "train", the remaining
# labelled ones are "valid", and graphs with an empty label are "test".
# `.strip()` (instead of dropping the last character) makes the empty-label check
# robust to CRLF line endings and to a final line without a trailing newline, and
# matches how the labels cell below parses the same file.
ref = dict()
n_train_target = n_labels - treshold_valid
n_train = 0  # running count, avoids rescanning `ref` for every line
with open('data/graph_labels.txt', 'r') as f:
    for i, line in enumerate(f):
        label = line.split(',')[1].strip()
        if not label:
            ref[i] = "test"
        elif n_train < n_train_target:
            ref[i] = "train"
            n_train += 1
        else:
            ref[i] = "valid"

# `ref` keys are 0..len(ref)-1 inserted in order, so items() preserves file order.
ref_train = np.array([i for i, split in ref.items() if split == "train"])
ref_valid = np.array([i for i, split in ref.items() if split == "valid"])
ref_test = np.array([i for i, split in ref.items() if split == "test"])

# The labels are later split positionally as y[:-treshold_valid] / y[-treshold_valid:];
# that only lines up with ref_train / ref_valid if the file has exactly n_labels labelled graphs.
assert len(ref_train) == n_train_target and len(ref_valid) == treshold_valid, (
    f"expected {n_labels} labelled graphs, found {len(ref_train) + len(ref_valid)}"
)

In [3]:
# Compute device for tensors and the model; alternatives: 'mps' (Apple Silicon) or 'cuda' (NVIDIA GPU).
device = 'cpu'

# Data loading & preprocessing

In [4]:
# labels

# Each line of graph_labels.txt is "<id>,<label>"; an empty label marks an unlabelled graph.
y = []          # integer labels of the labelled graphs, in file order
valid_id = []   # ids (as strings) of the graphs with an empty label

with open('data/graph_labels.txt', "r") as label_file:
    for raw_line in label_file:
        graph_id, label = raw_line.strip().split(',')
        if not label.strip():
            valid_id.append(graph_id)
            continue
        y.append(int(label))

y = np.array(y)
# Positional split: the last `treshold_valid` labelled graphs form the validation set.
y_train = y[:-treshold_valid]
y_valid = y[-treshold_valid:]

In [5]:
# graph features

# load_data() yields five items; the edge features and Flist are discarded right away.
adj, adj_weight, features, _edge_features, _flist = load_data()
del _edge_features, _flist

# Split adjacency matrices, node features and per-graph node counts using the ref_* index arrays.
split_parts = graph_train_valid_test(adj, adj_weight, features, ref_train, ref_valid, ref_test)
(adj_train, adj_valid, adj_test,
 features_train, features_valid, features_test,
 adj_shapes_train, adj_shapes_valid, adj_shapes_test) = split_parts
del split_parts

# Experiment with GTN

In [6]:
# GTN arguments are positional; presumably (input dim, hidden dim, dropout, number of classes)
# — TODO confirm against models/self_attention_graph.GTN. 18 matches the one-hot class count
# used for the training targets.
model = GTN(86, 64, 0.2, 18)
model = model.to(device)

In [8]:
# Training with early stopping on the validation loss.
# Whenever the validation loss improves (ties included), the weights are checkpointed and the
# test-set logits are recomputed into `res` (a list of per-batch numpy arrays, in ref_test order),
# so the checkpoint and `res` always come from the same (best) model.

N_CLASSES = 18  # number of graph classes; must match the last GTN constructor argument

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
epochs = 50
criterion = nn.CrossEntropyLoss()  # one-hot float targets are treated as class probabilities

batchSize = 64

patience = 50
patience_count = 0

previous_epoch_loss_valid = float('inf')  # best validation loss seen so far

shuffle_train = np.arange(len(ref_train))
# Validation and test are iterated in a fixed order; test predictions must stay aligned with ref_test.
shuffle_valid = np.arange(len(ref_valid))
shuffle_test = np.arange(len(ref_test))

res = list()


def _make_batch(features_split, adj_split, adj_shapes_split, indices):
    """Merge the graphs at `indices` into one disjoint block-diagonal graph on `device`.

    Args:
        features_split: per-graph node-feature matrices of one split, indexable by int.
        adj_split: per-graph scipy sparse adjacency matrices (fancy-indexed with `indices`,
            so presumably an object ndarray — TODO confirm).
        adj_shapes_split: per-graph node counts (ndarray, fancy-indexed with `indices`).
        indices: positions of the graphs to include.

    Returns:
        (features_, adj_, idx_batch, cum_nodes): stacked node features (FloatTensor),
        block-diagonal sparse adjacency tensor, graph id 0..len(indices)-1 of every node
        (LongTensor), and the per-graph node counts.
    """
    # Gather graph by graph rather than np.array(features_split)[indices]: no re-materialisation
    # of the whole split on every batch, and it also works when the split is a ragged list.
    features_ = np.vstack([features_split[j] for j in indices])
    features_ = torch.FloatTensor(features_).to(device)

    cum_nodes = adj_shapes_split[indices]
    # One graph id per node; len(indices) also covers the short final batch.
    idx_batch = np.repeat(np.arange(len(indices)), cum_nodes)
    idx_batch = torch.LongTensor(idx_batch).to(device)

    adj_ = sp.block_diag(adj_split[indices])
    adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)
    return features_, adj_, idx_batch, cum_nodes


def _predict_batches(features_split, adj_split, adj_shapes_split, order, batch_size):
    """Yield `(indices, logits)` for consecutive batches of `order` with the model in eval mode.

    Callers should iterate inside `torch.no_grad()`.
    """
    model.eval()
    for start in range(0, len(order), batch_size):
        indices = order[start:start + batch_size]
        features_, adj_, idx_batch, cum_nodes = _make_batch(
            features_split, adj_split, adj_shapes_split, indices)
        yield indices, model(features_, adj_, idx_batch, cum_nodes)


def _one_hot_targets(y_split, indices):
    """One-hot float targets for the graphs at `indices`, on `device`."""
    return F.one_hot(torch.tensor(y_split[indices]), N_CLASSES).float().to(device)


def _n_correct(logits, targets_):
    """Number of rows whose argmax prediction matches the one-hot target's class."""
    return int((logits.argmax(dim=1) == targets_.argmax(dim=1)).sum())


pbar = tqdm(range(epochs))
for epoch in pbar:
    np.random.shuffle(shuffle_train)

    # Train
    model.train()
    train_loss_sum, train_correct = 0.0, 0
    for i in tqdm(range(0, len(ref_train), batchSize), leave=False):
        indices = shuffle_train[i:i + batchSize]
        targets_ = _one_hot_targets(y_train, indices)
        features_, adj_, idx_batch, cum_nodes = _make_batch(
            features_train, adj_train, adj_shapes_train, indices)

        optimizer.zero_grad()
        output = model(features_, adj_, idx_batch, cum_nodes)
        loss = criterion(output, targets_)
        loss.backward()
        optimizer.step()

        # Sample-weighted sums, so the short final batch is not over-weighted in epoch metrics.
        train_loss_sum += loss.item() * len(indices)
        train_correct += _n_correct(output.detach(), targets_)

    # Validation (uses the *validation* node features, consistent with adj_valid / y_valid)
    valid_loss_sum, valid_correct = 0.0, 0
    with torch.no_grad():
        for indices, output in _predict_batches(
                features_valid, adj_valid, adj_shapes_valid, shuffle_valid, batchSize):
            targets_ = _one_hot_targets(y_valid, indices)
            valid_loss_sum += criterion(output, targets_).item() * len(indices)
            valid_correct += _n_correct(output, targets_)

    epoch_loss = train_loss_sum / len(ref_train)
    epoch_loss_valid = valid_loss_sum / len(ref_valid)
    epoch_accuracy = train_correct / len(ref_train)
    epoch_accuracy_valid = valid_correct / len(ref_valid)

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
        )

    # Early stopping
    if epoch_loss_valid <= previous_epoch_loss_valid:
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0

        # Checkpoint the best model so it matches the test predictions stored in `res`.
        torch.save({"state": model.state_dict()}, "last_model_checkpoint.pt")

        # Test logits (not softmaxed), one numpy array per batch, in ref_test order.
        with torch.no_grad():
            res = [test_output.to('cpu').numpy()
                   for _test_indices, test_output in _predict_batches(
                       features_test, adj_test, adj_shapes_test, shuffle_test, batchSize)]
    else:
        patience_count += 1
        if patience_count == patience:
            print(f'Early stopping end of epoch {epoch}')
            break

  0%|          | 0/50 [00:00<?, ?it/s]