In [1]:
# ## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# # from google.colab import drive

# # drive.mount('/content/gdrive/', force_remount=True)
# # ../gdrive/MyDrive/data
# %cd zetaFold

In [24]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions
from dataloader import get_loader

from sklearn.metrics import log_loss


In [25]:
n_labels = 4888
treshold_valid = 733

ref = dict()
with open('data/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) != 0:
            if len([_ for _ in ref.values() if _ == "train"]) < n_labels - treshold_valid:
              ref[i] = "train"
            else :
              ref[i] = "valid"
        else:
            ref[i] = "test"

ref_train = np.array([i for i in range(len(ref)) if ref[i]=="train"])
ref_valid = np.array([i for i in range(len(ref)) if ref[i]=="valid"])
ref_test = np.array([i for i in range(len(ref)) if ref[i]=="test"])

In [26]:
device = 'cpu'
# device = 'mps'
# device = 'cuda'

version = "train"
# version = "valid"


In [4]:
y = []
valid_id = []

with open('data/graph_labels.txt', "r") as f1:
  for line in f1:
    s1, s2 = line.strip().split(',')
    if len(s2.strip())>0:
      y.append(int(s2))
    else :
      valid_id.append(s1)

y = np.array(y)
y_train, y_valid = y[:-treshold_valid], y[-treshold_valid:]

# SEQUENCES ONLY : Finetune DistillProtbert Model

In [10]:
sequences = [] 
with open('data/sequences.txt', "r") as f1:
  for line in f1:
    sequences.append(' '.join(list(line[:-1])))

sequences = np.array(sequences)
sequences_train, sequences_valid, sequences_test = sequences[ref_train], sequences[ref_valid], sequences[ref_test]

In [7]:
from transformers import BertModel, BertTokenizer

PRE_TRAINED_MODEL_NAME = 'yarongef/DistilProtBert'
tokenizer = BertTokenizer.from_pretrained(PRE_TRAINED_MODEL_NAME, do_lower_case=False)

Downloading:   0%|          | 0.00/80.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/589 [00:00<?, ?B/s]

In [8]:
## Define probert-based classifier
 
class ProteinClassifier(nn.Module):
    def __init__(self, n_classes):
        super(ProteinClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(PRE_TRAINED_MODEL_NAME).to(device)
        self.classifier = nn.Linear(self.bert.config.hidden_size, n_classes).to(device)
        
    def forward(self, input_ids, attention_mask):
        input_ids, attention_mask = input_ids.to(device), attention_mask.to(device)
        output = self.bert(
          input_ids=input_ids,
          attention_mask=attention_mask
        )
        output = output.last_hidden_state[:, 0, :]
        output = nn.Dropout(0.2)(output) # 0.3 meilleure submission
        output = nn.ReLU()(output)
        output = self.classifier(output)
        # output = nn.ReLU()(output)
        # return output

        return nn.LogSoftmax(dim=1)(output)

# model = ProteinClassifier(18).to(device)

# for module in model.bert.encoder.layer[0:-1]:
#     for param in module.parameters():
#         param.requires_grad = False

# GRAPHE : 

In [6]:
## Prepare graph features

# 1. Full graph-related database in arrays

adj, adj_weight, features, edge_features, Flist = load_data()
del Flist
del edge_features

adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
adj_shapes = np.array([at.shape[0] for at in adj])
adj = np.array([adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))])

features = np.array(features, dtype=object)

adj_train, adj_valid, adj_test = adj[ref_train], adj[ref_valid], adj[ref_test]
features_train, features_valid, features_test = features[ref_train], features[ref_valid], features[ref_test]
adj_shapes_train, adj_shapes_valid, adj_shapes_test = adj_shapes[ref_train], adj_shapes[ref_valid], adj_shapes[ref_test]


# features = features[:n_labels] if version != "valid" else features[n_labels:]
# adj = adj[:n_labels] if version != "valid" else adj[n_labels:]

## 2. Then, we load batchs thanks to "indexs" key of dataloader element

## Example : loading only the first batch of graph-related data
# for e in data_loader:
#     break
# batch_indices = e['indexs'] ## Indexs that are part of the batch

# features_ = np.array(features)[batch_indices]
# features_ = np.vstack(features_)
# features_ = torch.FloatTensor(features_).to(device)

# adj_ = adj[batch_indices]
# adj_ = sp.block_diag(adj_)
# adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

# cum_nodes = adj_shapes[batch_indices]
# idx_batch = np.repeat(np.arange(
#     len(batch_indices)
#     ), cum_nodes)
# idx_batch = torch.LongTensor(idx_batch).to(device)

# # With GNN model

# model = GNN(...).to(device)
# output = model(features_, adj_, idx_batch)

In [27]:
## Prepare graph features

# 1. Full graph-related database in arrays

adj, adj_weight, features, edge_features, Flist = load_data()
del Flist
del edge_features

adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
adj_shapes = np.array([at.shape[0] for at in adj])
adj = np.array([adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))])

features = np.array(features, dtype=object)

adj_train, adj_valid, adj_test = adj[ref_train], adj[ref_valid], adj[ref_test]
features_train, features_valid, features_test = features[ref_train], features[ref_valid], features[ref_test]
adj_shapes_train, adj_shapes_valid, adj_shapes_test = adj_shapes[ref_train], adj_shapes[ref_valid], adj_shapes[ref_test]


# features = features[:n_labels] if version != "valid" else features[n_labels:]
# adj = adj[:n_labels] if version != "valid" else adj[n_labels:]

## 2. Then, we load batchs thanks to "indexs" key of dataloader element

## Example : loading only the first batch of graph-related data
# for e in data_loader:
#     break
# batch_indices = e['indexs'] ## Indexs that are part of the batch

shuffle_train = np.arange(len(sequences_train))
indices = shuffle_train[:32]

features_ = np.array(features_train)[indices]
features_ = np.vstack(features_)
features_ = torch.FloatTensor(features_).to(device)

adj_ = adj_train[indices]
adj_ = sp.block_diag(adj_)
adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

