In [None]:
# ## COLAB AREA

# !git clone https://github.com/AugustinCombes/zetaFold.git
# %pip install transformers
# # from google.colab import drive

# # drive.mount('/content/gdrive/', force_remount=True)
# # ../gdrive/MyDrive/data
# %cd zetaFold

In [1]:
from tqdm import trange, tqdm
import numpy as np
import scipy.sparse as sp

import torch
import torch.nn.functional as F
import torch.nn as nn

from utils import normalize_adjacency, sparse_mx_to_torch_sparse_tensor, load_data, submit_predictions
from dataloader import get_loader

from sklearn.metrics import log_loss

from models.sequence_protbert import ProteinClassifier

In [2]:
# Assign every protein (one per line of graph_labels.txt, keyed by its line
# index) to "train", "valid" or "test": proteins with an empty label are test;
# the first n_labels - treshold_valid labelled proteins are train and the
# remaining labelled ones are valid.
n_labels = 4888
treshold_valid = 733

n_train_target = n_labels - treshold_valid
n_train_seen = 0  # running count instead of re-scanning `ref` on every line (was O(n^2))

ref = dict()
with open('data/graph_labels.txt', 'r') as f:
    for i, line in enumerate(f):
        # strip() instead of [:-1]: a final line without '\n' keeps its label,
        # and the emptiness test agrees with the strip()-based label parsing
        label = line.split(',')[1].strip()
        if not label:
            ref[i] = "test"
        elif n_train_seen < n_train_target:
            ref[i] = "train"
            n_train_seen += 1
        else:
            ref[i] = "valid"

# dtype=int keeps an empty split usable as an index array (e.g. treshold_valid = 0)
ref_train = np.array([i for i in range(len(ref)) if ref[i] == "train"], dtype=int)
ref_valid = np.array([i for i in range(len(ref)) if ref[i] == "valid"], dtype=int)
ref_test = np.array([i for i in range(len(ref)) if ref[i] == "test"], dtype=int)

# The labels are later split as "last treshold_valid labels = validation";
# that only lines up with this split when n_labels is the true labelled count.
assert len(ref_valid) == treshold_valid, (
    f"expected {treshold_valid} validation proteins, got {len(ref_valid)}; check n_labels"
)

In [3]:
# Device that models and batches are moved to.
# Other values used with this notebook: "mps" (Apple GPU) and "cuda" (NVIDIA GPU).
device = "cpu"


In [4]:
# Integer class labels of the labelled proteins, in file order. The names of
# the unlabelled proteins go to `valid_id` (despite its name, these are the
# proteins of the test split).
y = []
valid_id = []

with open('data/graph_labels.txt', "r") as f1:
    for line in f1:
        name, label = line.strip().split(',')
        label = label.strip()
        if label:
            y.append(int(label))
        else:
            valid_id.append(name)

y = np.array(y)

# The last `treshold_valid` labels form the validation split. An explicit train
# size is used because y[:-0] would be empty when treshold_valid == 0.
n_train = len(y) - treshold_valid
y_train, y_valid = y[:n_train], y[n_train:]

# Data preprocessing

In [None]:
# One string per protein with its residues separated by spaces, aligned
# line-by-line with graph_labels.txt (and therefore with ref_train/valid/test).
sequences = []
with open('data/sequences.txt', "r") as f1:
    for line in f1:
        # rstrip the line terminator only: line[:-1] dropped the last residue of a
        # final line without '\n' and kept a stray '\r' on CRLF files
        sequences.append(' '.join(line.rstrip('\r\n')))

sequences = np.array(sequences)
sequences_train, sequences_valid, sequences_test = sequences[ref_train], sequences[ref_valid], sequences[ref_test]

In [8]:
## Prepare graph features

# 1. Full graph-related database in arrays

adj, adj_weight, features, edge_features, Flist = load_data()
del Flist
del edge_features


