In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [None]:
# Read both moment tables as tab-separated files without a header row.
read_opts = dict(sep='\t', header=None)
cath_df = pd.read_csv("./data/cath_moments.tsv", **read_opts)  # training data
ecod_df = pd.read_csv("./data/ecod_moments.tsv", **read_opts)  # evaluation data

In [None]:
# DataFrame.info() prints its report and returns None, so binding the result
# to a name only ever stored None. Call it for its printed output instead.
cath_df.info()

In [None]:
# DataFrame.info() prints its report and returns None, so binding the result
# to a name only ever stored None. Call it for its printed output instead.
ecod_df.info()

In [None]:
def _describe_frame(frame):
    """Collect the shape, label statistics and a preview of one moments table."""
    class_labels = frame[0]  # column 0 is used as the class label throughout
    return {
        "shape": frame.shape,
        "unique_classes": class_labels.nunique(),
        "class_distribution": class_labels.value_counts(),
        "head": frame.head(),
    }


# One summary dict per dataset, built the same way for both.
cath_summary = _describe_frame(cath_df)
ecod_summary = _describe_frame(ecod_df)

In [None]:
# Dump the CATH summary dict (shape, number of classes, class distribution, preview rows).
print(cath_summary)

In [None]:
# Dump the ECOD summary dict (shape, number of classes, class distribution, preview rows).
print(ecod_summary)

In [None]:
# Two toy feature vectors for the notes that follow.
# (The original trailing comma after the first list turned `a` into a
# one-element tuple wrapping the list, so a[i] would not index the numbers.)
a = [1.2, 3.4, 5.6]
b = [2.4, 3.2, 4.5]

z_ab = abs(a[i]-b[i])

2*(1.2-2.4)/1+abs(1.2)+abs(2.4)

Vector in the original dimension

30 proteins
30*29/2 = 435 pairs of proteins

In [None]:
# Display the first row of the CATH table.
cath_df.iloc[0]

In [None]:
# Show the Python type stored in the first two columns of the first CATH row.
for col_idx in range(2):
    first_row_value = cath_df.iloc[0, col_idx]
    print(type(first_row_value))

In [None]:
def plot_class_frequency(class_counts, dataset_name, bins):
    """Draw a histogram of how many proteins each class contains.

    Args:
        class_counts: Per-class protein counts (the result of ``value_counts()``
            on the label column).
        dataset_name: Name inserted into the figure title.
        bins: Number of histogram bins.
    """
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.hist(class_counts, bins=bins, edgecolor='black')
    ax.set_title(f'Class Frequency Distribution in {dataset_name} Dataset')
    ax.set_xlabel('Number of proteins per class')
    ax.set_ylabel('Number of classes')
    ax.grid(True)
    fig.tight_layout()
    plt.show()


# Plot class distribution for CATH
cath_class_counts = cath_df[0].value_counts()
plot_class_frequency(cath_class_counts, 'CATH', bins=30)

# Plot class distribution for ECOD
ecod_class_counts = ecod_df[0].value_counts()
plot_class_frequency(ecod_class_counts, 'ECOD', bins=20)


### Visualizing Class Imbalance

Each protein in the dataset is associated with a structural class label (column 0), which identifies its 3D (1D in this case) shape category. Understanding how many proteins fall into each class is critical because the dataset is not balanced: some classes have many proteins, while others have very few.

To quantify this, we use `value_counts()` to compute the frequency of each class label and visualize the distribution with a histogram.

#### Why this is important:
During training, we will create pairs of proteins to determine structural similarity. If a particular class contains many proteins, it will generate significantly more pairs, which may bias the model toward frequently occurring classes. This can lead to overfitting and poor generalization, especially on underrepresented classes.

By plotting the class frequency distribution:
- We confirm whether class imbalance is present.
- We motivate the need for sampling strategies such as `WeightedRandomSampler` to correct for this imbalance during training.

This analysis is a key step in understanding the structure of the data and informing how we design the training process.


In [None]:
import sys
import numpy as np

# Report the interpreter and NumPy versions used by this kernel.
for label, version in (("Python version:", sys.version), ("NumPy version:", np.__version__)):
    print(label, version)