cum_nodes = adj_shapes_train[indices]
idx_batch = np.repeat(np.arange(
    len(indices)
    ), cum_nodes)
idx_batch = torch.LongTensor(idx_batch).to(device)

# # With GNN model

# model = GNN(...).to(device)
# output = model(features_, adj_, idx_batch)

In [112]:
# OLD

# features_ = np.array(features_train)[indices]
# features_ = np.vstack(features_)
# features_ = torch.FloatTensor(features_).to(device)
# features_.shape

torch.Size([8221, 86])

In [126]:
# NEW

# layer = nn.TransformerEncoderLayer(d_model=86, nhead=2, dim_feedforward=256, batch_first=True)

# features_ = np.array(features_train)[indices]
# features_ = torch.nn.utils.rnn.pad_sequence(
#     [torch.tensor(f, dtype=torch.float32) for f in features_],
#     batch_first=True
#     )
# mask = torch.nn.utils.rnn.pad_sequence([torch.zeros(e) for e in cum_nodes], batch_first=True, padding_value=1).type(torch.bool)#
# features_ = layer(features_, src_key_padding_mask=mask)
# features_ = torch.vstack([features_[j,:c] for j,c in enumerate(cum_nodes)])
# print(features_.shape)

# for k,v in layer.state_dict().items():
#     # print(k, v.shape)
#     ()
# 258*86+258+86*86+86+256*86+256+86*256+86*5    

torch.Size([8221, 86])
74646


In [22]:
class GNN(nn.Module):
    """
    Simple message passing model that consists of 2 message passing layers
    and the sum aggregation function
    """
    def __init__(self, input_dim, hidden_dim, dropout, n_class):
        super(GNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        # self.fc2 = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=8)
        
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_class)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_in, adj, idx):
        # first message passing layer
        x = self.fc1(x_in)
        x = self.relu(torch.mm(adj, x))
        x = self.dropout(x)

        # second message passing layer
        x = self.fc2(x)
        x = self.relu(torch.mm(adj, x))
        
        # sum aggregator
        idx = idx.unsqueeze(1).repeat(1, x.size(1))
        out = torch.zeros(torch.max(idx)+1, x.size(1)).to(x_in.device)
        out = out.scatter_add_(0, idx, x)
        
        # batch normalization layer
        out = self.bn(out)

        # mlp to produce output
        out = self.relu(self.fc3(out))
        out = self.dropout(out)
        out = self.fc4(out)
        
        return F.log_softmax(out, dim=1)

model = GNN(86, 64, 0.2, 18).to(device)

In [132]:
class GNN(nn.Module):
    """
    Simple message passing model that consists of 2 message passing layers
    and the sum aggregation function
    """
    def __init__(self, input_dim, hidden_dim, dropout, n_class):
        super(GNN, self).__init__()
        # self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc1 = nn.TransformerEncoderLayer(d_model=86, nhead=2, dim_feedforward=256, batch_first=True)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_class)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_in, adj, idx, cum_nodes):
        # first message passing layer
        pad_mask = torch.nn.utils.rnn.pad_sequence([torch.zeros(e) for e in cum_nodes], batch_first=True, padding_value=1).type(torch.bool)
        x = self.fc1(x_in, src_key_padding_mask=pad_mask)
        print('1', x.shape)
        x = torch.vstack([x[j,:c] for j,c in enumerate(cum_nodes)])
        print('2', x.shape)
        x = self.relu(torch.mm(adj, x))
        x = self.dropout(x)

        # second message passing layer
        x = self.fc2(x)
        x = self.relu(torch.mm(adj, x))
        
        # sum aggregator
        idx = idx.unsqueeze(1).repeat(1, x.size(1))
        out = torch.zeros(torch.max(idx)+1, x.size(1)).to(x_in.device)
        out = out.scatter_add_(0, idx, x)
        
        # batch normalization layer
        out = self.bn(out)

        # mlp to produce output
        out = self.relu(self.fc3(out))
        out = self.dropout(out)
        out = self.fc4(out)
        
        return F.log_softmax(out, dim=1)

model = GNN(86, 64, 0.2, 18).to(device)

In [133]:
optimizer = torch.optim.Adam(
    model.parameters(), 
    # lr=1,
    lr=1e-3,
    # weight_decay=0.0001
    )
epochs = 50
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
#                 lr_lambda=lambda epoch: 1e-6 + (1e-6 - 1e-4) * ((epoch - epochs)/ epochs))
criterion = nn.CrossEntropyLoss()

shuffle_train = np.arange(len(sequences_train))
shuffle_valid = np.arange(len(sequences_valid))

batchSize = 64

patience = 50
patience_count = 0

previous_epoch_loss_valid = 300

res = list()