def _as_object_array(items):
    """Pack a sequence of matrices into a 1-D object array, one matrix per slot.

    np.array(list_of_matrices) may instead stack equally-shaped matrices into a
    dense N-D array, or can raise on partially ragged shapes; filling a
    pre-allocated object array always gives shape (len(items),).
    """
    packed = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        packed[k] = item
    return packed


adj = [normalize_adjacency(A, W) for A, W in zip(adj, adj_weight)]
adj_shapes = np.array([at.shape[0] for at in adj])
# add the identity (self-loops) to every normalized adjacency matrix
adj = _as_object_array([A + sp.identity(n_nodes) for A, n_nodes in zip(adj, adj_shapes)])

features = _as_object_array(features)

adj_train, adj_valid, adj_test = adj[ref_train], adj[ref_valid], adj[ref_test]
features_train, features_valid, features_test = features[ref_train], features[ref_valid], features[ref_test]
adj_shapes_train, adj_shapes_valid, adj_shapes_test = adj_shapes[ref_train], adj_shapes[ref_valid], adj_shapes[ref_test]

# GRAPH MODEL

In [None]:
# Fixed debug batch used by the smoke tests below: the first 32 training graphs, in file order.
shuffle_train = np.arange(sequences_train.shape[0])
indices = shuffle_train[:32]

In [None]:
# OLD

class GNN(nn.Module):
    """
    Simple message passing model that consists of 2 message passing layers
    and the sum aggregation function.

    forward(x_in, adj, idx):
        x_in: node features of all graphs of the batch stacked along dim 0,
              shape (total_nodes, input_dim).
        adj:  (total_nodes, total_nodes) adjacency matrix (dense or torch sparse);
              the callers in this notebook pass the block-diagonal batch matrix.
        idx:  (total_nodes,) long tensor with the graph id of every node.
    Returns (n_graphs, n_class) log-probabilities, n_graphs = idx.max() + 1.
    """
    def __init__(self, input_dim, hidden_dim, dropout, n_class):
        super(GNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)

        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_class)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_in, adj, idx):
        # first message passing layer
        x = self.fc1(x_in)
        x = self.relu(torch.mm(adj, x))
        x = self.dropout(x)

        # second message passing layer
        x = self.fc2(x)
        x = self.relu(torch.mm(adj, x))

        # sum aggregator: row g of `out` is the sum of the embeddings of graph g's nodes
        idx = idx.unsqueeze(1).repeat(1, x.size(1))
        out = x.new_zeros(int(idx.max()) + 1, x.size(1))  # same device/dtype as x
        out = out.scatter_add_(0, idx, x)

        # batch normalization layer
        out = self.bn(out)

        # mlp to produce output
        out = self.relu(self.fc3(out))
        out = self.dropout(out)
        out = self.fc4(out)

        return F.log_softmax(out, dim=1)

model = GNN(86, 64, 0.2, 18).to(device)

# Smoke run on the debug batch `indices`: the graphs of the batch are merged
# into one disjoint graph via a block-diagonal adjacency matrix.
adj_ = sparse_mx_to_torch_sparse_tensor(sp.block_diag(adj_train[indices])).to(device)

cum_nodes = adj_shapes_train[indices]  # node count of each graph
idx_batch = torch.LongTensor(np.repeat(np.arange(len(indices)), cum_nodes)).to(device)

# node features of all graphs stacked along the node axis
features_ = torch.FloatTensor(np.vstack(np.array(features_train)[indices])).to(device)

output = model(features_, adj_, idx_batch)

In [None]:
# NEW

