In [None]:
import io

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch

In [None]:
# Read the tab-separated moment tables; the files carry no header row.
cath_df = pd.read_table("./data/cath_moments.tsv", header=None)  # For Train
ecod_df = pd.read_table("./data/ecod_moments.tsv", header=None)  # For eval

In [None]:
# DataFrame.info() writes its report to a stream and returns None, so the
# report is captured in a buffer to leave a usable string in `cath_info`.
_cath_info_buf = io.StringIO()
cath_df.info(buf=_cath_info_buf)
cath_info = _cath_info_buf.getvalue()
print(cath_info)

In [None]:
# DataFrame.info() writes its report to a stream and returns None, so the
# report is captured in a buffer to leave a usable string in `ecod_info`.
_ecod_info_buf = io.StringIO()
ecod_df.info(buf=_ecod_info_buf)
ecod_info = _ecod_info_buf.getvalue()
print(ecod_info)

In [None]:
def _summarize(df):
    """Collect the shape, class statistics and a preview of one dataset.

    Column 0 of ``df`` is read as the class label column.
    """
    labels = df[0]
    return {
        "shape": df.shape,
        "unique_classes": labels.nunique(),
        "class_distribution": labels.value_counts(),
        "head": df.head(),
    }


# One summary dict per dataset, built the same way for both.
cath_summary = _summarize(cath_df)
ecod_summary = _summarize(ecod_df)

In [None]:
# Print each summary entry under its own heading: printing the dict in one go
# embeds the multi-line Series/DataFrame reprs inside a single dict repr.
for key, value in cath_summary.items():
    print(f"--- {key} ---")
    print(value)

In [None]:
# Print each summary entry under its own heading: printing the dict in one go
# embeds the multi-line Series/DataFrame reprs inside a single dict repr.
for key, value in ecod_summary.items():
    print(f"--- {key} ---")
    print(value)

In [None]:
# Toy example: turning two protein vectors into one pairwise feature vector.
a = [1.2, 3.4, 5.6]
b = [2.4, 3.2, 4.5]

# Element-wise absolute difference; the result is a vector in the original
# dimension (one value per component of a and b).
z_ab = [abs(x - y) for x, y in zip(a, b)]

# Normalised variant for the first component: 2*|a - b| / (1 + |a| + |b|).
# NOTE(review): the absolute value in the numerator and the parentheses around
# the denominator are inferred from the scratch expression - confirm.
z_ab_norm_0 = 2 * abs(a[0] - b[0]) / (1 + abs(a[0]) + abs(b[0]))

# 30 proteins -> 30*29/2 unordered pairs of proteins.
n_proteins = 30
n_pairs = n_proteins * (n_proteins - 1) // 2

z_ab, z_ab_norm_0, n_pairs

In [None]:
# Display the first row of the CATH table as a Series (positional row 0).
cath_df.iloc[0]

In [None]:
# Show the Python type stored in the first two columns of the first row.
for col_pos in (0, 1):
    first_row_value = cath_df.iloc[0, col_pos]
    print(type(first_row_value))

In [None]:
def plot_class_frequency(df, dataset_name, bins):
    """Plot a histogram of class sizes for one dataset and return the counts.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame whose column 0 holds the class label of each protein.
    dataset_name : str
        Name inserted into the figure title.
    bins : int
        Number of histogram bins.

    Returns
    -------
    pandas.Series
        Number of proteins per class, as produced by ``value_counts()``.
    """
    class_counts = df[0].value_counts()
    # The histogram is over class sizes, so each bar counts classes, not proteins.
    fig, ax = plt.subplots(figsize=(12, 4))
    ax.hist(class_counts, bins=bins, edgecolor='black')
    ax.set_title(f'Class Frequency Distribution in {dataset_name} Dataset')
    ax.set_xlabel('Number of proteins per class')
    ax.set_ylabel('Number of classes')
    ax.grid(True)
    fig.tight_layout()
    plt.show()
    return class_counts


# Plot class distribution for CATH
cath_class_counts = plot_class_frequency(cath_df, 'CATH', bins=30)

# Plot class distribution for ECOD
ecod_class_counts = plot_class_frequency(ecod_df, 'ECOD', bins=20)


### Visualizing Class Imbalance