pbar = tqdm(range(epochs))
for epoch in pbar:
    epoch_loss, epoch_loss_valid, epoch_accuracy, epoch_accuracy_valid = [], [], [], []

    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)


    # Train
    # batchSize=2
    model.train()
    for i in range(0, len(sequences_train), batchSize):

        incr_i = min(i+batchSize, len(sequences_train))
        indices = shuffle_train[i: incr_i]

        targets_ = torch.tensor(y_train[indices])
        targets_ = F.one_hot(targets_, 18).float().to(device)
        
        # features_ = np.array(features_train)[indices]
        # features_ = np.vstack(features_)
        # features_ = torch.FloatTensor(features_).to(device)
        features_ = np.array(features_train)[indices]
        features_ = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(f, dtype=torch.float32) for f in features_],
            batch_first=True
            ).to(device)

        cum_nodes = adj_shapes_train[indices]
        idx_batch = np.repeat(np.arange(
            batchSize if incr_i%batchSize==0 else incr_i%batchSize
            ), cum_nodes)
        idx_batch = torch.LongTensor(idx_batch).to(device)
        
        adj_ = adj_train[indices]
        adj_ = sp.block_diag(adj_)
        adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

        optimizer.zero_grad()
        output = model(features_, adj_, idx_batch, cum_nodes)
        
        del features_
        del adj_
        del idx_batch

        # Backward
        loss = criterion(output, targets_)
        loss.backward()
        optimizer.step()

        # Compute metrics
        epoch_loss.append(loss.item())
        
        targets_ = targets_.to('cpu').detach().numpy()
        output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

        accuracy = (targets_.argmax(axis=1)==output.argmax(axis=1)).mean()
        epoch_accuracy.append(accuracy)

        del output

    # Validation
    # batchSize=64
    model.eval()
    for i in range(0, len(sequences_valid), int(batchSize)):

        incr_i = min(i+int(batchSize), len(sequences_valid))
        indices = shuffle_valid[i: incr_i]

        targets_ = torch.tensor(y_valid[indices])
        targets_ = F.one_hot(targets_, 18).float().to(device)

        # features_ = np.array(features_train)[indices]
        # features_ = np.vstack(features_)
        # features_ = torch.FloatTensor(features_).to(device)
        features_ = np.array(features_train)[indices]
        features_ = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(f, dtype=torch.float32) for f in features_],
            batch_first=True
            ).to(device)

        cum_nodes = adj_shapes_valid[indices]
        idx_batch = np.repeat(np.arange(
            int(batchSize) if incr_i%int(batchSize)==0 else incr_i%int(batchSize)
            ), cum_nodes)
        idx_batch = torch.LongTensor(idx_batch).to(device)
        
        adj_ = adj_valid[indices]
        adj_ = sp.block_diag(adj_)
        adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

        output = model(features_, adj_, idx_batch, cum_nodes)

        del features_
        del adj_
        del idx_batch

        # Compute metrics
        loss = criterion(output, targets_)
        epoch_loss_valid.append(loss.item())
        
        targets_ = targets_.to('cpu').detach().numpy()
        output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

        accuracy = (targets_.argmax(axis=1)==output.argmax(axis=1)).mean()
        epoch_accuracy_valid.append(accuracy)

        del output

    # scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_loss_valid = np.array(epoch_loss_valid).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()
    epoch_accuracy_valid = np.array(epoch_accuracy_valid).mean()

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
        )
    
    # Early stopping :
    if previous_epoch_loss_valid < epoch_loss_valid:
        if patience_count == 0:
          torch.save({"state": model.state_dict(),}, "last_model_checkpoint.pt")
        patience_count +=1
        if patience_count == patience:
          print(f'Early stopping end of epoch {epoch}')
          break

    else :
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0

        # Save last best results
        del res
        res = list()
        shuffle_test = np.arange(len(sequences_test))

        for i in range(0, len(sequences_test), int(batchSize)):
            model.eval()

            incr_i = min(i+int(batchSize), len(sequences_test))
            indices = shuffle_test[i: incr_i]

            features_ = np.array(features_test)[indices]
            features_ = np.vstack(features_)
            features_ = torch.FloatTensor(features_).to(device)

            cum_nodes = adj_shapes_test[indices]
            idx_batch = np.repeat(np.arange(
                batchSize if incr_i%batchSize==0 else incr_i%batchSize
                ), cum_nodes)
            idx_batch = torch.LongTensor(idx_batch).to(device)
            
            adj_ = adj_test[indices]
            adj_ = sp.block_diag(adj_)
            adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

            output = model(features_, adj_, idx_batch)

            res.append(output.to('cpu').detach().numpy())

            del output

  0%|          | 0/50 [00:00<?, ?it/s]

1 torch.Size([64, 796, 86])
2 torch.Size([14822, 86])





RuntimeError: mat1 and mat2 shapes cannot be multiplied (14822x86 and 64x64)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

k = 0

résumé = []
for k in range(18):
    bol = tf.argmax(y_test, axis=1) == k
    preds = model(X_test)[bol]
    targets = y_test[bol]
    accuracy_preds = (tf.argmax(preds, axis=1)==k).numpy().mean()
    nll = log_loss(y_true=targets, y_pred=preds)
    résumé.append((accuracy_preds, nll))

acc = [tup[0] for tup in résumé]
nll = [tup[1] for tup in résumé]

fig, axs = plt.subplots(1, 3, tight_layout=True, figsize=(14, 6))
axs[0].bar(np.arange(len(acc)), acc)
axs[0].set_title('Accuracy')
plt.xticks(np.arange(len(acc)))
axs[1].bar(np.arange(len(nll)), nll)
axs[1].set_title('Negative Log Likelihood')
plt.xticks(np.arange(len(acc)))
axs[2].bar(np.arange(len(nll)), [440.,  50., 939.,  60., 112., 625., 202.,  74., 998.,  57.,  43.,305.,  44.,  59., 548., 226.,  60.,  46.])
axs[2].set_title('Pops')
plt.xticks(np.arange(len(acc)))
plt.show()

In [None]:
model = GNN(86, 64, 0.1, 18).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
epochs = 50
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
#                 lr_lambda=lambda epoch: 1e-6 + (1e-6 - 1e-3) * ((epoch - epochs)/ epochs))
criterion = nn.CrossEntropyLoss()

data_loader = get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=tokenizer, max_len=600, batch_size=64, shuffle=False, version=version, drop_last=True)
                
pbar = tqdm(range(epochs))
for epoch in pbar:
    epoch_loss, epoch_log_loss, epoch_accuracy = [], [], []

    for batch_num, e in enumerate(data_loader):
        
        batch_indices = e['indexs']
        optimizer.zero_grad()

        target = F.one_hot(e['target'], 18).float().to(device)

        # Compute graph forward
        features_ = np.array(features)[batch_indices]
        features_ = np.vstack(features_)
        features_ = torch.FloatTensor(features_).to(device)

        adj_ = adj[batch_indices]
        adj_ = sp.block_diag(adj_)
        adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

        cum_nodes = adj_shapes[batch_indices]
        idx_batch = np.repeat(np.arange(
            len(batch_indices)
            ), cum_nodes)
        idx_batch = torch.LongTensor(idx_batch).to(device)
        
        output = model(features_, adj_, idx_batch)

        # Backward
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        # Compute metrics
        epoch_loss.append(loss.item())
        
        target = target.to('cpu').detach().numpy()
        output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()
        nll_loss = log_loss(target, output)
        epoch_log_loss.append(nll_loss)

        accuracy = (target.argmax(axis=1)==output.argmax(axis=1)).mean()
        epoch_accuracy.append(accuracy)

    # scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_log_loss = np.array(epoch_log_loss).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()

    pbar.set_description(
        f'Epoch {epoch}, cel: {round(epoch_loss, 4)}, '
        f'nll: {round(epoch_log_loss, 4)}, acc: {round(epoch_accuracy, 4)}'
        )

