In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [None]:
# Load both moment tables: tab-separated files without a header row.
cath_df = pd.read_csv(
    "./data/cath_moments.tsv",  # training set
    sep='\t',
    header=None,
)
ecod_df = pd.read_csv(
    "./data/ecod_moments.tsv",  # evaluation set
    sep='\t',
    header=None,
)

In [None]:
# DataFrame.info() prints its report and returns None, so call it directly instead of binding the result.
cath_df.info()

In [None]:
# DataFrame.info() prints its report and returns None, so call it directly instead of binding the result.
ecod_df.info()

In [None]:
def _label_summary(frame):
    """Return shape, class-label (column 0) statistics and a five-row preview of ``frame``."""
    class_labels = frame[0]
    return {
        "shape": frame.shape,
        "unique_classes": class_labels.nunique(),
        "class_distribution": class_labels.value_counts(),
        "head": frame.head(),
    }


# Summarize shapes, classes, and a preview for both datasets
cath_summary = _label_summary(cath_df)
ecod_summary = _label_summary(ecod_df)

In [None]:
# Show each summary entry with Jupyter's rich display. print() on the whole dict renders
# the Series/DataFrame values as cramped plain text.
for key, value in cath_summary.items():
    print(f"{key}:")
    display(value)

In [None]:
# Show each summary entry with Jupyter's rich display. print() on the whole dict renders
# the Series/DataFrame values as cramped plain text.
for key, value in ecod_summary.items():
    print(f"{key}:")
    display(value)

In [None]:
# Scratch: compare two toy feature vectors element by element.
a = [1.2, 3.4, 5.6]  # no trailing comma: `[...],` would make `a` a 1-tuple containing the list
b = [2.4, 3.2, 4.5]

# Per-dimension absolute difference. The result has the same length as the inputs,
# so no undefined loop index is needed.
z_ab = [abs(x - y) for x, y in zip(a, b)]

# NOTE(review): operator precedence makes this `(2*(1.2-2.4)/1) + abs(1.2) + abs(2.4)`.
# If a normalized difference such as 2*(x-y)/(1+|x|+|y|) was intended, the denominator
# needs parentheses. TODO confirm the intended formula.
2*(1.2-2.4)/1+abs(1.2)+abs(2.4)

Vector in the original dimension.

30 proteins → $30 \cdot 29 / 2 = 435$ pairs of proteins.

In [None]:
# Preview the first CATH row: column 0 is the class label, the remaining columns are features.
cath_df.iloc[0, :]

In [None]:
# Inspect the Python type of the label (column 0) and of the next column (column 1) in the first row.
for col in (0, 1):
    print(type(cath_df.iloc[0, col]))

In [None]:
def plot_class_frequency(frame, dataset_name, bins):
    """Plot a histogram of class sizes (proteins per class label in column 0).

    Args:
        frame: Protein table whose column 0 holds the class label.
        dataset_name: Name used in the plot title (e.g. 'CATH').
        bins: Number of histogram bins.

    Returns:
        The per-class counts, ``frame[0].value_counts()``.
    """
    class_counts = frame[0].value_counts()
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.hist(class_counts, bins=bins, edgecolor='black')
    ax.set_title(f'Class Frequency Distribution in {dataset_name} Dataset')
    ax.set_xlabel('Number of proteins per class')
    ax.set_ylabel('Number of classes')
    ax.grid(True)
    fig.tight_layout()
    plt.show()
    return class_counts


# Plot class distribution for CATH and ECOD (each dataset keeps its own bin count).
cath_class_counts = plot_class_frequency(cath_df, 'CATH', bins=30)
ecod_class_counts = plot_class_frequency(ecod_df, 'ECOD', bins=20)


### Visualizing Class Imbalance

Each protein in the dataset is associated with a structural class label (column 0), which identifies its 3D shape category (the protein itself is represented here by a 1D feature vector). Understanding how many proteins fall into each class is critical because the dataset is not balanced—some classes have many proteins, while others have very few.

To quantify this, we use `value_counts()` to compute the frequency of each class label and visualize the distribution with a histogram.

#### Why this is important:
During training, we create pairs of proteins to determine structural similarity. A class with $n$ proteins yields $n(n-1)/2$ within-class pairs, so the number of pairs grows quadratically with class size and large classes contribute disproportionately many pairs. This can bias the model toward frequently occurring classes, leading to overfitting and poor generalization, especially on underrepresented classes.