Each protein in the dataset is associated with a structural class label (column 0), which identifies its 3D (1D in this case) shape category. Understanding how many proteins fall into each class is critical because the dataset is not balanced: some classes have many proteins, while others have very few.

To quantify this, we use `value_counts()` to compute the frequency of each class label and visualize the distribution with a histogram.

#### Why this is important:
During training, we will create pairs of proteins to determine structural similarity. A class with $n$ proteins contributes $n(n-1)/2$ same-class pairs, so the number of pairs grows quadratically with class size. If a particular class contains many proteins, it will therefore generate significantly more pairs, which may bias the model toward frequently occurring classes. This can lead to overfitting and poor generalization, especially on underrepresented classes.

By plotting the class frequency distribution:
- We confirm whether class imbalance is present.
- We motivate the need for sampling strategies such as `WeightedRandomSampler` to correct for this imbalance during training.

This analysis is a key step in understanding the structure of the data and informing how we design the training process.


In [None]:
import sys
import numpy as np

# Report the interpreter and NumPy versions in use by this kernel.
for label, version in (
    ("Python version:", sys.version),
    ("NumPy version:", np.__version__),
):
    print(label, version)


In [1]:
import pandas as pd
import torch
import importlib
import train
import dataset
import time
import os
# Reload the already-imported `train` and `dataset` modules so that the
# `from train import ...` line that follows binds to their current code
# without restarting the kernel.
importlib.reload(train)
importlib.reload(dataset)


from train import train_model, test_model_on_ecod

[INFO] Using 9 CPU threads
[INFO] Using 9 CPU threads


In [2]:
# Load both moment tables and drop every column that contains a missing value.
cath_df = pd.read_table("./data/cath_moments.tsv", header=None).dropna(axis="columns")
ecod_df = pd.read_table("./data/ecod_moments.tsv", header=None).dropna(axis="columns")

# Sanity check: report (rows, columns) of each cleaned table.
print(f"CATH: {cath_df.shape}, ECOD: {ecod_df.shape}")


CATH: (2685, 3923), ECOD: (761, 3923)


In [3]:
# Number of CATH proteins, taken from the loaded table instead of a hard-coded
# literal so it stays in sync with the data file (2685 rows for the current file).
num_proteins = cath_df.shape[0]

In [4]:
# import pandas as pd
# from cache_utils import cache_pairwise_data

# df = pd.read_csv("./data/cath_moments.tsv", sep="\t", header=None).dropna(axis=1)


# cache_dir = "./cache/cath_"+str(num_proteins)
# merge_dir = cache_dir+"/cath_merged.pkl"

# if not os.path.exists(merge_dir):
#     cache_pairwise_data(df.head(num_proteins), cache_dir=cache_dir, buffer_limit_mb=100)

In [5]:
# from cache_utils import load_cached_parts,load_and_merge_parts
# from dataset import ProteinPairDataset

# tic = time.time_ns()
# # Step 1: Load and merge buffered part_*.pkl files
# # features, labels = load_cached_parts("cache/cath_buffered", max_threads=16)

# #Parallel Load all the parts
# features, labels = load_and_merge_parts(
#     cache_dir=cache_dir,
#     save_path=merge_dir, 
#     max_threads=16
# )

# tac = time.time_ns()
# print("Loaded in ",(tac-tic)/(10**6),"ms")

In [6]:
# Number of unordered protein pairs: n choose 2.
expected_pairs = num_proteins * (num_proteins - 1) // 2
print("Expected pairs:", expected_pairs)

Expected pairs: 3603270


In [7]:
# Input width is the column count minus one (all columns except one).
input_dim = len(cath_df.columns) - 1

# Train with the streaming data loader, holding out 10% for validation.
model = train_model(
    protein_df=cath_df,
    hidden_dim=64,
    input_dim=input_dim,
    streaming=True,
    batch_size=128,
    num_epochs=20,
    val_split=0.1,
)

Training on: cpu (CPU)
[TensorBoard] Logging to: tensorboard_logs/baseline_h64_bs128_lr0.001_ep20_20250327-211640
[INFO] Loading Dataloader using streaming : True
[INFO] Creating dataloaders... streaming=True
[INFO] Initializing Streaming Dataset from DataFrame of size 450
[INFO] Generated index pairs: 101025 total
[INFO] Streaming init done in 0.00 seconds
[INFO] Initializing Streaming Dataset from DataFrame of size 50
[INFO] Generated index pairs: 1225 total
[INFO] Streaming init done in 0.00 seconds
Training model (hidden_dim=64) for 20 epochs...


