In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import gc
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from tqdm.notebook import tqdm
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

In [2]:
def set_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

set_seed()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def create_synthetic_dataset(output_dir, num_samples=1000, input_dim=400):
    """Create synthetic protein structure data on CPU"""
    os.makedirs(output_dir, exist_ok=True)
    data = []
    for i in tqdm(range(num_samples), desc="Generating synthetic data"):
        base = np.random.rand(input_dim // 4).astype(np.float32)
        structure = np.tile(base, 4) + 0.05 * np.random.rand(input_dim).astype(np.float32)
        structure += np.sin(np.linspace(0, 8 * np.pi, input_dim)) * 0.1
        structure = (structure - structure.min()) / (structure.max() - structure.min())
        data.append(structure)
        np.save(os.path.join(output_dir, f"protein_{i}.npy"), structure)
    return data

Using device: cuda


In [3]:
class ProteinStructureDataset(Dataset):
    def __init__(self, data_dir):
        self.data_files = [os.path.join(data_dir, f) for f in os.listdir(data_dir) 
                         if f.endswith('.npy')]
    
    def __len__(self):
        return len(self.data_files)
    
    def __getitem__(self, idx):
        data = np.load(self.data_files[idx])
        return torch.tensor(data, dtype=torch.float32)

In [4]:
class HelixSynthModel(nn.Module):
    def __init__(self, input_dim=400, hidden_dim=256, latent_dim=32):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim 
        
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2)
        )
        self.fc_mu = nn.Linear(hidden_dim, latent_dim)
        self.fc_var = nn.Linear(hidden_dim, latent_dim)
        
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.LeakyReLU(0.2),
            nn.Linear(hidden_dim, input_dim)
        )
        
        self.time_embed = nn.Sequential(
            nn.Linear(1, latent_dim),  
            nn.SiLU(),
            nn.Linear(latent_dim, latent_dim) 
        )
        
    def encode(self, x):
        h = self.encoder(x)
        return self.fc_mu(h), self.fc_var(h)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z, t=None):
        if t is not None:
            t_emb = self.time_embed(t)
            z = z + t_emb  
        return self.decoder(z)
    
    def forward(self, x, t=None):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z, t)
        return recon, mu, log_var

In [5]:
def vae_diffusion_loss(recon_x, x, mu, log_var, noise_pred, noise, beta=1.0):
    recon_loss = F.mse_loss(recon_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
    diffusion_loss = F.mse_loss(noise_pred, noise)
    return recon_loss + beta * kl_loss + diffusion_loss

In [6]:
def train_model(model, train_loader, val_loader, device, epochs=50, lr=1e-4):
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    scaler = GradScaler()
    
    train_losses = []
    val_losses = []
    
    for epoch in tqdm(range(epochs), desc="Training"):
        model.train()
        train_loss = 0
        
        for batch in train_loader:
            x = batch.to(device)
            t = torch.rand(x.shape[0], 1, device=device)
            
            noise = torch.randn_like(x)
            x_noisy = x + noise
            
            optimizer.zero_grad()
            with autocast():
                recon_x, mu, log_var = model(x_noisy, t)
                noise_pred = recon_x - x_noisy
                loss = vae_diffusion_loss(recon_x, x, mu, log_var, noise_pred, noise)
            
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            
            train_loss += loss.item()
            
            del x, x_noisy, recon_x, mu, log_var, noise, noise_pred, loss
            torch.cuda.empty_cache()
        
        model.eval()
        val_loss = 0
        with torch.no_grad():
            for batch in val_loader:
                x = batch.to(device)
                t = torch.rand(x.shape[0], 1, device=device)
                noise = torch.randn_like(x)
                x_noisy = x + noise
                
                with autocast():
                    recon_x, mu, log_var = model(x_noisy, t)
                    noise_pred = recon_x - x_noisy
                    loss = vae_diffusion_loss(recon_x, x, mu, log_var, noise_pred, noise)
                
                val_loss += loss.item()
                
                del x, x_noisy, recon_x, mu, log_var, noise, noise_pred, loss
                torch.cuda.empty_cache()
        
        train_losses.append(train_loss / len(train_loader))
        val_losses.append(val_loss / len(val_loader))
        
        print(f"Epoch {epoch+1}: Train Loss: {train_losses[-1]:.4f}, Val Loss: {val_losses[-1]:.4f}")
        
        scheduler.step()
        gc.collect()
    
    return model, train_losses, val_losses

In [7]:
def evaluate_and_plot(model, test_loader, device, output_dir='/kaggle/working/'):
    model.eval()
    reconstructions = []
    originals = []
    latent_vars = []
    
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            x = batch.to(device)
            t = torch.rand(x.shape[0], 1, device=device)
            recon_x, mu, log_var = model(x, t)
            
            reconstructions.append(recon_x.cpu().numpy())
            originals.append(x.cpu().numpy())
            latent_vars.append(mu.cpu().numpy())
            
            del x, recon_x, mu, log_var
            torch.cuda.empty_cache()
    
    reconstructions = np.concatenate(reconstructions)
    originals = np.concatenate(originals)
    latent_vars = np.concatenate(latent_vars)
    
    mse = np.mean((reconstructions - originals) ** 2)
    mae = np.mean(np.abs(reconstructions - originals))
    
    plt.figure(figsize=(15, 10))
    
    plt.subplot(2, 2, 1)
    plt.plot(originals[0], label='Original')
    plt.plot(reconstructions[0], label='Reconstructed')
    plt.title('Sample Reconstruction')
    plt.legend()
    
    plt.subplot(2, 2, 2)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_losses, label='Val Loss')
    plt.title('Training History')
    plt.legend()
    
    plt.subplot(2, 2, 3)
    plt.scatter(latent_vars[:, 0], latent_vars[:, 1], alpha=0.5)
    plt.title('Latent Space Distribution')
    
    plt.subplot(2, 2, 4)
    sns.histplot((reconstructions - originals).flatten(), bins=50)
    plt.title('Reconstruction Error Distribution')
    
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, 'evaluation_plots.png'))
    plt.close()
    
    metrics = {
        'MSE': float(mse),
        'MAE': float(mae),
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    }
    with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
        json.dump(metrics, f, indent=4)
    
    return metrics