By plotting the class frequency distribution:
- We confirm whether class imbalance is present.
- We motivate the need for sampling strategies such as `WeightedRandomSampler` to correct for this imbalance during training.

This analysis is a key step in understanding the structure of the data and informing how we design the training process.


In [2]:
import sys
import numpy as np

# Record interpreter and NumPy versions so results can be tied to an environment.
print(f"Python version: {sys.version}")
print(f"NumPy version: {np.__version__}")


Python version: 3.12.2 | packaged by Anaconda, Inc. | (main, Feb 27 2024, 12:57:28) [Clang 14.0.6 ]
NumPy version: 1.26.4


In [3]:
import pandas as pd
import torch
import importlib
import train
import dataset
import time
import os
import numpy as np
importlib.reload(train)
importlib.reload(dataset)
import random

# Seed every RNG the training stack may draw from. random.seed alone leaves NumPy and
# PyTorch (weight init, shuffling, dropout) unseeded, so runs are not reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)


from train import train_model, test_model_on_ecod

[INFO] Using 9 CPU threads
[INFO] Using 9 CPU threads


In [4]:
# Load datasets. dropna(axis=1) removes every column that contains any NaN. It is applied to
# each table independently, so the resulting widths are checked below.
cath_df = pd.read_csv("./data/cath_moments.tsv", sep='\t', header=None).dropna(axis=1)
ecod_df = pd.read_csv("./data/ecod_moments.tsv", sep='\t', header=None).dropna(axis=1)

print(f"CATH: {cath_df.shape}, ECOD: {ecod_df.shape}")

# input_dim is derived from CATH's width and the trained model is later evaluated on ECOD,
# so both tables must keep the same number of columns after NaN-column removal.
assert cath_df.shape[1] == ecod_df.shape[1], (
    f"Column mismatch after dropna: CATH has {cath_df.shape[1]}, ECOD has {ecod_df.shape[1]}"
)

# Placeholders for the cached (non-streaming) training path; set only by the cache-loading cells.
features, labels = None, None


CATH: (2685, 3923), ECOD: (761, 3923)


In [5]:
# Number of CATH proteins used for pair generation. Defaults to the whole table; replace with a
# smaller int for quick experiments (pairs grow as n*(n-1)/2).
num_proteins = len(cath_df)

In [6]:
# import pandas as pd
# from cache_utils import cache_pairwise_data

# df = pd.read_csv("./data/cath_moments.tsv", sep="\t", header=None).dropna(axis=1)


# cache_dir = "./cache/cath_"+str(num_proteins)
# merge_dir = cache_dir+"/cath_merged.pkl"

# # if not os.path.exists(merge_dir):
# #     cache_pairwise_data(df.head(num_proteins), cache_dir=cache_dir, buffer_limit_mb=100)

In [7]:
# from cache_utils import load_cached_parts,load_and_merge_parts
# from dataset import ProteinPairDataset

# tic = time.time_ns()
# # Step 1: Load and merge buffered part_*.pkl files
# features, labels = load_cached_parts(cache_dir, max_threads=16)

# # #Parallel Load all the parts
# # features, labels = load_and_merge_parts(
# #     cache_dir=cache_dir,
# #     save_path=merge_dir, 
# #     max_threads=16
# # )

# tac = time.time_ns()
# print("Loaded in ",(tac-tic)/(10**6),"ms")

In [8]:
# from cache_utils import load_cached_parts

# # Load cached features/labels 
# features, labels = load_cached_parts("./cache/cath_"+str(num_proteins))

# # Check dimensions
# print(f"[INFO] Loaded {len(features)} pairs with shape {features[0].shape}")

In [12]:
expected_pairs = num_proteins * (num_proteins - 1) // 2  # unordered pairs: n choose 2
print("Expected pairs:", expected_pairs)

Expected pairs: 3603270


In [13]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=32,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=100,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

-------------------- OTG --------------------
Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h32_bs128_lr0.001_ep100_20250329-015511
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.65 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=32) for 100 epochs...


Epoch 1/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/100 | Train Loss: 3520.1832 | Val Loss: 198.9830 | ROC AUC: 0.983 | PR AUC: 0.585 | MCC: 0.299


