In [None]:
# ## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# # from google.colab import drive

# # drive.mount('/content/gdrive/', force_remount=True)
# # ../gdrive/MyDrive/data
# %cd zetaFold

In [None]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions, graph_train_valid_test

from sklearn.metrics import log_loss

# from models.enhanced_graph import GTN
from models.self_attention_graph import GTN
from models.baseline_graph import GNN

In [None]:
# Assign every line of data/graph_labels.txt to a split.
# Lines with a label are "train" for the first (n_labels - treshold_valid) labelled
# lines in file order and "valid" afterwards; lines with an empty label are "test".
n_labels = 4888        # expected number of labelled graphs in graph_labels.txt
treshold_valid = 733   # how many of those (the last ones in file order) go to validation
n_train_target = n_labels - treshold_valid

ref = dict()
n_train_seen = 0  # running count instead of rescanning `ref` for every line (was O(n^2))
with open('data/graph_labels.txt', 'r') as f:
    for i, line in enumerate(f):
        # Strip the whole line so a missing final newline or '\r\n' line endings do not
        # change whether the label field is considered empty (same parse as the labels cell).
        _name, label = line.strip().split(',')
        if label.strip():
            if n_train_seen < n_train_target:
                ref[i] = "train"
                n_train_seen += 1
            else:
                ref[i] = "valid"
        else:
            ref[i] = "test"

# Line indices of each split, in increasing file order.
ref_train = np.array([i for i, split in ref.items() if split == "train"])
ref_valid = np.array([i for i, split in ref.items() if split == "valid"])
ref_test = np.array([i for i, split in ref.items() if split == "test"])

In [None]:
# Torch device that the model and every batch tensor are moved to.
# Other options: 'mps' (Apple silicon) or 'cuda' (NVIDIA GPU).
device = 'cpu'

# Data loading & preprocessing

In [None]:
# Labels: one integer class per labelled line of graph_labels.txt, in file order.
y = []
valid_id = []  # NOTE(review): despite the name, these are the ids of the unlabelled (test) graphs

with open('data/graph_labels.txt', "r") as f1:
    for line in f1:
        s1, s2 = line.strip().split(',')
        if len(s2.strip()) > 0:
            y.append(int(s2))
        else:
            valid_id.append(s1)

y = np.array(y)
# Split at the actual number of training graphs selected in `ref`, so y_train / y_valid
# stay aligned with ref_train / ref_valid even if the file does not contain exactly
# n_labels labelled lines (identical to y[:-treshold_valid] when it does).
y_train, y_valid = y[:len(ref_train)], y[len(ref_train):]

In [None]:
# Load the graph structure and node features for every protein, then split them
# into train / valid / test using the line-index arrays built from graph_labels.txt.
adj, adj_weight, features = load_data()

# Nine outputs, grouped by kind: adjacency, node features, and adj_shapes
# (presumably per-graph node counts — TODO confirm in utils), each for train/valid/test.
adj_train, adj_valid, adj_test, \
    features_train, features_valid, features_test, \
    adj_shapes_train, adj_shapes_valid, adj_shapes_test = graph_train_valid_test(
        adj, adj_weight, features, ref_train, ref_valid, ref_test
    )

# Experiment: GTN (models.self_attention_graph) vs. baseline GNN (models.baseline_graph)

In [None]:
# Choose the architecture explicitly instead of building a GTN and immediately
# overwriting it with a GNN (which wasted a model construction and hid the choice).
MODEL_NAME = 'gnn'  # 'gtn' -> models.self_attention_graph.GTN, 'gnn' -> models.baseline_graph.GNN
model_cls = {'gtn': GTN, 'gnn': GNN}[MODEL_NAME]
# Positional args presumably (input feature dim, hidden dim, dropout, n_classes=18)
# — TODO confirm against the model definitions.
model = model_cls(90, 64, 0.2, 18).to(device)

