In [41]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import h5py

from util.util import get_device
device = get_device()
print(f"Using device: {device}")

torch.manual_seed(42)
np.random.seed(42)

PEPTIDE_DATASET_PATH = "data/peptide_dataset.h5"
PEPTIDE_DATASET_LEN = 14774723 # speed up loading
EXTINCT_CLASSIFIER_BEST_PATH = "saved_models/final_extinct_model_1.pt"

from gdiffusion.classifier.extinct_predictor import EsmClassificationHead

Using device: cuda


In [43]:
class PeptideLatentDataset(Dataset):
    def __init__(self, file_loc, latent_dim=256, latent_name='LATENTS', extinct_name='EXTINCT', dataset_len=None, transform=None):
        # keep file open
        self.file = h5py.File(file_loc, 'r')
        self.latent_dataset = self.file[latent_name]
        self.extinct_dataset = self.file[extinct_name]
        self.latent_dim = latent_dim

        self._cached_len = dataset_len if dataset_len is not None else len(self.latent_dataset[:])
        self.transform = transform

    def __len__(self, use_cached=True):
        if use_cached:
            return self._cached_len
        else:
            return len(self.latent_dataset[:])
    
    def __getitem__(self, idx):
        raw_latent = self.latent_dataset[idx]
        raw_extinct = np.array(self.extinct_dataset[idx], dtype=np.int32)

        latent = torch.FloatTensor(data=raw_latent)
        labels = torch.LongTensor(data=raw_extinct)

        out = (latent,  labels)
        if self.transform:
            out = self.transform(out)
        return out
    
class EsmClassificationHead(nn.Module):
    # slightly modified from the original ESM classification head
    def __init__(self, input_dim=256):
        super().__init__()
        self.dense = nn.Linear(input_dim, 2048)
        self.dropout = nn.Dropout(0.05)
        self.dense2 = nn.Linear(2048, 2048)
        self.dense3 = nn.Linear(2048, 2048)
        self.out_proj = nn.Linear(2048, 2)
    
    def forward(self, x):
        x = self.dropout(x)
        x = self.dense(x)
        x = F.silu(x)
        x = self.dropout(x)
        x = self.dense2(x)
        x = F.silu(x)
        x = self.dropout(x)
        x = self.dense3(x)
        x = F.silu(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x

In [44]:
classifier = EsmClassificationHead()
classifier.load_state_dict(torch.load(EXTINCT_CLASSIFIER_BEST_PATH)['model_state_dict'])

  classifier.load_state_dict(torch.load(EXTINCT_CLASSIFIER_BEST_PATH)['model_state_dict'])


<All keys matched successfully>

In [46]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    model.train()
    total_loss = 0
    for embeddings, labels in tqdm(train_loader, desc="Training"):
        embeddings, labels = embeddings.to(device), labels.to(device)
        
        optimizer.zero_grad()
        logits = model(embeddings)

        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    return total_loss / len(train_loader)

def evaluate(model, data_loader, device):
    model.eval()
    predictions = []
    true_labels = []
    
    with torch.inference_mode():
        for embeddings, labels in tqdm(data_loader, desc="Evaluating"):
            embeddings = embeddings.to(device)
            logits = model(embeddings)
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.cpu().numpy())
            true_labels.extend(labels.cpu().numpy())
    
    predictions = np.array(predictions)
    true_labels = np.array(true_labels)
    
    return {
        'accuracy': accuracy_score(true_labels, predictions),
        'precision': precision_score(true_labels, predictions),
        'recall': recall_score(true_labels, predictions),
        'f1': f1_score(true_labels, predictions)
    }

print("Loading training data...")
peptide_latent_dataset = PeptideLatentDataset(PEPTIDE_DATASET_PATH, latent_dim=256, dataset_len=PEPTIDE_DATASET_LEN)


# 90/10 split
train_size = int(0.9 * len(peptide_latent_dataset))
test_size = len(peptide_latent_dataset) - train_size

train_dataset, test_dataset = random_split(peptide_latent_dataset, [train_size, test_size])


# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
    
test_loader = DataLoader(
    test_dataset,
    batch_size=1024,
    shuffle=True,
    num_workers=0,
    pin_memory=True
)
    