4888 4888 4888


Epoch 8, cel: 2.3809, nll: 2.3809, acc: 0.2111:  18%|█▊        | 9/50 [00:46<03:32,  5.19s/it]


KeyboardInterrupt: ignored

In [None]:
torch.save({"state": model.state_dict(),}, "graphe_pretrain.pt")

# MULTICULTURAL MODEL

In [49]:
# class ConcatModel(nn.Module):
#     def __init__(self, out_dim, feature_dim, graph_dim):
#         super(ConcatModel, self).__init__()
#         # self.seqModel = ProteinClassifier(out_dim).to(device)
#         # self.graphModel = GNN(feature_dim, graph_dim, dropout=0.25, n_class=out_dim).to(device)

#         self.seqModel = ProteinClassifier(18).to(device)

#         # self.classifier = nn.Linear(out_dim, 18)
#         # self.activation = nn.LogSoftmax(dim=1)

#     def forward(self, x_in, adj, idx, input_ids, attention_mask):
#         seqEmbedding = self.seqModel(input_ids, attention_mask)
#         # graphEmbedding = self.graphModel(x_in, adj, idx)
        
#         # out = torch.concat([graphEmbedding, seqEmbedding], dim=1)
        
#         # out = self.classifier(out)
        
#         # return self.activation(out)
#         return nn.LogSoftmax(dim=1)(seqEmbedding)

# model = ConcatModel(64, 86, 128).to(device)
model = ProteinClassifier(18).to(device)

Some weights of the model checkpoint at yarongef/DistilProtBert were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at yarongef/DistilProtBert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bi

In [53]:
for module in model.bert.encoder.layer[:-1]:
    for param in module.parameters():
        param.requires_grad = False

for module in model.bert.encoder.layer[-1:]:
    # module._modules["output"].dense.weight.data.normal_(mean=0.0, std=0.1)
    module._modules["output"].dense.weight.data.uniform_(-0.1, 0.1)

In [45]:
encode = lambda s : tokenizer.encode_plus(
            s,
            truncation=True,
            add_special_tokens=True,
            max_length=600,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )

In [54]:
optimizer = torch.optim.Adam(
    model.parameters(), 
    # lr=1,
    lr=3e-5,
    # weight_decay=0.0001
    )
epochs = 20
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, 
#                 lr_lambda=lambda epoch: 1e-6 + (1e-6 - 1e-4) * ((epoch - epochs)/ epochs))
criterion = nn.CrossEntropyLoss()

shuffle_train = np.arange(len(sequences_train))
shuffle_valid = np.arange(len(sequences_valid))

batchSize = 4

patience = 2
patience_count = 0

previous_epoch_loss_valid = 300

res = list()

pbar = tqdm(range(epochs))
for epoch in pbar:
    epoch_loss, epoch_loss_valid, epoch_accuracy, epoch_accuracy_valid = [], [], [], []

    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)


    # Train
    batchSize=2
    model.train()
    for i in range(0, len(sequences_train), batchSize):

        incr_i = min(i+batchSize, len(sequences_train))
        indices = shuffle_train[i: incr_i]

        targets_ = torch.tensor(y_train[indices])
        targets_ = F.one_hot(targets_, 18).float().to(device)
        
        sequences_ = torch.concat([encode(s)['input_ids'] for s in sequences_train[indices]]).to(device)

        attention_mask_ = torch.concat([encode(s)['attention_mask'] for s in sequences_train[indices]]).to(device)

        # features_ = np.array(features_train)[indices]
        # features_ = np.vstack(features_)
        # features_ = torch.FloatTensor(features_).to(device)

        # cum_nodes = adj_shapes_train[indices]
        # idx_batch = np.repeat(np.arange(
        #     batchSize if incr_i%batchSize==0 else incr_i%batchSize
        #     ), cum_nodes)
        # idx_batch = torch.LongTensor(idx_batch).to(device)
        
        # adj_ = adj_train[indices]
        # adj_ = sp.block_diag(adj_)
        # adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

        optimizer.zero_grad()
        output = model(sequences_, attention_mask_)
        
        # del features_
        # del adj_
        # del idx_batch
        del sequences_
        del attention_mask_

        # Backward
        loss = criterion(output, targets_)
        loss.backward()
        optimizer.step()

        # Compute metrics
        epoch_loss.append(loss.item())
        
        targets_ = targets_.to('cpu').detach().numpy()
        output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

        accuracy = (targets_.argmax(axis=1)==output.argmax(axis=1)).mean()
        epoch_accuracy.append(accuracy)

        del output

    # Validation
    batchSize=4
    model.eval()
    for i in range(0, len(sequences_valid), int(batchSize/4)):

        incr_i = min(i+int(batchSize/4), len(sequences_valid))
        indices = shuffle_valid[i: incr_i]

        targets_ = torch.tensor(y_valid[indices])
        targets_ = F.one_hot(targets_, 18).float().to(device)
        
        sequences_ = torch.concat([encode(s)['input_ids'] for s in sequences_valid[indices]]).to(device)

        attention_mask_ = torch.concat([encode(s)['attention_mask'] for s in sequences_valid[indices]]).to(device)

        # features_ = np.array(features_valid)[indices]
        # features_ = np.vstack(features_)
        # features_ = torch.FloatTensor(features_).to(device)

        # cum_nodes = adj_shapes_valid[indices]
        # idx_batch = np.repeat(np.arange(
        #     int(batchSize/4) if incr_i%int(batchSize/4)==0 else incr_i%int(batchSize/4)
        #     ), cum_nodes)
        # idx_batch = torch.LongTensor(idx_batch).to(device)
        
        # adj_ = adj_valid[indices]
        # adj_ = sp.block_diag(adj_)
        # adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

        output = model(sequences_, attention_mask_)

        # del features_
        # del adj_
        # del idx_batch
        del sequences_
        del attention_mask_

        # Compute metrics
        loss = criterion(output, targets_)
        epoch_loss_valid.append(loss.item())
        
        targets_ = targets_.to('cpu').detach().numpy()
        output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

        accuracy = (targets_.argmax(axis=1)==output.argmax(axis=1)).mean()
        epoch_accuracy_valid.append(accuracy)

        del output

    # scheduler.step()

    epoch_loss = np.array(epoch_loss).mean()
    epoch_loss_valid = np.array(epoch_loss_valid).mean()
    epoch_accuracy = np.array(epoch_accuracy).mean()
    epoch_accuracy_valid = np.array(epoch_accuracy_valid).mean()

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
        )
    
    # Early stopping :
    if previous_epoch_loss_valid < epoch_loss_valid:
        if patience_count == 0:
          torch.save({"state": model.state_dict(),}, "last_model_checkpoint.pt")
        patience_count +=1
        if patience_count == patience:
          print(f'Early stopping end of epoch {epoch}')
          break

    else :
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0

        # Save last best results
        del res
        res = list()
        shuffle_test = np.arange(len(sequences_test))

        for i in range(0, len(sequences_test), int(batchSize/2)):
            model.eval()

            incr_i = min(i+int(batchSize/2), len(sequences_test))
            sequences_ = torch.concat([encode(s)['input_ids'] for s in sequences_test[shuffle_test[i: incr_i]]])
            attention_mask_ = torch.concat([encode(s)['attention_mask'] for s in sequences_test[shuffle_test[i: incr_i]]])

            output = model(sequences_, attention_mask_)

            del sequences_
            del attention_mask_
            
            res.append(output.to('cpu').detach().numpy())

            del output

  0%|          | 0/20 [03:52<?, ?it/s]


