In [2]:
from pathlib import Path

import h5py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from torch.utils.data import Dataset, DataLoader, random_split
from tqdm import tqdm

from util.util import get_device
device = get_device()
print(f"Using device: {device}")

# Single seed shared by every RNG the notebook seeds, for reproducible runs.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

PEPTIDE_DATASET_PATH = "data/peptide_dataset.h5"
# Expected number of samples in PEPTIDE_DATASET_PATH. It is passed to the dataset
# as `dataset_len` so the length is not recomputed from the file on every load.
# Update it if the HDF5 file is regenerated.
PEPTIDE_DATASET_LEN = 14_774_723

Using device: cuda


In [4]:
class PeptideLatentDataset(Dataset):
    """Map-style dataset of ``(latent, label)`` pairs read lazily from an HDF5 file.

    Each item is read from disk on demand, so the file never has to fit in memory.

    Args:
        file_loc: Path to the HDF5 file.
        latent_dim: Expected latent width; stored on the instance but not used here.
        latent_name: Name of the HDF5 dataset holding the latent vectors.
        extinct_name: Name of the HDF5 dataset holding the labels.
        dataset_len: Optional precomputed number of samples. When ``None`` the
            length is taken from the latent dataset's shape.
        transform: Optional callable applied to the ``(latent, label)`` tuple.
    """

    def __init__(self, file_loc, latent_dim=256, latent_name='LATENTS', extinct_name='EXTINCT', dataset_len=None, transform=None):
        # The file stays open for the lifetime of the dataset; call close() when done.
        # NOTE(review): open h5py handles cannot be pickled, which is presumably
        # why the DataLoaders in this notebook use num_workers=0.
        self.file = h5py.File(file_loc, 'r')
        self.latent_dataset = self.file[latent_name]
        self.extinct_dataset = self.file[extinct_name]
        self.latent_dim = latent_dim

        # len() only needs the dataset's shape; slicing with [:] would read every
        # row into memory first.
        self._cached_len = dataset_len if dataset_len is not None else len(self.latent_dataset)
        self.transform = transform

    def __len__(self, use_cached=True):
        """Return the number of samples (the cached value unless ``use_cached`` is False)."""
        if use_cached:
            return self._cached_len
        return len(self.latent_dataset)

    def __getitem__(self, idx):
        """Return ``(latent, label)`` as float32 / int64 tensors, optionally transformed."""
        latent = torch.as_tensor(self.latent_dataset[idx], dtype=torch.float32)
        labels = torch.as_tensor(np.asarray(self.extinct_dataset[idx], dtype=np.int64), dtype=torch.long)

        out = (latent, labels)
        if self.transform is not None:
            out = self.transform(out)
        return out

    def close(self):
        """Close the underlying HDF5 file handle."""
        self.file.close()
    