In [1]:
import pandas as pd
import torch
import importlib
import train
import dataset
import time
import os
importlib.reload(train)
importlib.reload(dataset)
import random

# Seed every RNG this cell has in scope. Seeding only Python's `random`
# leaves torch's generator unseeded, so torch-side randomness would still
# differ between runs.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)


from train import train_model, test_model_on_ecod

[INFO] Using 9 CPU threads
[INFO] Using 9 CPU threads


In [2]:
# Load datasets
cath_df = pd.read_csv("./data/cath_moments.tsv", sep='\t', header=None).dropna(axis=1)
ecod_df = pd.read_csv("./data/ecod_moments.tsv", sep='\t', header=None).dropna(axis=1)

print(f"CATH: {cath_df.shape}, ECOD: {ecod_df.shape}")
features,labels=None,None


CATH: (2685, 3923), ECOD: (761, 3923)


In [3]:
# Number of CATH rows handed to training via cath_df.head(num_proteins);
# equals the CATH row count printed when the table was loaded.
num_proteins = 2_685

In [4]:
# import pandas as pd
# from cache_utils import cache_pairwise_data

# df = pd.read_csv("./data/cath_moments.tsv", sep="\t", header=None).dropna(axis=1)


# cache_dir = "./cache/cath_"+str(num_proteins)
# merge_dir = cache_dir+"/cath_merged.pkl"

# # if not os.path.exists(merge_dir):
# #     cache_pairwise_data(df.head(num_proteins), cache_dir=cache_dir, buffer_limit_mb=100)

In [None]:
# from cache_utils import load_cached_parts,load_and_merge_parts
# from dataset import ProteinPairDataset

# tic = time.time_ns()
# # Step 1: Load and merge buffered part_*.pkl files
# features, labels = load_cached_parts(cache_dir, max_threads=16)

# # #Parallel Load all the parts
# # features, labels = load_and_merge_parts(
# #     cache_dir=cache_dir,
# #     save_path=merge_dir, 
# #     max_threads=16
# # )

# tac = time.time_ns()
# print("Loaded in ",(tac-tic)/(10**6),"ms")

In [None]:
# from cache_utils import load_cached_parts

# # Load cached features/labels 
# features, labels = load_cached_parts("./cache/cath_"+str(num_proteins))

# # Check dimensions
# print(f"[INFO] Loaded {len(features)} pairs with shape {features[0].shape}")

In [5]:
# Unordered pairs among num_proteins items: n * (n - 1) / 2.
expected_pairs = num_proteins * (num_proteins - 1) // 2
print("Expected pairs:", expected_pairs)

Expected pairs: 3603270


In [None]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True
lr = 1e-4

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=32,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if streaming:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True
# Set explicitly so this cell does not silently reuse whatever `lr` a
# previously run cell left in the kernel; 1e-4 is the value assigned in the
# hidden_dim=32 training cell.
lr = 1e-4

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=64,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if not streaming:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)
else:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True
# Set explicitly so this cell does not silently reuse whatever `lr` a
# previously run cell left in the kernel; 1e-4 is the value assigned in the
# hidden_dim=32 training cell.
lr = 1e-4

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=128,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if not streaming:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)
else:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True
# Set explicitly so this cell does not silently reuse whatever `lr` a
# previously run cell left in the kernel; 1e-4 is the value assigned in the
# hidden_dim=32 training cell.
lr = 1e-4

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=256,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if not streaming:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)
else:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True
# Set explicitly so this cell does not silently reuse whatever `lr` a
# previously run cell left in the kernel; 1e-4 is the value assigned in the
# hidden_dim=32 training cell.
lr = 1e-4

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=512,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if not streaming:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)
else:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)

In [8]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True

lr = 5e-6

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=1024,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if streaming:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