class GNN(nn.Module):
    """
    Graph classifier: a transformer encoder layer over the nodes of each graph,
    one propagation step with the batch adjacency, a sum readout and an MLP head.

    forward(x_in, adj, idx, cum_nodes):
        x_in:      (n_graphs, max_nodes, input_dim) node features, zero-padded
                   per graph (as built with pad_sequence(batch_first=True)).
        adj:       (total_nodes, total_nodes) adjacency (dense or torch sparse);
                   the callers pass the block-diagonal batch matrix.
        idx:       (total_nodes,) long tensor with the graph id of every node.
        cum_nodes: number of real (non-padded) nodes of every graph.
    Returns (n_graphs, n_class) log-probabilities.
    """
    def __init__(self, input_dim, hidden_dim, dropout, n_class):
        super(GNN, self).__init__()
        self.embedding = nn.Linear(input_dim, hidden_dim)
        self.fc1 = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=2, dim_feedforward=256, batch_first=True)

        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, n_class)
        self.bn = nn.BatchNorm1d(hidden_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x_in, adj, idx, cum_nodes):
        # embedding layer
        x = self.embedding(x_in)

        # Padding mask (True = padded node), built on x's device. Comparing
        # positions with lengths also works when x_in is padded past the
        # longest graph of the batch.
        lengths = torch.as_tensor(cum_nodes, device=x.device)
        positions = torch.arange(x.size(1), device=x.device)
        pad_mask = positions.unsqueeze(0) >= lengths.unsqueeze(1)

        # attention between the nodes of each graph (padded nodes ignored)
        x = self.fc1(x, src_key_padding_mask=pad_mask)
        # drop padded positions; real nodes stay grouped graph after graph,
        # matching the row order of `adj` and `idx`
        x = x[~pad_mask]

        # message passing with the (block-diagonal) adjacency
        x = self.relu(torch.mm(adj, x))
        x = self.dropout(x)

        # sum aggregator: row g of `out` is the sum of the embeddings of graph g's nodes
        idx = idx.unsqueeze(1).repeat(1, x.size(1))
        out = x.new_zeros(int(idx.max()) + 1, x.size(1))
        out = out.scatter_add_(0, idx, x)

        # batch normalization layer
        out = self.bn(out)

        # mlp to produce output
        out = self.relu(self.fc3(out))
        out = self.dropout(out)
        out = self.fc4(out)

        return F.log_softmax(out, dim=1)

model = GNN(86, 64, 0.2, 18).to(device)

# Smoke run of the transformer-based model on the debug batch `indices`.
adj_ = sparse_mx_to_torch_sparse_tensor(sp.block_diag(adj_train[indices])).to(device)

cum_nodes = adj_shapes_train[indices]  # node count of each graph
idx_batch = torch.LongTensor(np.repeat(np.arange(len(indices)), cum_nodes)).to(device)

# per-graph node features, zero-padded to the largest graph of the batch
features_ = torch.nn.utils.rnn.pad_sequence(
    [torch.tensor(node_features, dtype=torch.float32)
     for node_features in np.array(features_train)[indices]],
    batch_first=True,
).to(device)

output = model(features_, adj_, idx_batch, cum_nodes)

In [None]:
# Train the graph model on the train split and evaluate it on the validation
# split after every epoch.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3,
    )
epochs = 50
criterion = nn.CrossEntropyLoss()

shuffle_train = np.arange(len(sequences_train))
shuffle_valid = np.arange(len(sequences_valid))

batchSize = 64

# Early stopping on the validation loss. With patience == epochs it can never
# trigger (the first epoch always improves on inf); lower `patience` to use it.
patience = 50
patience_count = 0
previous_epoch_loss_valid = float('inf')

res = list()


def make_graph_batch(indices, adj_split, features_split, shapes_split):
    """Collate the graphs at positions `indices` of one split into GNN inputs.

    All four arguments must come from the SAME split so nodes stay aligned.
    Returns (features_, adj_, idx_batch, cum_nodes):
        features_  per-graph node features, zero-padded to the largest graph,
        adj_       block-diagonal adjacency of the batch (torch sparse tensor),
        idx_batch  graph id (0 .. len(indices) - 1) of every node,
        cum_nodes  node count of every graph.
    """
    cum_nodes = shapes_split[indices]
    features_ = torch.nn.utils.rnn.pad_sequence(
        [torch.tensor(f, dtype=torch.float32) for f in features_split[indices]],
        batch_first=True,
        ).to(device)
    idx_batch = torch.LongTensor(np.repeat(np.arange(len(indices)), cum_nodes)).to(device)
    adj_ = sparse_mx_to_torch_sparse_tensor(sp.block_diag(adj_split[indices])).to(device)
    return features_, adj_, idx_batch, cum_nodes