Epoch 0:
train loss: 1.9487, valid loss: 1.829, delta loss: 0.1196,
acc train: 0.4297, acc valid: 0.4216


  5%|▌         | 1/20 [08:15<1:23:19, 263.11s/it]


Epoch 1:
train loss: 1.4058, valid loss: 1.4701, delta loss: -0.0643,
acc train: 0.6049, acc valid: 0.5675


 10%|█         | 2/20 [12:38<1:18:54, 263.01s/it]


Epoch 2:
train loss: 1.1432, valid loss: 1.3205, delta loss: -0.1773,
acc train: 0.6747, acc valid: 0.6153


 15%|█▌        | 3/20 [17:01<1:14:30, 262.99s/it]


Epoch 3:
train loss: 0.9378, valid loss: 1.2714, delta loss: -0.3336,
acc train: 0.7435, acc valid: 0.6412


 20%|██        | 4/20 [21:23<1:10:07, 263.00s/it]


Epoch 4:
train loss: 0.7524, valid loss: 1.2048, delta loss: -0.4524,
acc train: 0.7952, acc valid: 0.6385


 25%|██▌       | 5/20 [25:46<1:05:43, 262.93s/it]


Epoch 5:
train loss: 0.5916, valid loss: 1.165, delta loss: -0.5734,
acc train: 0.8446, acc valid: 0.6589


 30%|███       | 6/20 [30:09<1:01:20, 262.86s/it]


Epoch 6:
train loss: 0.4429, valid loss: 1.1887, delta loss: -0.7458,
acc train: 0.8859, acc valid: 0.6712


 35%|███▌      | 7/20 [34:03<54:52, 253.29s/it]


Epoch 7:
train loss: 0.3187, valid loss: 1.1319, delta loss: -0.8131,
acc train: 0.9223, acc valid: 0.6849


 40%|████      | 8/20 [38:26<51:16, 256.37s/it]


Epoch 8:
train loss: 0.2109, valid loss: 1.2003, delta loss: -0.9895,
acc train: 0.9569, acc valid: 0.6698


 45%|████▌     | 9/20 [42:19<51:44, 282.21s/it]


Epoch 9:
train loss: 0.1417, valid loss: 1.2104, delta loss: -1.0687,
acc train: 0.9757, acc valid: 0.6835
Early stopping end of epoch 9





In [14]:
bs= 4, dernière couche sans weight réinit

   0%|          | 0/20 [03:36<?, ?it/s]
Epoch 0:
train loss: 2.0976, valid loss: 1.9806, delta loss: 0.117,
acc train: 0.3805, acc valid: 0.3342
  5%|▌         | 1/20 [07:40<1:18:14, 247.06s/it]
Epoch 1:
train loss: 1.5981, valid loss: 1.6424, delta loss: -0.0443,
acc train: 0.5566, acc valid: 0.5198
 10%|█         | 2/20 [11:44<1:13:35, 245.30s/it]
Epoch 2:
train loss: 1.3427, valid loss: 1.4754, delta loss: -0.1327,
acc train: 0.6323, acc valid: 0.558
 15%|█▌        | 3/20 [15:48<1:09:20, 244.76s/it]
Epoch 3:
train loss: 1.1706, valid loss: 1.3436, delta loss: -0.173,
acc train: 0.6741, acc valid: 0.6003
 20%|██        | 4/20 [19:52<1:05:11, 244.48s/it]
Epoch 4:
train loss: 1.0137, valid loss: 1.255, delta loss: -0.2413,
acc train: 0.7238, acc valid: 0.6357
 25%|██▌       | 5/20 [23:56<1:01:04, 244.29s/it]
Epoch 5:
train loss: 0.8842, valid loss: 1.2542, delta loss: -0.37,
acc train: 0.7601, acc valid: 0.6303
 30%|███       | 6/20 [28:00<56:59, 244.26s/it]
Epoch 6:
train loss: 0.7609, valid loss: 1.1673, delta loss: -0.4064,
acc train: 0.7885, acc valid: 0.6821
 35%|███▌      | 7/20 [32:04<52:54, 244.19s/it]