-------------------- OTG --------------------
Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h1024_bs128_lr5e-06_ep30_20250330-062545
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.57 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=1024) for 30 epochs...


Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 3770.6268 | Val Loss: 208.8888 | ROC AUC: 0.988 | PR AUC: 0.611 | MCC: 0.317


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1626.6882 | Val Loss: 133.7798 | ROC AUC: 0.988 | PR AUC: 0.623 | MCC: 0.372


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 952.2820 | Val Loss: 82.9622 | ROC AUC: 0.989 | PR AUC: 0.642 | MCC: 0.436


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 643.5892 | Val Loss: 72.6091 | ROC AUC: 0.988 | PR AUC: 0.634 | MCC: 0.449


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 470.4529 | Val Loss: 68.9474 | ROC AUC: 0.989 | PR AUC: 0.644 | MCC: 0.463


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 362.4367 | Val Loss: 62.9589 | ROC AUC: 0.988 | PR AUC: 0.646 | MCC: 0.477


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 296.8323 | Val Loss: 57.6975 | ROC AUC: 0.987 | PR AUC: 0.636 | MCC: 0.488


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 243.5714 | Val Loss: 54.7768 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.502


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 209.7473 | Val Loss: 57.1940 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.496


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 183.3671 | Val Loss: 48.3058 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.524


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 163.0535 | Val Loss: 52.0852 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.508


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 146.4547 | Val Loss: 47.2597 | ROC AUC: 0.987 | PR AUC: 0.631 | MCC: 0.525


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 135.9912 | Val Loss: 47.6240 | ROC AUC: 0.987 | PR AUC: 0.637 | MCC: 0.522


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 124.0646 | Val Loss: 45.8735 | ROC AUC: 0.987 | PR AUC: 0.639 | MCC: 0.532


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 116.1758 | Val Loss: 42.4412 | ROC AUC: 0.986 | PR AUC: 0.634 | MCC: 0.535


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 109.8195 | Val Loss: 45.2003 | ROC AUC: 0.987 | PR AUC: 0.639 | MCC: 0.539


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 102.2645 | Val Loss: 46.2193 | ROC AUC: 0.987 | PR AUC: 0.641 | MCC: 0.532


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 98.3857 | Val Loss: 48.7326 | ROC AUC: 0.987 | PR AUC: 0.640 | MCC: 0.522


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 94.5555 | Val Loss: 46.2653 | ROC AUC: 0.987 | PR AUC: 0.633 | MCC: 0.532


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 91.3819 | Val Loss: 47.5565 | ROC AUC: 0.986 | PR AUC: 0.632 | MCC: 0.522


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 88.0607 | Val Loss: 41.5615 | ROC AUC: 0.986 | PR AUC: 0.637 | MCC: 0.555


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 85.4118 | Val Loss: 50.1963 | ROC AUC: 0.986 | PR AUC: 0.637 | MCC: 0.515


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 82.3716 | Val Loss: 45.7618 | ROC AUC: 0.987 | PR AUC: 0.638 | MCC: 0.529


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 81.8189 | Val Loss: 44.4549 | ROC AUC: 0.986 | PR AUC: 0.628 | MCC: 0.540


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 80.0545 | Val Loss: 49.8281 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.522


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 76.3627 | Val Loss: 46.1886 | ROC AUC: 0.987 | PR AUC: 0.638 | MCC: 0.531


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 75.8855 | Val Loss: 40.6204 | ROC AUC: 0.986 | PR AUC: 0.639 | MCC: 0.560


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 75.0835 | Val Loss: 39.8620 | ROC AUC: 0.988 | PR AUC: 0.645 | MCC: 0.565


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 73.7308 | Val Loss: 37.5907 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.580


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 72.1046 | Val Loss: 37.2721 | ROC AUC: 0.988 | PR AUC: 0.646 | MCC: 0.578


In [10]:
from weight_strategy import inverse_class_weighting
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1
streaming = True

lr = 1e-5

# Arguments common to both data-source modes.
train_kwargs = dict(
    hidden_dim=1024,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=lr,
)

if streaming:
    # Pairs are generated on the fly from the protein table.
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # Pairs come from the pre-computed features/labels.
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

