In [1]:
# For testing model.py
import torch
from models.model import ptm_model, batch_converter

print("Testing PTM Adapter Model...")

# Set model to evaluation mode
ptm_model.eval()

# Create sample protein sequences
sample_sequences = [
    ("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),
    ("protein2", "ARNDCEQGHILKMFPSTWYV" * 20),
    ("protein3", "MKLLNVINFVFLMFVSSGRGMSVRGQSQDIVCPITCGQDLKKLGLCATLVVAGMVNPNAZK")
]

# Convert sequences to tokens
batch_labels, batch_strs, batch_tokens = batch_converter(sample_sequences)

print(f"Input shape: {batch_tokens.shape}")

# Test forward pass
with torch.no_grad():
    logits = ptm_model(batch_tokens)
    print(f"Output shape: {logits.shape}")
    print(f"Logits range: [{logits.min().item():.3f}, {logits.max().item():.3f}]")
    
    probs = torch.softmax(logits, dim=-1)
    print(f"Probabilities sum: {probs.sum(dim=-1)[0][1].item():.3f}")

# Test gradient computation
ptm_model.train()
batch_size, seq_len = batch_tokens.shape
dummy_targets = torch.randint(0, 10, (batch_size, seq_len))

logits = ptm_model(batch_tokens)
loss_fn = torch.nn.CrossEntropyLoss()
loss = loss_fn(logits.view(-1, 10), dummy_targets.view(-1))

print(f"Loss: {loss.item():.4f}")

loss.backward()

# Check gradients
adapter_grad_norm = torch.norm(torch.cat([p.grad.flatten() for p in ptm_model.adapter.parameters() if p.grad is not None]))
classifier_grad_norm = torch.norm(torch.cat([p.grad.flatten() for p in ptm_model.ptm_classifier.parameters() if p.grad is not None]))

print(f"Adapter gradient norm: {adapter_grad_norm.item():.6f}")
print(f"Classifier gradient norm: {classifier_grad_norm.item():.6f}")

# Check ESM parameters are frozen
esm_grads = [p.grad for p in ptm_model.esm_model.parameters() if p.grad is not None]
print(f"ESM gradients (should be 0): {len(esm_grads)}")

# Parameter summary
total_params = sum(p.numel() for p in ptm_model.parameters())
trainable_params = sum(p.numel() for p in ptm_model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Trainable ratio: {trainable_params/total_params:.1%}")

print("Test complete!")

Testing PTM Adapter Model...
Input shape: torch.Size([3, 402])
Output shape: torch.Size([3, 402, 10])
Logits range: [-45.447, 38.840]
Probabilities sum: 1.000
Loss: 17.5419
Adapter gradient norm: 0.000220
Classifier gradient norm: 333.558075
ESM gradients (should be 0): 0
Total parameters: 651,223,809
Trainable parameters: 180,555
Trainable ratio: 0.0%
Test complete!


In [2]:
# For testing data_loader.py
import os
import pandas as pd
import torch
from esm import pretrained
import sys
sys.path.append('utils')
from data_loader import PTMDataset, collate_fn, get_data_loaders

def test_dataset_creation(csv_path):
    print(f"Testing PTMDataset with {csv_path}")
    
    # Load ESM model components
    esm_model, alphabet = pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    print("✓ ESM2 model loaded successfully")
    
    # Load and inspect CSV
    df = pd.read_csv(csv_path)
    print(f"✓ CSV loaded successfully")
    print(len(df))
    print(f"Sample positions: {df['positions'].iloc[0]}")
    
    # Test dataset creation
    print("Creating PTMDataset...")
    dataset = PTMDataset(csv_path, alphabet, batch_converter)
    print(f"✓ Dataset created successfully")
    print(f"Dataset length: {len(dataset)}")
    print(f"Number of PTM types: {dataset.num_ptm_types}")
    print()
    
    return dataset, alphabet, batch_converter

def test_dataset_items(dataset):
    print("Testing Dataset Items")
    
    # Test getting items
    print("Testing first 5 dataset items...")
    for i in range(min(5, len(dataset))):
        item = dataset[i]
        print(f"Item {i}:")
        print(f"  Sequence length: {item['seq_length']}")
        print(f"  Labels shape: {item['labels'].shape}")
                
        # Show PTM positions for each type
        ptm_info = []
        total_ptms = 0
        for ptm_type in range(item['labels'].shape[1]):
            positions = torch.where(item['labels'][:, ptm_type] == 1)[0].tolist()
            ptm_info.append(f"PTM{ptm_type}: {[p+1 for p in positions]}")  # Convert back to 1-indexed
            total_ptms += len(positions)
        
        print(f"  Total PTMs: {total_ptms}")
        print(f"  PTM positions: {ptm_info}")
        print()

def test_collate_function(dataset, batch_converter, batch_size=5):
    print("Testing Collate Function")
    
    # Create a batch
    actual_batch_size = min(batch_size, len(dataset))
    batch = [dataset[i] for i in range(actual_batch_size)]
    
    print(f"Testing collate function with batch size {len(batch)}...")
    print(f"Input sequence lengths: {[item['seq_length'] for item in batch]}")
    
    collated = collate_fn(batch, batch_converter)
    
    print("Collated batch contents:")
    print(f"  Tokens shape: {collated['tokens'].shape}")
    print(f"  Labels shape: {collated['labels'].shape}")
    print(f"  Attention mask shape: {collated['attention_mask'].shape}")
    print(f"  Sequence lengths: {collated['seq_lengths'].tolist()}")
    
    # Verify shapes match
    batch_size_actual, max_seq_len = collated['tokens'].shape
    expected_labels_shape = (batch_size_actual, max_seq_len, dataset.num_ptm_types)
    actual_labels_shape = collated['labels'].shape
    if actual_labels_shape == expected_labels_shape:
        print("✓ Shapes are consistent")
    else:
        print(f"✗ Shape mismatch: expected {expected_labels_shape}, got {actual_labels_shape}")
    
    # Check token types
    print(f"Token analysis:")
    unique_tokens = torch.unique(collated['tokens']).tolist()
    print(f"  Padding token (1) count: {(collated['tokens'] == 1).sum().item()}")
    print(f"  CLS token (0) count: {(collated['tokens'] == 0).sum().item()}")
    print(f"  EOS token (2) count: {(collated['tokens'] == 2).sum().item()}")
    
    # Check attention mask
    attention_sum = collated['attention_mask'].sum(dim=1)
    print(f"  Attention mask sums (non-padding positions): {attention_sum.tolist()}")
    
    # Verify attention mask matches sequence lengths (+2 for CLS and EOS tokens)
    expected_attention = [seq_len + 2 for seq_len in collated['seq_lengths'].tolist()]
    actual_attention = attention_sum.tolist()
    
    if expected_attention == actual_attention:
        print("✓ Attention mask correctly matches sequence lengths")
    else:
        print(f"⚠️  Attention mask mismatch:")
        print(f"    Expected: {expected_attention}")
        print(f"    Actual: {actual_attention}")
    print()

def test_data_loaders(train_csv, val_csv, test_csv):
    print("Testing Data Loaders")
    
    # Load ESM components
    print("Loading ESM2 model...")
    esm_model, alphabet = pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    print("Loaded ESM2 model")
    
    print("Creating data loaders...")
    train_loader, val_loader, test_loader = get_data_loaders(
        train_csv, val_csv, test_csv, 
        alphabet, batch_converter,
        batch_size=5, num_workers=0  # Use 0 workers for testing to avoid multiprocessing issues
    )
    print("✓ Data loaders created successfully")
    
    # Test iterating through loaders
    loaders = [("Train", train_loader), ("Validation", val_loader), ("Test", test_loader)]
    
    for loader_name, loader in loaders:
        print(f"Testing {loader_name.lower()} loader...")
        batch_count = 0
        for i, batch in enumerate(loader):
            print(f"  Batch {i}: tokens {batch['tokens'].shape}, labels {batch['labels'].shape}")
            batch_count += 1
            if i >= 4:  # Only test first 5 batches
                break
        print(f"  ✓ {loader_name} loader working ({batch_count} batches tested)")
    
    print("✓ All data loaders test passed")

In [3]:
dataset, alphabet, batch_converter = test_dataset_creation('./data/train.csv')
test_dataset_items(dataset)
test_collate_function(dataset, batch_converter)
test_data_loaders("data/train.csv", "data/val.csv", "data/test.csv")

Testing PTMDataset with ./data/train.csv
✓ ESM2 model loaded successfully
✓ CSV loaded successfully
48018
Sample positions: [[36], None]
Creating PTMDataset...
✓ Dataset created successfully
Dataset length: 48018
Number of PTM types: 2

Testing Dataset Items
Testing first 5 dataset items...
Item 0:
  Sequence length: 79
  Labels shape: torch.Size([79, 2])
  Total PTMs: 1
  PTM positions: ['PTM0: [36]', 'PTM1: []']

Item 1:
  Sequence length: 419
  Labels shape: torch.Size([419, 2])
  Total PTMs: 1
  PTM positions: ['PTM0: []', 'PTM1: [115]']

Item 2:
  Sequence length: 80
  Labels shape: torch.Size([80, 2])
  Total PTMs: 1
  PTM positions: ['PTM0: [38]', 'PTM1: []']

Item 3:
  Sequence length: 430
  Labels shape: torch.Size([430, 2])
  Total PTMs: 1
  PTM positions: ['PTM0: []', 'PTM1: [116]']

Item 4:
  Sequence length: 418
  Labels shape: torch.Size([418, 2])
  Total PTMs: 1
  PTM positions: ['PTM0: []', 'PTM1: [116]']

Testing Collate Function
Testing collate function with batch siz

In [1]:
import torch
import numpy as np
from utils.metrics import PTMMetrics, calculate_class_weights
from torch.utils.data import DataLoader, TensorDataset

def test_ptm_metrics():
    """Test the PTMMetrics class with synthetic data"""
    print("Testing PTMMetrics class...")
    
    # Initialize metrics
    num_ptm_types = 5
    metrics = PTMMetrics(num_ptm_types)
    
    # Test 1: Perfect predictions
    print("Test 1: Perfect predictions")
    metrics.reset()
    
    # Create synthetic data
    batch_size, seq_len = 5, 20
    # In real ESM output, seq_len would include <cls> and <eos> tokens
    # So actual sequence length would be seq_len + 2
    total_len = seq_len + 2  # Adding <cls> at start and <eos> at end

    labels = torch.randint(0, 2, (batch_size, total_len, num_ptm_types)).float()
    predictions = labels.clone() # Make predictions match labels perfectly
    
    # Create mask that excludes <cls> (position 0) and <eos> (position seq_len+1)
    mask = torch.ones(batch_size, total_len)
    mask[:, 0] = 0  # Exclude <cls> token
    mask[:, seq_len+1] = 0  # Exclude <eos> token

    metrics.update(predictions, labels, mask)
    results = metrics.compute(threshold=0.5)
    
    print(f"Position accuracy: {results.get('position_accuracy', 0):.3f}")
    print(f"Exact match accuracy: {results.get('exact_match_accuracy', 0):.3f}")
    print(f"Overall F1: {results.get('overall_f1', 0):.3f}")

    # Test 2: Random predictions
    print("Test 2: Random predictions")
    metrics.reset()
    
    # Create more realistic data with imbalanced classes
    # Include special tokens in the sequence length
    total_len = seq_len + 2  # +2 for <cls> and <eos>
    predictions = torch.sigmoid(torch.randn(batch_size, total_len, num_ptm_types))
    labels = torch.zeros(batch_size, total_len, num_ptm_types)
    
    # Create mask excluding special tokens
    mask = torch.ones(batch_size, total_len)
    mask[:, 0] = 0  # <cls> token
    mask[:, seq_len+1] = 0  # <eos> token

    # Add some positive labels (simulate rare PTMs)
    # Only add labels to valid amino acid positions (not special tokens)
    for i in range(num_ptm_types):
        num_positives = np.random.randint(1, 5)
        # Only choose from valid positions (excluding positions 0 and seq_len+1)
        valid_positions = []
        for b in range(batch_size):
            for s in range(1, seq_len+1):  # Skip position 0 (<cls>) and seq_len+1 (<eos>)
                valid_positions.append((b, s))
        
        # Randomly select positions for positive labels
        selected_positions = np.random.choice(len(valid_positions), 
                                            min(num_positives, len(valid_positions)), 
                                            replace=False)
        
        for pos_idx in selected_positions:
            batch_idx, seq_idx = valid_positions[pos_idx]
            labels[batch_idx, seq_idx, i] = 1

    metrics.update(predictions, labels, mask)
    results = metrics.compute(threshold=0.5)
    
    print(f"Position accuracy: {results.get('position_accuracy', 0):.3f}")
    print(f"Exact match accuracy: {results.get('exact_match_accuracy', 0):.3f}")
    print(f"Overall F1: {results.get('overall_f1', 0):.3f}")

    # Print per-PTM metrics
    print("Per-PTM metrics:")
    for i in range(num_ptm_types):
        ptm_name = f"PTM_{i}"
        if f'{ptm_name}_f1' in results:
            print(f"{ptm_name}: F1={results[f'{ptm_name}_f1']:.3f}, "
                  f"Precision={results[f'{ptm_name}_precision']:.3f}, "
                  f"Recall={results[f'{ptm_name}_recall']:.3f}")

    # Test 3: Testing with padding mask
    print("Test 3: Testing with padding mask")
    metrics.reset()
    
    # Create data with different sequence lengths (simulating padding)
    batch_size, max_seq_len = 10, 30
    # Total length includes <cls>, actual sequence, <eos>, and padding
    total_max_len = max_seq_len + 2
    predictions = torch.sigmoid(torch.randn(batch_size, total_max_len, num_ptm_types))
    labels = torch.zeros(batch_size, total_max_len, num_ptm_types)
    
    # Create mask with different lengths
    mask = torch.zeros(batch_size, total_max_len)
    seq_lengths = [8,10,12,14,16,18,20,22,24,26]  # Different actual sequence lengths

    for i, length in enumerate(seq_lengths):
        # Set mask to 1 for valid amino acid positions only
        # Position 0 is <cls>, positions 1 to length are amino acids, 
        # position length+1 is <eos>, rest is padding
        mask[i, 1:length+1] = 1  # Only amino acid positions
        
        # Add some positive labels only in valid amino acid positions
        for j in range(num_ptm_types):
            if np.random.random() > 0.5:  # 50% chance of having this PTM
                num_pos = np.random.randint(1, min(4, length))
                # Choose positions between 1 and length (amino acid positions only)
                positions = np.random.choice(range(1, length+1), num_pos, replace=False)
                for pos in positions:
                    labels[i, pos, j] = 1
    
    metrics.update(predictions, labels, mask)
    results = metrics.compute(threshold=0.5)

    print(f"Position accuracy: {results.get('position_accuracy', 0):.3f}")
    print(f"Overall F1: {results.get('overall_f1', 0):.3f}")
    
    # Get PTM statistics
    stats = metrics.get_per_ptm_stats()
    print("PTM statistics:")
    for ptm_name, ptm_stats in stats.items():
        print(f"{ptm_name}: {ptm_stats['positive_samples']} positive, "
              f"{ptm_stats['negative_samples']} negative "
              f"(ratio: {ptm_stats['positive_ratio']:.3f})")
    print()

def test_class_weights():
    """Test the calculate_class_weights function"""
    print("Testing calculate_class_weights function...")
    
    num_ptm_types = 4
    positive_rates = [0.05, 0.20, 0.50, 0.7]  # 5%, 20%, 50% positive rates
    
    # Create fake data loader
    data = []
    for _ in range(5):  # 5 batches
        batch_size = 7
        seq_len = 20
        total_len = seq_len + 2  # Add <cls> and <eos>
        
        labels = torch.zeros(batch_size, total_len, num_ptm_types)
        mask = torch.zeros(batch_size, total_len)
        mask[:, 1:seq_len+1] = 1  # Only amino acids are valid
        
        # Add positive labels based on rates
        for ptm_idx, rate in enumerate(positive_rates):
            for b in range(batch_size):
                for pos in range(1, seq_len+1):
                    if torch.rand(1).item() < rate:
                        labels[b, pos, ptm_idx] = 1
        
        data.append({
            'labels': labels,
            'attention_mask': mask
        })
    
    # Calculate weights
    weights = calculate_class_weights(data, num_ptm_types, device='cpu')
    
    print("Class weights (negative, positive):")
    for i in range(num_ptm_types):
        print(f"PTM {i} (rate: {positive_rates[i]:.0%}): "
              f"neg={weights[i, 0]:.3f}, pos={weights[i, 1]:.3f}")
    
    print("Rarer PTMs should have higher positive weights:")
    print(f"PTM 0 pos weight: {weights[0, 1]:.3f} (should be highest)")
    print(f"PTM 1 pos weight: {weights[1, 1]:.3f} (should be middle)")
    print(f"PTM 2 pos weight: {weights[2, 1]:.3f} (should be middle)")
    print(f"PTM 3 pos weight: {weights[3, 1]:.3f} (should be lowest)")
    print()

def test_edge_cases():
    """Test edge cases for the metrics"""
    print("Testing edge cases...")
    
    metrics = PTMMetrics(num_ptm_types=3)
    
    # Test 1: No positive samples for some PTMs
    print("Test 1: No positive samples for PTM 1")
    metrics.reset()
    
    predictions = torch.tensor([[[0.8, 0.2, 0.9], [0.1, 0.3, 0.7]]])
    labels = torch.tensor([[[1, 0, 1], [0, 0, 1]]], dtype=torch.float)
    mask = torch.ones(1, 2)
    
    metrics.update(predictions, labels, mask)
    results = metrics.compute(threshold=0.5)
    
    # Should handle missing PTM gracefully
    print(f"Metrics computed successfully: {len(results)} metrics")
    print(f"PTM_1 metrics present: {'PTM_1_f1' in results}")
    
    # Test 2: All predictions are negative
    print("Test 2: All predictions are negative")
    metrics.reset()
    
    predictions = torch.zeros(2, 5, 3)  # All zeros
    labels = torch.randint(0, 2, (2, 5, 3)).float()
    mask = torch.ones(2, 5)
    
    metrics.update(predictions, labels, mask)
    results = metrics.compute(threshold=0.5)
    
    print(f"Overall recall (should be 0): {results.get('overall_recall', -1):.3f}")
    
    # Test 3: Multiple batches
    print("Test 3: Multiple batches accumulation")
    metrics.reset()
    
    for i in range(3):
        predictions = torch.sigmoid(torch.randn(2, 10, 3))
        labels = torch.randint(0, 2, (2, 10, 3)).float()
        mask = torch.ones(2, 10)
        metrics.update(predictions, labels, mask)
    
    results = metrics.compute()
    stats = metrics.get_per_ptm_stats()
    
    total_samples = sum(stats[f'PTM_{i}']['positive_samples'] + 
                       stats[f'PTM_{i}']['negative_samples'] 
                       for i in range(3))
    expected_samples = 3 * 2 * 10 * 3  # 3 batches * 2 batch_size * 10 seq_len * 3 PTMs
    
    print(f"Total samples accumulated: {total_samples}")
    print(f"Expected: {expected_samples}")
    print(f"Correct accumulation: {total_samples == expected_samples}")

In [2]:
test_ptm_metrics()
test_class_weights()
test_edge_cases()

Testing PTMMetrics class...
Test 1: Perfect predictions
Position accuracy: 1.000
Exact match accuracy: 1.000
Overall F1: 1.000
Test 2: Random predictions
Position accuracy: 0.960
Exact match accuracy: 0.060
Overall F1: 0.046
Per-PTM metrics:
PTM_0: F1=0.041, Precision=0.021, Recall=0.500
PTM_1: F1=0.077, Precision=0.040, Recall=1.000
PTM_2: F1=0.035, Precision=0.018, Recall=1.000
PTM_3: F1=0.035, Precision=0.019, Recall=0.333
PTM_4: F1=0.042, Precision=0.021, Recall=1.000
Test 3: Testing with padding mask
Position accuracy: 0.947
Overall F1: 0.095
PTM statistics:
PTM_0: 9 positive, 161 negative (ratio: 0.053)
PTM_1: 9 positive, 161 negative (ratio: 0.053)
PTM_2: 18 positive, 152 negative (ratio: 0.106)
PTM_3: 6 positive, 164 negative (ratio: 0.035)
PTM_4: 8 positive, 162 negative (ratio: 0.047)

Testing calculate_class_weights function...
Class weights (negative, positive):
PTM 0 (rate: 5%): neg=0.039, pos=0.961
PTM 1 (rate: 20%): neg=0.187, pos=0.813
PTM 2 (rate: 50%): neg=0.476, pos=