# ML Hands-on Challenge - Data preparation and exploration

This notebook contains the embedding model for the ML Hands-on Challenge. We will use [EsmModel](https://github.com/facebookresearch/esm) for gettings these embeddings.


## Loading things

In [1]:
import pandas as pd
import numpy as np
import sys
import glob
import os
import torch

# Get parent directory and add src/ to pythonpath
parent_dir = os.path.dirname(os.getcwd())
sys.path.append(parent_dir)

from src.config import CONSTANTS

DATA_HOME = os.path.join(parent_dir, CONSTANTS.DATA_HOME)
SEED = CONSTANTS.SEED
emb_model_tag = CONSTANTS.EMBEDDING_MODEL
architecture_names = CONSTANTS.ARCHITECTURE_NAMES
batch_size = CONSTANTS.BATCH_SIZE

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Loading processed data from the [`data_prep_exploration.ipynb`](./data_prep_exploration.ipynb) notebook

In [2]:
# Open the training data sequences and structure
proc_data = pd.read_csv(f"{DATA_HOME}/data_processed.csv", index_col=0)

# Converting `piece_edges` and `domain_edges` into a proper format
proc_data = proc_data.assign(
    **{
        "piece_edges": proc_data["piece_edges"].apply(eval),
        "domain_edges": proc_data["domain_edges"].apply(eval),
    }
)

# Converting `target` to string format
proc_data["target"] = proc_data["target"].astype(str)
display(proc_data)

# Get the schema of the data
proc_data.info()

Unnamed: 0,cath_id,pdb_id,class,architecture,topology,superfamily,resolution_in_angstroms,sequence,piece_edges,num_residues,domain_edges,total_num_residues,num_gaps,target
0,2w3sB01,2w3s,3,90,1170,50,2.60,SVGKPLPHDSARAHVTGQARYLDDLPCPANTLHLAFGLSTEASAAI...,"[(2, 124)]",123,"[2, 124]",123,0,3.9
1,3be3A00,3be3,2,30,30,320,2.04,QDFRPGVYRHYKGDHYLALGLARADETDEVVVVYTRLYARAGLPST...,"[(6, 49), (51, 81)]",75,"[6, 81]",76,1,2.3
2,3zq4C03,3zq4,3,10,20,580,3.00,DIGNIVLRDRRILSEEGLVIVVVSIDMDDFKISAGPDLISRGFVIN...,"[(449, 492), (501, 555)]",99,"[449, 555]",107,1,3.1
3,1peqA03,1peq,1,10,1650,20,2.80,DITFRLAKENAQMALFSPYDIQRRYGKPFGDIAISERYDELIADPH...,"[(294, 346)]",53,"[294, 346]",53,0,1.1
4,1bdoA00,1bdo,2,40,50,100,1.80,EISGHIVRSPMVGTFYRTPSPDAKAFIEVGQKVNVGDTLCIVEAMK...,"[(77, 156)]",80,"[77, 156]",80,0,2.4
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6263,2yyiA02,2yyi,2,40,110,10,1.66,ATTHALTNPQVNRARPPSGQPDPYIPVGVVKQTEKGIVVRGARMTA...,"[(139, 196), (199, 266)]",126,"[139, 266]",128,1,2.4
6264,4mo0A00,4mo0,3,30,780,10,2.10,EQKIKIYVTKRRFGKLMTIIEGFDTSVIDLKELAKKLKDICACGGT...,"[(24, 102)]",79,"[24, 102]",79,0,3.3
6265,1vq8X00,1vq8,3,10,440,10,2.20,ERVVTIPLRDARAEPNHKRADKAMILIREHLAKHFSVDEDAVRLDP...,"[(7, 88)]",82,"[7, 88]",82,0,3.1
6266,1ze3D00,1ze3,3,10,20,410,1.84,DLYFNPRFLLSRFENGQELPPGTYRVDIYLNNGYMATRDVTFNTGD...,"[(1, 9), (19, 125)]",116,"[1, 125]",125,1,3.1


<class 'pandas.core.frame.DataFrame'>
Index: 6268 entries, 0 to 6267
Data columns (total 14 columns):
 #   Column                   Non-Null Count  Dtype  
---  ------                   --------------  -----  
 0   cath_id                  6268 non-null   object 
 1   pdb_id                   6268 non-null   object 
 2   class                    6268 non-null   int64  
 3   architecture             6268 non-null   int64  
 4   topology                 6268 non-null   int64  
 5   superfamily              6268 non-null   int64  
 6   resolution_in_angstroms  6268 non-null   float64
 7   sequence                 6268 non-null   object 
 8   piece_edges              6268 non-null   object 
 9   num_residues             6268 non-null   int64  
 10  domain_edges             6268 non-null   object 
 11  total_num_residues       6268 non-null   int64  
 12  num_gaps                 6268 non-null   int64  
 13  target                   6268 non-null   object 
dtypes: float64(1), int64(7), obje

## Load the ESM model, try in on a sequence

In [3]:
from transformers import AutoTokenizer, EsmModel

test_sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"

tokenizer = AutoTokenizer.from_pretrained(emb_model_tag)
embedding_model = EsmModel.from_pretrained(emb_model_tag).to(device)

inputs = tokenizer(test_sequence, return_tensors="pt").to(device)
with torch.no_grad():
    # Move inputs to cuda
    outputs = embedding_model(**inputs)

last_hidden_state = outputs.last_hidden_state
print(last_hidden_state)
print(last_hidden_state.shape)

  from .autonotebook import tqdm as notebook_tqdm
Some weights of EsmModel were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['esm.pooler.dense.bias', 'esm.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


tensor([[[ 0.2231,  0.5661,  0.1139,  ...,  1.0212, -0.1900, -0.6870],
         [ 0.4873,  0.2405, -0.1978,  ...,  0.6398, -0.0806, -0.3449],
         [-0.1065, -0.3528, -0.1022,  ..., -0.1548,  0.2464,  0.0080],
         ...,
         [-0.2542, -0.3260,  0.6081,  ...,  0.2127, -0.1515, -0.6503],
         [-0.0516, -0.1907,  0.3541,  ..., -0.0673, -0.0118, -0.5485],
         [-0.1355,  0.0183,  0.1578,  ...,  0.2765, -0.7183, -0.4105]]],
       device='cuda:0')
torch.Size([1, 67, 320])


We can observe the shape of the last hidden state:

```
torch.Size([1, 67, 320])  # (batch_size, sequence_length, hidden_size)
```

Let's apply the `EsmModel` to all the sequences that we have.

## Creating sequence embeddings

In [5]:
from src.embedding import embed_sequences

# Preprocess sequences and create embeddings
sequences = proc_data["sequence"].tolist()
embeddings = (
    embed_sequences(
        tokenizer, embedding_model, sequences, batch_size=batch_size, device=device
    )
    .cpu()
    .numpy()
)

# Print the shape of embeddings
print(embeddings)
print(embeddings.shape)

[[ 0.06451418  0.6619094   0.06489832 ...  1.0157654  -0.11351484
  -0.5050586 ]
 [ 0.0648082   0.7021824   0.15627454 ...  0.8986866  -0.15614237
  -0.5502283 ]
 [ 0.2197142   0.7074016   0.16096805 ...  0.82776386 -0.1777948
  -0.38856116]
 ...
 [ 0.132289    0.7346928  -0.04843182 ...  1.1155903  -0.09529998
  -0.5828432 ]
 [ 0.14722496  0.94436026  0.23472166 ...  0.9659419  -0.1346486
  -0.55718654]
 [-0.14243062  0.74474734  0.26580694 ...  0.9322149  -0.19289963
  -0.6503325 ]]
(6268, 320)


## Add embeddings to the original pandas dataframe

In [6]:
master_df = pd.concat(
    [
        proc_data,
        pd.DataFrame(
            embeddings,
            index=proc_data.index,
            columns=[f"emb_{i}" for i in range(embeddings.shape[1])],
        ),
    ],
    axis=1,
)

## Split data into train and test

We are ready to split the data into train and test sets. We will use 80% of the data for training and 20% for testing.

Based on the comment:
> One common challenge when working with protein data is that some sequences and
structures may originate from evolutionarily related organisms (also called homologous proteins). This can cause issues because related proteins can be really similar to each other, so when a model is evaluated on a protein that is related to one that it has been trained on, it will likely be more accurate because of data leakage. Thus, we will want to account when splitting data, to make sure we are fairly assessing the model performance and generalizability.
> - For the dataset that we are working with, proteins that are in the **same Homologous superfamily (H) level** in the CATH hierarchy are **related**, so you will want to split the dataset to account for this.

We will have to stratify the split based on the `superfamily` column. Let us double check the distribution of the `superfamily` column.

In [7]:
master_df["superfamily"].value_counts()

superfamily
10       1799
20        401
30        291
40        196
140       157
         ... 
12600       1
12710       1
4820        1
11810       1
2490        1
Name: count, Length: 604, dtype: int64

We can notice that some classes are underrepresented, i.e. there is only one sample per `superfamily` class. We will put them all in the test set.

In [8]:
from sklearn.model_selection import train_test_split

SPLIT = 0.2  # 20% of the data will be used for testing
VAL_SPLIT = 0.1  # 10% of the data will be used for validation

# Identify classes in 'superfamily' with only one sample
single_sample_classes = (
    master_df["superfamily"].value_counts().loc[lambda x: x == 1].index
)

# Separate these into a different DataFrame
single_sample_df = master_df[master_df["superfamily"].isin(single_sample_classes)]

# Remove these samples from the original DataFrame
trunc_master_df = master_df[~master_df["superfamily"].isin(single_sample_classes)]

# Perform stratified train-test split on the remaining data
train_master, test_master = train_test_split(
    trunc_master_df,
    test_size=SPLIT,
    random_state=SEED,
    stratify=trunc_master_df["superfamily"],
)

# Further split the training data into train and validation sets
train_master, val_master = train_test_split(
    train_master,
    test_size=VAL_SPLIT,
    random_state=SEED,
    stratify=train_master["superfamily"],
)

# Add the single sample classes to the test set
test_master = pd.concat([test_master, single_sample_df])

print("Train set:")
display(train_master)

print("Val set:")
display(val_master)

print("Test set:")
display(test_master)

Train set:


Unnamed: 0,cath_id,pdb_id,class,architecture,topology,superfamily,resolution_in_angstroms,sequence,piece_edges,num_residues,...,emb_310,emb_311,emb_312,emb_313,emb_314,emb_315,emb_316,emb_317,emb_318,emb_319
5897,3rrkA03,3rrk,3,30,70,2750,2.64,GLDESPRLGVIPFLVAKPEELEAVRKALQEALADRFVLEAEPLENQ...,"[(128, 201)]",74,...,-0.125288,0.286744,0.609983,-0.140890,0.345590,0.207203,0.454774,0.923088,-0.162230,-0.418057
2319,2a90A01,2a90,3,30,720,50,2.15,AHAVSVFYAPSSPAGKGTKWEWSGGSADSNNDWRPYNHVQSIIEDA...,"[(43, 48), (120, 150), (152, 195), (197, 211)]",96,...,-0.293845,0.112523,0.597217,-0.068405,0.274758,0.142621,0.177390,0.995879,-0.128748,-0.511798
181,1z2zA02,1z2z,3,30,70,3160,2.60,NREEGEEGKYLIVELTKRDWDTHHLTRTLSRILQVSQKRISVAGTK...,"[(40, 129)]",90,...,-0.153192,0.189229,0.733223,-0.078219,0.430919,0.155335,0.384786,0.929161,-0.144540,-0.312290
689,2xtuA00,2xtu,1,20,1540,10,1.85,ERAGPVTWVMMIACVVVFIAMQILGDQEVMLWLAWPFDPTLKFEFW...,"[(91, 271)]",181,...,-0.690923,0.006688,0.719441,-0.019133,0.687013,0.156064,0.226512,1.149459,-0.114837,-0.374691
5170,3hbxA03,3hbx,3,90,1150,160,2.67,TERFNIVSKDEGVPLVAFSLKDSSCHTEFEISDMLRRYGWIVPAYT...,"[(361, 448)]",88,...,-0.262856,0.083803,0.695856,-0.018524,0.495029,0.115150,0.376047,0.968112,-0.185106,-0.456067
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2201,2p5zX04,2p5z,2,40,50,230,2.60,TLPARVTSDIYAHIDKDGRYRVNLDFRDTWKPGYESLWVRLLAGTE...,"[(384, 391), (396, 413), (415, 428), (444, 468)]",65,...,-0.368285,0.204450,0.630889,-0.125225,0.343566,0.024804,0.250409,0.965241,-0.137626,-0.510485
2970,1wlyA01,1wly,3,90,180,10,1.30,VMAAVIHKKGGPDNFVWEEVKVGSPGPGQVRLRNTAIGVNFLDTYH...,"[(2, 49), (60, 128), (276, 333)]",175,...,-0.265764,0.092823,0.544944,0.006851,0.295354,0.215891,0.182885,0.935894,-0.041790,-0.450210
990,2xgjA05,2xgj,2,40,30,300,2.90,QPGRLVEISVNGKDNYGWGAVVDFAKRINKRNPSAVYTDHESYIVN...,"[(675, 804)]",130,...,-0.573048,0.270214,0.777427,-0.268919,0.723553,0.148158,0.105918,1.182733,-0.118003,-0.402654
961,2y7bA00,2y7b,2,30,29,30,1.90,VNSSVEERGFLTIFEDVSGFGAWHRRWCVLSGNCISYWTYPDDEKR...,"[(980, 1113)]",134,...,-0.141735,-0.045978,0.683498,0.104705,0.251053,0.184552,0.416486,0.866533,-0.167048,-0.498696


Val set:


Unnamed: 0,cath_id,pdb_id,class,architecture,topology,superfamily,resolution_in_angstroms,sequence,piece_edges,num_residues,...,emb_310,emb_311,emb_312,emb_313,emb_314,emb_315,emb_316,emb_317,emb_318,emb_319
4502,4q6jA00,4q6j,3,20,20,450,1.37,DISSTEIWDAIRRNSYLLYYQPKVDAKTNKIIGFEGLVRLKTATTI...,"[(2, 65), (67, 101), (103, 126), (130, 150), (...",238,...,-0.157351,0.069194,0.956312,-0.055715,0.596132,0.255238,0.279304,0.846944,0.000283,-0.489018
1098,2aegA02,2aeg,3,90,1680,20,2.30,EDKDWVSKWAQDAESLINLPAYQNPDQGPIVRNTADGKKQLVHARW...,"[(8, 26), (28, 31), (33, 36), (38, 59), (92, 1...",135,...,-0.238793,0.156936,0.666001,-0.082238,0.448866,0.060233,0.286238,1.014989,-0.190741,-0.653920
3330,5jx5A00,5jx5,3,20,20,40,1.80,TSDNFFENELYSNYKFQGEVDQSIQRLSGSLQEKAKKVKYVPTAAW...,"[(128, 449)]",322,...,-0.342259,0.037232,0.623547,-0.118052,0.410517,0.102899,0.151461,0.922462,-0.121846,-0.249787
986,2xurA02,2xur,3,10,150,10,1.90,SEVEFTLPQATMKRLIEATQFSMAHQDVRYYLNCMLFETEGEELRT...,"[(124, 247)]",124,...,-0.201360,0.169413,0.675132,-0.031235,0.357389,0.058143,0.334610,0.964583,-0.027345,-0.460087
5684,4jhmA02,4jhm,3,20,20,120,2.80,PLYLLGGARTKIKSYASTPFDTVEEYFPYIDDCIEHGFTAIKLHCY...,"[(120, 122), (124, 139), (141, 191), (193, 204...",248,...,-0.427229,0.020816,0.621137,-0.257110,0.515372,0.155890,0.127954,0.919916,-0.090994,-0.470683
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1192,4kxwA03,4kxw,3,40,50,920,0.97,GQAKVVLKSKDDQVTVIGAGVTLHEALAAAELLKKEKINIRVLDPF...,"[(490, 622)]",133,...,-0.255454,0.238175,0.471215,-0.247368,0.445484,0.093481,0.209289,0.879157,-0.262471,-0.357759
546,3pc3A03,3pc3,3,10,580,10,1.55,EPVNEHGHWWWSLAIAELELPAPPVILKSDATVGEAIALMKKHRVD...,"[(367, 510)]",144,...,0.005830,0.349781,0.460956,-0.142876,0.393321,0.297320,0.300077,1.313645,-0.184660,-0.630027
4181,1hw7A02,1hw7,1,10,287,480,2.20,AQNAQQDDFDHLATLTETIKTEELLTLPANEVLWRLYHEEEVTVYD...,"[(179, 228)]",50,...,-0.357700,0.141953,0.636461,-0.182168,0.442599,0.086261,0.319394,0.921757,-0.200529,-0.576125
5132,2apoA01,2apo,2,30,130,10,1.95,ELIVKEEVETNWDYGCKKVVVKDSAVDAICHGADVYVRGIAKLSKG...,"[(17, 32), (249, 331)]",99,...,-0.324399,0.306441,0.553460,-0.121175,0.379777,0.019876,0.063452,0.924122,-0.175774,-0.309278


Test set:


Unnamed: 0,cath_id,pdb_id,class,architecture,topology,superfamily,resolution_in_angstroms,sequence,piece_edges,num_residues,...,emb_310,emb_311,emb_312,emb_313,emb_314,emb_315,emb_316,emb_317,emb_318,emb_319
4574,4xhyA00,4xhy,2,30,110,10,1.53,EITFHPAARLLREALGRFATGVTVVTTAGPQGPLGMTVNSFSSVSL...,"[(15, 21), (25, 178)]",161,...,-0.315072,0.296664,0.414774,-0.195118,0.379154,0.028909,0.240505,1.023734,-0.183559,-0.585804
3947,1lmlA04,1lml,2,30,34,10,1.86,YSDGSCTQRASEAHASLLPFNVFSDAARCIDGAFRPKASYAGLCAN...,"[(461, 498), (505, 565)]",99,...,-0.254015,0.094303,0.496386,-0.174896,0.162396,0.093555,0.023665,0.826625,-0.094012,-0.355415
1720,1dmgA00,1dmg,3,40,1370,10,1.70,AQVDLLNVKGEKVGTLEISDFVFNIDPNYDVMWRYVDMQLSDWSKK...,"[(2, 42), (96, 226)]",172,...,-0.192167,0.227641,0.807497,-0.090866,0.485441,0.251523,0.416560,0.690793,-0.154467,-0.390929
3814,2i71A02,2i71,1,10,3740,10,1.70,FNKIFSDSVNAIPRFATALDNGLFIYLSEKDSSLHLKRLEDDLSKD...,"[(205, 249), (266, 266), (268, 318), (320, 329...",145,...,-0.155918,0.122748,0.677432,-0.122079,0.508399,0.049267,0.272791,0.760193,-0.186250,-0.458647
3823,2evrA01,2evr,2,30,30,40,1.60,KLGEYQCLADLNLFDSPECTRLATQSASGRHLWVTSNHQNLAVEVY...,"[(13, 86)]",74,...,-0.268060,0.082375,0.716748,-0.005333,0.411968,0.180418,0.353511,1.080868,-0.210659,-0.676763
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6143,3uh8A00,3uh8,2,60,40,3350,2.30,MTEHFITLSTTEPNNNIGIVKLRHADVNSQAIVAQIVENGQPKNFE...,"[(1, 58), (64, 123)]",118,...,-0.285001,0.075679,0.770287,0.095357,0.513960,-0.032139,0.231449,1.014048,-0.321017,-0.445138
6165,2vdwB02,2vdw,3,40,50,11680,2.70,PESDLDKVYEILKINSVKYYGRSTKADAVVADLSARNKLFKRERDA...,"[(71, 119), (124, 201)]",127,...,-0.285426,0.077763,0.739526,-0.277646,0.415413,0.128815,0.213480,0.904471,-0.106010,-0.303624
6191,3lzdA03,3lzd,3,40,50,11860,2.10,PERFIRKRWAQIAKAMDAKKFGVIVSIKKGQLRLAEAKRIVKLLKK...,"[(211, 312)]",102,...,-0.185348,0.142605,0.657943,-0.250432,0.489998,0.056253,0.422040,1.062361,-0.160836,-0.374179
6249,4dmzA02,4dmz,3,30,70,2880,2.10,IDAQRFSQYLKRSLLDARDHGLPACLYAFELTDARYGEEVQRLLEG...,"[(319, 455)]",137,...,-0.280054,0.115678,0.574847,-0.086957,0.426436,0.134081,0.413892,0.980684,-0.203086,-0.574839


Done!

## Create `X` and `y` arrays

In [9]:
# Create X and y arrays from train_master
X_train = train_master.filter(regex="^emb_").to_numpy()
y_train = train_master["target"].to_numpy()

# Create X and y arrays from val_master
X_val = val_master.filter(regex="^emb_").to_numpy()
y_val = val_master["target"].to_numpy()

# Create X and y arrays from test_master
X_test = test_master.filter(regex="^emb_").to_numpy()
y_test = test_master["target"].to_numpy()

# Train set
print(f"Train set shape: {X_train.shape}")

# Validation set
print(f"Validation set shape: {X_val.shape}")

# Test set
print(f"Test set shape: {X_test.shape}")

Train set shape: (4350, 320)
Validation set shape: (484, 320)
Test set shape: (1434, 320)


## Load the model, define the training configuration

In [11]:
from src.config import ProteinClassifier
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import OneHotEncoder

encoder = OneHotEncoder(sparse_output=False)

# One-hot encode y_train, y_val, and y_test
y_train_encoded = encoder.fit_transform(y_train.reshape(-1, 1))
y_val_encoded = encoder.transform(y_val.reshape(-1, 1))
y_test_encoded = encoder.transform(y_test.reshape(-1, 1))

# Convert data into PyTorch DataLoader
train_dataset = TensorDataset(
    torch.tensor(X_train).to(device), torch.tensor(y_train_encoded).to(device)
)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

val_dataset = TensorDataset(
    torch.tensor(X_val).to(device), torch.tensor(y_val_encoded).to(device)
)
val_loader = DataLoader(val_dataset, batch_size=batch_size)

test_dataset = TensorDataset(
    torch.tensor(X_test).to(device), torch.tensor(y_test_encoded).to(device)
)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

# Define hyperparameters
input_dim = X_train.shape[1]
hidden_dim = CONSTANTS.HIDDEN_DIM
output_dim = len(encoder.categories_[0])  # Number of unique classes
learning_rate = 0.001
epochs = 30

# Define the model, loss function, and optimizer
model = ProteinClassifier(input_dim, hidden_dim, output_dim).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

print(model)

ProteinClassifier(
  (fc1): Linear(in_features=320, out_features=128, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=128, out_features=10, bias=True)
  (softmax): Softmax(dim=1)
)


## Training

In [12]:
# Training loop
for epoch in range(epochs):
    model.train()
    for inputs, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, torch.argmax(labels, dim=1))
        loss.backward()
        optimizer.step()

    # Validation
    if (epoch + 1) % 2 == 0:  # Validate every 2 epochs
        model.eval()
        val_losses = []
        with torch.no_grad():
            for inputs, labels in val_loader:
                outputs = model(inputs)
                val_loss = criterion(outputs, torch.argmax(labels, dim=1))
                val_losses.append(val_loss.item())

        avg_val_loss = np.mean(val_losses)
        print(f"Epoch [{epoch+1}/{epochs}], Validation Loss: {avg_val_loss}")

        # Save model checkpoint
        if (epoch + 1) % 10 == 0:  # Save model checkpoint every 10 epochs
            checkpoint_path = f"{DATA_HOME}/model_epoch_{epoch+1}.pt"
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Model checkpoint saved at {checkpoint_path}")

Epoch [2/30], Validation Loss: 2.077776126563549
Epoch [4/30], Validation Loss: 2.0248253867030144
Epoch [6/30], Validation Loss: 1.994204431772232
Epoch [8/30], Validation Loss: 1.9725118577480316
Epoch [10/30], Validation Loss: 1.9518269151449203
Model checkpoint saved at /root/protein-structure-classification/data/model_epoch_10.pt
Epoch [12/30], Validation Loss: 1.9586259052157402
Epoch [14/30], Validation Loss: 1.926658347249031
Epoch [16/30], Validation Loss: 1.9425735548138618
Epoch [18/30], Validation Loss: 1.9464656114578247
Epoch [20/30], Validation Loss: 1.9244306236505508
Model checkpoint saved at /root/protein-structure-classification/data/model_epoch_20.pt
Epoch [22/30], Validation Loss: 1.918099820613861
Epoch [24/30], Validation Loss: 1.9108771532773972
Epoch [26/30], Validation Loss: 1.9319684132933617
Epoch [28/30], Validation Loss: 1.918594814836979
Epoch [30/30], Validation Loss: 1.9222517386078835
Model checkpoint saved at /root/protein-structure-classification/dat

## Final test validation

In [22]:
from sklearn.metrics import classification_report

# Evaluate the final model on the test set
model.eval()
test_losses = []
predictions = []
true_labels = []

with torch.no_grad():
    for inputs, labels in test_loader:
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)
        predictions.extend(predicted.cpu().numpy())
        true_labels.extend(torch.argmax(labels, dim=1).cpu().numpy())
        test_loss = criterion(outputs, torch.argmax(labels, dim=1))
        test_losses.append(test_loss.item())

# Calculate average test loss
avg_test_loss = np.mean(test_losses)

# Generate classification report
report = classification_report(true_labels, predictions, target_names=encoder.categories_[0])

# Calculate accuracy
accuracy = sum(1 for x, y in zip(predictions, true_labels) if x == y) / len(predictions)
print(f"Final Test Loss: {avg_test_loss}")
print(f"Accuracy on Test Set: {100 * accuracy:.2f}%")
print("Classification Report:")
print(report)

Final Test Loss: 1.9218494680192735
Accuracy on Test Set: 54.18%
Classification Report:
              precision    recall  f1-score   support

         1.1       0.60      0.44      0.51       133
         1.2       0.52      0.70      0.60       122
         2.3       0.61      0.53      0.57       137
         2.4       0.40      0.57      0.47       131
         2.6       0.71      0.67      0.69       172
         3.1       0.49      0.31      0.38       131
         3.2       0.89      0.90      0.90       125
         3.3       0.33      0.47      0.38       147
         3.4       0.57      0.63      0.60       203
         3.9       0.36      0.14      0.20       133

    accuracy                           0.54      1434
   macro avg       0.55      0.54      0.53      1434
weighted avg       0.55      0.54      0.53      1434