Epoch 2/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/100 | Train Loss: 2992.8175 | Val Loss: 190.6895 | ROC AUC: 0.984 | PR AUC: 0.582 | MCC: 0.308


Epoch 3/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/100 | Train Loss: 2947.7973 | Val Loss: 243.5905 | ROC AUC: 0.983 | PR AUC: 0.610 | MCC: 0.273


Epoch 4/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/100 | Train Loss: 2937.2675 | Val Loss: 217.8781 | ROC AUC: 0.983 | PR AUC: 0.599 | MCC: 0.293


Epoch 5/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/100 | Train Loss: 2920.0295 | Val Loss: 185.8627 | ROC AUC: 0.983 | PR AUC: 0.617 | MCC: 0.312


Epoch 6/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/100 | Train Loss: 2921.6869 | Val Loss: 221.5775 | ROC AUC: 0.983 | PR AUC: 0.595 | MCC: 0.294


Epoch 7/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/100 | Train Loss: 2915.6762 | Val Loss: 191.5993 | ROC AUC: 0.985 | PR AUC: 0.614 | MCC: 0.310


Epoch 8/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/100 | Train Loss: 2906.9275 | Val Loss: 227.1589 | ROC AUC: 0.985 | PR AUC: 0.614 | MCC: 0.293


Epoch 9/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/100 | Train Loss: 2908.5714 | Val Loss: 234.8356 | ROC AUC: 0.986 | PR AUC: 0.626 | MCC: 0.289


Epoch 10/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/100 | Train Loss: 2888.8267 | Val Loss: 222.6163 | ROC AUC: 0.985 | PR AUC: 0.598 | MCC: 0.292


Epoch 11/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/100 | Train Loss: 2906.3659 | Val Loss: 218.6133 | ROC AUC: 0.983 | PR AUC: 0.574 | MCC: 0.289


Epoch 12/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/100 | Train Loss: 2898.6091 | Val Loss: 135.0981 | ROC AUC: 0.985 | PR AUC: 0.612 | MCC: 0.356


Epoch 13/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/100 | Train Loss: 2892.7979 | Val Loss: 151.0942 | ROC AUC: 0.984 | PR AUC: 0.602 | MCC: 0.338


Epoch 14/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/100 | Train Loss: 2898.1178 | Val Loss: 152.4809 | ROC AUC: 0.984 | PR AUC: 0.612 | MCC: 0.334


Epoch 15/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/100 | Train Loss: 2892.4190 | Val Loss: 218.6225 | ROC AUC: 0.984 | PR AUC: 0.606 | MCC: 0.295


Epoch 16/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/100 | Train Loss: 2899.3027 | Val Loss: 147.6819 | ROC AUC: 0.985 | PR AUC: 0.610 | MCC: 0.346


Epoch 17/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/100 | Train Loss: 2889.7413 | Val Loss: 182.1471 | ROC AUC: 0.984 | PR AUC: 0.618 | MCC: 0.316


Epoch 18/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/100 | Train Loss: 2897.9802 | Val Loss: 182.4980 | ROC AUC: 0.983 | PR AUC: 0.579 | MCC: 0.309


Epoch 19/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/100 | Train Loss: 2885.2197 | Val Loss: 237.5616 | ROC AUC: 0.985 | PR AUC: 0.608 | MCC: 0.288


Epoch 20/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/100 | Train Loss: 2883.7221 | Val Loss: 184.3162 | ROC AUC: 0.986 | PR AUC: 0.615 | MCC: 0.319


Epoch 21/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/100 | Train Loss: 2906.6591 | Val Loss: 186.5171 | ROC AUC: 0.984 | PR AUC: 0.610 | MCC: 0.317


Epoch 22/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/100 | Train Loss: 2900.4765 | Val Loss: 180.3560 | ROC AUC: 0.983 | PR AUC: 0.595 | MCC: 0.312


Epoch 23/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/100 | Train Loss: 2902.4215 | Val Loss: 173.6331 | ROC AUC: 0.986 | PR AUC: 0.605 | MCC: 0.329


Epoch 24/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/100 | Train Loss: 2902.5077 | Val Loss: 204.4029 | ROC AUC: 0.985 | PR AUC: 0.613 | MCC: 0.309


Epoch 25/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/100 | Train Loss: 2894.5058 | Val Loss: 193.0706 | ROC AUC: 0.983 | PR AUC: 0.614 | MCC: 0.300