Epoch 7:
train loss: 0.6425, valid loss: 1.1662, delta loss: -0.5236,
acc train: 0.8265, acc valid: 0.6726
 40%|████      | 8/20 [36:08<48:49, 244.15s/it]
Epoch 8:
train loss: 0.5414, valid loss: 1.1611, delta loss: -0.6197,
acc train: 0.8583, acc valid: 0.6903
 45%|████▌     | 9/20 [40:12<44:45, 244.10s/it]
Epoch 9:
train loss: 0.442, valid loss: 1.1847, delta loss: -0.7427,
acc train: 0.8867, acc valid: 0.6739
 55%|█████▌    | 11/20 [43:46<34:13, 228.21s/it]
Epoch 10:
train loss: 0.3472, valid loss: 1.185, delta loss: -0.8377,
acc train: 0.9175, acc valid: 0.6876

SyntaxError: ignored

In [None]:
bs = 4 en changeant les poids des deux dernières couches, avec 0.2 en std

train loss: 0.8832, valid loss: 1.2811, delta loss: -0.3979,

In [None]:
bs = 4 en changeant les poids des deux dernières couches, avec 0.1 en std

train loss: 0.8925, valid loss: 1.2393, delta loss: -0.3468,

In [None]:
bs = 4 sans changer les poids de la dernière couche

train loss: 0.79, valid loss: 1.2503, delta loss: -0.4603,

In [55]:
sub = list(map(lambda x: np.array(nn.Softmax(dim=1)(torch.tensor(x))), res))

y_pred = np.concatenate(sub, axis=0)
y_pred.shape