In [8]:
if __name__ == "__main__":
    output_dir = '/kaggle/working/'
    data_dir = os.path.join(output_dir, 'synthetic_data')
    
    print("Generating synthetic data on CPU...")
    synthetic_data = create_synthetic_dataset(data_dir, num_samples=1000)
    
    dataset = ProteinStructureDataset(data_dir)
    train_size = int(0.7 * len(dataset))
    val_size = int(0.15 * len(dataset))
    test_size = len(dataset) - train_size - val_size
    
    train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size, test_size])
    
    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=2)
    test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
    
    print("\nTraining HelixSynth Model...")
    model = HelixSynthModel()
    model, train_losses, val_losses = train_model(
        model, train_loader, val_loader, device, epochs=75
    )
    
    print("\nEvaluating model...")
    metrics = evaluate_and_plot(model, test_loader, device)
    
    torch.save(model.state_dict(), os.path.join(output_dir, 'helixsynth_model.pt'))
    print(f"\nModel saved to {os.path.join(output_dir, 'helixsynth_model.pt')}")
    print(f"Metrics: MSE: {metrics['MSE']:.4f}, MAE: {metrics['MAE']:.4f}")
    
    gc.collect()
    torch.cuda.empty_cache()

Generating synthetic data on CPU...


Generating synthetic data:   0%|          | 0/1000 [00:00<?, ?it/s]


Training HelixSynth Model...


Training:   0%|          | 0/75 [00:00<?, ?it/s]

Epoch 1: Train Loss: 23716.5000, Val Loss: 10139.2769
Epoch 2: Train Loss: 23570.0234, Val Loss: 10362.2432
Epoch 3: Train Loss: 21607.1960, Val Loss: 10215.4658
Epoch 4: Train Loss: 19417.4718, Val Loss: 9980.4705
Epoch 5: Train Loss: 17484.4744, Val Loss: 9433.8573
Epoch 6: Train Loss: 15828.9615, Val Loss: 8669.8557
Epoch 7: Train Loss: 14353.7583, Val Loss: 8174.7970
Epoch 8: Train Loss: 13214.9430, Val Loss: 7559.5978
Epoch 9: Train Loss: 12123.6750, Val Loss: 6951.1969
Epoch 10: Train Loss: 11225.9003, Val Loss: 6619.2798
Epoch 11: Train Loss: 10503.7304, Val Loss: 6119.0073
Epoch 12: Train Loss: 9903.2173, Val Loss: 5798.3243
Epoch 13: Train Loss: 9313.8175, Val Loss: 5421.6799
Epoch 14: Train Loss: 8871.8610, Val Loss: 5160.4524
Epoch 15: Train Loss: 8494.3141, Val Loss: 4921.2590
Epoch 16: Train Loss: 8092.2455, Val Loss: 4817.9601
Epoch 17: Train Loss: 7830.9969, Val Loss: 4587.9377
Epoch 18: Train Loss: 7537.9782, Val Loss: 4465.9646
Epoch 19: Train Loss: 7296.9778, Val Loss

Evaluating:   0%|          | 0/2 [00:00<?, ?it/s]


Model saved to /kaggle/working/helixsynth_model.pt
Metrics: MSE: 0.0931, MAE: 0.2523
