In [1]:
import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

In [None]:
# Read the two tab-separated moment tables. Neither file has a header row, so
# pandas assigns integer column names (0, 1, 2, ...).
tsv_options = {"sep": '\t', "header": None}
cath_df = pd.read_csv("./data/cath_moments.tsv", **tsv_options)  # training data
ecod_df = pd.read_csv("./data/ecod_moments.tsv", **tsv_options)  # evaluation data

In [None]:
# DataFrame.info() prints its report and returns None, so assigning its return
# value would leave cath_info as None. Route the report into a text buffer so
# cath_info holds the actual text, then echo it to keep the cell's output.
_cath_info_buf = io.StringIO()
cath_df.info(buf=_cath_info_buf)
cath_info = _cath_info_buf.getvalue()
print(cath_info, end="")

In [None]:
# DataFrame.info() prints its report and returns None, so assigning its return
# value would leave ecod_info as None. Route the report into a text buffer so
# ecod_info holds the actual text, then echo it to keep the cell's output.
_ecod_info_buf = io.StringIO()
ecod_df.info(buf=_ecod_info_buf)
ecod_info = _ecod_info_buf.getvalue()
print(ecod_info, end="")

In [None]:
def _summarize(frame):
    """Collect a quick overview of one dataset.

    Args:
        frame: DataFrame whose column 0 holds the class label.

    Returns:
        dict with the frame's shape, the number of distinct labels in
        column 0, the per-label row counts, and the first rows as a preview.
    """
    label_column = frame[0]
    return {
        "shape": frame.shape,
        "unique_classes": label_column.nunique(),
        "class_distribution": label_column.value_counts(),
        "head": frame.head(),
    }


# Same overview for the training (CATH) and evaluation (ECOD) tables.
cath_summary = _summarize(cath_df)
ecod_summary = _summarize(ecod_df)

In [None]:
# Printing the whole dict at once embeds the Series/DataFrame reprs inside a
# single dict literal, which is hard to read. Print each entry under its key.
for summary_key, summary_value in cath_summary.items():
    print(f"[{summary_key}]")
    print(summary_value)
    print()

In [None]:
# Printing the whole dict at once embeds the Series/DataFrame reprs inside a
# single dict literal, which is hard to read. Print each entry under its key.
for summary_key, summary_value in ecod_summary.items():
    print(f"[{summary_key}]")
    print(summary_value)
    print()

In [None]:
# Display the first row of the CATH table (returned as a Series) to inspect
# its raw values.
cath_df.iloc[0]

In [None]:
# Report the Python type of the values stored in the first two columns of the
# first CATH row. Scalar access via iloc[row, col] keeps each column's own dtype.
print(*(type(cath_df.iloc[0, col]) for col in (0, 1)), sep="\n")