proteins_test = list()
with open('data/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) == 0:
            proteins_test.append(t[0])

In [56]:
import csv

with open('8epochs_samedi_uniform_bs2_lr3e5.csv', 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    lst = list()
    for i in range(18):
        lst.append('class'+str(i))
    lst.insert(0, "name")
    writer.writerow(lst)
    for i, protein in enumerate(proteins_test):
        lst = y_pred[i,:].tolist()
        lst.insert(0, protein)
        writer.writerow(lst)

In [None]:
Epoch 0:
train loss: 1.8907, valid loss: 1.6979,
acc train: 0.4579acc valid: 0.485
Epoch 1:
train loss: 1.3323, valid loss: 1.4126,
acc train: 0.622acc valid: 0.5913
Epoch 2:
train loss: 1.0555, valid loss: 1.2206,
acc train: 0.7087acc valid: 0.6335
Epoch 3:
train loss: 0.8302, valid loss: 1.1845,
acc train: 0.7745acc valid: 0.6662
Epoch 4:
train loss: 0.649, valid loss: 1.2185,
acc train: 0.8267acc valid: 0.6499
Epoch 5:
train loss: 0.4991, valid loss: 1.1734,
acc train: 0.8745acc valid: 0.6703
Epoch 6:
train loss: 0.3834, valid loss: 1.1397,
acc train: 0.909acc valid: 0.688
Epoch 7:
train loss: 0.2994, valid loss: 1.1628,
acc train: 0.9351acc valid: 0.6839
Epoch 8:
train loss: 0.238, valid loss: 1.186,
acc train: 0.9553acc valid: 0.6826
Early stopping end of epoch 8

In [None]:
stop

In [None]:
  5%|▌         | 1/20 [03:28<1:05:59, 208.38s/it]
Epoch 0:
train loss: 1.9055, valid loss: 1.6867,
acc train: 0.4484acc valid: 0.4946
 10%|█         | 2/20 [06:56<1:02:28, 208.27s/it]
Epoch 1:
train loss: 1.3533, valid loss: 1.4342,
acc train: 0.6246acc valid: 0.5681
 15%|█▌        | 3/20 [10:24<59:01, 208.30s/it]  
Epoch 2:
train loss: 1.0844, valid loss: 1.3051,
acc train: 0.6972acc valid: 0.6131
 20%|██        | 4/20 [13:52<55:30, 208.18s/it]
Epoch 3:
train loss: 0.8398, valid loss: 1.248,
acc train: 0.7732acc valid: 0.6349
 25%|██▌       | 5/20 [17:21<52:02, 208.15s/it]
Epoch 4:
train loss: 0.6354, valid loss: 1.2457,
acc train: 0.8288acc valid: 0.6471
 30%|███       | 6/20 [20:49<48:33, 208.13s/it]
Epoch 5:
train loss: 0.4625, valid loss: 1.1694,
acc train: 0.8825acc valid: 0.6662
 35%|███▌      | 7/20 [24:17<45:05, 208.15s/it]
Epoch 6:
train loss: 0.3168, valid loss: 1.1416,
acc train: 0.9216acc valid: 0.6771
 35%|███▌      | 7/20 [27:45<45:05, 208.15s/it]
Epoch 7:
train loss: 0.2154, valid loss: 1.1496,
acc train: 0.9583acc valid: 0.6798
 40%|████      | 8/20 [31:15<46:52, 234.40s/it]
Epoch 8:
train loss: 0.1402, valid loss: 1.204,
acc train: 0.9791acc valid: 0.6812
Early stopping end of epoch 8

In [None]:
checkpoint = torch.load("../gdrive/MyDrive/data/last_model_checkpoint.pt")
model.load_state_dict(checkpoint['state'])
model.eval()
()

()

In [None]:
import gc
gc.collect()
torch.cuda.empty_cache()

In [None]:
import os 
os.environ ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:516"

In [None]:
model = model.to('cuda')

In [None]:
# Test
batchSize = 2
res = list()

indices = np.arange(len(sequences_test))

for i in tqdm(range(0, len(sequences_test), int(batchSize/2))):
    model.eval()

    incr_i = min(i+int(batchSize/2), len(sequences_test))
    # indices = shuffle_test[i: incr_i]

    # targets_ = torch.tensor(y_valid[indices])
    # targets_ = F.one_hot(targets_, 18).float().to(device)
    
    sequences_ = torch.concat([encode(s)['input_ids'] for s in sequences_test[indices]])

    attention_mask_ = torch.concat([encode(s)['attention_mask'] for s in sequences_test[indices]])

    # features_ = np.array(features_test)[indices]
    # features_ = np.vstack(features_)
    # features_ = torch.FloatTensor(features_).to(device)

    # cum_nodes = adj_shapes_test[indices]
    # idx_batch = np.repeat(np.arange(
    #     int(batchSize/4) if incr_i%int(batchSize/4)==0 else incr_i%int(batchSize/4)
    #     ), cum_nodes)
    # idx_batch = torch.LongTensor(idx_batch).to(device)
    
    # adj_ = adj_test[indices]
    # adj_ = sp.block_diag(adj_)
    # adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

    output = model(sequences_, attention_mask_)

    del features_
    del adj_
    del idx_batch
    del sequences_
    del attention_mask_
    
    res.append(output)

    del output

  0%|          | 0/1223 [00:05<?, ?it/s]


OutOfMemoryError: ignored

In [None]:
arg1 = np.arange(
  int(batchSize/4) if incr_i%int(batchSize/4)==0 else incr_i%int(batchSize/4)
  )

In [None]:
cum_nodes.shape

(1223,)

In [None]:
  5%|▌         | 1/20 [03:20<1:03:26, 200.36s/it]
Epoch 0:
train loss: 1.8427, valid loss: 1.6749,
acc train: 0.4519acc valid: 0.4986
 10%|█         | 2/20 [06:40<1:00:03, 200.17s/it]
Epoch 1:
train loss: 1.3959, valid loss: 1.4417,
acc train: 0.5976acc valid: 0.5938
 10%|█         | 2/20 [10:00<1:00:03, 200.17s/it]
Epoch 2:
train loss: 1.1997, valid loss: 1.4535,
acc train: 0.6586acc valid: 0.5978
 15%|█▌        | 3/20 [13:21<1:15:40, 267.10s/it]
Epoch 3:
train loss: 1.046, valid loss: 1.5351,
acc train: 0.7006acc valid: 0.5557
Early stopping end of epoch 3

In [None]:
 10%|█         | 1/10 [03:22<30:19, 202.12s/it]
Epoch 0, train loss: 1.8413, valid loss: 2.0762
, acc train: 0.4657acc train: 0.3791
 20%|██        | 2/10 [06:44<26:56, 202.10s/it]
Epoch 1, train loss: 1.6625, valid loss: 1.7745
, acc train: 0.5146acc train: 0.4742
 20%|██        | 2/10 [10:06<40:24, 303.07s/it]
Epoch 2, train loss: 1.5559, valid loss: 1.8608
, acc train: 0.553acc train: 0.5068
Early stopping end of epoch 2

16

In [None]:
10%|█         | 1/10 [03:22<30:19, 202.14s/it]
Epoch 0:
train loss: 2.9345, valid loss: 2.6558,
acc train: 0.2496acc valid: 0.2908
 20%|██        | 2/10 [06:44<26:56, 202.11s/it]
Epoch 1:
train loss: 2.2657, valid loss: 2.4044,
acc train: 0.3677acc valid: 0.3152
 30%|███       | 3/10 [10:06<23:34, 202.09s/it]
Epoch 2:
train loss: 1.9986, valid loss: 2.3356,
acc train: 0.4311acc valid: 0.3356
 40%|████      | 4/10 [13:28<20:12, 202.01s/it]
Epoch 3:
train loss: 1.8384, valid loss: 1.9591,
acc train: 0.4693acc valid: 0.4389
 40%|████      | 4/10 [16:50<25:15, 252.53s/it]
Epoch 4:
train loss: 1.6969, valid loss: 2.1203,
acc train: 0.5152acc valid: 0.4524
Early stopping end of epoch 4

In [None]:
# for batch_num, e in enumerate(get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
#                 tokenizer=tokenizer, max_len=600, batch_size=8, shuffle=True, version=version, drop_last=True)):
        
#         batch_indices = e['indexs']
        

#         target = F.one_hot(e['target'], 18).float().to(device)

#         #Compute sequence forward
#         input = e['input_ids'].to(device)
#         src_mask = e['attention_mask'].to(device)

#         # Compute graph forward
#         features_ = np.array(features)[batch_indices]
#         features_ = np.vstack(features_)
#         features_ = torch.FloatTensor(features_).to(device)

#         adj_ = adj[batch_indices]
#         adj_ = sp.block_diag(adj_)
#         adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

#         cum_nodes = adj_shapes[batch_indices]
#         idx_batch = np.repeat(np.arange(
#             len(batch_indices)
#             ), cum_nodes)
#         idx_batch = torch.LongTensor(idx_batch).to(device)

#         optimizer.zero_grad()
        
#         output = model(features_, adj_, idx_batch, input, src_mask)

#         # Backward
#         loss = criterion(output, target)
#         loss.backward()
#         optimizer.step()

#         # Compute metrics
#         epoch_loss.append(loss.item())
        
#         target = target.to('cpu').detach().numpy()
#         output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

#         accuracy = (target.argmax(axis=1)==output.argmax(axis=1)).mean()
#         epoch_accuracy.append(accuracy)

#     for batch_num, e in enumerate(get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
#                 tokenizer=tokenizer, max_len=600, batch_size=8, shuffle=False, version="valid", drop_last=False)):
      
#         batch_indices = e['indexs']
#         target = F.one_hot(e['target'], 18).float().to(device)

#         #Compute sequence forward
#         input = e['input_ids'].to(device)
#         src_mask = e['attention_mask'].to(device)

#         # Compute graph forward
#         features_ = np.array(features_valid)[batch_indices]
#         features_ = np.vstack(features_)
#         features_ = torch.FloatTensor(features_).to(device)

#         adj_ = adj_valid[batch_indices]
#         adj_ = sp.block_diag(adj_)
#         adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

#         cum_nodes = adj_shapes_valid[batch_indices]
#         idx_batch = np.repeat(np.arange(
#             len(batch_indices)
#             ), cum_nodes)
#         idx_batch = torch.LongTensor(idx_batch).to(device)
        
#         output = model(features_, adj_, idx_batch, input, src_mask)

#         # Compute metrics
#         epoch_loss_valid.append(loss.item())
        
#         target = target.to('cpu').detach().numpy()
#         output = nn.Softmax(dim=1)(output).to('cpu').detach().numpy()

#         accuracy = (target.argmax(axis=1)==output.argmax(axis=1)).mean()
#         epoch_accuracy_valid.append(accuracy)


#     scheduler.step()

#     epoch_loss = np.array(epoch_loss).mean()
#     epoch_loss_valid = np.array(epoch_loss_valid).mean()
#     epoch_accuracy = np.array(epoch_accuracy).mean()
#     epoch_accuracy_valid = np.array(epoch_accuracy_valid).mean()

#     tqdm.write(
#         '\n'
#         f'Epoch {epoch}, train loss: {round(epoch_loss, 4)}, '
#         f'valid loss: {round(epoch_loss_valid, 4)}\n, acc train: {round(epoch_accuracy, 4)}'
#         f'acc train: {round(epoch_accuracy_valid, 4)}'
#         )

In [None]:
torch.save({"state": model.state_dict(),}, "1901__5epochs.pt")

In [None]:
version = "valid"

In [None]:
ref = list()
with open('data/graph_labels.txt', 'r') as f:
    for i,line in enumerate(f):
        t = line.split(',')
        if len(t[1][:-1]) == 0:
            ref.append(True)
        else:
            ref.append(False)

ref = np.array(ref)
ref = np.logical_not(ref)
ref.shape

In [None]:
is_kept = []
with open('data/graph_labels.txt', "r") as f1:
    for line in f1:
        s1, s2 = line.strip().split(',')
        if len(s2.strip())>0:
            is_kept.append(version != "valid")
        else :
            is_kept.append(version == "valid")

In [None]:
adj, adj_weight, features, edge_features, Flist = load_data()
del Flist
del edge_features

adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
adj_shapes = np.array([at.shape[0] for at in adj])
adj = [adj[idx] + sp.identity(adj_shapes[idx]) for idx in range(len(adj))]

adj = np.array(adj)[np.array(is_kept)]
features = np.array(features, dtype=object)[np.array(is_kept)]
adj_shapes = adj_shapes[np.array(is_kept)]

In [None]:
len(adj), len(features), len(adj_shapes)

In [None]:
data_loader = get_loader(path_documents='data/sequences.txt', path_labels='data/graph_labels.txt', 
                tokenizer=tokenizer, max_len=600, batch_size=8, shuffle=False, version=version, drop_last=False)

In [None]:
model.eval()
model.training

In [None]:
sub = []

for batch_num, e in tqdm(enumerate(data_loader)):
        
    batch_indices = e['indexs']

    #Compute sequence forward
    input = e['input_ids'].to(device)
    src_mask = e['attention_mask'].to(device)

    # Compute graph forward
    features_ = np.array(features)[batch_indices]
    features_ = np.vstack(features_)
    features_ = torch.FloatTensor(features_).to(device)

    adj_ = adj[batch_indices]
    adj_ = sp.block_diag(adj_)
    adj_ = sparse_mx_to_torch_sparse_tensor(adj_).to(device)

    cum_nodes = adj_shapes[batch_indices]
    idx_batch = np.repeat(np.arange(
        len(batch_indices)
        ), cum_nodes)
    idx_batch = torch.LongTensor(idx_batch).to(device)
    
    output = model(features_, adj_, idx_batch, input, src_mask)
    
    sub.append(output.to('cpu').detach().numpy())

In [None]:
import csv

with open('5epochs_underfitting.csv', 'w') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    lst = list()
    for i in range(18):
        lst.append('class'+str(i))
    lst.insert(0, "name")
    writer.writerow(lst)
    for i, protein in enumerate(proteins_test):
        lst = y_pred[i,:].tolist()
        lst.insert(0, protein)
        writer.writerow(lst)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

k = 0

résumé = []
for k in range(18):
    bol = tf.argmax(y_test, axis=1) == k
    preds = model(X_test)[bol]
    targets = y_test[bol]
    accuracy_preds = (tf.argmax(preds, axis=1)==k).numpy().mean()
    nll = log_loss(y_true=targets, y_pred=preds)
    résumé.append((accuracy_preds, nll))

acc = [tup[0] for tup in résumé]
nll = [tup[1] for tup in résumé]

fig, axs = plt.subplots(1, 3, tight_layout=True, figsize=(14, 6))
axs[0].bar(np.arange(len(acc)), acc)
axs[0].set_title('Accuracy')
plt.xticks(np.arange(len(acc)))
axs[1].bar(np.arange(len(nll)), nll)
axs[1].set_title('Negative Log Likelihood')
plt.xticks(np.arange(len(acc)))
axs[2].bar(np.arange(len(nll)), [440.,  50., 939.,  60., 112., 625., 202.,  74., 998.,  57.,  43.,305.,  44.,  59., 548., 226.,  60.,  46.])
axs[2].set_title('Pops')
plt.xticks(np.arange(len(acc)))
plt.show()

In [None]:
#GTN version brouillon

from torch.nn.utils.rnn import pad_sequence
a = torch.ones(25, 300)
b = torch.ones(22, 300)
c = torch.ones(15, 300)
pad_sequence([a, b, c]).size()



from graph_transformer_pytorch import GraphTransformer

model = GraphTransformer(
    dim = 256,
    depth = 6,
    edge_dim = 512,             # optional - if left out, edge dimensions is assumed to be the same as the node dimensions above
    with_feedforwards = True,   # whether to add a feedforward after each attention layer, suggested by literature to be needed
    gated_residual = True,      # to use the gated residual to prevent over-smoothing
    rel_pos_emb = True          # set to True if the nodes are ordered, default to False
)

nodes = torch.randn(1, 128, 256)
edges = torch.randn(1, 128, 128, 512)
mask = torch.ones(1, 128).bool()

nodes, edges = model(nodes, edges, mask = mask)

nodes.shape # (1, 128, 256) - project to R^3 for coordinates