def run_graph_epoch(order, adj_split, features_split, shapes_split, labels, train, batch_size=batchSize):
    """One pass over the graphs listed in `order` (positions within one split).

    Optimizes the model batch by batch when `train` is true; otherwise evaluates
    it in eval mode without building autograd graphs.
    Returns (mean batch loss, mean batch accuracy).
    """
    model.train(train)
    losses, accuracies = [], []
    with torch.set_grad_enabled(train):
        for start in range(0, len(order), batch_size):
            indices = order[start:start + batch_size]
            targets_ = F.one_hot(torch.tensor(labels[indices]), 18).float().to(device)
            features_, adj_, idx_batch, cum_nodes = make_graph_batch(
                indices, adj_split, features_split, shapes_split)

            if train:
                optimizer.zero_grad()
            output = model(features_, adj_, idx_batch, cum_nodes)
            loss = criterion(output, targets_)
            if train:
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            hits = output.argmax(dim=1) == targets_.argmax(dim=1)
            accuracies.append(hits.float().mean().item())
    return np.mean(losses), np.mean(accuracies)


pbar = tqdm(range(epochs))
for epoch in pbar:
    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)

    epoch_loss, epoch_accuracy = run_graph_epoch(
        shuffle_train, adj_train, features_train, adj_shapes_train, y_train, train=True)
    # validation features must come from the validation split, like its adjacency
    epoch_loss_valid, epoch_accuracy_valid = run_graph_epoch(
        shuffle_valid, adj_valid, features_valid, adj_shapes_valid, y_valid, train=False)

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
        )

    # Early stopping bookkeeping (no checkpoint / test predictions for this
    # exploratory graph model).
    if previous_epoch_loss_valid < epoch_loss_valid:
        patience_count += 1
        if patience_count == patience:
            print(f'Early stopping end of epoch {epoch}')
            break
    else:
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0

# MULTIMODAL MODEL

In [7]:
# Sequence classifier (ProtBert). A ConcatModel fusing the GNN and ProtBert
# embeddings was prototyped here; only the sequence branch is used.
# The model is built on the same `device` the batches are moved to; the
# hard-coded 'cpu' broke as soon as `device` was changed.
model = ProteinClassifier(18, device=device)

Some weights of the model checkpoint at yarongef/DistilProtBert were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at yarongef/DistilProtBert and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bi

In [None]:
# Fine-tune only the top of the encoder: every transformer layer except the
# last one stops receiving gradients ...
for module in model.bert.encoder.layer[:-1]:
    module.requires_grad_(False)

# ... and the last layer's output projection weights are re-drawn from U(-0.1, 0.1).
for module in model.bert.encoder.layer[-1:]:
    module.output.dense.weight.data.uniform_(-0.1, 0.1)

In [None]:
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('yarongef/DistilProtBert', do_lower_case=False)


def encode(s):
    """Tokenize one space-separated protein sequence for the ProtBert model.

    Special tokens are added and the result is truncated / padded to exactly
    600 tokens. Returns the tokenizer output holding PyTorch tensors under
    'input_ids' and 'attention_mask' (no token type ids).
    """
    return tokenizer.encode_plus(
        s,
        add_special_tokens=True,
        truncation=True,
        padding='max_length',
        max_length=600,
        return_attention_mask=True,
        return_token_type_ids=False,
        return_tensors='pt',
    )

In [None]:
# Fine-tune the ProtBert classifier. Validation loss drives early stopping;
# every improvement saves a checkpoint and refreshes the test predictions `res`.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=3e-5,
    )
epochs = 20
criterion = nn.CrossEntropyLoss()

shuffle_train = np.arange(len(sequences_train))
shuffle_valid = np.arange(len(sequences_valid))