-------------------- OTG --------------------
Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h1024_bs128_lr1e-05_ep30_20250330-133020
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.89 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=1024) for 30 epochs...


Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2746.5221 | Val Loss: 138.3194 | ROC AUC: 0.989 | PR AUC: 0.625 | MCC: 0.371


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 852.9638 | Val Loss: 87.3493 | ROC AUC: 0.988 | PR AUC: 0.640 | MCC: 0.428


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 460.5240 | Val Loss: 69.1546 | ROC AUC: 0.987 | PR AUC: 0.632 | MCC: 0.455


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 309.8509 | Val Loss: 58.8614 | ROC AUC: 0.987 | PR AUC: 0.640 | MCC: 0.489


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 234.3192 | Val Loss: 59.8611 | ROC AUC: 0.988 | PR AUC: 0.635 | MCC: 0.487


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 190.7421 | Val Loss: 50.1116 | ROC AUC: 0.987 | PR AUC: 0.633 | MCC: 0.510


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 164.8990 | Val Loss: 44.8325 | ROC AUC: 0.986 | PR AUC: 0.628 | MCC: 0.533


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 146.1045 | Val Loss: 44.1459 | ROC AUC: 0.988 | PR AUC: 0.635 | MCC: 0.537


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 133.9301 | Val Loss: 44.8565 | ROC AUC: 0.988 | PR AUC: 0.641 | MCC: 0.537


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 124.6119 | Val Loss: 47.1424 | ROC AUC: 0.986 | PR AUC: 0.622 | MCC: 0.517


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 115.8700 | Val Loss: 48.7630 | ROC AUC: 0.986 | PR AUC: 0.634 | MCC: 0.520


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 107.8231 | Val Loss: 46.8281 | ROC AUC: 0.986 | PR AUC: 0.637 | MCC: 0.533


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 105.8403 | Val Loss: 47.6294 | ROC AUC: 0.985 | PR AUC: 0.623 | MCC: 0.517


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 100.6666 | Val Loss: 63.5519 | ROC AUC: 0.987 | PR AUC: 0.635 | MCC: 0.481


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 98.9007 | Val Loss: 53.2379 | ROC AUC: 0.987 | PR AUC: 0.642 | MCC: 0.508


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 96.5948 | Val Loss: 39.4519 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.563


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 93.7793 | Val Loss: 45.1510 | ROC AUC: 0.988 | PR AUC: 0.644 | MCC: 0.537


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 92.7296 | Val Loss: 38.8475 | ROC AUC: 0.987 | PR AUC: 0.631 | MCC: 0.566


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 91.0217 | Val Loss: 47.9414 | ROC AUC: 0.984 | PR AUC: 0.631 | MCC: 0.525


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 90.2498 | Val Loss: 46.4865 | ROC AUC: 0.986 | PR AUC: 0.638 | MCC: 0.530


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 88.0587 | Val Loss: 48.5960 | ROC AUC: 0.986 | PR AUC: 0.642 | MCC: 0.524


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 88.4131 | Val Loss: 38.7476 | ROC AUC: 0.986 | PR AUC: 0.625 | MCC: 0.565


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 87.4264 | Val Loss: 54.4992 | ROC AUC: 0.988 | PR AUC: 0.637 | MCC: 0.501


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 85.7063 | Val Loss: 39.6342 | ROC AUC: 0.985 | PR AUC: 0.621 | MCC: 0.562


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 84.3025 | Val Loss: 45.9629 | ROC AUC: 0.987 | PR AUC: 0.643 | MCC: 0.533


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 85.0155 | Val Loss: 46.4134 | ROC AUC: 0.985 | PR AUC: 0.617 | MCC: 0.518


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 84.7408 | Val Loss: 42.9891 | ROC AUC: 0.988 | PR AUC: 0.649 | MCC: 0.551


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 83.7393 | Val Loss: 47.6503 | ROC AUC: 0.985 | PR AUC: 0.620 | MCC: 0.520


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 84.8638 | Val Loss: 41.8626 | ROC AUC: 0.987 | PR AUC: 0.627 | MCC: 0.548


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 83.7484 | Val Loss: 48.9844 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.520


In [None]:
#Load the best model
import torch
from model import ProteinClassifier
# Feature count: every column except the class label in column 0.
input_dim = cath_df.shape[1] - 1

checkpoint_path = "./modelData/baseline_h64_bs128_lr0.001_ep20_20250328-041027_best.pt"

# The architecture must match the saved model.
# NOTE(review): hidden_dim=64 presumably corresponds to the "h64" in the
# checkpoint file name; confirm before pointing this at another checkpoint.
model = ProteinClassifier(hidden_dim=64, input_dim=input_dim)