Epoch 26/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/100 | Train Loss: 2899.2790 | Val Loss: 203.7092 | ROC AUC: 0.985 | PR AUC: 0.619 | MCC: 0.306


Epoch 27/100:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/100 | Train Loss: 2889.9720 | Val Loss: 174.1644 | ROC AUC: 0.984 | PR AUC: 0.613 | MCC: 0.320


Epoch 28/100:   0%|          | 0/25336 [00:00<?, ?it/s]

: 

In [10]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=64,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

-------------------- OTG --------------------
Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h64_bs128_lr0.001_ep30_20250329-023139
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.60 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=64) for 30 epochs...


Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2792.1561 | Val Loss: 193.5107 | ROC AUC: 0.984 | PR AUC: 0.613 | MCC: 0.317


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 2268.0892 | Val Loss: 146.4664 | ROC AUC: 0.984 | PR AUC: 0.643 | MCC: 0.353


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 2238.5265 | Val Loss: 142.7934 | ROC AUC: 0.984 | PR AUC: 0.605 | MCC: 0.356


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 2217.5006 | Val Loss: 191.6465 | ROC AUC: 0.983 | PR AUC: 0.612 | MCC: 0.316


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 2204.0521 | Val Loss: 215.1471 | ROC AUC: 0.985 | PR AUC: 0.644 | MCC: 0.308


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 2194.2527 | Val Loss: 182.0491 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.330


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 2196.1210 | Val Loss: 187.2602 | ROC AUC: 0.988 | PR AUC: 0.635 | MCC: 0.333


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 8/30 | Train Loss: 2185.2612 | Val Loss: 165.7398 | ROC AUC: 0.987 | PR AUC: 0.644 | MCC: 0.349


Epoch 9/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 9/30 | Train Loss: 2185.0298 | Val Loss: 159.7106 | ROC AUC: 0.987 | PR AUC: 0.633 | MCC: 0.350


Epoch 10/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 10/30 | Train Loss: 2178.3192 | Val Loss: 217.3017 | ROC AUC: 0.986 | PR AUC: 0.606 | MCC: 0.310


Epoch 11/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 11/30 | Train Loss: 2182.0822 | Val Loss: 159.7148 | ROC AUC: 0.986 | PR AUC: 0.609 | MCC: 0.346


Epoch 12/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 12/30 | Train Loss: 2179.8685 | Val Loss: 148.9172 | ROC AUC: 0.985 | PR AUC: 0.625 | MCC: 0.354


Epoch 13/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 13/30 | Train Loss: 2172.1163 | Val Loss: 169.1241 | ROC AUC: 0.984 | PR AUC: 0.618 | MCC: 0.339


Epoch 14/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 14/30 | Train Loss: 2176.3748 | Val Loss: 162.2686 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.349


Epoch 15/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 15/30 | Train Loss: 2171.1682 | Val Loss: 165.4771 | ROC AUC: 0.985 | PR AUC: 0.612 | MCC: 0.344


Epoch 16/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 16/30 | Train Loss: 2181.5160 | Val Loss: 166.4884 | ROC AUC: 0.987 | PR AUC: 0.628 | MCC: 0.341


Epoch 17/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 17/30 | Train Loss: 2177.7757 | Val Loss: 185.7901 | ROC AUC: 0.988 | PR AUC: 0.633 | MCC: 0.335


Epoch 18/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 18/30 | Train Loss: 2175.9752 | Val Loss: 155.2974 | ROC AUC: 0.985 | PR AUC: 0.637 | MCC: 0.351


Epoch 19/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 19/30 | Train Loss: 2169.1826 | Val Loss: 145.1825 | ROC AUC: 0.987 | PR AUC: 0.619 | MCC: 0.367


Epoch 20/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 20/30 | Train Loss: 2164.4379 | Val Loss: 171.0112 | ROC AUC: 0.983 | PR AUC: 0.598 | MCC: 0.328


Epoch 21/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 21/30 | Train Loss: 2176.3894 | Val Loss: 146.3963 | ROC AUC: 0.986 | PR AUC: 0.621 | MCC: 0.367


Epoch 22/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 22/30 | Train Loss: 2168.1331 | Val Loss: 131.9585 | ROC AUC: 0.986 | PR AUC: 0.632 | MCC: 0.373