model = EsmClassificationHead().to(device)
criterion = nn.CrossEntropyLoss()  # Changed to CrossEntropyLoss for 2-class output
optimizer = optim.Adam(model.parameters(), lr=1e-3)
    

Loading training data...


In [47]:
sd = torch.load(f="saved_models/final_extinct_model_1.pt")
print(sd.keys())

dict_keys(['model_state_dict', 'optimizer_state_dict', 'final_metrics', 'best_accuracy', 'input_dim'])


  sd = torch.load(f="saved_models/final_extinct_model_1.pt")


In [48]:
model.load_state_dict(sd['model_state_dict'])
optimizer.load_state_dict(sd['optimizer_state_dict'])
metrics = sd['final_metrics']
best_accuracy = sd['best_accuracy']
input_dim = sd['input_dim']

In [61]:
# Weird puzzle, gradient blows up on random but not train?
classifier.eval()
classifier.to('cuda')
with torch.no_grad():
    z, _ = train_dataset[200]
    z = z.to('cuda')
    print(classifier(z))

    print(classifier(torch.randn_like(z, device=z.device, dtype=z.dtype)))

tensor([ 1.7174, -1.6824], device='cuda:0')
tensor([ 11805.5635, -11533.3574], device='cuda:0')


tensor(-3.4326, device='cuda:0', grad_fn=<SelectBackward0>)

In [14]:
metrics = evaluate(model, test_loader, device)

Evaluating:   2%|▏         | 32/1443 [00:07<05:34,  4.21it/s]


KeyboardInterrupt: 

In [None]:
# # Training loop
# num_epochs = 5
# best_accuracy = 0
    
# print("Starting training...")
# for epoch in range(num_epochs):
#     train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
#     metrics = evaluate(model, test_loader, device)
    
#     print(f"\nEpoch {epoch+1}/{num_epochs}")
#     print(f"Train Loss: {train_loss:.4f}")
#     print(f"Test Metrics:")
#     for metric, value in metrics.items():
#         print(f"{metric}: {value:.4f}")
    
#     # Save best model
#     if metrics['accuracy'] > best_accuracy:
#         best_accuracy = metrics['accuracy']
#         torch.save(model.state_dict(), 'train/best_extinct_model.pt')

# # Save final model and training info
# final_save = {
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'final_metrics': metrics,
#     'best_accuracy': best_accuracy,
#     'input_dim': 256
# }
# torch.save(final_save, 'train/final_extinct_model.pt')

# print("\nTraining completed!")
# print(f"Best test accuracy: {best_accuracy:.4f}")


Starting training...


Training: 100%|██████████| 12986/12986 [14:26<00:00, 14.99it/s]
Evaluating: 100%|██████████| 1443/1443 [01:13<00:00, 19.55it/s]



Epoch 1/5
Train Loss: 0.4554
Test Metrics:
accuracy: 0.8056
precision: 0.6932
recall: 0.5790
f1: 0.6310


Training: 100%|██████████| 12986/12986 [13:30<00:00, 16.02it/s]
Evaluating: 100%|██████████| 1443/1443 [01:14<00:00, 19.35it/s]



Epoch 2/5
Train Loss: 0.4127
Test Metrics:
accuracy: 0.8147
precision: 0.6970
recall: 0.6273
f1: 0.6603


Training: 100%|██████████| 12986/12986 [13:25<00:00, 16.13it/s]
Evaluating: 100%|██████████| 1443/1443 [01:12<00:00, 19.89it/s]



Epoch 3/5
Train Loss: 0.3919
Test Metrics:
accuracy: 0.8274
precision: 0.7289
recall: 0.6347
f1: 0.6786


Training: 100%|██████████| 12986/12986 [13:16<00:00, 16.30it/s]
Evaluating: 100%|██████████| 1443/1443 [01:17<00:00, 18.65it/s]



Epoch 4/5
Train Loss: 0.3770
Test Metrics:
accuracy: 0.8326
precision: 0.7305
recall: 0.6608
f1: 0.6939


Training: 100%|██████████| 12986/12986 [13:32<00:00, 15.98it/s]
Evaluating: 100%|██████████| 1443/1443 [01:18<00:00, 18.41it/s]



Epoch 5/5
Train Loss: 0.3620
Test Metrics:
accuracy: 0.8417
precision: 0.7650
recall: 0.6475
f1: 0.7014

Training completed!
Best test accuracy: 0.8417