In [None]:
def process_batch_data(features, adj_shapes, adj, indices, batchSize, incr_i, device):
    """Assemble a mini-batch of graphs into one disjoint block-diagonal graph.

    Parameters
    ----------
    features : sequence of per-graph 2-D node-feature arrays, indexable by integer position.
    adj_shapes : array indexable by an integer array; ``adj_shapes[indices]`` gives how many
        node rows each selected graph contributes (presumably its node count).
    adj : sequence of per-graph scipy sparse adjacency matrices.
    indices : positions of the graphs that form this batch.
    batchSize, incr_i : kept for backward compatibility only. They were used to
        reconstruct the number of graphs in the batch, which is now ``len(indices)``.
    device : torch device the returned tensors are placed on.

    Returns
    -------
    features_ : float32 tensor, the selected graphs' node features stacked row-wise.
    adj_ : torch sparse tensor of the block-diagonal adjacency of the batch.
    idx_batch : int64 tensor giving, for every node row, its graph's position in the batch.
    cum_nodes : ``adj_shapes[indices]``.
    """
    indices = np.asarray(indices)

    # Gather only the selected graphs. Converting the whole (ragged) feature list with
    # np.array on every batch was O(N) per batch and raises on NumPy >= 1.24.
    features_ = np.vstack([features[k] for k in indices])
    features_ = torch.as_tensor(features_, dtype=torch.float32, device=device)

    cum_nodes = adj_shapes[indices]
    # Graph j of the batch contributes cum_nodes[j] copies of the id j.
    idx_batch = np.repeat(np.arange(len(indices)), cum_nodes)
    idx_batch = torch.as_tensor(idx_batch, dtype=torch.long, device=device)

    adj_ = sp.block_diag([adj[k] for k in indices])
    adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

    return features_, adj_, idx_batch, cum_nodes

In [None]:
# Class-balanced cross-entropy: weight each class by the inverse of its frequency.
# `pops` presumably holds the per-class counts of the labelled graphs (they sum to 4888 == n_labels).
pops = np.array([440.,  50., 939.,  60., 112., 625., 202.,  74., 998.,  57.,  43.,305.,  44.,  59., 548., 226.,  60.,  46.])
weights = 1/pops
weights = weights/np.mean(weights)  # normalise so the average class weight is 1
# Weights must live on the same device as the model outputs; a hard-coded 'cuda'
# crashed whenever `device` is 'cpu' or 'mps'.
weights = torch.tensor(weights, dtype=torch.float32, device=device)
balanced_criterion = nn.CrossEntropyLoss(weight=weights)

In [None]:
# Optimiser and schedule. The base lr is 1 so that the LambdaLR factor *is* the effective
# learning rate: it decays linearly from 5e-3 at epoch 0 to 1e-4 at epoch == epochs.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1,
    weight_decay=1e-5,
)
epochs = 70
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda epoch: 1e-4 + (1e-4 - 5e-3) * ((epoch - epochs) / epochs),
)
criterion = nn.CrossEntropyLoss()
# Unweighted cross-entropy: a fixed yardstick for model selection that does not change
# when `criterion` is swapped for the class-balanced loss part-way through training.
nll_criterion = nn.CrossEntropyLoss()

batchSize = 64                 # training batch size
batchSizeVal = 64              # validation / test batch size
switch_to_balanced_epoch = 50  # from this epoch on, train with `balanced_criterion`

patience = 50                  # epochs without validation improvement before stopping
patience_count = 0

best_valid_loss = float('inf')  # best unweighted validation NLL seen so far

res = list()  # per-batch test logits from the best epoch (used by the submission cell)
shuffle_train = np.arange(len(ref_train))
shuffle_valid = np.arange(len(ref_valid))


def _batch_slices(order, batch_size):
    """Yield ``(indices, end)`` for consecutive chunks of ``order`` of length ``batch_size``.

    ``end`` is the exclusive end position of the chunk within ``order`` (the value
    ``process_batch_data`` expects as ``incr_i``); the last chunk may be shorter.
    """
    for start in range(0, len(order), batch_size):
        indices = order[start:start + batch_size]
        yield indices, start + len(indices)


def _one_hot_targets(labels, indices):
    """Return one-hot float targets (18 classes) for ``labels[indices]`` on ``device``."""
    return F.one_hot(torch.as_tensor(labels[indices]), 18).float().to(device)


def _accuracy(output, targets):
    """Return the fraction of rows whose arg-max logit equals the one-hot target's class."""
    return (output.argmax(dim=1) == targets.argmax(dim=1)).float().mean().item()


def _predict_test():
    """Return the model's test-set logits as a list of per-batch numpy arrays, in file order."""
    model.eval()
    outputs = []
    with torch.no_grad():
        for indices, end in _batch_slices(np.arange(len(ref_test)), batchSizeVal):
            features_, adj_, idx_batch, _ = process_batch_data(
                features_test, adj_shapes_test, adj_test, indices, batchSizeVal, end, device)
            outputs.append(model(features_, adj_, idx_batch).cpu().numpy())
    return outputs