class EsmClassificationHead(nn.Module):
    """MLP classification head, slightly modified from the original ESM head.

    Layout: dropout, then three (linear -> SiLU -> dropout) stages, then a linear
    projection to class logits. The submodule names (dense, dense2, dense3,
    out_proj) determine the state_dict keys, so keep them stable for saved
    checkpoints to load.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of the three hidden layers.
        num_classes: Number of output logits.
        dropout: Dropout probability applied before every linear layer.
    """

    def __init__(self, input_dim=256, hidden_dim=2048, num_classes=2, dropout=0.05):
        super().__init__()
        # Registration order matches the original so seeded initialization is unchanged.
        self.dense = nn.Linear(input_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.dense2 = nn.Linear(hidden_dim, hidden_dim)
        self.dense3 = nn.Linear(hidden_dim, hidden_dim)
        self.out_proj = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map features of shape (..., input_dim) to logits of shape (..., num_classes)."""
        x = self.dropout(x)
        for layer in (self.dense, self.dense2, self.dense3):
            x = self.dropout(F.silu(layer(x)))
        return self.out_proj(x)

In [5]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimization pass over ``train_loader``.

    Args:
        model: Module mapping a batch of embeddings to logits.
        train_loader: Iterable of ``(embeddings, labels)`` batches.
        criterion: Loss function called as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Mean of the per-batch losses, or ``nan`` if the loader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for embeddings, labels in tqdm(train_loader, desc="Training"):
        # non_blocking lets host->device copies overlap compute when the loader pins memory.
        embeddings = embeddings.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad()
        logits = model(embeddings)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        return float("nan")
    return total_loss / num_batches

def evaluate(model, data_loader, device):
    """Compute classification metrics of ``model`` over ``data_loader``.

    Predictions are the argmax over dim 1 of the model's logits. Precision,
    recall and F1 use scikit-learn's default binary setting (label 1 positive).

    Args:
        model: Module mapping a batch of embeddings to logits.
        data_loader: Iterable of ``(embeddings, labels)`` batches.
        device: Device the embeddings are moved to.

    Returns:
        Dict with keys ``accuracy``, ``precision``, ``recall`` and ``f1``.
    """
    model.eval()
    pred_chunks = []
    label_chunks = []

    with torch.inference_mode():
        for embeddings, labels in tqdm(data_loader, desc="Evaluating"):
            logits = model(embeddings.to(device, non_blocking=True))
            pred_chunks.append(torch.argmax(logits, dim=1).cpu().numpy())
            label_chunks.append(labels.cpu().numpy())

    # Concatenate per-batch arrays once instead of extending Python lists element-wise.
    predictions = np.concatenate(pred_chunks) if pred_chunks else np.empty(0, dtype=np.int64)
    true_labels = np.concatenate(label_chunks) if label_chunks else np.empty(0, dtype=np.int64)

    # zero_division=0 returns the same 0.0 sklearn falls back to by default, but
    # without an UndefinedMetricWarning when a class is never predicted/present.
    return {
        'accuracy': accuracy_score(true_labels, predictions),
        'precision': precision_score(true_labels, predictions, zero_division=0),
        'recall': recall_score(true_labels, predictions, zero_division=0),
        'f1': f1_score(true_labels, predictions, zero_division=0),
    }

print("Loading training data...")
peptide_latent_dataset = PeptideLatentDataset(PEPTIDE_DATASET_PATH, latent_dim=256, dataset_len=PEPTIDE_DATASET_LEN)

# 90/10 train/test split. A dedicated, seeded generator (same seed as the config
# cell) keeps the held-out set identical across re-runs of this cell and kernel
# restarts, independent of how much global RNG state earlier cells consumed.
train_size = int(0.9 * len(peptide_latent_dataset))
test_size = len(peptide_latent_dataset) - train_size
split_generator = torch.Generator().manual_seed(42)
train_dataset, test_dataset = random_split(
    peptide_latent_dataset, [train_size, test_size], generator=split_generator
)

# Shared DataLoader settings. NOTE(review): num_workers=0 presumably because the
# dataset holds an open h5py file handle; confirm before raising it.
loader_kwargs = dict(batch_size=1024, num_workers=0, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
# Evaluation metrics do not depend on sample order, so the test set is not shuffled.
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

model = EsmClassificationHead().to(device)
criterion = nn.CrossEntropyLoss()  # 2 output logits; labels are int64 class indices
optimizer = optim.Adam(model.parameters(), lr=1e-3)
    

Loading training data...


In [6]:
# Training loop
num_epochs = 10
CHECKPOINT_DIR = Path("train")
# Create the output directory up front so torch.save cannot fail after a full epoch.
CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

best_accuracy = float("-inf")  # guarantees the first epoch's weights are saved
metrics = {}  # keeps the final checkpoint well-defined even if num_epochs == 0

print("Starting training...")
for epoch in range(num_epochs):
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    metrics = evaluate(model, test_loader, device)

    print(f"\nEpoch {epoch+1}/{num_epochs}")
    print(f"Train Loss: {train_loss:.4f}")
    print("Test Metrics:")
    for metric, value in metrics.items():
        print(f"{metric}: {value:.4f}")

    # Keep the weights with the best test accuracy seen so far.
    # NOTE(review): selecting on the test split biases the reported test accuracy;
    # consider a separate validation split.
    if metrics['accuracy'] > best_accuracy:
        best_accuracy = metrics['accuracy']
        torch.save(model.state_dict(), CHECKPOINT_DIR / 'best_extinct_model.pt')

# Save final model, optimizer state and a summary of the run.
final_save = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'final_metrics': metrics,
    'best_accuracy': best_accuracy,
    'input_dim': 256
}
torch.save(final_save, CHECKPOINT_DIR / 'final_extinct_model.pt')

print("\nTraining completed!")
print(f"Best test accuracy: {best_accuracy:.4f}")


Starting training...


Training:   2%|▏         | 254/12986 [00:30<25:08,  8.44it/s]


KeyboardInterrupt: 