# The file is consumed purely as a state dict, so restrict unpickling to
# tensors and plain containers instead of arbitrary Python objects.
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()

In [11]:
# Evaluate whatever `model` is currently bound in the kernel on the ECOD table.
# Going by the log line printed with the result, the returned 4-tuple is
# (loss, ROC AUC, PR AUC, MCC).
test_model_on_ecod(model, ecod_df)

[INFO] Initializing Streaming Dataset from DataFrame of size 761
[INFO] Streaming init done in 0.10 seconds
[FINAL TEST on ECOD] Loss: 431.4965 | ROC AUC: 0.966 | PR AUC: 0.782 | MCC: 0.705


(431.49647467000403, 0.9660356894615169, 0.782256605396797, 0.704535367872464)

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from dataset import StreamingProteinPairDataset, StreamingProteinPairDatasetV2
import time
import tracemalloc  # optional: for memory profiling

# Load a manageable subset for testing
# Keep the profiling run small: first 1000 rows, after dropping every column
# that contains a missing value.
df = (
    pd.read_csv("data/cath_moments.tsv", sep='\t', header=None)
    .dropna(axis=1)
    .head(1000)
)

# Batch size used by the DataLoader in the profiling helper.
BATCH_SIZE = 64

def profile_loader(dataset_class, name="", data=None, batch_size=None, num_workers=4, max_batches=5):
    """Print construction cost and per-batch fetch latency for a dataset class.

    Args:
        dataset_class: Callable invoked as ``dataset_class(data)``; the result
            is wrapped in a ``DataLoader``.
        name: Label used as a prefix in every printed line.
        data: Object handed to ``dataset_class``. ``None`` (the default) uses
            the module-level ``df``.
        batch_size: DataLoader batch size. ``None`` (the default) uses the
            module-level ``BATCH_SIZE``.
        num_workers: Number of DataLoader worker processes.
        max_batches: Maximum number of batches to fetch and time.

    Returns:
        None. All measurements are printed.
    """
    if data is None:
        data = df
    if batch_size is None:
        batch_size = BATCH_SIZE

    print(f"\n--- Profiling {name} ---")

    # Measure dataset construction: elapsed time and peak traced allocation.
    tracemalloc.start()
    try:
        # perf_counter_ns is monotonic, so intervals cannot go negative if the
        # wall clock is adjusted mid-run.
        start_ns = time.perf_counter_ns()
        dataset = dataset_class(data)
        init_time_ns = time.perf_counter_ns() - start_ns
        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Always stop tracing, even if the constructor raises.
        tracemalloc.stop()

    print(f"[{name}] Init time: {init_time_ns / 1e6:.2f} ms")
    print(f"[{name}] Peak memory usage: {peak / (1024**2):.2f} MB")

    # Profile batch loading. The clock has to be running while the loader
    # produces the batch; timing only the code after the batch has arrived
    # measures nothing about loading.
    loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)
    batch_times_ns = []
    if max_batches > 0:
        # The first sample also covers creating the loader iterator.
        fetch_start_ns = time.perf_counter_ns()
        for x, y in loader:
            _ = x.shape, y.shape  # touch the batch so it is actually materialised
            batch_times_ns.append(time.perf_counter_ns() - fetch_start_ns)
            # Stop before requesting a batch that would not be timed.
            if len(batch_times_ns) >= max_batches:
                break
            fetch_start_ns = time.perf_counter_ns()

    if batch_times_ns:
        # Samples are in nanoseconds; convert to milliseconds for the report.
        avg_batch_time_ms = sum(batch_times_ns) / len(batch_times_ns) / 1e6
        print(f"[{name}] Avg batch load time (first {len(batch_times_ns)}): {avg_batch_time_ms:.2f} ms")
    else:
        print(f"[{name}] No batches loaded.")

# Profile both streaming dataset implementations with the same helper.
for dataset_cls, label in (
    (StreamingProteinPairDataset, "Streaming v1"),
    (StreamingProteinPairDatasetV2, "Streaming v2"),
):
    profile_loader(dataset_cls, label)