In [None]:
def _plot_class_frequency(class_counts, bins, title):
    """Draw a histogram of how many proteins each class contains.

    Args:
        class_counts: per-class row counts (output of ``value_counts()``).
        bins: number of histogram bins.
        title: figure title.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.hist(class_counts, bins=bins, edgecolor='black')
    ax.set_title(title)
    ax.set_xlabel('Number of proteins per class')
    ax.set_ylabel('Number of classes')
    ax.grid(True)
    fig.tight_layout()
    plt.show()


# Per-class row counts; the class label lives in column 0 of each table.
cath_class_counts = cath_df[0].value_counts()
ecod_class_counts = ecod_df[0].value_counts()

# CATH histogram
_plot_class_frequency(
    cath_class_counts,
    bins=30,
    title='Class Frequency Distribution in CATH Dataset',
)

# ECOD histogram
_plot_class_frequency(
    ecod_class_counts,
    bins=20,
    title='Class Frequency Distribution in ECOD Dataset',
)


### Visualizing Class Imbalance

Each protein in the dataset is associated with a structural class label (column 0), which identifies its 3D (1D in this case) shape category. Understanding how many proteins fall into each class is critical because the dataset is not balanced: some classes have many proteins, while others have very few.

To quantify this, we use `value_counts()` to compute the frequency of each class label and visualize the distribution with a histogram.

#### Why this is important:
During training, we will create pairs of proteins to determine structural similarity. If a particular class contains many proteins, it will generate significantly more pairs, which may bias the model toward frequently occurring classes. This can lead to overfitting and poor generalization, especially on underrepresented classes.

By plotting the class frequency distribution:
- We confirm whether class imbalance is present.
- We motivate the need for sampling strategies such as `WeightedRandomSampler` to correct for this imbalance during training.

This analysis is a key step in understanding the structure of the data and informing how we design the training process.


In [2]:
import sys
import numpy as np

# Record the interpreter and NumPy versions this notebook was run with.
for caption, version in (
    ("Python version:", sys.version),
    ("NumPy version:", np.__version__),
):
    print(caption, version)


Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ]
NumPy version: 1.26.4


In [3]:
# Standard library
import importlib
import os
import random
import time

# Third-party
import numpy as np
import pandas as pd
import torch

# Project-local modules
import train
import dataset

# Re-execute the project modules so edits made to them on disk are picked up
# without restarting the kernel.
importlib.reload(train)
importlib.reload(dataset)

# Seed every RNG available in this cell, not only Python's `random`: the
# training code runs on torch, so seeding `random` alone does not make runs
# repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Imported after the reload so these names bind to the reloaded module's functions.
from train import train_model, test_model_on_ecod

[INFO] Using 12 CPU threads
[INFO] Using 12 CPU threads


In [4]:
# Re-read both tables from disk and drop every column that contains a missing
# value. dropna(axis=1) is applied to each frame independently.
tsv_kwargs = {"sep": '\t', "header": None}
cath_df = pd.read_csv("./data/cath_moments.tsv", **tsv_kwargs).dropna(axis=1)
ecod_df = pd.read_csv("./data/ecod_moments.tsv", **tsv_kwargs).dropna(axis=1)

# Sanity check: report (rows, columns) for both tables.
print(f"CATH: {cath_df.shape}, ECOD: {ecod_df.shape}")

# Placeholders; nothing in this cell fills them in.
features = labels = None


CATH: (2685, 3923), ECOD: (761, 3923)


In [7]:
# Number of CATH rows handed to training via cath_df.head(num_proteins).
# 2685 equals the CATH row count printed by the loading cell, i.e. the full table.
num_proteins = 2685

In [8]:
# Number of unordered protein pairs: n choose 2 = n * (n - 1) / 2.
expected_pairs = num_proteins * (num_proteins - 1) // 2
print("Expected pairs:", expected_pairs)

Expected pairs: 3603270


In [5]:
# NOTE(review): inverse_class_weighting is imported but not used in this cell.
from weight_strategy import inverse_class_weighting

# Input width is the column count minus one; presumably the excluded column is
# the class label in column 0 (confirm against train_model).
input_dim = cath_df.shape[1] - 1
lr = 1e-3

# Train with hidden_dim=32. The original line ended in a stray "bb" after the
# closing parenthesis, which is a SyntaxError; it is removed here.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=32,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h32_bs128_lr0.001_ep30_20250401-013606
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.63 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=32) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 3512.7006 | Val Loss: 248.7990 | ROC AUC: 0.985 | PR AUC: 0.612 | MCC: 0.287


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 3018.2108 | Val Loss: 185.9943 | ROC AUC: 0.986 | PR AUC: 0.611 | MCC: 0.327


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 2972.2143 | Val Loss: 205.3223 | ROC AUC: 0.984 | PR AUC: 0.608 | MCC: 0.307


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 2953.9332 | Val Loss: 198.4350 | ROC AUC: 0.985 | PR AUC: 0.637 | MCC: 0.318


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 2948.7918 | Val Loss: 169.8047 | ROC AUC: 0.985 | PR AUC: 0.612 | MCC: 0.333


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 2954.5920 | Val Loss: 209.0730 | ROC AUC: 0.984 | PR AUC: 0.617 | MCC: 0.310


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 2936.8068 | Val Loss: 204.3735 | ROC AUC: 0.986 | PR AUC: 0.635 | MCC: 0.317


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 2929.3975 | Val Loss: 216.1191 | ROC AUC: 0.986 | PR AUC: 0.622 | MCC: 0.311


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 2927.5631 | Val Loss: 182.3573 | ROC AUC: 0.987 | PR AUC: 0.620 | MCC: 0.329


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 2920.6008 | Val Loss: 258.2216 | ROC AUC: 0.986 | PR AUC: 0.613 | MCC: 0.288


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 2932.0531 | Val Loss: 228.9377 | ROC AUC: 0.986 | PR AUC: 0.615 | MCC: 0.301


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 2507.6669 | Val Loss: 154.7678 | ROC AUC: 0.984 | PR AUC: 0.623 | MCC: 0.344


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 2440.5497 | Val Loss: 180.3335 | ROC AUC: 0.985 | PR AUC: 0.623 | MCC: 0.328


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 2415.5204 | Val Loss: 187.7129 | ROC AUC: 0.986 | PR AUC: 0.617 | MCC: 0.321


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 2409.6902 | Val Loss: 166.7239 | ROC AUC: 0.985 | PR AUC: 0.624 | MCC: 0.328


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 2407.7017 | Val Loss: 127.6555 | ROC AUC: 0.985 | PR AUC: 0.629 | MCC: 0.368


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 2407.9084 | Val Loss: 149.8583 | ROC AUC: 0.986 | PR AUC: 0.638 | MCC: 0.356


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 2403.4713 | Val Loss: 181.1475 | ROC AUC: 0.985 | PR AUC: 0.607 | MCC: 0.325


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 2403.6819 | Val Loss: 150.1593 | ROC AUC: 0.983 | PR AUC: 0.603 | MCC: 0.341


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 2406.9587 | Val Loss: 145.2736 | ROC AUC: 0.984 | PR AUC: 0.608 | MCC: 0.349


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 2393.0361 | Val Loss: 160.5319 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.346


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 2397.4725 | Val Loss: 178.0464 | ROC AUC: 0.982 | PR AUC: 0.607 | MCC: 0.323


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 2100.5460 | Val Loss: 137.3884 | ROC AUC: 0.984 | PR AUC: 0.618 | MCC: 0.359


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 2027.3500 | Val Loss: 143.7809 | ROC AUC: 0.985 | PR AUC: 0.609 | MCC: 0.357


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 2012.4368 | Val Loss: 168.0254 | ROC AUC: 0.986 | PR AUC: 0.621 | MCC: 0.338


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 2013.3372 | Val Loss: 140.4234 | ROC AUC: 0.984 | PR AUC: 0.612 | MCC: 0.354


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 1996.3876 | Val Loss: 106.9720 | ROC AUC: 0.982 | PR AUC: 0.604 | MCC: 0.383


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 1989.0283 | Val Loss: 121.1449 | ROC AUC: 0.984 | PR AUC: 0.627 | MCC: 0.376


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 1993.8984 | Val Loss: 119.3949 | ROC AUC: 0.984 | PR AUC: 0.618 | MCC: 0.375


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 1989.6741 | Val Loss: 154.8932 | ROC AUC: 0.985 | PR AUC: 0.623 | MCC: 0.347


In [8]:
# NOTE(review): inverse_class_weighting is imported but not used in this cell.
from weight_strategy import inverse_class_weighting

lr = 1e-3
# Input width is the column count minus one; presumably the excluded column is
# the class label in column 0 (confirm against train_model).
input_dim = cath_df.shape[1] - 1

# Same configuration as the other training cells, with hidden_dim=64.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=64,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h64_bs128_lr0.001_ep30_20250401-044602
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.90 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=64) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2790.1215 | Val Loss: 133.2154 | ROC AUC: 0.986 | PR AUC: 0.627 | MCC: 0.382


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 2286.8043 | Val Loss: 153.6702 | ROC AUC: 0.987 | PR AUC: 0.646 | MCC: 0.365


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 2253.7898 | Val Loss: 218.2447 | ROC AUC: 0.988 | PR AUC: 0.647 | MCC: 0.324


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 2223.7766 | Val Loss: 212.8820 | ROC AUC: 0.988 | PR AUC: 0.634 | MCC: 0.328


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 2223.5772 | Val Loss: 190.3462 | ROC AUC: 0.984 | PR AUC: 0.619 | MCC: 0.323


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 2211.1190 | Val Loss: 167.3094 | ROC AUC: 0.986 | PR AUC: 0.638 | MCC: 0.352


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 2208.8373 | Val Loss: 214.9265 | ROC AUC: 0.987 | PR AUC: 0.636 | MCC: 0.319


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 1772.2714 | Val Loss: 145.9266 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.373


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 1710.9974 | Val Loss: 129.5568 | ROC AUC: 0.987 | PR AUC: 0.641 | MCC: 0.387


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 1691.6539 | Val Loss: 139.3129 | ROC AUC: 0.987 | PR AUC: 0.629 | MCC: 0.375


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 1696.8884 | Val Loss: 179.0632 | ROC AUC: 0.988 | PR AUC: 0.644 | MCC: 0.353


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 1692.1973 | Val Loss: 90.8814 | ROC AUC: 0.982 | PR AUC: 0.610 | MCC: 0.417


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 1686.4800 | Val Loss: 177.0457 | ROC AUC: 0.986 | PR AUC: 0.636 | MCC: 0.347


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 1679.1960 | Val Loss: 131.7595 | ROC AUC: 0.989 | PR AUC: 0.638 | MCC: 0.390


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 1678.4261 | Val Loss: 121.8030 | ROC AUC: 0.987 | PR AUC: 0.640 | MCC: 0.394


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 1681.0582 | Val Loss: 146.7806 | ROC AUC: 0.986 | PR AUC: 0.627 | MCC: 0.363


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 1676.6251 | Val Loss: 170.2624 | ROC AUC: 0.988 | PR AUC: 0.651 | MCC: 0.350


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 1683.9991 | Val Loss: 99.4850 | ROC AUC: 0.985 | PR AUC: 0.637 | MCC: 0.413


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 1371.1694 | Val Loss: 145.0430 | ROC AUC: 0.988 | PR AUC: 0.639 | MCC: 0.372


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 1319.4940 | Val Loss: 110.1391 | ROC AUC: 0.987 | PR AUC: 0.641 | MCC: 0.414


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 1305.5765 | Val Loss: 116.3597 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.394


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 1301.9928 | Val Loss: 102.9868 | ROC AUC: 0.986 | PR AUC: 0.634 | MCC: 0.416


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 1295.1747 | Val Loss: 120.8696 | ROC AUC: 0.988 | PR AUC: 0.651 | MCC: 0.407


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 1301.7364 | Val Loss: 110.8722 | ROC AUC: 0.986 | PR AUC: 0.658 | MCC: 0.407


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 1094.8021 | Val Loss: 93.0280 | ROC AUC: 0.987 | PR AUC: 0.656 | MCC: 0.434


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 1053.1534 | Val Loss: 113.5270 | ROC AUC: 0.986 | PR AUC: 0.642 | MCC: 0.405


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 1034.4528 | Val Loss: 86.6123 | ROC AUC: 0.984 | PR AUC: 0.629 | MCC: 0.437


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 1029.9672 | Val Loss: 96.3476 | ROC AUC: 0.986 | PR AUC: 0.643 | MCC: 0.426


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 1026.2533 | Val Loss: 108.3883 | ROC AUC: 0.987 | PR AUC: 0.647 | MCC: 0.416


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 1018.7080 | Val Loss: 74.7417 | ROC AUC: 0.984 | PR AUC: 0.638 | MCC: 0.458


In [9]:
# NOTE(review): inverse_class_weighting is imported but not used in this cell.
from weight_strategy import inverse_class_weighting

lr = 1e-3
# Input width is the column count minus one; presumably the excluded column is
# the class label in column 0 (confirm against train_model).
input_dim = cath_df.shape[1] - 1

# Same configuration as the other training cells, with hidden_dim=128.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=128,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h128_bs128_lr0.001_ep30_20250401-052913
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.92 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=128) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2328.2184 | Val Loss: 141.1325 | ROC AUC: 0.986 | PR AUC: 0.628 | MCC: 0.363


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1825.4780 | Val Loss: 142.5654 | ROC AUC: 0.987 | PR AUC: 0.636 | MCC: 0.376


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 1785.9445 | Val Loss: 154.4732 | ROC AUC: 0.988 | PR AUC: 0.626 | MCC: 0.361


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 1776.9168 | Val Loss: 124.9217 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.386


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 1754.0096 | Val Loss: 134.5164 | ROC AUC: 0.987 | PR AUC: 0.621 | MCC: 0.375


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 1759.6000 | Val Loss: 171.1572 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.340


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 1760.9256 | Val Loss: 168.2418 | ROC AUC: 0.985 | PR AUC: 0.605 | MCC: 0.345


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 1758.8113 | Val Loss: 141.2236 | ROC AUC: 0.986 | PR AUC: 0.651 | MCC: 0.378


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 1758.0704 | Val Loss: 162.8733 | ROC AUC: 0.988 | PR AUC: 0.642 | MCC: 0.360


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 1759.3924 | Val Loss: 166.9655 | ROC AUC: 0.984 | PR AUC: 0.623 | MCC: 0.345


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 1283.3705 | Val Loss: 107.0652 | ROC AUC: 0.986 | PR AUC: 0.660 | MCC: 0.412


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 1217.0594 | Val Loss: 149.6367 | ROC AUC: 0.987 | PR AUC: 0.632 | MCC: 0.368


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 1217.1598 | Val Loss: 105.3324 | ROC AUC: 0.987 | PR AUC: 0.646 | MCC: 0.413


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 1213.6235 | Val Loss: 133.7346 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.384


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 1203.4818 | Val Loss: 145.6880 | ROC AUC: 0.988 | PR AUC: 0.644 | MCC: 0.373


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 1200.0008 | Val Loss: 108.7921 | ROC AUC: 0.986 | PR AUC: 0.624 | MCC: 0.400


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 1197.0745 | Val Loss: 119.9393 | ROC AUC: 0.984 | PR AUC: 0.608 | MCC: 0.382


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 1201.3801 | Val Loss: 106.7702 | ROC AUC: 0.987 | PR AUC: 0.637 | MCC: 0.409


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 1203.2906 | Val Loss: 117.5022 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.394


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 888.2281 | Val Loss: 90.2595 | ROC AUC: 0.988 | PR AUC: 0.642 | MCC: 0.441


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 842.8128 | Val Loss: 85.2345 | ROC AUC: 0.987 | PR AUC: 0.631 | MCC: 0.445


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 836.2232 | Val Loss: 107.3878 | ROC AUC: 0.986 | PR AUC: 0.635 | MCC: 0.408


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 837.8560 | Val Loss: 76.1065 | ROC AUC: 0.985 | PR AUC: 0.625 | MCC: 0.446


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 836.5286 | Val Loss: 82.6572 | ROC AUC: 0.987 | PR AUC: 0.646 | MCC: 0.446


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 829.9343 | Val Loss: 101.4233 | ROC AUC: 0.988 | PR AUC: 0.641 | MCC: 0.419


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 827.6719 | Val Loss: 97.4677 | ROC AUC: 0.988 | PR AUC: 0.658 | MCC: 0.429


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 824.5602 | Val Loss: 92.3477 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.425


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 824.4684 | Val Loss: 112.4822 | ROC AUC: 0.987 | PR AUC: 0.641 | MCC: 0.406


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 825.4096 | Val Loss: 97.5090 | ROC AUC: 0.987 | PR AUC: 0.630 | MCC: 0.422


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 628.7913 | Val Loss: 72.1604 | ROC AUC: 0.987 | PR AUC: 0.650 | MCC: 0.471


In [10]:
# NOTE(review): inverse_class_weighting is imported but not used in this cell.
from weight_strategy import inverse_class_weighting

lr = 1e-3
# Input width is the column count minus one; presumably the excluded column is
# the class label in column 0 (confirm against train_model).
input_dim = cath_df.shape[1] - 1

# Same configuration as the other training cells, with hidden_dim=256.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=256,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h256_bs128_lr0.001_ep30_20250401-061858
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.90 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=256) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2053.8347 | Val Loss: 229.8138 | ROC AUC: 0.987 | PR AUC: 0.618 | MCC: 0.310


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1588.8233 | Val Loss: 152.8427 | ROC AUC: 0.988 | PR AUC: 0.627 | MCC: 0.369


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 1565.5507 | Val Loss: 119.3067 | ROC AUC: 0.987 | PR AUC: 0.624 | MCC: 0.395


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 1556.6047 | Val Loss: 172.4796 | ROC AUC: 0.986 | PR AUC: 0.619 | MCC: 0.346


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 1559.4548 | Val Loss: 148.2712 | ROC AUC: 0.990 | PR AUC: 0.633 | MCC: 0.377


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 1559.8390 | Val Loss: 142.1818 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.373


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 1558.6847 | Val Loss: 125.2523 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.392


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 1555.0498 | Val Loss: 161.6505 | ROC AUC: 0.989 | PR AUC: 0.636 | MCC: 0.363


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 1550.9422 | Val Loss: 114.2115 | ROC AUC: 0.988 | PR AUC: 0.648 | MCC: 0.408


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 1555.3490 | Val Loss: 229.7787 | ROC AUC: 0.987 | PR AUC: 0.612 | MCC: 0.312


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 1551.2154 | Val Loss: 121.3077 | ROC AUC: 0.987 | PR AUC: 0.637 | MCC: 0.394


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 1550.6711 | Val Loss: 191.9358 | ROC AUC: 0.989 | PR AUC: 0.649 | MCC: 0.347


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 1555.4943 | Val Loss: 228.1879 | ROC AUC: 0.988 | PR AUC: 0.630 | MCC: 0.326


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 1551.5698 | Val Loss: 183.3618 | ROC AUC: 0.985 | PR AUC: 0.623 | MCC: 0.335


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 1551.0701 | Val Loss: 158.2805 | ROC AUC: 0.989 | PR AUC: 0.645 | MCC: 0.367


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 1071.6475 | Val Loss: 99.3919 | ROC AUC: 0.987 | PR AUC: 0.647 | MCC: 0.428


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 1029.3992 | Val Loss: 145.9456 | ROC AUC: 0.990 | PR AUC: 0.648 | MCC: 0.382


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 1023.8955 | Val Loss: 75.4154 | ROC AUC: 0.988 | PR AUC: 0.646 | MCC: 0.463


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 1016.4491 | Val Loss: 118.2910 | ROC AUC: 0.986 | PR AUC: 0.632 | MCC: 0.403


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 1008.1196 | Val Loss: 137.5046 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.378


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 1007.5017 | Val Loss: 85.4847 | ROC AUC: 0.987 | PR AUC: 0.632 | MCC: 0.439


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 1010.0389 | Val Loss: 75.9583 | ROC AUC: 0.988 | PR AUC: 0.653 | MCC: 0.473


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 1013.0289 | Val Loss: 95.4321 | ROC AUC: 0.986 | PR AUC: 0.639 | MCC: 0.429


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 1002.9763 | Val Loss: 123.3797 | ROC AUC: 0.987 | PR AUC: 0.623 | MCC: 0.395


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 695.2301 | Val Loss: 78.0055 | ROC AUC: 0.985 | PR AUC: 0.632 | MCC: 0.453


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 668.7200 | Val Loss: 81.0931 | ROC AUC: 0.988 | PR AUC: 0.638 | MCC: 0.463


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 661.1331 | Val Loss: 95.8941 | ROC AUC: 0.986 | PR AUC: 0.643 | MCC: 0.433


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 659.1683 | Val Loss: 92.5871 | ROC AUC: 0.986 | PR AUC: 0.621 | MCC: 0.428


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 655.4341 | Val Loss: 64.2525 | ROC AUC: 0.987 | PR AUC: 0.639 | MCC: 0.486


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 657.1748 | Val Loss: 107.0144 | ROC AUC: 0.988 | PR AUC: 0.652 | MCC: 0.422


In [11]:
# NOTE(review): inverse_class_weighting is imported but not used in this cell.
from weight_strategy import inverse_class_weighting

lr = 1e-3
# Input width is the column count minus one; presumably the excluded column is
# the class label in column 0 (confirm against train_model).
input_dim = cath_df.shape[1] - 1

# Same configuration as the other training cells, with hidden_dim=512.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=512,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h512_bs128_lr0.001_ep30_20250401-072234
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.90 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=512) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 1961.3288 | Val Loss: 159.6353 | ROC AUC: 0.989 | PR AUC: 0.622 | MCC: 0.363


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1557.0879 | Val Loss: 170.5405 | ROC AUC: 0.988 | PR AUC: 0.633 | MCC: 0.351


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 1536.3636 | Val Loss: 115.0179 | ROC AUC: 0.988 | PR AUC: 0.641 | MCC: 0.399


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 1534.1690 | Val Loss: 138.5587 | ROC AUC: 0.986 | PR AUC: 0.633 | MCC: 0.377


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 1531.9305 | Val Loss: 108.8055 | ROC AUC: 0.985 | PR AUC: 0.625 | MCC: 0.406


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 1525.1428 | Val Loss: 109.2508 | ROC AUC: 0.986 | PR AUC: 0.608 | MCC: 0.401


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 1535.3477 | Val Loss: 112.7471 | ROC AUC: 0.987 | PR AUC: 0.624 | MCC: 0.412


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 1535.6617 | Val Loss: 156.4115 | ROC AUC: 0.987 | PR AUC: 0.630 | MCC: 0.359


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 1533.6373 | Val Loss: 83.9969 | ROC AUC: 0.989 | PR AUC: 0.648 | MCC: 0.453


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 1532.0881 | Val Loss: 107.6331 | ROC AUC: 0.991 | PR AUC: 0.665 | MCC: 0.426


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 1531.9673 | Val Loss: 151.9211 | ROC AUC: 0.988 | PR AUC: 0.631 | MCC: 0.370


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 1537.3731 | Val Loss: 207.7579 | ROC AUC: 0.989 | PR AUC: 0.654 | MCC: 0.329


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 1526.9563 | Val Loss: 86.4721 | ROC AUC: 0.986 | PR AUC: 0.625 | MCC: 0.441


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 1540.7314 | Val Loss: 156.2611 | ROC AUC: 0.989 | PR AUC: 0.638 | MCC: 0.373


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 1537.5542 | Val Loss: 95.7621 | ROC AUC: 0.986 | PR AUC: 0.621 | MCC: 0.413


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 1051.9136 | Val Loss: 94.7919 | ROC AUC: 0.988 | PR AUC: 0.649 | MCC: 0.432


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 1015.3338 | Val Loss: 122.0553 | ROC AUC: 0.985 | PR AUC: 0.610 | MCC: 0.389


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 1006.4522 | Val Loss: 85.8736 | ROC AUC: 0.989 | PR AUC: 0.659 | MCC: 0.447


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 1000.3169 | Val Loss: 78.9926 | ROC AUC: 0.986 | PR AUC: 0.639 | MCC: 0.451


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 1000.9327 | Val Loss: 99.5368 | ROC AUC: 0.987 | PR AUC: 0.627 | MCC: 0.423


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 996.5780 | Val Loss: 81.3820 | ROC AUC: 0.984 | PR AUC: 0.631 | MCC: 0.442


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 994.5309 | Val Loss: 106.4630 | ROC AUC: 0.986 | PR AUC: 0.618 | MCC: 0.415


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 994.9463 | Val Loss: 151.9480 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.371


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 989.8707 | Val Loss: 122.6487 | ROC AUC: 0.987 | PR AUC: 0.627 | MCC: 0.395


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 992.0880 | Val Loss: 95.5645 | ROC AUC: 0.987 | PR AUC: 0.629 | MCC: 0.426


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 679.7770 | Val Loss: 73.5784 | ROC AUC: 0.985 | PR AUC: 0.616 | MCC: 0.458


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 657.9531 | Val Loss: 77.1408 | ROC AUC: 0.984 | PR AUC: 0.621 | MCC: 0.447


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 643.1639 | Val Loss: 85.6609 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.451


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 647.4269 | Val Loss: 102.6664 | ROC AUC: 0.989 | PR AUC: 0.651 | MCC: 0.427


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 644.3447 | Val Loss: 105.7595 | ROC AUC: 0.986 | PR AUC: 0.632 | MCC: 0.414


In [None]:
from weight_strategy import inverse_class_weighting
# Column 0 of cath_df holds the class label, so every remaining column is
# treated as an input feature.
input_dim = len(cath_df.columns) - 1
# Learning rate for this run; other cells in this notebook repeat the same
# call with different values.
lr = 1e-3

# Train on the first `num_proteins` rows, holding out 10% for validation.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=1024,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h1024_bs128_lr0.001_ep30_20250401-085811
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.90 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=1024) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 1957.2268 | Val Loss: 142.4008 | ROC AUC: 0.987 | PR AUC: 0.628 | MCC: 0.375


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1558.1308 | Val Loss: 165.5138 | ROC AUC: 0.987 | PR AUC: 0.626 | MCC: 0.356


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 1536.6935 | Val Loss: 186.9460 | ROC AUC: 0.989 | PR AUC: 0.649 | MCC: 0.344


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 1530.9479 | Val Loss: 254.4352 | ROC AUC: 0.988 | PR AUC: 0.623 | MCC: 0.306


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 1529.7708 | Val Loss: 138.3738 | ROC AUC: 0.988 | PR AUC: 0.629 | MCC: 0.381


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 1529.8863 | Val Loss: 240.0452 | ROC AUC: 0.989 | PR AUC: 0.620 | MCC: 0.307


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 1528.6857 | Val Loss: 171.6383 | ROC AUC: 0.990 | PR AUC: 0.623 | MCC: 0.352


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 1521.1513 | Val Loss: 105.2589 | ROC AUC: 0.988 | PR AUC: 0.616 | MCC: 0.424


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 1528.3352 | Val Loss: 126.1394 | ROC AUC: 0.990 | PR AUC: 0.655 | MCC: 0.395


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 1532.8266 | Val Loss: 124.6978 | ROC AUC: 0.990 | PR AUC: 0.663 | MCC: 0.408


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

In [None]:
from weight_strategy import inverse_class_weighting
# Column 0 of cath_df holds the class label, so every remaining column is
# treated as an input feature.
input_dim = len(cath_df.columns) - 1
# Learning rate for this run; other cells in this notebook repeat the same
# call with different values.
lr = 1e-4

# Train on the first `num_proteins` rows, holding out 10% for validation.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=1024,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

In [None]:
from weight_strategy import inverse_class_weighting
# Column 0 of cath_df holds the class label, so every remaining column is
# treated as an input feature.
input_dim = len(cath_df.columns) - 1
# Learning rate for this run; other cells in this notebook repeat the same
# call with different values.
lr = 1e-6

# Train on the first `num_proteins` rows, holding out 10% for validation.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=1024,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

In [7]:
from weight_strategy import inverse_class_weighting
# Column 0 of cath_df holds the class label, so every remaining column is
# treated as an input feature.
input_dim = len(cath_df.columns) - 1
# Learning rate for this run. The checkpoint loaded further down in the
# notebook carries this run's name (lr1e-05).
lr = 1e-5

# Train on the first `num_proteins` rows, holding out 10% for validation.
model = train_model(
    protein_df=cath_df.head(num_proteins),
    hidden_dim=1024,
    input_dim=input_dim,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h1024_bs128_lr1e-05_ep30_20250401-022146
[INFO] Loading Dataloader
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.60 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=1024) for 30 epochs...




Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2746.4681 | Val Loss: 116.7348 | ROC AUC: 0.988 | PR AUC: 0.618 | MCC: 0.386


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 844.7784 | Val Loss: 74.8910 | ROC AUC: 0.988 | PR AUC: 0.640 | MCC: 0.447


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 464.5531 | Val Loss: 54.0727 | ROC AUC: 0.989 | PR AUC: 0.647 | MCC: 0.499


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 310.8563 | Val Loss: 65.7824 | ROC AUC: 0.988 | PR AUC: 0.644 | MCC: 0.475


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 240.4056 | Val Loss: 52.1849 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.505


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 190.3872 | Val Loss: 52.6310 | ROC AUC: 0.988 | PR AUC: 0.637 | MCC: 0.506


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 163.0812 | Val Loss: 46.0035 | ROC AUC: 0.988 | PR AUC: 0.646 | MCC: 0.534


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 146.1827 | Val Loss: 52.6643 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.510


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 132.4804 | Val Loss: 50.1694 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.526


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 122.8850 | Val Loss: 52.2132 | ROC AUC: 0.987 | PR AUC: 0.636 | MCC: 0.508


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 115.2397 | Val Loss: 47.8817 | ROC AUC: 0.987 | PR AUC: 0.639 | MCC: 0.529


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 110.9929 | Val Loss: 42.5267 | ROC AUC: 0.987 | PR AUC: 0.648 | MCC: 0.561


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 105.1057 | Val Loss: 46.1547 | ROC AUC: 0.989 | PR AUC: 0.650 | MCC: 0.542


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 101.3321 | Val Loss: 65.5287 | ROC AUC: 0.987 | PR AUC: 0.625 | MCC: 0.469


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 97.5679 | Val Loss: 48.2613 | ROC AUC: 0.988 | PR AUC: 0.648 | MCC: 0.530


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 95.9059 | Val Loss: 55.3150 | ROC AUC: 0.986 | PR AUC: 0.647 | MCC: 0.499


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 94.1026 | Val Loss: 45.8226 | ROC AUC: 0.987 | PR AUC: 0.647 | MCC: 0.538


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 92.5556 | Val Loss: 47.6928 | ROC AUC: 0.988 | PR AUC: 0.642 | MCC: 0.528


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 74.1032 | Val Loss: 40.6657 | ROC AUC: 0.987 | PR AUC: 0.634 | MCC: 0.565


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 68.3083 | Val Loss: 42.0075 | ROC AUC: 0.987 | PR AUC: 0.645 | MCC: 0.563


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 65.4603 | Val Loss: 42.5044 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.549


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 66.5855 | Val Loss: 50.2334 | ROC AUC: 0.987 | PR AUC: 0.645 | MCC: 0.525


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 65.3657 | Val Loss: 37.0782 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.586


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 64.4120 | Val Loss: 37.3654 | ROC AUC: 0.988 | PR AUC: 0.647 | MCC: 0.581


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 65.0250 | Val Loss: 40.7792 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.560


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 65.0299 | Val Loss: 42.4405 | ROC AUC: 0.988 | PR AUC: 0.639 | MCC: 0.546


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 64.6315 | Val Loss: 45.5514 | ROC AUC: 0.987 | PR AUC: 0.628 | MCC: 0.525


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 63.8310 | Val Loss: 40.9893 | ROC AUC: 0.986 | PR AUC: 0.640 | MCC: 0.559


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 63.3009 | Val Loss: 42.6834 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.549


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 55.9526 | Val Loss: 39.0010 | ROC AUC: 0.986 | PR AUC: 0.638 | MCC: 0.565


In [9]:
#Load the best model
import torch
from model import ProteinClassifier
# Column 0 of cath_df is the class label; the remaining columns are features.
input_dim = cath_df.shape[1]-1
# Rebuild the architecture with the same sizes used for training; the layer
# shapes must match the saved state_dict or load_state_dict will raise.
model = ProteinClassifier(hidden_dim=1024, input_dim=input_dim)

# Load saved weights onto the CPU.
# weights_only=True restricts unpickling to tensors and plain containers, which
# is all a state_dict needs. It avoids executing arbitrary pickled code from the
# checkpoint file and silences the FutureWarning torch.load emits otherwise.
model.load_state_dict(
    torch.load(
        "./modelData/baseline_h1024_bs128_lr1e-05_ep30_20250401-022146_best.pt",
        map_location="cpu",
        weights_only=True,
    )
)

# Set model to evaluation mode (kept as the final expression so the cell still
# displays the module summary).
model.eval()

  model.load_state_dict(torch.load("./modelData/baseline_h1024_bs128_lr1e-05_ep30_20250401-022146_best.pt", map_location="cpu"))


ProteinClassifier(
  (linear1): Linear(in_features=3922, out_features=1024, bias=True)
  (batchnorm1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
  (linear2): Linear(in_features=1024, out_features=1, bias=True)
)

In [10]:
# Evaluate the reloaded checkpoint on the ECOD table (loaded for evaluation at
# the top of the notebook). Kept as the cell's final bare expression so the
# returned tuple is echoed under the printed summary. Judging by the recorded
# output the tuple is (loss, ROC AUC, PR AUC, MCC) -- TODO confirm the order
# against test_model_on_ecod.
test_model_on_ecod(model, ecod_df)

[INFO] Initializing Streaming Dataset from DataFrame of size 761
[INFO] Streaming V2 init done in 0.34 seconds
[FINAL TEST on ECOD] Loss: 462.5683 | ROC AUC: 0.965 | PR AUC: 0.788 | MCC: 0.704


(462.5683311524644, 0.9652908960888762, 0.7884873464234206, 0.7042552401693547)

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from dataset import StreamingProteinPairDataset, StreamingProteinPairDatasetV2
import time
import tracemalloc  # optional: for memory profiling

# Small slice of the CATH table so the profiling below stays quick.
# Columns containing any NaN are dropped across the whole file first, and only
# then are the first 1000 rows kept.
df = (
    pd.read_csv("data/cath_moments.tsv", sep='\t', header=None)
    .dropna(axis=1)
    .head(1000)
)

# Batch size handed to the DataLoader inside profile_loader.
BATCH_SIZE = 64

def profile_loader(dataset_class, name="", data=None, batch_size=None, num_workers=4, num_batches=5):
    """Profile how long a dataset takes to build and to yield its first batches.

    Prints three things: the wall time of ``dataset_class(data)``, the peak
    memory traced by ``tracemalloc`` during that construction, and the average
    time spent fetching each of the first ``num_batches`` batches from a
    ``DataLoader`` wrapped around the dataset.

    Args:
        dataset_class: Callable taking the data as its single positional
            argument and returning a dataset accepted by ``DataLoader``.
        name: Label used in the printed lines.
        data: Object passed to ``dataset_class``. Defaults to the notebook-level
            ``df`` when omitted.
        batch_size: Batch size for the ``DataLoader``. Defaults to the
            notebook-level ``BATCH_SIZE`` when omitted.
        num_workers: Number of ``DataLoader`` worker processes (default 4).
        num_batches: Maximum number of batches to fetch and time (default 5).

    Returns:
        None. All measurements are printed.
    """
    # Fall back to the notebook globals only when the caller gave nothing, so
    # existing two-argument calls behave as before.
    if data is None:
        data = df
    if batch_size is None:
        batch_size = BATCH_SIZE

    print(f"\n--- Profiling {name} ---")

    # Measure construction time and peak traced memory. The try/finally makes
    # sure tracing is switched off even if the dataset constructor raises;
    # otherwise every later cell would keep paying the tracing overhead.
    tracemalloc.start()
    try:
        start = time.perf_counter_ns()  # monotonic clock, unlike time.time_ns()
        dataset = dataset_class(data)
        init_time_ns = time.perf_counter_ns() - start
        _current, peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()

    print(f"[{name}] Init time: {init_time_ns / 1e6:.2f} ms")
    print(f"[{name}] Peak memory usage: {peak / (1024**2):.2f} MB")

    # Profile batch loading. The fetch itself happens inside next(), so that is
    # the call that must sit between the two clock reads. Creating the iterator
    # is kept outside the timed region.
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    batch_iter = iter(loader)
    batch_times_ns = []
    for _ in range(num_batches):
        t0 = time.perf_counter_ns()
        try:
            next(batch_iter)
        except StopIteration:
            # Dataset ran out before num_batches were fetched.
            break
        batch_times_ns.append(time.perf_counter_ns() - t0)
    # Drop the iterator right away so any worker processes are released instead
    # of lingering until the next garbage collection.
    del batch_iter

    if batch_times_ns:
        # Timings are collected in nanoseconds; convert to milliseconds so the
        # printed unit is correct.
        avg_batch_time_ms = sum(batch_times_ns) / len(batch_times_ns) / 1e6
        print(f"[{name}] Avg batch load time (first {len(batch_times_ns)}): {avg_batch_time_ms:.2f} ms")
    else:
        print(f"[{name}] No batches loaded.")

# Profile each streaming dataset implementation in turn, v1 first, so the two
# printed reports can be compared side by side.
profile_loader(dataset_class=StreamingProteinPairDataset, name="Streaming v1")
profile_loader(dataset_class=StreamingProteinPairDatasetV2, name="Streaming v2")