# Batch sizes (previously derived from batchSize = 4 as 2, 4 / 4 and 4 / 2).
train_batch_size = 2
valid_batch_size = 1
test_batch_size = 2

patience = 2
patience_count = 0

previous_epoch_loss_valid = float('inf')

res = list()


def encode_batch(batch_sequences):
    """Tokenize each sequence once; return stacked (input_ids, attention_mask) on `device`."""
    encodings = [encode(s) for s in batch_sequences]
    input_ids = torch.concat([e['input_ids'] for e in encodings]).to(device)
    attention_mask = torch.concat([e['attention_mask'] for e in encodings]).to(device)
    return input_ids, attention_mask


def run_protbert_epoch(order, split_sequences, labels, batch_size, train):
    """One pass over the proteins listed in `order` (positions within one split).

    Optimizes the model when `train` is true; otherwise evaluates it in eval
    mode without building autograd graphs.
    Returns (mean batch loss, mean batch accuracy).
    """
    model.train(train)
    losses, accuracies = [], []
    with torch.set_grad_enabled(train):
        for start in range(0, len(order), batch_size):
            indices = order[start:start + batch_size]
            targets_ = F.one_hot(torch.tensor(labels[indices]), 18).float().to(device)
            input_ids, attention_mask = encode_batch(split_sequences[indices])

            if train:
                optimizer.zero_grad()
            output = model(input_ids, attention_mask)
            loss = criterion(output, targets_)
            if train:
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            hits = output.argmax(dim=1) == targets_.argmax(dim=1)
            accuracies.append(hits.float().mean().item())
    return np.mean(losses), np.mean(accuracies)


def predict_test_outputs():
    """Raw model outputs for sequences_test in file order, one numpy array per batch."""
    model.eval()
    outputs = []
    with torch.no_grad():
        for start in range(0, len(sequences_test), test_batch_size):
            input_ids, attention_mask = encode_batch(sequences_test[start:start + test_batch_size])
            outputs.append(model(input_ids, attention_mask).to('cpu').numpy())
    return outputs


pbar = tqdm(range(epochs))
for epoch in pbar:
    np.random.shuffle(shuffle_train)
    np.random.shuffle(shuffle_valid)

    epoch_loss, epoch_accuracy = run_protbert_epoch(
        shuffle_train, sequences_train, y_train, train_batch_size, train=True)
    epoch_loss_valid, epoch_accuracy_valid = run_protbert_epoch(
        shuffle_valid, sequences_valid, y_valid, valid_batch_size, train=False)

    tqdm.write(
        '\n'
        f'Epoch {epoch}:\ntrain loss: {round(epoch_loss, 4)}, '
        f'valid loss: {round(epoch_loss_valid, 4)}, delta loss: {round(epoch_loss-epoch_loss_valid, 4)},\nacc train: {round(epoch_accuracy, 4)}, '
        f'acc valid: {round(epoch_accuracy_valid, 4)}'
        )

    # Early stopping :
    if previous_epoch_loss_valid < epoch_loss_valid:
        patience_count += 1
        if patience_count == patience:
            print(f'Early stopping end of epoch {epoch}')
            break
    else:
        previous_epoch_loss_valid = epoch_loss_valid
        patience_count = 0
        # Checkpoint the improved model here, not on the first worse epoch,
        # so the file holds the best model, always exists after the first
        # epoch, and matches the test predictions stored in `res`.
        torch.save({"state": model.state_dict(),}, "last_model_checkpoint.pt")
        res = predict_test_outputs()

# SUBMISSION

In [None]:
# Class probabilities of every test protein: softmax each stored batch of test
# outputs and stack them into one (n_test, n_classes) array.
sub = [nn.Softmax(dim=1)(torch.tensor(x)).numpy() for x in res]
y_pred = np.concatenate(sub, axis=0)