Epoch 1/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 1/20 | Train Loss: 102.7773 | Val Loss: 1.5514 | ROC AUC: 0.957 | PR AUC: 0.771 | MCC: 0.672


Epoch 2/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 2/20 | Train Loss: 44.6737 | Val Loss: 4.2816 | ROC AUC: 0.954 | PR AUC: 0.728 | MCC: 0.542


Epoch 3/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 3/20 | Train Loss: 29.8739 | Val Loss: 2.4945 | ROC AUC: 0.958 | PR AUC: 0.776 | MCC: 0.609


Epoch 4/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 4/20 | Train Loss: 25.8515 | Val Loss: 2.0758 | ROC AUC: 0.960 | PR AUC: 0.778 | MCC: 0.698


Epoch 5/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 5/20 | Train Loss: 16.7408 | Val Loss: 2.2167 | ROC AUC: 0.961 | PR AUC: 0.795 | MCC: 0.742


Epoch 6/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 6/20 | Train Loss: 13.3005 | Val Loss: 2.1734 | ROC AUC: 0.960 | PR AUC: 0.794 | MCC: 0.729


Epoch 7/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 7/20 | Train Loss: 8.9712 | Val Loss: 2.3222 | ROC AUC: 0.960 | PR AUC: 0.807 | MCC: 0.755


Epoch 8/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 8/20 | Train Loss: 10.4235 | Val Loss: 2.3236 | ROC AUC: 0.961 | PR AUC: 0.822 | MCC: 0.769


Epoch 9/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 9/20 | Train Loss: 6.4678 | Val Loss: 2.5766 | ROC AUC: 0.964 | PR AUC: 0.824 | MCC: 0.760


Epoch 10/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 10/20 | Train Loss: 3.9185 | Val Loss: 2.5770 | ROC AUC: 0.964 | PR AUC: 0.822 | MCC: 0.764


Epoch 11/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 11/20 | Train Loss: 6.1373 | Val Loss: 2.6728 | ROC AUC: 0.967 | PR AUC: 0.816 | MCC: 0.762


Epoch 12/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 12/20 | Train Loss: 1.7682 | Val Loss: 3.2582 | ROC AUC: 0.966 | PR AUC: 0.838 | MCC: 0.751


Epoch 13/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 13/20 | Train Loss: 6.5790 | Val Loss: 2.5301 | ROC AUC: 0.965 | PR AUC: 0.830 | MCC: 0.770


Epoch 14/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 14/20 | Train Loss: 1.7887 | Val Loss: 2.6154 | ROC AUC: 0.963 | PR AUC: 0.847 | MCC: 0.768


Epoch 15/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 15/20 | Train Loss: 4.7313 | Val Loss: 3.1773 | ROC AUC: 0.966 | PR AUC: 0.830 | MCC: 0.750


Epoch 16/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 16/20 | Train Loss: 1.8356 | Val Loss: 2.7836 | ROC AUC: 0.965 | PR AUC: 0.842 | MCC: 0.766


Epoch 17/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 17/20 | Train Loss: 3.7013 | Val Loss: 2.9593 | ROC AUC: 0.964 | PR AUC: 0.853 | MCC: 0.782


Epoch 18/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 18/20 | Train Loss: 1.4340 | Val Loss: 2.7829 | ROC AUC: 0.968 | PR AUC: 0.862 | MCC: 0.777


Epoch 19/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 19/20 | Train Loss: 4.1026 | Val Loss: 3.4702 | ROC AUC: 0.967 | PR AUC: 0.856 | MCC: 0.766


Epoch 20/20:   0%|          | 0/790 [00:00<?, ?it/s]

Epoch 20/20 | Train Loss: 1.9134 | Val Loss: 3.6293 | ROC AUC: 0.964 | PR AUC: 0.840 | MCC: 0.777


In [None]:
# Evaluate the trained model on the first 200 ECOD rows only.
ecod_subset = ecod_df.iloc[:200]
test_model_on_ecod(model, ecod_subset)

[INFO] Generating nC2 pairs from raw DataFrame of size 200...
