In [145]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
from sklearn.model_selection import train_test_split

In [146]:
import sys
import os

# Add the project root to Python path
project_root = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from src.utils.data import load_and_prepare_data_classification

In [147]:
# load data
data_path = os.path.join(project_root, "data", "processed", "ProSeq_with_5component_analysis.csv")
sequences, labels = load_and_prepare_data_classification(data_path)
print(f"Loaded {len(sequences)} sequences with {len(set(labels))} unique classes")
print(f"Label distribution: {np.bincount(labels)}")


Loaded 8304 sequences with 4 unique classes
Label distribution: [ 657 3656 2295 1696]


In [148]:
# Simplified faster CNN model

class SimplifiedPromoterCNN(nn.Module):
    def __init__(self, num_classes=4, max_length=400):  # Reduced from 600
        super(SimplifiedPromoterCNN, self).__init__()
        # Reduced number of channels and layers for speed
        self.conv1 = nn.Conv1d(4, 16, kernel_size=8, padding=4)  # Smaller kernel, fewer channels
        self.conv2 = nn.Conv1d(16, 32, kernel_size=6, padding=3)  # Smaller kernel, fewer channels
        
        self.pool = nn.MaxPool1d(3)  # Larger pooling for more aggressive downsampling
        self.dropout = nn.Dropout(0.2)  # Reduced dropout
        
        # Calculate the final feature size after convolutions and pooling
        # After 2 conv+pool layers with pool size 3: max_length // (3^2) = max_length // 9
        final_length = max_length // 9
        self.fc1 = nn.Linear(32 * final_length, 64)  # Much smaller FC layer
        self.fc2 = nn.Linear(64, num_classes)
        
    def forward(self, x):
        # Input shape: (batch, length, 4) -> transpose to (batch, 4, length) for Conv1d
        if x.dim() == 3 and x.shape[-1] == 4:
            x = x.transpose(1, 2)
        
        # First conv block
        x = self.pool(F.relu(self.conv1(x)))
        x = self.dropout(x)
        
        # Second conv block  
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout(x)
        
        # Flatten and fully connected layers
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        
        return x


In [149]:
# Get the best available device for training
from src.utils.training import get_best_device

device, loader_kwargs, device_name = get_best_device()
print(f"Using device: {device_name}")

# Create simplified model with reduced sequence length
cnn = SimplifiedPromoterCNN(num_classes=4, max_length=400)
cnn = cnn.to(device)

# Count parameters
total_params = sum(p.numel() for p in cnn.parameters())
trainable_params = sum(p.numel() for p in cnn.parameters() if p.requires_grad)
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Using device: mps
Total parameters: 94,068
Trainable parameters: 94,068


In [150]:

# Better optimizer and learning rate for faster convergence
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=0.001, momentum=0.9)

# Learning rate scheduler for adaptive learning
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)




In [151]:
# Load data and build loaders
csv_path = "../../data/processed/ProSeq_with_5component_analysis.csv"
sequences, targets = load_and_prepare_data_classification(csv_path)


In [152]:



# Use the classification dataset for integer labels

train_seq, test_seq, train_labels, test_labels = train_test_split(
    sequences, targets, test_size=0.2, random_state=42, stratify=targets
)
train_seq, val_seq, train_labels, val_labels = train_test_split(
    train_seq, train_labels, test_size=0.2, random_state=4, stratify=train_labels
)

# Import the correct classification dataset
from src.utils.data import PromoterClassificationDataset

# Use reduced max_length for faster processing
train_ds = PromoterClassificationDataset(train_seq, train_labels, max_length=400)
val_ds = PromoterClassificationDataset(val_seq, val_labels, max_length=400)
test_ds = PromoterClassificationDataset(test_seq, test_labels, max_length=400)

# Increase batch size for better GPU utilization and faster training
batch_size = 64  # Increased from 32
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, **loader_kwargs)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, **loader_kwargs)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=2, **loader_kwargs)

print(f"Dataset sizes - Train: {len(train_ds)}, Val: {len(val_ds)}, Test: {len(test_ds)}")
print("Sample data:", train_ds[0])


Dataset sizes - Train: 5314, Val: 1329, Test: 1661
Sample data: {'sequence': tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        ...,
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]]), 'target': tensor(3)}


In [153]:
# Efficient training loop with validation and early stopping
from src.utils.training import train_epoch_ce, validate_epoch_ce
import time

num_epochs = 100  # Reduced epochs with early stopping
best_val_loss = float('inf')
patience = 100
patience_counter = 0

print("Starting training...")
start_time = time.time()