pbar = tqdm(range(epochs))
for epoch in pbar:
    epoch_loss, epoch_loss_valid, epoch_accuracy, epoch_accuracy_valid = [], [], [], []
    epoch_nll, epoch_nll_valid = [], []

    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)

    if epoch == switch_to_balanced_epoch:
        criterion = balanced_criterion

    # Train
    model.train()
    for indices, end in _batch_slices(shuffle_train, batchSize):
        targets_ = _one_hot_targets(y_train, indices)
        features_, adj_, idx_batch, _ = process_batch_data(
            features_train, adj_shapes_train, adj_train, indices, batchSize, end, device)

        optimizer.zero_grad()
        output = model(features_, adj_, idx_batch)
        loss = criterion(output, targets_)
        loss.backward()
        optimizer.step()

        epoch_loss.append(loss.item())
        epoch_nll.append(nll_criterion(output, targets_).item())
        epoch_accuracy.append(_accuracy(output, targets_))

    # Validation: evaluation only, so no autograd graph is built.
    model.eval()
    with torch.no_grad():
        for indices, end in _batch_slices(shuffle_valid, batchSizeVal):
            targets_ = _one_hot_targets(y_valid, indices)
            features_, adj_, idx_batch, _ = process_batch_data(
                features_valid, adj_shapes_valid, adj_valid, indices, batchSizeVal, end, device)
            output = model(features_, adj_, idx_batch)

            epoch_loss_valid.append(criterion(output, targets_).item())
            epoch_nll_valid.append(nll_criterion(output, targets_).item())
            epoch_accuracy_valid.append(_accuracy(output, targets_))

    scheduler.step()

    epoch_loss = np.mean(epoch_loss)
    epoch_loss_valid = np.mean(epoch_loss_valid)
    epoch_nll = np.mean(epoch_nll)
    epoch_nll_valid = np.mean(epoch_nll_valid)
    epoch_accuracy = np.mean(epoch_accuracy)
    epoch_accuracy_valid = np.mean(epoch_accuracy_valid)

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\n'
        f'train nll: {round(epoch_nll, 4)}, valid nll: {round(epoch_nll_valid, 4)},\n'
        f'acc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
    )

    # Model selection / early stopping on the unweighted validation NLL, so the switch to
    # the class-balanced training loss does not change the scale being compared.
    if epoch_nll_valid < best_valid_loss:
        best_valid_loss = epoch_nll_valid
        patience_count = 0
        res = _predict_test()  # keep the test predictions of the best epoch
        torch.save({"state": model.state_dict()}, "last_model_checkpoint.pt")
        tqdm.write(f'New best valid nll {round(best_valid_loss, 4)}: checkpoint and test predictions saved')
    else:
        patience_count += 1
        if patience_count == patience:
            print(f'Early stopping end of epoch {epoch}')
            break

Best validation loss so far: 1.68\
Submission score: 1.7445\
Settings for that run: `lr=1e-3`, `weight_decay=1e-5`, `epochs = 70`, `batchSize = 64`

In [None]:
import csv

# Convert each batch of test logits saved during training into class probabilities and
# stack the batches back into one (n_test_graphs, 18) array, in file order.
y_pred = np.concatenate(
    [torch.softmax(torch.as_tensor(batch), dim=1).numpy() for batch in res], axis=0)

# Test proteins are the lines of graph_labels.txt with an empty label, in file order.
# Parsed like the labels cell, so a missing final newline or '\r\n' endings are handled.
proteins_test = list()
with open('data/graph_labels.txt', 'r') as f:
    for line in f:
        name, label = line.strip().split(',')
        if len(label.strip()) == 0:
            proteins_test.append(name)

# Every prediction row must correspond to exactly one test protein.
assert len(proteins_test) == y_pred.shape[0], (len(proteins_test), y_pred.shape)

# newline='' is required by the csv module; without it, blank rows appear on Windows.
with open('graphedemerde70.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    writer.writerow(['name'] + ['class' + str(i) for i in range(18)])
    for protein, probs in zip(proteins_test, y_pred):
        writer.writerow([protein] + probs.tolist())