Epoch 23/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 23/30 | Train Loss: 2165.9831 | Val Loss: 161.5849 | ROC AUC: 0.985 | PR AUC: 0.629 | MCC: 0.349


Epoch 24/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 24/30 | Train Loss: 2172.6239 | Val Loss: 122.7869 | ROC AUC: 0.986 | PR AUC: 0.622 | MCC: 0.389


Epoch 25/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 25/30 | Train Loss: 2165.4997 | Val Loss: 238.7690 | ROC AUC: 0.988 | PR AUC: 0.643 | MCC: 0.303


Epoch 26/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 26/30 | Train Loss: 2174.3465 | Val Loss: 151.9889 | ROC AUC: 0.986 | PR AUC: 0.630 | MCC: 0.355


Epoch 27/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 27/30 | Train Loss: 2169.6083 | Val Loss: 166.9017 | ROC AUC: 0.986 | PR AUC: 0.643 | MCC: 0.348


Epoch 28/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 28/30 | Train Loss: 2170.4923 | Val Loss: 221.0865 | ROC AUC: 0.986 | PR AUC: 0.624 | MCC: 0.314


Epoch 29/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 29/30 | Train Loss: 2171.5792 | Val Loss: 171.1932 | ROC AUC: 0.988 | PR AUC: 0.636 | MCC: 0.347


Epoch 30/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 30/30 | Train Loss: 2162.7835 | Val Loss: 161.3375 | ROC AUC: 0.988 | PR AUC: 0.637 | MCC: 0.351


In [None]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=128,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

-------------------- OTG --------------------
Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h128_bs128_lr0.001_ep30_20250329-031057
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 2685
[INFO] Streaming V2 init done in 0.89 seconds
[INFO] Creating traindata with 3242943
[INFO] Creating valdata with 360327
Training model (hidden_dim=128) for 30 epochs...


Epoch 1/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 1/30 | Train Loss: 2324.5780 | Val Loss: 168.7088 | ROC AUC: 0.988 | PR AUC: 0.632 | MCC: 0.352


Epoch 2/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 2/30 | Train Loss: 1819.3355 | Val Loss: 169.6321 | ROC AUC: 0.986 | PR AUC: 0.635 | MCC: 0.352


Epoch 3/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 3/30 | Train Loss: 1778.6006 | Val Loss: 155.0705 | ROC AUC: 0.986 | PR AUC: 0.631 | MCC: 0.365


Epoch 4/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 4/30 | Train Loss: 1766.4223 | Val Loss: 150.4354 | ROC AUC: 0.987 | PR AUC: 0.613 | MCC: 0.367


Epoch 5/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 5/30 | Train Loss: 1758.9423 | Val Loss: 149.0742 | ROC AUC: 0.986 | PR AUC: 0.599 | MCC: 0.363


Epoch 6/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 6/30 | Train Loss: 1758.9229 | Val Loss: 163.9209 | ROC AUC: 0.987 | PR AUC: 0.630 | MCC: 0.350


Epoch 7/30:   0%|          | 0/25336 [00:00<?, ?it/s]

Epoch 7/30 | Train Loss: 1741.5289 | Val Loss: 174.9852 | ROC AUC: 0.988 | PR AUC: 0.641 | MCC: 0.348


Epoch 8/30:   0%|          | 0/25336 [00:00<?, ?it/s]

In [None]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=256,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=512,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

In [None]:
from weight_strategy import inverse_class_weighting
input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
streaming = True

# Keyword arguments shared by both data paths, so the two branches cannot drift apart.
train_kwargs = dict(
    hidden_dim=1024,
    input_dim=input_dim,
    streaming=streaming,
    batch_size=128,
    num_epochs=30,
    val_split=0.1,
    lr=1e-3,
)

if streaming:
    print("-"*20, "OTG","-"*20)
    model = train_model(protein_df=cath_df.head(num_proteins), **train_kwargs)
else:
    # The cached path needs `features`/`labels`, which stay None unless the cache-loading cells are run.
    if features is None or labels is None:
        raise ValueError("streaming=False requires cached features/labels; run the cache-loading cells first")
    print("-"*20, "CACHE","-"*20)
    model = train_model(features=features, labels=labels, **train_kwargs)

In [None]:
# Load the best saved checkpoint for evaluation.
import os
import torch
from model import ProteinClassifier