for epoch in range(num_epochs):
    epoch_start = time.time()
    
    # Training
    train_loss = train_epoch_ce(cnn, train_loader, criterion, optimizer, device)
    
    # Validation
    val_loss = validate_epoch_ce(cnn, val_loader, criterion, device)
    
    # Learning rate scheduling
    scheduler.step(val_loss)
    
    # Early stopping check
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        patience_counter = 0
        # Save best model
        torch.save(cnn.state_dict(), 'best_simplified_cnn.pt')
    else:
        patience_counter += 1
    
    epoch_time = time.time() - epoch_start
    
    print(f'Epoch [{epoch+1:2d}/{num_epochs}] '
          f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | '
          f'Time: {epoch_time:.1f}s | LR: {optimizer.param_groups[0]["lr"]:.6f}')
    
    # Early stopping
    if patience_counter >= patience:
        print(f"Early stopping triggered after {epoch+1} epochs")
        break

total_time = time.time() - start_time
print(f'\nTraining completed in {total_time:.1f} seconds')
print(f'Best validation loss: {best_val_loss:.4f}')

# Load best model for evaluation
cnn.load_state_dict(torch.load('best_simplified_cnn.pt'))
print("Loaded best model weights")

Starting training...
Epoch [ 1/100] Train Loss: 1.2937 | Val Loss: 1.2524 | Time: 26.0s | LR: 0.001000
Epoch [ 2/100] Train Loss: 1.2417 | Val Loss: 1.2433 | Time: 26.0s | LR: 0.001000
Epoch [ 3/100] Train Loss: 1.2424 | Val Loss: 1.2436 | Time: 26.0s | LR: 0.001000
Epoch [ 4/100] Train Loss: 1.2484 | Val Loss: 1.2425 | Time: 25.9s | LR: 0.001000
Epoch [ 5/100] Train Loss: 1.2387 | Val Loss: 1.2427 | Time: 26.3s | LR: 0.001000
Epoch [ 6/100] Train Loss: 1.2452 | Val Loss: 1.2433 | Time: 26.2s | LR: 0.001000
Epoch [ 7/100] Train Loss: 1.2499 | Val Loss: 1.2430 | Time: 26.2s | LR: 0.001000
Epoch [ 8/100] Train Loss: 1.2448 | Val Loss: 1.2431 | Time: 26.2s | LR: 0.001000
Epoch [ 9/100] Train Loss: 1.2463 | Val Loss: 1.2422 | Time: 26.0s | LR: 0.001000


KeyboardInterrupt: 

In [156]:
subset = torch.utils.data.Subset(train_ds, range(32))
tiny_loader = DataLoader(subset, batch_size=32, shuffle=True)
for _ in range(200):
    train = train_epoch_ce(cnn, tiny_loader, criterion, optimizer, device)  # expect loss to plummet
    val = validate_epoch_ce(cnn, tiny_loader, criterion, device)

    print(f"Train loss: {train}, Val loss: {val}")
    


Train loss: 1.2383153438568115, Val loss: 1.219427227973938
Train loss: 1.2417333126068115, Val loss: 1.2192087173461914
Train loss: 1.2182636260986328, Val loss: 1.2189810276031494
Train loss: 1.1986217498779297, Val loss: 1.218729853630066
Train loss: 1.2234951257705688, Val loss: 1.2184820175170898
Train loss: 1.217616319656372, Val loss: 1.2182201147079468
Train loss: 1.197298288345337, Val loss: 1.2179462909698486
Train loss: 1.211430549621582, Val loss: 1.2176754474639893
Train loss: 1.2295634746551514, Val loss: 1.2174067497253418
Train loss: 1.2208251953125, Val loss: 1.2171376943588257
Train loss: 1.2183176279067993, Val loss: 1.2168627977371216
Train loss: 1.2358319759368896, Val loss: 1.2165920734405518
Train loss: 1.23792564868927, Val loss: 1.2163264751434326
Train loss: 1.2487882375717163, Val loss: 1.2160704135894775
Train loss: 1.2250200510025024, Val loss: 1.215811848640442
Train loss: 1.2198175191879272, Val loss: 1.2155628204345703
Train loss: 1.2298556566238403, Val

In [None]:
# Evaluate the simplified model
from src.utils.training import evaluate_model_ce
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Get predictions on test set
test_logits, test_labels = evaluate_model_ce(cnn, test_loader, device)
test_predictions = np.argmax(test_logits, axis=1)

# Calculate accuracy
accuracy = accuracy_score(test_labels, test_predictions)
print(f"Test Accuracy: {accuracy:.4f}")

# Classification report
print("\nClassification Report:")
print(classification_report(test_labels, test_predictions, target_names=[f'Class {i}' for i in range(4)]))

# Confusion matrix
cm = confusion_matrix(test_labels, test_predictions)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=[f'Class {i}' for i in range(4)],
            yticklabels=[f'Class {i}' for i in range(4)])
plt.title('Confusion Matrix - Simplified CNN')
plt.ylabel('True Label')
plt.xlabel('Predicted Label')
plt.tight_layout()
plt.savefig('../../plots/simplified_cnn_confusion_matrix.png', dpi=300, bbox_inches='tight')
plt.show()
