In [None]:
import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt

In [None]:
# Load the dataset
cath_df = pd.read_csv("./data/cath_moments.tsv", sep='\t', header=None) # For Train
ecod_df = pd.read_csv("./data/ecod_moments.tsv", sep='\t', header=None) # For eval

In [None]:
cath_info = cath_df.info()

In [None]:
ecod_info = ecod_df.info()

In [None]:
# Summarize shapes, classes, and a preview
cath_summary = {
    "shape": cath_df.shape,
    "unique_classes": cath_df[0].nunique(),
    "class_distribution": cath_df[0].value_counts(),
    "head": cath_df.head()
}

ecod_summary = {
    "shape": ecod_df.shape,
    "unique_classes": ecod_df[0].nunique(),
    "class_distribution": ecod_df[0].value_counts(),
    "head": ecod_df.head()
}

In [None]:
print(cath_summary)

In [None]:
print(ecod_summary)

In [None]:
for i in range(0,2):
    print(type(cath_df.iloc[0,i]))

In [None]:
# Plot class distribution for CATH
cath_class_counts = cath_df[0].value_counts()
plt.figure(figsize=(12, 4))
plt.hist(cath_class_counts, bins=30, edgecolor='black')
plt.title('Class Frequency Distribution in CATH Dataset')
plt.xlabel('Number of proteins per class')
plt.ylabel('Number of classes')
plt.grid(True)
plt.tight_layout()
plt.show()

# Plot class distribution for ECOD
ecod_class_counts = ecod_df[0].value_counts()
plt.figure(figsize=(12, 4))
plt.hist(ecod_class_counts, bins=20, edgecolor='black')
plt.title('Class Frequency Distribution in ECOD Dataset')
plt.xlabel('Number of proteins per class')
plt.ylabel('Number of classes')
plt.grid(True)
plt.tight_layout()
plt.show()


### Visualizing Class Imbalance

Each protein in the dataset is associated with a structural class label (column 0), which identifies its 3D(1D in this case) shape category. Understanding how many proteins fall into each class is critical because the dataset is not balanced—some classes have many proteins, while others have very few.

To quantify this, we use `value_counts()` to compute the frequency of each class label and visualize the distribution with a histogram.

#### Why this is important:
During training, we will create pairs of proteins to determine structural similarity. If a particular class contains many proteins, it will generate significantly more pairs, which may bias the model toward frequently occurring classes. This can lead to overfitting and poor generalization, especially on underrepresented classes.

By plotting the class frequency distribution:
- We confirm whether class imbalance is present.
- We motivate the need for sampling strategies such as `WeightedRandomSampler` to correct for this imbalance during training.

This analysis is a key step in understanding the structure of the data and informing how we design the training process.


In [11]:
import pandas as pd
import torch
import importlib
import train

importlib.reload(train)

from train import train_model, test_model_on_ecod

In [4]:
# Load datasets
cath_df = pd.read_csv("./data/cath_moments.tsv", sep='\t', header=None).dropna(axis=1)
ecod_df = pd.read_csv("./data/ecod_moments.tsv", sep='\t', header=None).dropna(axis=1)

print(f"CATH: {cath_df.shape}, ECOD: {ecod_df.shape}")


CATH: (2685, 3923), ECOD: (761, 3923)


In [5]:
# import pandas as pd
# from cache_utils import cache_pairwise_data

# df = pd.read_csv("./data/cath_moments.tsv", sep="\t", header=None).dropna(axis=1)
# cache_pairwise_data(df.head(2685), cache_dir="./cache/cath_buffered", buffer_limit_mb=100)

In [6]:
from cache_utils import load_cached_parts
from dataset import ProteinPairDataset

tic = time.time_ns()
# Step 1: Load and merge buffered part_*.pkl files
features, labels = load_cached_parts("./cache/cath_buffered")
tac = time.time_ns()
print("Loaded in ",(tac-tic)/(10**6),"ms")

Loading 3605 cached parts from: ./cache/cath_buffered
Loaded  0
Loaded  1
Loaded  2
Loaded  3
Loaded  4
Loaded  5
Loaded  6
Loaded  7
Loaded  8
Loaded  9
Loaded  10
Loaded  11
Loaded  12
Loaded  13
Loaded  14
Loaded  15
Loaded  16
Loaded  17
Loaded  18
Loaded  19
Loaded  20
Loaded  21
Loaded  22
Loaded  23
Loaded  24
Loaded  25
Loaded  26
Loaded  27
Loaded  28
Loaded  29
Loaded  30
Loaded  31
Loaded  32
Loaded  33
Loaded  34
Loaded  35
Loaded  36
Loaded  37
Loaded  38
Loaded  39
Loaded  40
Loaded  41
Loaded  42
Loaded  43
Loaded  44
Loaded  45
Loaded  46
Loaded  47
Loaded  48
Loaded  49
Loaded  50
Loaded  51
Loaded  52
Loaded  53
Loaded  54
Loaded  55
Loaded  56
Loaded  57
Loaded  58
Loaded  59
Loaded  60
Loaded  61
Loaded  62
Loaded  63
Loaded  64
Loaded  65
Loaded  66
Loaded  67
Loaded  68
Loaded  69
Loaded  70
Loaded  71
Loaded  72
Loaded  73
Loaded  74
Loaded  75
Loaded  76
Loaded  77
Loaded  78
Loaded  79
Loaded  80
Loaded  81
Loaded  82
Loaded  83
Loaded  84
Loaded  85
Loaded  86

In [8]:
# Step 2: Pass them into your dataset directly
dataset = ProteinPairDataset(features=features, labels=labels)

In [7]:
2685*2684/2

3603270.0

In [9]:
type(dataset)

dataset.ProteinPairDataset

In [None]:
from cache_utils import cache_pairwise_data
import time

# caching proteins for fast prototyping
tic = time.time_ns()
num_items = 2685
cache_pairwise_data(cath_df.head(num_items),cache_dir="cache/test_cache", buffer_limit_mb=10)
tac = time.time_ns()
print((tac-tic)/(10**6),"ms")After, thi

In [None]:


# Optional: re-import the functions explicitly
from train import train_model, test_model_on_ecod

# # Train logistic regression
# model_logistic = train_model(cath_df, hidden_dim=None, num_epochs=5, batch_size=8)

# Train small neural net
model = train_model(features=features, labels=labels, hidden_dim=64, num_epochs=5)

Training model (hidden_dim=64) for 5 epochs...


Epoch 1/5:   0%|          | 0/720654 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve

Epoch 1/5 | Train Loss: 227275.4313 | Val Loss: 41369.0562 | ROC AUC: 0.965 | PR AUC: 0.511 | MCC: 0.222


Epoch 2/5:   0%|          | 0/720654 [00:00<?, ?it/s]

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The Jupyter serve