# Names of the unlabelled (test) proteins in file order; the test predictions
# were produced over sequences_test in that same order.
proteins_test = list()
with open('data/graph_labels.txt', 'r') as f:
    for line in f:
        t = line.split(',')
        # strip() instead of [:-1]: a final line without '\n' must not lose its label
        if len(t[1].strip()) == 0:
            proteins_test.append(t[0])

assert len(proteins_test) == y_pred.shape[0], (
    f"{len(proteins_test)} test proteins but {y_pred.shape[0]} prediction rows"
)

In [None]:
import csv

# Submission file: header "name,class0,...,classK" then, per test protein, its
# name followed by its class probabilities.
# Check the alignment first so a mismatch cannot leave a half-written file.
assert len(proteins_test) == len(y_pred), "need exactly one prediction row per test protein"

# newline='' lets the csv module control line endings (no blank rows on Windows)
with open('8epochs_samedi_uniform_bs2_lr3e5.csv', 'w', newline='') as csvfile:
    writer = csv.writer(csvfile, delimiter=',')
    writer.writerow(["name"] + ['class' + str(i) for i in range(y_pred.shape[1])])
    for protein, probabilities in zip(proteins_test, y_pred):
        writer.writerow([protein] + probabilities.tolist())

In [None]:
from pathlib import Path

# Prefer the copy on Google Drive (Colab), falling back to the file the training
# cell writes next to the notebook.
drive_checkpoint = Path("../gdrive/MyDrive/data/last_model_checkpoint.pt")
local_checkpoint = Path("last_model_checkpoint.pt")
checkpoint_path = drive_checkpoint if drive_checkpoint.exists() or not local_checkpoint.exists() else local_checkpoint

# map_location lets a checkpoint saved on a GPU be loaded on the current device
checkpoint = torch.load(checkpoint_path, map_location=device)
model.load_state_dict(checkpoint['state'])
model.eval();

In [None]:
## Per-class metrics on the validation split
# (rewritten in torch/numpy: the previous draft used TensorFlow names `tf`,
# `X_test`, `y_test` that do not exist in this notebook)

import matplotlib.pyplot as plt
import numpy as np

n_classes = 18
eval_batch_size = 2

# Class probabilities of every validation protein, in sequences_valid order.
model.eval()
valid_batches = []
with torch.no_grad():
    for start in range(0, len(sequences_valid), eval_batch_size):
        encodings = [encode(s) for s in sequences_valid[start:start + eval_batch_size]]
        input_ids = torch.concat([e['input_ids'] for e in encodings]).to(device)
        attention_mask = torch.concat([e['attention_mask'] for e in encodings]).to(device)
        output = model(input_ids, attention_mask)
        valid_batches.append(nn.Softmax(dim=1)(output).to('cpu').numpy())
valid_probs = np.concatenate(valid_batches, axis=0)
valid_pred = valid_probs.argmax(axis=1)

# Accuracy and negative log-likelihood over the validation proteins of each
# class (NaN for a class with no validation protein).
per_class_acc = np.full(n_classes, np.nan)
per_class_nll = np.full(n_classes, np.nan)
for k in range(n_classes):
    in_class = y_valid == k
    if in_class.any():
        per_class_acc[k] = (valid_pred[in_class] == k).mean()
        # clip to avoid log(0)
        per_class_nll[k] = -np.log(np.clip(valid_probs[in_class, k], 1e-15, 1.0)).mean()

# Class populations over all labelled proteins (computed instead of the
# previously hard-coded list, which summed to n_labels).
class_counts = np.bincount(y, minlength=n_classes)

fig, axs = plt.subplots(1, 3, tight_layout=True, figsize=(14, 6))
panels = (
    (per_class_acc, 'Accuracy'),
    (per_class_nll, 'Negative Log Likelihood'),
    (class_counts, 'Pops'),
)
for ax, (values, title) in zip(axs, panels):
    ax.bar(np.arange(n_classes), values)
    ax.set_title(title)
    ax.set_xlabel('class')
    ax.set_xticks(np.arange(n_classes))  # per axis: plt.xticks only touched the last one
plt.show()