# Must match the architecture the checkpoint was trained with; the run name encodes it
# (h64 = hidden_dim 64). TODO update both if a newer run from this notebook performs better.
BEST_HIDDEN_DIM = 64
BEST_CHECKPOINT = "./modelData/baseline_h64_bs128_lr0.001_ep20_20250328-041027_best.pt"

input_dim = cath_df.shape[1]-1  # column 0 is the class label; every other column is a feature
model = ProteinClassifier(hidden_dim=BEST_HIDDEN_DIM, input_dim=input_dim)

if not os.path.exists(BEST_CHECKPOINT):
    raise FileNotFoundError(f"Checkpoint not found: {BEST_CHECKPOINT}")

# weights_only=True restricts unpickling to tensors and plain containers, so a tampered
# .pt file cannot execute arbitrary code on load.
state_dict = torch.load(BEST_CHECKPOINT, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)

# Set model to evaluation mode
model.eval()

In [None]:
# Evaluate the loaded model on the held-out ECOD set; the last expression shows the result.
ecod_results = test_model_on_ecod(model, ecod_df)
ecod_results

In [None]:
import pandas as pd
import torch
from torch.utils.data import DataLoader
from dataset import StreamingProteinPairDataset, StreamingProteinPairDatasetV2
import time
import tracemalloc  # optional: for memory profiling

# Profile on the first 1000 CATH proteins only, to keep dataset construction cheap.
df = (
    pd.read_csv("data/cath_moments.tsv", sep='\t', header=None)
    .dropna(axis=1)
    .head(1000)
)

BATCH_SIZE = 64  # batch size used by profile_loader's DataLoader

def profile_loader(dataset_class, name="", data=None, batch_size=None, num_batches=5, num_workers=4):
    """Profile construction cost and per-batch fetch latency of a dataset class.

    Builds ``dataset_class(data)`` while tracing Python allocations with ``tracemalloc``,
    then wraps it in a ``DataLoader`` and times how long each of the first ``num_batches``
    batches takes to arrive. Results are printed; nothing is returned.

    Args:
        dataset_class: Callable taking a DataFrame and returning a dataset usable by ``DataLoader``.
        name: Label used in the printed report.
        data: DataFrame passed to ``dataset_class``; defaults to the module-level ``df``.
        batch_size: DataLoader batch size; defaults to the module-level ``BATCH_SIZE``.
        num_batches: Maximum number of batches to time.
        num_workers: Number of DataLoader worker processes.
    """
    if data is None:
        data = df
    if batch_size is None:
        batch_size = BATCH_SIZE

    print(f"\n--- Profiling {name} ---")

    # tracemalloc only sees allocations made by the Python interpreter in this process.
    tracemalloc.start()
    try:
        start_ns = time.perf_counter_ns()
        ds = dataset_class(data)
        init_time_ns = time.perf_counter_ns() - start_ns
        _, peak = tracemalloc.get_traced_memory()
    finally:
        # Stop tracing even if construction fails, so later cells do not pay its overhead.
        tracemalloc.stop()

    print(f"[{name}] Init time: {init_time_ns / 1e6:.2f} ms")
    print(f"[{name}] Peak memory usage: {peak / (1024**2):.2f} MB")

    # Profile batch loading. Time the next() call itself, i.e. the actual fetch/collate.
    # With num_workers > 0 the first batch also includes worker start-up.
    loader = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
    batch_iter = iter(loader)
    batch_times_ns = []
    for _ in range(num_batches):
        t0 = time.perf_counter_ns()
        try:
            next(batch_iter)
        except StopIteration:
            break
        batch_times_ns.append(time.perf_counter_ns() - t0)
    del batch_iter  # release DataLoader worker processes promptly

    if batch_times_ns:
        # Timings are collected in nanoseconds; convert to milliseconds before reporting.
        avg_batch_time_ms = sum(batch_times_ns) / len(batch_times_ns) / 1e6
        print(f"[{name}] Avg batch load time (first {len(batch_times_ns)}): {avg_batch_time_ms:.2f} ms")
    else:
        print(f"[{name}] No batches loaded.")

# Profile both streaming dataset implementations under identical settings.
for loader_cls, label in [
    (StreamingProteinPairDataset, "Streaming v1"),
    (StreamingProteinPairDatasetV2, "Streaming v2"),
]:
    profile_loader(loader_cls, label)