In [None]:
import pandas as pd

import torch
from torch import optim
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

from sklearn.utils.class_weight import compute_class_weight
import seaborn as sns
from common import *

#### Load data

In [None]:
train_data = pd.read_csv("../../data/ld50/train.csv")
test_data = pd.read_csv("../../data/ld50/test.csv")

Y_train = train_data["Class"]
Y_test = test_data["Class"]

# Precomputed embeddings; paired with the labels above purely by row position.
X_train = pd.read_csv("train_embeddings.csv")
X_test = pd.read_csv("test_embeddings.csv")

# Row-positional pairing only makes sense if the lengths agree.
assert len(X_train) == len(Y_train), "train embeddings/labels length mismatch"
assert len(X_test) == len(Y_test), "test embeddings/labels length mismatch"

# Only a cell's last expression is rendered, so show the summaries explicitly
# (`display` is provided by the IPython kernel).
display(X_train.describe(), Y_train.describe())

# The label list is reversed so that class_labels[i] names sorted class i
# (0 -> "Desprezível", 3 -> "Alto"); assumes the labels are the integers 0..3.
counts = Y_test.value_counts().sort_index()
counts.index = class_labels = ["Alto", "Moderado", "Leve", "Desprezível"][::-1]

ax = sns.barplot(counts)
ax.set(title="Distribuição das classes (teste)", xlabel="Classe", ylabel="Contagem");

#### Define the classification head (MLP on top of the precomputed 768-d embeddings)

In [None]:
# Seed torch's global RNG so the weight initialisation below is reproducible
# across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

EMBEDDING_DIM = 768  # width of each precomputed embedding vector
HIDDEN_DIM = 1024
NUM_CLASSES = 4

# MLP head: two Linear -> BatchNorm -> LeakyReLU stages, then a final Linear that
# emits one raw logit per class (the softmax is left to CrossEntropyLoss).
# Layers stay positional in a Sequential so state_dict keys ("0.weight", ...)
# are unchanged.
head = torch.nn.Sequential(
    torch.nn.Linear(EMBEDDING_DIM, HIDDEN_DIM),
    torch.nn.BatchNorm1d(HIDDEN_DIM),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(HIDDEN_DIM, HIDDEN_DIM),
    torch.nn.BatchNorm1d(HIDDEN_DIM),
    torch.nn.LeakyReLU(),
    torch.nn.Linear(HIDDEN_DIM, NUM_CLASSES),
)

#### Train the classification head

In [None]:
class SmilesDataset(Dataset):
    """Map-style dataset pairing feature rows with integer class labels.

    Rows of ``x`` and entries of ``y`` are paired by position. Both are
    converted to tensors once at construction, so ``__getitem__`` is a cheap
    tensor index rather than a pandas ``.iloc`` lookup plus conversion per item.

    Args:
        x: Feature frame; each row becomes a float32 vector.
        y: Labels; each entry becomes an int64 scalar tensor.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """

    def __init__(self, x: pd.DataFrame, y: pd.Series):
        if len(x) != len(y):
            raise ValueError(f"x has {len(x)} rows but y has {len(y)} labels")
        # Keep the original pandas objects reachable under the same attributes.
        self.X = x
        self.Y = y
        self._features = torch.tensor(x.to_numpy()).float()
        self._labels = torch.tensor(y.to_numpy()).long()

    def __len__(self):
        return len(self._labels)

    def __getitem__(self, index: int):
        # Positional indexing, matching the original .iloc semantics.
        return self._features[index], self._labels[index]

train_dataset = SmilesDataset(X_train, Y_train)
test_dataset = SmilesDataset(X_test, Y_test)

# Shuffle only the training set. The test loss is averaged over per-batch
# means, so shuffling would change batch composition each epoch and make the
# logged test loss jitter; a fixed order keeps it comparable across epochs.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# compute_class_weight returns weights in the order of `classes`, but
# CrossEntropyLoss expects weight[i] to belong to class index i.
# Series.unique() returns labels in order of first appearance, so sort them.
class_weights = compute_class_weight(
    class_weight='balanced',
    classes=Y_train.sort_values().unique(),
    y=Y_train.values,
)

optimizer = optim.Adam(head.parameters(), lr=12e-5)
criterion = torch.nn.CrossEntropyLoss(weight=torch.tensor(class_weights).float())
num_epochs = 200


def train_one_epoch(model, loader, opt, loss_fn):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()  # BatchNorm normalises with batch statistics while training
    running_loss = 0.
    for inputs, labels in loader:
        opt.zero_grad()
        loss = loss_fn(model(inputs), labels)
        loss.backward()
        opt.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def evaluate(model, loader, loss_fn):
    """Return the mean per-batch loss over ``loader`` without updating ``model``.

    Runs in eval mode (BatchNorm uses its running statistics) under no_grad,
    and restores the model's previous train/eval mode before returning.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.
    with torch.no_grad():
        for inputs, labels in loader:
            # Integer class-index targets of shape (batch,), as in training.
            total_loss += loss_fn(model(inputs), labels).item()
    model.train(was_training)
    return total_loss / len(loader)


writer = SummaryWriter()
try:
    for epoch_index in range(num_epochs):
        train_loss = train_one_epoch(head, train_dataloader, optimizer, criterion)
        test_loss = evaluate(head, test_dataloader, criterion)
        # global_step=epoch so TensorBoard plots each loss as a curve over epochs.
        writer.add_scalar('Train/loss', train_loss, epoch_index)
        writer.add_scalar('Test/loss', test_loss, epoch_index)
finally:
    writer.close()  # flush pending events to disk even if training is interrupted

In [None]:
# 5min13s — presumably the recorded runtime of the training cell above.
from datetime import datetime

# isoformat() contains ':' characters, which are invalid in Windows file names;
# use a filesystem-safe timestamp and the conventional .pt extension.
timestamp = datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
torch.save(head.state_dict(), f"nn_model_classification_{timestamp}.pt")

In [None]:
import numpy as np
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.metrics import r2_score  # NOTE(review): not used in the visible cells

# BatchNorm must run in eval mode for inference. In train mode it would
# normalise with this batch's statistics and also overwrite its running
# averages. The previous train/eval mode is restored afterwards.
was_training = head.training
head.eval()
with torch.no_grad():
    inputs = torch.tensor(X_test.values).float()
    outputs = head(inputs)
head.train(was_training)

# argmax over the logits equals argmax over softmax(logits); no softmax needed.
pred = outputs.argmax(dim=1).numpy()
y_true = Y_test.to_numpy()

# Rows are true classes, columns are predictions.
# Assumes the labels are the integers 0..len(class_labels)-1.
class_ids = list(range(len(class_labels)))
annot = confusion_matrix(y_true, pred, labels=class_ids).astype(np.float64)

# Colour by row-normalised values (per-class recall); empty rows stay at 0
# instead of dividing by zero.
row_totals = annot.sum(axis=1, keepdims=True)
matrix = np.divide(annot, row_totals, out=np.zeros_like(annot), where=row_totals > 0)

ax = sns.heatmap(matrix, cmap='coolwarm', robust=True, annot=annot, fmt='g',
                 xticklabels=class_labels, yticklabels=class_labels)
ax.set(title="Matriz de confusão", xlabel="Previsto", ylabel="Real")

print(classification_report(y_true, pred, labels=class_ids, target_names=class_labels))