# Mount Drive, Installs and Imports

In [21]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
!pip install biopython
!pip install pytorch_wavelets pywavelets
!pip install anndata
!pip install scanpy
!pip install igraph
!pip install leidenalg
!pip install louvain

import os
import sys
import pandas as pd
import numpy as np
import pywt
from pytorch_wavelets import DWTForward, DWTInverse
from pytorch_wavelets import DTCWTForward, DTCWTInverse
import matplotlib.pyplot as plt
import json
import logging
from tqdm import tqdm
from Bio import SeqIO
import traceback
import gc
import re
import tqdm
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import spectral_norm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import psutil
import anndata as ad
import scanpy as sc
import scipy.sparse as sp
from torch.amp import GradScaler, autocast
import umap
import seaborn as sns
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

from scipy.spatial.distance import cdist

import umap
from sklearn.metrics import adjusted_rand_score

from sklearn.utils import resample

%load_ext tensorboard

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


# Sarscov2 S Protein

## (ONLY RUN ONCE - IF SOMEONE ALREADY DOWNLOADED THE DATA, DO NOT RUN THIS AGAIN) Download from NCBI and Save to Drive

In [None]:
# (species: 12,620 total genomes) Severe acute respiratory syndrome-related coronavirus
'''https://www.ncbi.nlm.nih.gov/datasets/taxonomy/694009/'''

# Set the path where you want to save the data
data_dir = '/content/drive/MyDrive/540data/sarscov2'

# Create the directory if it doesn't exist
if not os.path.exists(data_dir):
    os.makedirs(data_dir)

# Change to the data directory
%cd $data_dir

# Download NCBI Datasets command-line tool
!curl -o datasets 'https://ftp.ncbi.nlm.nih.gov/pub/datasets/command-line/v2/linux-amd64/datasets'
!curl -o dataformat 'https://ftp.ncbi.nlm.nih.gov/pub/datasets/command-line/v2/linux-amd64/dataformat'

# Make executable
!chmod +x datasets dataformat

# Download data - specifically including CDS for spike protein
!./datasets download virus protein S taxon 694009 --host human --include cds --filename spikes.zip

# Create directories for extracted data
!mkdir -p spike_data
!mkdir -p cds_sequences

# Unzip all files to examine structure
!unzip spikes.zip -d spike_data

# List the contents of the zip file
!unzip -l spikes.zip

# Extract metadata
!./dataformat tsv virus-genome --inputfile spike_data/ncbi_dataset/data/data_report.jsonl --fields accession,virus-name,host-name,isolate-collection-date > spike_cds_metadata.tsv

# Copy the CDS files to the dedicated directory
!cp spike_data/ncbi_dataset/data/cds.fna cds_sequences/spike_cds.fna

## Data

In [28]:
def read_jsonl(file_path):
    """Reads metadata from JSONL file."""
    with open(file_path, 'r') as f:
        for line in f:
            yield json.loads(line)

def parse_cds_file(cds_file, batch_size=10000):
    """Parse the CDS file and return a dictionary of sequences with isolate information."""
    sequences = {}
    isolates = {}

    # Process in smaller batches to manage memory
    batch_count = 0

    for record in tqdm(SeqIO.parse(cds_file, "fasta"), desc="Parsing CDS file"):
        accession = record.id.split('.')[0]
        sequences[accession] = str(record.seq)

        # Extract isolate information from the description
        # Example: [isolate=SARS-CoV-2/human/USA/UT-UPHL-250228674407/2025]
        isolate_match = re.search(r'\[isolate=(.*?)\]', record.description)
        if isolate_match:
            isolates[accession] = isolate_match.group(1)
        else:
            isolates[accession] = "Unknown"

        batch_count += 1
        if batch_count % batch_size == 0:
            gc.collect()

    return sequences, isolates

def extract_variant_info(entry):
    """Extract variant information from taxonomy/lineage data."""
    variant = "Unknown"

    if 'virus' in entry and 'lineage' in entry['virus']:
        lineage = entry['virus']['lineage']

        # Extract variant name from the taxonomy list
        if isinstance(lineage, list) and lineage:
            # Get the most specific taxon (last in list)
            last_item = lineage[-1]
            if isinstance(last_item, dict) and 'name' in last_item:
                variant = last_item['name']
            else:
                variant = str(last_item)
        else:
            variant = str(lineage)

    # Get other metadata (there is more if we want): https://www.ncbi.nlm.nih.gov/datasets/docs/v2/reference-docs/data-reports/virus/
    collection_date = entry.get('collection_date', 'Unknown')
    geo_location = entry.get('geo_location', 'Unknown')

    return variant, collection_date, geo_location

def process_data(cds_file, metadata_file, top_n=5, samples_per_variant=100, chunk_size=1000):
    """
    Efficiently process CDS and metadata files to create a filtered AnnData object.
    Only processes sequences from top variants and limits samples per variant.
    """
    # Initialize counters and containers
    variant_counts = {}
    target_accessions = {}
    metadata_entries = {}

    logging.info("Processing metadata in a single pass...")

    # Single pass through metadata
    for entry in tqdm.tqdm(read_jsonl(metadata_file), desc="Processing metadata"):
        if entry.get('completeness') != 'COMPLETE' or entry.get('isAnnotated') != True:
            continue

        variant = entry.get('virus', {}).get('pangolinClassification', 'Unknown')
        accession = entry.get('accession', '').split('.')[0]

        # Count variants (for first-pass equivalent)
        variant_counts[variant] = variant_counts.get(variant, 0) + 1

        # Store metadata entry for later use
        metadata_entries[accession] = {
            'variant': variant,
            'species': entry.get('virus', {}).get('organismName', 'Unknown'),
            'accession': accession
        }

    # Determine top variants
    top_variants = sorted(variant_counts.items(), key=lambda x: x[1], reverse=True)[:top_n]
    top_variant_names = [v[0] for v in top_variants]
    logging.info(f"Top {top_n} variants: {top_variant_names}")

    # Initialize containers for selected samples
    selected_accessions = []
    samples_collected = {v: 0 for v in top_variant_names}

    # Select samples for each top variant
    for accession, entry in metadata_entries.items():
        variant = entry['variant']
        if variant in top_variant_names and samples_collected[variant] < samples_per_variant:
            selected_accessions.append(accession)
            samples_collected[variant] += 1

            # If we have enough samples for all variants, stop
            if all(count >= samples_per_variant for count in samples_collected.values()):
                break

    logging.info(f"Collected {len(selected_accessions)} accessions for processing")

    # Extract sequences (unchanged)
    sequences = {}
    for record in tqdm.tqdm(SeqIO.parse(cds_file, "fasta"), desc="Extracting sequences"):
        accession = record.id.split('.')[0]
        if accession in selected_accessions:
            sequences[accession] = str(record.seq)

            # If we have all sequences, stop
            if len(sequences) == len(selected_accessions):
                break

    # Create metadata list from stored entries
    metadata = [
        {
            'accession': acc,
            'variant': metadata_entries[acc]['variant'],
            'species': metadata_entries[acc]['species'],
            'sequence_length': len(sequences[acc])
        }
        for acc in selected_accessions if acc in sequences
    ]

    # Create DataFrame and AnnData with string indices
    metadata_df = pd.DataFrame(metadata)
    metadata_df.index = metadata_df['accession'].astype(str)

    X = sp.csr_matrix((len(metadata_df), 1), dtype=np.float32)
    var_names = ['spike_protein']

    adata = ad.AnnData(X, obs=metadata_df, var=pd.DataFrame(index=var_names))
    adata.layers['sequence'] = np.array([sequences[acc] for acc in metadata_df['accession']]).reshape(-1, 1)

    # Encode categorical variables
    adata.obs['variant'] = adata.obs['variant'].astype('category')
    adata.obs['species'] = adata.obs['species'].astype('category')

    logging.info(f"Created AnnData object with {adata.n_obs} observations")

    return adata


def prepare_for_model(adata, max_seq_length=None, chunk_size=10000):
    """
    Prepare AnnData for model training with sequence truncation/padding.
    """

    logging.info(f"Preparing {adata.n_obs} sequences with chunk size {chunk_size}")

    processed_sequences = []

    # Process sequences in chunks
    for i in range(0, adata.n_obs, chunk_size):

        # Get current chunk
        end_idx = min(i + chunk_size, adata.n_obs)

        # Extract sequence chunk
        sequence_chunk = adata.layers['sequence'][i:end_idx]

        # Process chunk
        chunk_sequences = [seq[0] for seq in sequence_chunk]  # Extract sequence from shape (n, 1)

        # Apply padding/truncation
        if max_seq_length:
            processed_chunk = [
                seq[:max_seq_length] if len(seq) > max_seq_length else seq.ljust(max_seq_length, 'N')
                for seq in chunk_sequences
            ]
        else:
            processed_chunk = chunk_sequences

        # Append processed sequences
        processed_sequences.extend(processed_chunk)

        # Force garbage collection
        sequence_chunk = None
        chunk_sequences = None
        gc.collect()

    # Store all processed sequences
    logging.info(f"Storing {len(processed_sequences)} processed sequences")
    adata.uns['processed_sequences'] = processed_sequences

    return adata

class SARSCoV2VAEDataset(Dataset):
    def __init__(self, sequences, variants=None, max_length=None, one_hot=True, pad_to_power_of_two=True, augment=False, k=3):
        self.variants = variants
        self.one_hot = one_hot
        self.augment = augment
        self.k = k

        # Determine padded length
        self.max_length = max_length or max(len(seq) for seq in sequences)
        if pad_to_power_of_two:
            power = 1
            while power < self.max_length:
                power *= 2
            self.max_length = power
            logging.info(f"Padded sequence length to {self.max_length} (power of 2)")

        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        sequence = self.sequences[idx]
        sequence = sequence[:self.max_length].ljust(self.max_length, 'N')

        if self.augment and np.random.rand() > 0.5:
            sequence = self.kmer_shuffle(sequence, self.k)

        one_hot = self.one_hot_encode_torch(sequence)

        item = {'input': one_hot, 'target': one_hot, 'index': idx}
        if self.variants is not None:
            item['variant'] = self.variants[idx]
        return item

    def one_hot_encode_torch(self, sequence):
        map_dict = {'A': 0, 'T': 1, 'G': 2, 'C': 3}
        seq_indices = torch.tensor([map_dict.get(n, 3) for n in sequence])
        one_hot = torch.zeros(4, self.max_length)
        one_hot.scatter_(0, seq_indices.unsqueeze(0), 1)
        return one_hot

    def kmer_shuffle(self, sequence, k=3):
        kmers = [sequence[i:i+k] for i in range(len(sequence) - k + 1)]
        np.random.shuffle(kmers)
        return ''.join(kmers).ljust(self.max_length, 'N')[:self.max_length]


def prepare_vae_data(adata, max_length=1024, batch_size=32, one_hot=True, train_test_ratio=0.8, augment=True):
    train_idx, test_idx = train_test_split(
        np.arange(adata.n_obs),
        test_size=1 - train_test_ratio,
        stratify=adata.obs['variant_encoded'] if 'variant_encoded' in adata.obs.columns else None,
        random_state=42
    )

    train_dataset = SARSCoV2VAEDataset(
        [adata.uns['processed_sequences'][i] for i in train_idx],
        variants=adata.obs['variant_encoded'].values[train_idx] if 'variant_encoded' in adata.obs.columns else None,
        max_length=max_length,
        one_hot=one_hot,
        augment=augment
    )

    test_dataset = SARSCoV2VAEDataset(
        [adata.uns['processed_sequences'][i] for i in test_idx],
        variants=adata.obs['variant_encoded'].values[test_idx] if 'variant_encoded' in adata.obs.columns else None,
        max_length=max_length,
        one_hot=one_hot,
        augment=False
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=os.cpu_count(), pin_memory=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=os.cpu_count(), pin_memory=True)

    return train_loader, test_loader

## Model

In [29]:
class AttentionLayer(nn.Module):
    def __init__(self, input_dim):
        super().__init__()
        self.attention = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.Tanh(),
            nn.Linear(input_dim, 1)
        )

    def forward(self, x):
        weights = F.softmax(self.attention(x), dim=1)
        return (weights * x).sum(dim=1)

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, dropout_prob=0.3):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_prob),
            nn.Conv1d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1) if in_channels != out_channels else nn.Identity()

    def forward(self, x):
        return F.leaky_relu(self.conv(x) + self.shortcut(x), 0.2)

class KLAnnealer:
    def __init__(self, total_epochs, anneal_type='logistic'):
        self.epoch = 0
        self.total = total_epochs
        self.anneal_type = anneal_type

    def __call__(self):
        if self.anneal_type == 'linear':
            return min(1.0, self.epoch/self.total*4)
        elif self.anneal_type == 'logistic':
            return 1/(1 + np.exp(-0.05*(self.epoch - self.total//2)))
        return 1.0

class GenomicVAE(nn.Module):
    def __init__(self, sequence_length, num_nucleotides=4, latent_dim=256, hidden_dims=[512, 256], dropout_prob=0.1):
        super().__init__()
        self.sequence_length = sequence_length
        self.num_nucleotides = num_nucleotides
        self.latent_dim = latent_dim

        self.flat_dim = hidden_dims[-1] * sequence_length

        # Encoder with Dropout
        self.encoder = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv1d(num_nucleotides, hidden_dims[0], kernel_size=3, padding=1)),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout_prob),
            ResidualBlock(hidden_dims[0], hidden_dims[1], dropout_prob)
        )

        self.flatten = nn.Flatten()

        # Attention mechanism for encoder
        self.attention = AttentionLayer(hidden_dims[-1])

        # Latent space
        self.fc_mu = nn.Linear(self.flat_dim, latent_dim)
        self.fc_var = nn.Linear(self.flat_dim, latent_dim)

        # Decoder
        self.decoder_fc = nn.Linear(latent_dim, self.flat_dim)

        # Decoder conv layers
        self.decoder_conv = nn.Sequential(
            ResidualBlock(hidden_dims[-1], hidden_dims[-1], dropout_prob),
            ResidualBlock(hidden_dims[-1], hidden_dims[0], dropout_prob),
            nn.Conv1d(hidden_dims[0], num_nucleotides, kernel_size=3, padding=1)
        )

        self.kl_annealer = KLAnnealer(100, 'logistic')

    def encode(self, x):
        h = self.encoder(x)
        h_flat = self.flatten(h)
        return self.fc_mu(h_flat), self.fc_var(h_flat)

    def decode(self, z):
        h = self.decoder_fc(z)
        h = h.view(-1, hidden_dims[-1], self.sequence_length)
        h = F.dropout(h, p=0.1, training=self.training)
        return self.decoder_conv(h)

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, x):
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        return self.decode(z), mu, log_var

    def loss_function(self, recon_x, x, mu, log_var, beta=0.5):
        recon_loss = F.binary_cross_entropy_with_logits(recon_x, x, reduction='sum')
        kl_div = -0.5 * torch.sum(1 + log_var - mu.pow(2) - log_var.exp())
        return recon_loss + self.kl_annealer() * beta * kl_div

def train_vae(model, train_loader, test_loader, epochs=50, lr=1e-3, beta=1.0, device='cuda', use_mixed_precision=True):
    """Train the model with proper mixed precision handling"""

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.5, patience=5)

    scaler = GradScaler(enabled=use_mixed_precision)

    # logs directory
    log_dir = os.path.join(output_dir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    monitor = TrainingMonitor(log_dir)

    train_losses = []
    test_losses = []

    for epoch in tqdm.trange(epochs, desc="Training Progress"):
        # Training phase
        model.train()
        train_loss = 0.0
        current_lr = scheduler.optimizer.param_groups[0]['lr']
        logging.info(f"Current LR: {current_lr}")

        for batch in train_loader:
            optimizer.zero_grad()
            inputs = batch['input'].to(device)

            with autocast(device_type='cuda' if device == 'cuda' else 'cpu', enabled=use_mixed_precision):
                recon, mu, logvar = model(inputs)
                loss = model.loss_function(recon, inputs, mu, logvar, beta=beta)

            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()


            train_loss += loss.item()
            model.kl_annealer.epoch += 1

        # Validation phase
        model.eval()
        test_loss = 0.0

        with torch.no_grad():
            for batch in test_loader:
                inputs = batch['input'].to(device)

                with autocast(device_type='cuda' if device == 'cuda' else 'cpu', enabled=use_mixed_precision):
                    recon, mu, logvar = model(inputs)
                    loss = model.loss_function(recon, inputs, mu, logvar, beta=beta)

                test_loss += loss.item()

        # Log metrics
        avg_train_loss = train_loss / len(train_loader)
        avg_test_loss = test_loss / len(test_loader)
        monitor.log('Loss/Train', avg_train_loss, epoch)
        monitor.log('Loss/Test', avg_test_loss, epoch)

        scheduler.step(avg_test_loss)

        train_losses.append(avg_train_loss)
        test_losses.append(avg_test_loss)

        # Print progress
        tqdm.tqdm.write(f"Epoch {epoch+1}/{epochs} - Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}")

        # Memory cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return model, train_losses, test_losses

class TrainingMonitor:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)
        self.metrics = dict()

    def log(self, name, value, epoch):
        self.writer.add_scalar(name, value, epoch)
        if name not in self.metrics:
            self.metrics[name] = []
        self.metrics[name].append(value)

    def plot_metrics(self):
        plt.figure(figsize=(12, 6))
        for name, values in self.metrics.items():
            if len(values) > 0:
                plt.plot(values, label=name)
        if any(len(v) > 0 for v in self.metrics.values()):
            plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.xlabel('Epoch')
        plt.ylabel('Value')
        plt.title('Training Metrics')
        plt.savefig(os.path.join(output_dir, 'training_metrics.png'))
        plt.close()

## Analysis

In [30]:
def extract_latent_vectors(model, data_loader, adata=None, device='cuda'):
    """Extract latent space representations from the trained VAE."""
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    latent_vectors = []
    dataset_indices = []

    with torch.no_grad():
        for batch_idx, batch in enumerate(tqdm.tqdm(data_loader, desc="Extracting latent vectors")):
            inputs = batch['input'].to(device)
            mu, _ = model.encode(inputs)
            latent_vectors.append(mu.cpu().numpy())

            batch_indices = batch['index'].cpu().numpy() if 'index' in batch else range(
                batch_idx * data_loader.batch_size,
                min((batch_idx + 1) * data_loader.batch_size, len(data_loader.dataset))
            )
            dataset_indices.extend(batch_indices)

    latent_vectors = np.vstack(latent_vectors)

    if adata is not None:
        metadata_df = adata.obs.iloc[dataset_indices].copy()
        return latent_vectors, metadata_df
    else:
        return latent_vectors, None

def perform_clustering_and_visualization(latent_vectors, metadata, output_dir, model_name, n_neighbors=30, metric='correlation'):
    # Create AnnData object
    adata = sc.AnnData(latent_vectors)
    adata.obs = metadata

    # Compute neighborhood graph with correlation metric
    sc.pp.neighbors(adata, n_neighbors=n_neighbors, use_rep='X', metric=metric, method='umap', knn=True)

    # Run Leiden
    resolutions = np.linspace(0.1, 2.0, 20)
    ari_scores = []

    for res in resolutions:
        sc.tl.leiden(
            adata,
            resolution=res,
            key_added=f'leiden_{res:.2f}',
            flavor='igraph',
            n_iterations=2,
            directed=False
        )
        ari = adjusted_rand_score(adata.obs['variant'], adata.obs[f'leiden_{res:.2f}'])
        ari_scores.append(ari)

    best_res = resolutions[np.argmax(ari_scores)]
    best_ari = max(ari_scores)
    adata.obs['leiden'] = adata.obs[f'leiden_{best_res:.2f}']

    sc.tl.umap(adata)

    fig, axs = plt.subplots(1, 2, figsize=(20, 10))
    sc.pl.umap(adata, color='variant', ax=axs[0], show=False)
    sc.pl.umap(adata, color='leiden', ax=axs[1], show=False)
    plt.suptitle(f"{model_name} - UMAP (ARI: {best_ari:.4f})")
    plt.savefig(os.path.join(output_dir, f'{model_name.lower()}_umap.png'))
    plt.close()

    return best_ari


## Main

In [31]:
# cler gpu and cpu
torch.cuda.empty_cache()
gc.collect()

54

In [32]:
# Setup logging
for handler in logging.root.handlers[:]:
    logging.root.removeHandler(handler)

logging.basicConfig(
    level=logging.DEBUG if "--debug" in sys.argv else logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[
        logging.StreamHandler(sys.stdout),
    ],
)

# GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [33]:
# DATA FROM NCBI
data_dir = '/content/drive/MyDrive/540data/sarscov2'
metadata_file = os.path.join(data_dir, 'spike_cds_metadata.tsv')
cds_file = os.path.join(data_dir, 'cds_sequences/spike_cds.fna')
data_report_jsonl = os.path.join(data_dir, 'spike_data/ncbi_dataset/data/data_report.jsonl')

# EXPERIMENT
output_dir = os.path.join(data_dir, 'results')
os.makedirs(output_dir, exist_ok=True)

# Set up training monitor
log_dir = os.path.join(output_dir, 'tensorboard_logs')
os.makedirs(log_dir, exist_ok=True)
training_monitor = TrainingMonitor(log_dir)

# Initialize data loaders
train_loader_path = os.path.join(output_dir, 'train_loader.pt')
test_loader_path = os.path.join(output_dir, 'test_loader.pt')

# Model
conv_model_path = os.path.join(output_dir, 'sars_cov2_conv_vae.pt')

In [34]:
# EXPERIMENT PARAMETERS
chunk_size = 1000  # For data processing
top_n = 3 # Top variants by sample count
samples_per_variant = 100  # Samples per variant - reduced from 10000 to be more realistic
max_sequence_length = 2048 # power of 2
num_nucleotides = 4 # A, C, G, T

In [35]:
# MODEL PARAMETERS
latent_dim = 128
hidden_dims = [256, 128]

batch_size = 128
epochs = 100
lr = 1e-4
beta = 2.0
dropout_prob = .5

use_mixed_precision = True

# clustering viz
n_neighbors=30
metric='correlation'

In [36]:
# Process data efficiently - directly get filtered data
filtered_data_path = os.path.join(output_dir, 'filtered_data.h5ad')
if not os.path.exists(filtered_data_path):
    logging.info("Processing and filtering data...")
    adata_filtered = process_data(
        cds_file,
        data_report_jsonl,
        top_n=top_n,
        samples_per_variant=samples_per_variant,
        chunk_size=chunk_size
    )

    if adata_filtered is None:
        logging.error("Failed to process data. Exiting.")
        sys.exit(1)

    # Prepare sequences for model
    adata_filtered = prepare_for_model(adata_filtered, max_seq_length=max_sequence_length)
    adata_filtered.write(filtered_data_path)
else:
    logging.info("Loading filtered data from file...")
    adata_filtered = sc.read(filtered_data_path)

# Log dataset information
logging.info(f"Filtered dataset size: {adata_filtered.n_obs} sequences")
logging.info(f"Samples per variant:\n{adata_filtered.obs['variant'].value_counts()}")
logging.info(f"Unique species: {adata_filtered.obs['species'].nunique()}")
logging.info(f"Species distribution:\n{adata_filtered.obs['species'].value_counts()}")
logging.info(f"Unique variants: {adata_filtered.obs['variant'].nunique()}")
logging.info(f"Variant distribution (top {top_n}):\n{adata_filtered.obs['variant'].value_counts().head(top_n)}")

# Free up memory after data processing
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

if os.path.exists(train_loader_path) and os.path.exists(test_loader_path):
    train_loader = torch.load(train_loader_path, weights_only=False)
    test_loader = torch.load(test_loader_path, weights_only=False)
    logging.info("Loaded existing data loaders")
else:
    train_loader, test_loader = prepare_vae_data(
        adata_filtered,
        max_length=max_sequence_length,
        batch_size=batch_size
    )
    torch.save(train_loader, train_loader_path)
    torch.save(test_loader, test_loader_path)
    logging.info("Created and saved new data loaders")

logging.info(f"Train loader size: {len(train_loader)} batches")
logging.info(f"Test loader size: {len(test_loader)} batches")

2025-03-26 23:06:37,708 - INFO - Processing and filtering data...
2025-03-26 23:06:37,710 - INFO - Processing metadata in a single pass...


Processing metadata: 9088286it [06:57, 21743.51it/s]

2025-03-26 23:13:35,691 - INFO - Top 3 variants: ['BA.1.1', 'B.1.1.7', 'BA.2.12.1']





2025-03-26 23:13:35,874 - INFO - Collected 300 accessions for processing


Extracting sequences: 2779175it [01:33, 29864.75it/s]

2025-03-26 23:15:09,181 - INFO - Created AnnData object with 300 observations
2025-03-26 23:15:09,412 - INFO - Preparing 300 sequences with chunk size 10000





2025-03-26 23:15:09,909 - INFO - Storing 300 processed sequences
2025-03-26 23:15:10,004 - INFO - Filtered dataset size: 300 sequences
2025-03-26 23:15:10,013 - INFO - Samples per variant:
variant
B.1.1.7      100
BA.1.1       100
BA.2.12.1    100
Name: count, dtype: int64
2025-03-26 23:15:10,017 - INFO - Unique species: 1
2025-03-26 23:15:10,018 - INFO - Species distribution:
species
Severe acute respiratory syndrome coronavirus 2    300
Name: count, dtype: int64
2025-03-26 23:15:10,020 - INFO - Unique variants: 3
2025-03-26 23:15:10,023 - INFO - Variant distribution (top 3):
variant
B.1.1.7      100
BA.1.1       100
BA.2.12.1    100
Name: count, dtype: int64
2025-03-26 23:15:10,578 - INFO - Padded sequence length to 2048 (power of 2)
2025-03-26 23:15:10,580 - INFO - Padded sequence length to 2048 (power of 2)
2025-03-26 23:15:10,605 - INFO - Created and saved new data loaders
2025-03-26 23:15:10,607 - INFO - Train loader size: 2 batches
2025-03-26 23:15:10,610 - INFO - Test loader si

In [37]:
'''Initialize Model'''
# Initialize conventional model
conv_model = GenomicVAE(
    sequence_length=max_sequence_length,
    num_nucleotides=num_nucleotides,
    latent_dim=latent_dim,
    hidden_dims=hidden_dims,
    dropout_prob=dropout_prob
).to(device)

# Free up memory
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

'''Train Model'''
# Train model with augmentation
if not os.path.exists(conv_model_path):
    logging.info(f"Training Conventional VAE with augmentation")

    # Create training monitor
    monitor = TrainingMonitor(os.path.join(log_dir, 'conv_vae'))

    # Train with augmentation
    conv_model, conv_train_losses, conv_test_losses = train_vae(
        conv_model,
        train_loader,
        test_loader,
        epochs=epochs,
        beta=beta,
        device=device,
        lr=lr,
        use_mixed_precision=use_mixed_precision
    )

    # Save the trained model
    torch.save(conv_model.state_dict(), conv_model_path)

    # Plot training curves
    plt.figure(figsize=(12, 6))
    plt.plot(conv_train_losses, label='Train Loss')
    plt.plot(conv_test_losses, label='Test Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('VAE Training Progress')
    plt.legend()
    plt.savefig(os.path.join(output_dir, 'training_curves.png'))
    plt.close()

    # Create Tensorboard training visualization
    monitor.plot_metrics()
else:
    logging.info(f"Loading trained conventional model from {conv_model_path}")
    conv_model.load_state_dict(torch.load(conv_model_path, map_location=device))
    conv_model.to(device)


2025-03-26 23:15:24,320 - INFO - Training Conventional VAE with augmentation


Training Progress:   0%|          | 0/100 [00:00<?, ?it/s]

2025-03-26 23:15:24,361 - INFO - Current LR: 0.0001


Training Progress:   1%|          | 1/100 [00:01<02:36,  1.58s/it]

Epoch 1/100 - Train Loss: 666950.3750, Test Loss: 328862.9688
2025-03-26 23:15:25,942 - INFO - Current LR: 0.0001


Training Progress:   2%|▏         | 2/100 [00:03<02:31,  1.55s/it]

Epoch 2/100 - Train Loss: 627089.5625, Test Loss: 322755.7812
2025-03-26 23:15:27,466 - INFO - Current LR: 0.0001


Training Progress:   3%|▎         | 3/100 [00:04<02:42,  1.68s/it]

Epoch 3/100 - Train Loss: 593568.8750, Test Loss: 310034.1250
2025-03-26 23:15:29,296 - INFO - Current LR: 0.0001


Training Progress:   4%|▍         | 4/100 [00:06<02:45,  1.72s/it]

Epoch 4/100 - Train Loss: 565535.1562, Test Loss: 302161.0938
2025-03-26 23:15:31,089 - INFO - Current LR: 0.0001


Training Progress:   5%|▌         | 5/100 [00:08<02:38,  1.67s/it]

Epoch 5/100 - Train Loss: 547441.4375, Test Loss: 298745.0625
2025-03-26 23:15:32,663 - INFO - Current LR: 0.0001


Training Progress:   6%|▌         | 6/100 [00:10<02:43,  1.74s/it]

Epoch 6/100 - Train Loss: 534252.5625, Test Loss: 290988.4062
2025-03-26 23:15:34,546 - INFO - Current LR: 0.0001


Training Progress:   7%|▋         | 7/100 [00:12<02:55,  1.89s/it]

Epoch 7/100 - Train Loss: 527214.6406, Test Loss: 284808.5938
2025-03-26 23:15:36,744 - INFO - Current LR: 0.0001


Training Progress:   8%|▊         | 8/100 [00:13<02:43,  1.77s/it]

Epoch 8/100 - Train Loss: 525373.2031, Test Loss: 281194.8750
2025-03-26 23:15:38,267 - INFO - Current LR: 0.0001


Training Progress:   9%|▉         | 9/100 [00:15<02:34,  1.69s/it]

Epoch 9/100 - Train Loss: 512674.3438, Test Loss: 275804.4375
2025-03-26 23:15:39,781 - INFO - Current LR: 0.0001


Training Progress:  10%|█         | 10/100 [00:17<02:32,  1.69s/it]

Epoch 10/100 - Train Loss: 505088.4688, Test Loss: 267089.1562
2025-03-26 23:15:41,469 - INFO - Current LR: 0.0001


Training Progress:  11%|█         | 11/100 [00:19<02:36,  1.76s/it]

Epoch 11/100 - Train Loss: 490216.1875, Test Loss: 258337.6406
2025-03-26 23:15:43,382 - INFO - Current LR: 0.0001


Training Progress:  12%|█▏        | 12/100 [00:20<02:32,  1.73s/it]

Epoch 12/100 - Train Loss: 478967.1562, Test Loss: 248226.1406
2025-03-26 23:15:45,050 - INFO - Current LR: 0.0001


Training Progress:  13%|█▎        | 13/100 [00:22<02:25,  1.67s/it]

Epoch 13/100 - Train Loss: 465405.0156, Test Loss: 237350.6406
2025-03-26 23:15:46,582 - INFO - Current LR: 0.0001


Training Progress:  14%|█▍        | 14/100 [00:23<02:19,  1.63s/it]

Epoch 14/100 - Train Loss: 452016.3594, Test Loss: 225147.1562
2025-03-26 23:15:48,107 - INFO - Current LR: 0.0001


Training Progress:  15%|█▌        | 15/100 [00:25<02:15,  1.59s/it]

Epoch 15/100 - Train Loss: 425335.7031, Test Loss: 212618.3750
2025-03-26 23:15:49,621 - INFO - Current LR: 0.0001


Training Progress:  16%|█▌        | 16/100 [00:26<02:12,  1.58s/it]

Epoch 16/100 - Train Loss: 435681.2812, Test Loss: 201251.4531
2025-03-26 23:15:51,175 - INFO - Current LR: 0.0001


Training Progress:  17%|█▋        | 17/100 [00:28<02:09,  1.56s/it]

Epoch 17/100 - Train Loss: 415537.7656, Test Loss: 192810.4688
2025-03-26 23:15:52,699 - INFO - Current LR: 0.0001


Training Progress:  18%|█▊        | 18/100 [00:29<02:07,  1.55s/it]

Epoch 18/100 - Train Loss: 424499.5000, Test Loss: 182428.5781
2025-03-26 23:15:54,218 - INFO - Current LR: 0.0001


Training Progress:  19%|█▉        | 19/100 [00:31<02:13,  1.65s/it]

Epoch 19/100 - Train Loss: 396755.9688, Test Loss: 169236.9375
2025-03-26 23:15:56,101 - INFO - Current LR: 0.0001


Training Progress:  20%|██        | 20/100 [00:33<02:14,  1.68s/it]

Epoch 20/100 - Train Loss: 369424.2656, Test Loss: 153725.0625
2025-03-26 23:15:57,854 - INFO - Current LR: 0.0001


Training Progress:  21%|██        | 21/100 [00:35<02:09,  1.64s/it]

Epoch 21/100 - Train Loss: 370099.9531, Test Loss: 144511.4844
2025-03-26 23:15:59,414 - INFO - Current LR: 0.0001


Training Progress:  22%|██▏       | 22/100 [00:36<02:05,  1.61s/it]

Epoch 22/100 - Train Loss: 355301.3438, Test Loss: 132896.0781
2025-03-26 23:16:00,944 - INFO - Current LR: 0.0001


Training Progress:  23%|██▎       | 23/100 [00:38<02:02,  1.59s/it]

Epoch 23/100 - Train Loss: 344333.5938, Test Loss: 126006.8906
2025-03-26 23:16:02,497 - INFO - Current LR: 0.0001


Training Progress:  24%|██▍       | 24/100 [00:39<02:01,  1.60s/it]

Epoch 24/100 - Train Loss: 341850.8594, Test Loss: 117852.1562
2025-03-26 23:16:04,110 - INFO - Current LR: 0.0001


Training Progress:  25%|██▌       | 25/100 [00:41<01:58,  1.59s/it]

Epoch 25/100 - Train Loss: 350067.5469, Test Loss: 106548.9844
2025-03-26 23:16:05,664 - INFO - Current LR: 0.0001


Training Progress:  26%|██▌       | 26/100 [00:42<01:56,  1.57s/it]

Epoch 26/100 - Train Loss: 346841.3281, Test Loss: 95483.8672
2025-03-26 23:16:07,199 - INFO - Current LR: 0.0001


Training Progress:  27%|██▋       | 27/100 [00:44<01:57,  1.61s/it]

Epoch 27/100 - Train Loss: 350743.7188, Test Loss: 92542.2969
2025-03-26 23:16:08,895 - INFO - Current LR: 0.0001


Training Progress:  28%|██▊       | 28/100 [00:46<02:02,  1.70s/it]

Epoch 28/100 - Train Loss: 341188.5781, Test Loss: 78751.8750
2025-03-26 23:16:10,804 - INFO - Current LR: 0.0001


Training Progress:  29%|██▉       | 29/100 [00:47<01:56,  1.65s/it]

Epoch 29/100 - Train Loss: 311270.0625, Test Loss: 77187.4844
2025-03-26 23:16:12,327 - INFO - Current LR: 0.0001


Training Progress:  30%|███       | 30/100 [00:49<01:53,  1.61s/it]

Epoch 30/100 - Train Loss: 325057.6094, Test Loss: 67936.1719
2025-03-26 23:16:13,869 - INFO - Current LR: 0.0001


Training Progress:  31%|███       | 31/100 [00:51<01:49,  1.59s/it]

Epoch 31/100 - Train Loss: 330025.2969, Test Loss: 65867.9297
2025-03-26 23:16:15,397 - INFO - Current LR: 0.0001


Training Progress:  32%|███▏      | 32/100 [00:52<01:47,  1.57s/it]

Epoch 32/100 - Train Loss: 350835.2344, Test Loss: 50852.3516
2025-03-26 23:16:16,940 - INFO - Current LR: 0.0001


Training Progress:  33%|███▎      | 33/100 [00:54<01:44,  1.56s/it]

Epoch 33/100 - Train Loss: 329409.9531, Test Loss: 51152.9102
2025-03-26 23:16:18,474 - INFO - Current LR: 0.0001


Training Progress:  34%|███▍      | 34/100 [00:55<01:42,  1.56s/it]

Epoch 34/100 - Train Loss: 306454.7969, Test Loss: 50755.5117
2025-03-26 23:16:20,022 - INFO - Current LR: 0.0001


Training Progress:  35%|███▌      | 35/100 [00:57<01:45,  1.62s/it]

Epoch 35/100 - Train Loss: 311897.4219, Test Loss: 44589.1133
2025-03-26 23:16:21,799 - INFO - Current LR: 0.0001


Training Progress:  36%|███▌      | 36/100 [00:59<01:48,  1.70s/it]

Epoch 36/100 - Train Loss: 310917.6562, Test Loss: 51761.4258
2025-03-26 23:16:23,684 - INFO - Current LR: 0.0001


Training Progress:  37%|███▋      | 37/100 [01:01<01:47,  1.70s/it]

Epoch 37/100 - Train Loss: 310280.3281, Test Loss: 52754.9570
2025-03-26 23:16:25,380 - INFO - Current LR: 0.0001


Training Progress:  38%|███▊      | 38/100 [01:02<01:42,  1.65s/it]

Epoch 38/100 - Train Loss: 310967.2500, Test Loss: 46646.0781
2025-03-26 23:16:26,897 - INFO - Current LR: 0.0001


Training Progress:  39%|███▉      | 39/100 [01:04<01:38,  1.62s/it]

Epoch 39/100 - Train Loss: 323374.9844, Test Loss: 45173.8438
2025-03-26 23:16:28,442 - INFO - Current LR: 0.0001


Training Progress:  40%|████      | 40/100 [01:05<01:35,  1.59s/it]

Epoch 40/100 - Train Loss: 310652.2188, Test Loss: 47454.9180
2025-03-26 23:16:29,967 - INFO - Current LR: 0.0001


Training Progress:  41%|████      | 41/100 [01:07<01:33,  1.59s/it]

Epoch 41/100 - Train Loss: 288822.5312, Test Loss: 47090.2500
2025-03-26 23:16:31,550 - INFO - Current LR: 5e-05


Training Progress:  42%|████▏     | 42/100 [01:08<01:31,  1.57s/it]

Epoch 42/100 - Train Loss: 320716.5781, Test Loss: 48652.7266
2025-03-26 23:16:33,088 - INFO - Current LR: 5e-05


Training Progress:  43%|████▎     | 43/100 [01:10<01:32,  1.62s/it]

Epoch 43/100 - Train Loss: 282208.3438, Test Loss: 47765.5547
2025-03-26 23:16:34,821 - INFO - Current LR: 5e-05


Training Progress:  44%|████▍     | 44/100 [01:12<01:36,  1.72s/it]

Epoch 44/100 - Train Loss: 323019.6094, Test Loss: 49190.0469
2025-03-26 23:16:36,787 - INFO - Current LR: 5e-05


Training Progress:  45%|████▌     | 45/100 [01:14<01:34,  1.72s/it]

Epoch 45/100 - Train Loss: 307025.9688, Test Loss: 50601.7734
2025-03-26 23:16:38,480 - INFO - Current LR: 5e-05


Training Progress:  46%|████▌     | 46/100 [01:15<01:29,  1.66s/it]

Epoch 46/100 - Train Loss: 323937.9375, Test Loss: 42531.5391
2025-03-26 23:16:40,020 - INFO - Current LR: 5e-05


Training Progress:  47%|████▋     | 47/100 [01:17<01:31,  1.72s/it]

Epoch 47/100 - Train Loss: 319324.8438, Test Loss: 40165.2734
2025-03-26 23:16:41,867 - INFO - Current LR: 5e-05


Training Progress:  48%|████▊     | 48/100 [01:19<01:30,  1.74s/it]

Epoch 48/100 - Train Loss: 325115.5938, Test Loss: 44257.4648
2025-03-26 23:16:43,667 - INFO - Current LR: 5e-05


Training Progress:  49%|████▉     | 49/100 [01:20<01:25,  1.68s/it]

Epoch 49/100 - Train Loss: 291972.0469, Test Loss: 53814.9648
2025-03-26 23:16:45,190 - INFO - Current LR: 5e-05


Training Progress:  50%|█████     | 50/100 [01:22<01:21,  1.63s/it]

Epoch 50/100 - Train Loss: 290171.9062, Test Loss: 49412.0938
2025-03-26 23:16:46,710 - INFO - Current LR: 5e-05


Training Progress:  51%|█████     | 51/100 [01:24<01:21,  1.67s/it]

Epoch 51/100 - Train Loss: 284170.2969, Test Loss: 43621.7500
2025-03-26 23:16:48,459 - INFO - Current LR: 5e-05


Training Progress:  52%|█████▏    | 52/100 [01:25<01:22,  1.73s/it]

Epoch 52/100 - Train Loss: 320223.1562, Test Loss: 47418.1484
2025-03-26 23:16:50,332 - INFO - Current LR: 5e-05


Training Progress:  53%|█████▎    | 53/100 [01:27<01:20,  1.71s/it]

Epoch 53/100 - Train Loss: 303594.2812, Test Loss: 54136.9805
2025-03-26 23:16:52,004 - INFO - Current LR: 2.5e-05


Training Progress:  54%|█████▍    | 54/100 [01:29<01:16,  1.66s/it]

Epoch 54/100 - Train Loss: 294998.4219, Test Loss: 50862.8867
2025-03-26 23:16:53,544 - INFO - Current LR: 2.5e-05


Training Progress:  55%|█████▌    | 55/100 [01:30<01:12,  1.62s/it]

Epoch 55/100 - Train Loss: 281516.1875, Test Loss: 48519.6719
2025-03-26 23:16:55,070 - INFO - Current LR: 2.5e-05


Training Progress:  56%|█████▌    | 56/100 [01:32<01:10,  1.60s/it]

Epoch 56/100 - Train Loss: 308715.3594, Test Loss: 48811.2344
2025-03-26 23:16:56,618 - INFO - Current LR: 2.5e-05


Training Progress:  57%|█████▋    | 57/100 [01:33<01:07,  1.57s/it]

Epoch 57/100 - Train Loss: 335542.2031, Test Loss: 44974.2266
2025-03-26 23:16:58,132 - INFO - Current LR: 2.5e-05


Training Progress:  58%|█████▊    | 58/100 [01:35<01:05,  1.56s/it]

Epoch 58/100 - Train Loss: 295767.9375, Test Loss: 44692.5234
2025-03-26 23:16:59,677 - INFO - Current LR: 2.5e-05


Training Progress:  59%|█████▉    | 59/100 [01:36<01:05,  1.60s/it]

Epoch 59/100 - Train Loss: 328523.8281, Test Loss: 49760.9297
2025-03-26 23:17:01,353 - INFO - Current LR: 1.25e-05


Training Progress:  60%|██████    | 60/100 [01:38<01:06,  1.67s/it]

Epoch 60/100 - Train Loss: 318929.4688, Test Loss: 49152.6289
2025-03-26 23:17:03,176 - INFO - Current LR: 1.25e-05


Training Progress:  61%|██████    | 61/100 [01:40<01:05,  1.67s/it]

Epoch 61/100 - Train Loss: 295361.2812, Test Loss: 48258.5039
2025-03-26 23:17:04,870 - INFO - Current LR: 1.25e-05


Training Progress:  62%|██████▏   | 62/100 [01:42<01:01,  1.63s/it]

Epoch 62/100 - Train Loss: 297420.8438, Test Loss: 50729.6875
2025-03-26 23:17:06,382 - INFO - Current LR: 1.25e-05


Training Progress:  63%|██████▎   | 63/100 [01:43<00:59,  1.59s/it]

Epoch 63/100 - Train Loss: 285154.3750, Test Loss: 50041.8828
2025-03-26 23:17:07,904 - INFO - Current LR: 1.25e-05


Training Progress:  64%|██████▍   | 64/100 [01:45<00:56,  1.58s/it]

Epoch 64/100 - Train Loss: 294860.2266, Test Loss: 49693.0469
2025-03-26 23:17:09,443 - INFO - Current LR: 1.25e-05


Training Progress:  65%|██████▌   | 65/100 [01:46<00:54,  1.56s/it]

Epoch 65/100 - Train Loss: 276592.5625, Test Loss: 49660.8125
2025-03-26 23:17:10,967 - INFO - Current LR: 6.25e-06


Training Progress:  66%|██████▌   | 66/100 [01:48<00:53,  1.56s/it]

Epoch 66/100 - Train Loss: 278598.5781, Test Loss: 50672.2461
2025-03-26 23:17:12,521 - INFO - Current LR: 6.25e-06


Training Progress:  67%|██████▋   | 67/100 [01:49<00:51,  1.55s/it]

Epoch 67/100 - Train Loss: 282149.4844, Test Loss: 54633.6055
2025-03-26 23:17:14,043 - INFO - Current LR: 6.25e-06


Training Progress:  68%|██████▊   | 68/100 [01:51<00:50,  1.59s/it]

Epoch 68/100 - Train Loss: 312396.2812, Test Loss: 51955.6680
2025-03-26 23:17:15,731 - INFO - Current LR: 6.25e-06


Training Progress:  69%|██████▉   | 69/100 [01:53<00:52,  1.68s/it]

Epoch 69/100 - Train Loss: 308209.4219, Test Loss: 54715.3047
2025-03-26 23:17:17,621 - INFO - Current LR: 6.25e-06


Training Progress:  70%|███████   | 70/100 [01:54<00:49,  1.64s/it]

Epoch 70/100 - Train Loss: 309049.9688, Test Loss: 55254.9688
2025-03-26 23:17:19,179 - INFO - Current LR: 6.25e-06


Training Progress:  71%|███████   | 71/100 [01:56<00:46,  1.61s/it]

Epoch 71/100 - Train Loss: 313564.6406, Test Loss: 54078.9141
2025-03-26 23:17:20,712 - INFO - Current LR: 3.125e-06


Training Progress:  72%|███████▏  | 72/100 [01:57<00:44,  1.59s/it]

Epoch 72/100 - Train Loss: 301101.9688, Test Loss: 54009.6523
2025-03-26 23:17:22,240 - INFO - Current LR: 3.125e-06


Training Progress:  73%|███████▎  | 73/100 [01:59<00:42,  1.58s/it]

Epoch 73/100 - Train Loss: 293205.2969, Test Loss: 52229.6172
2025-03-26 23:17:23,811 - INFO - Current LR: 3.125e-06


Training Progress:  74%|███████▍  | 74/100 [02:00<00:40,  1.57s/it]

Epoch 74/100 - Train Loss: 277999.1562, Test Loss: 54262.2852
2025-03-26 23:17:25,347 - INFO - Current LR: 3.125e-06


Training Progress:  75%|███████▌  | 75/100 [02:02<00:39,  1.56s/it]

Epoch 75/100 - Train Loss: 311264.7656, Test Loss: 52314.5547
2025-03-26 23:17:26,892 - INFO - Current LR: 3.125e-06


Training Progress:  76%|███████▌  | 76/100 [02:04<00:38,  1.62s/it]

Epoch 76/100 - Train Loss: 304974.6875, Test Loss: 50653.0195
2025-03-26 23:17:28,642 - INFO - Current LR: 3.125e-06


Training Progress:  77%|███████▋  | 77/100 [02:06<00:38,  1.69s/it]

Epoch 77/100 - Train Loss: 312952.0781, Test Loss: 51081.2188
2025-03-26 23:17:30,494 - INFO - Current LR: 1.5625e-06


Training Progress:  78%|███████▊  | 78/100 [02:07<00:37,  1.70s/it]

Epoch 78/100 - Train Loss: 303392.9844, Test Loss: 50744.8008
2025-03-26 23:17:32,216 - INFO - Current LR: 1.5625e-06


Training Progress:  79%|███████▉  | 79/100 [02:09<00:34,  1.64s/it]

Epoch 79/100 - Train Loss: 307278.2656, Test Loss: 52775.9219
2025-03-26 23:17:33,734 - INFO - Current LR: 1.5625e-06


Training Progress:  80%|████████  | 80/100 [02:10<00:32,  1.61s/it]

Epoch 80/100 - Train Loss: 309828.3594, Test Loss: 54070.8125
2025-03-26 23:17:35,268 - INFO - Current LR: 1.5625e-06


Training Progress:  81%|████████  | 81/100 [02:12<00:30,  1.58s/it]

Epoch 81/100 - Train Loss: 283504.9531, Test Loss: 51712.2734
2025-03-26 23:17:36,789 - INFO - Current LR: 1.5625e-06


Training Progress:  82%|████████▏ | 82/100 [02:13<00:28,  1.57s/it]

Epoch 82/100 - Train Loss: 291071.1094, Test Loss: 53028.1562
2025-03-26 23:17:38,326 - INFO - Current LR: 1.5625e-06


Training Progress:  83%|████████▎ | 83/100 [02:15<00:26,  1.55s/it]

Epoch 83/100 - Train Loss: 275484.2812, Test Loss: 51804.6016
2025-03-26 23:17:39,837 - INFO - Current LR: 7.8125e-07


Training Progress:  84%|████████▍ | 84/100 [02:17<00:25,  1.59s/it]

Epoch 84/100 - Train Loss: 302440.7969, Test Loss: 52731.1406
2025-03-26 23:17:41,531 - INFO - Current LR: 7.8125e-07


Training Progress:  85%|████████▌ | 85/100 [02:19<00:25,  1.71s/it]

Epoch 85/100 - Train Loss: 270115.7266, Test Loss: 54284.4141
2025-03-26 23:17:43,497 - INFO - Current LR: 7.8125e-07


Training Progress:  86%|████████▌ | 86/100 [02:20<00:24,  1.72s/it]

Epoch 86/100 - Train Loss: 294801.5859, Test Loss: 52863.0664
2025-03-26 23:17:45,236 - INFO - Current LR: 7.8125e-07


Training Progress:  87%|████████▋ | 87/100 [02:22<00:21,  1.67s/it]

Epoch 87/100 - Train Loss: 286157.9531, Test Loss: 54328.2461
2025-03-26 23:17:46,792 - INFO - Current LR: 7.8125e-07


Training Progress:  88%|████████▊ | 88/100 [02:23<00:19,  1.63s/it]

Epoch 88/100 - Train Loss: 336591.3750, Test Loss: 51888.1914
2025-03-26 23:17:48,338 - INFO - Current LR: 7.8125e-07


Training Progress:  89%|████████▉ | 89/100 [02:25<00:17,  1.60s/it]

Epoch 89/100 - Train Loss: 297359.4688, Test Loss: 52610.7383
2025-03-26 23:17:49,866 - INFO - Current LR: 3.90625e-07


Training Progress:  90%|█████████ | 90/100 [02:27<00:15,  1.57s/it]

Epoch 90/100 - Train Loss: 314708.8594, Test Loss: 52950.0156
2025-03-26 23:17:51,379 - INFO - Current LR: 3.90625e-07


Training Progress:  91%|█████████ | 91/100 [02:28<00:14,  1.56s/it]

Epoch 91/100 - Train Loss: 257477.0156, Test Loss: 50051.9609
2025-03-26 23:17:52,907 - INFO - Current LR: 3.90625e-07


Training Progress:  92%|█████████▏| 92/100 [02:30<00:12,  1.56s/it]

Epoch 92/100 - Train Loss: 291971.3438, Test Loss: 53509.1602
2025-03-26 23:17:54,464 - INFO - Current LR: 3.90625e-07


Training Progress:  93%|█████████▎| 93/100 [02:31<00:11,  1.62s/it]

Epoch 93/100 - Train Loss: 306182.2812, Test Loss: 55203.3086
2025-03-26 23:17:56,233 - INFO - Current LR: 3.90625e-07


Training Progress:  94%|█████████▍| 94/100 [02:33<00:10,  1.68s/it]

Epoch 94/100 - Train Loss: 309121.2578, Test Loss: 52478.5352
2025-03-26 23:17:58,056 - INFO - Current LR: 3.90625e-07


Training Progress:  95%|█████████▌| 95/100 [02:35<00:08,  1.64s/it]

Epoch 95/100 - Train Loss: 305576.8125, Test Loss: 52843.9648
2025-03-26 23:17:59,585 - INFO - Current LR: 1.953125e-07


Training Progress:  96%|█████████▌| 96/100 [02:36<00:06,  1.60s/it]

Epoch 96/100 - Train Loss: 307250.9531, Test Loss: 52478.4922
2025-03-26 23:18:01,101 - INFO - Current LR: 1.953125e-07


Training Progress:  97%|█████████▋| 97/100 [02:38<00:04,  1.58s/it]

Epoch 97/100 - Train Loss: 309721.1406, Test Loss: 51938.4688
2025-03-26 23:18:02,630 - INFO - Current LR: 1.953125e-07


Training Progress:  98%|█████████▊| 98/100 [02:39<00:03,  1.57s/it]

Epoch 98/100 - Train Loss: 329342.7031, Test Loss: 51870.1641
2025-03-26 23:18:04,185 - INFO - Current LR: 1.953125e-07


Training Progress:  99%|█████████▉| 99/100 [02:41<00:01,  1.57s/it]

Epoch 99/100 - Train Loss: 298237.8906, Test Loss: 51319.6641
2025-03-26 23:18:05,744 - INFO - Current LR: 1.953125e-07


Training Progress: 100%|██████████| 100/100 [02:42<00:00,  1.63s/it]


Epoch 100/100 - Train Loss: 307444.9062, Test Loss: 49815.0273


In [38]:
'''Extract Latent Representations'''
logging.info("Extracting latent space representations...")
conv_latent, conv_meta = extract_latent_vectors(conv_model, test_loader, adata=adata_filtered, device=device)
logging.info(f"Extracted {len(conv_latent)} latent vectors")

'''Perform Clustering and Visualization'''
logging.info("Performing clustering and visualization...")
conv_ari = perform_clustering_and_visualization(
    latent_vectors=conv_latent,
    metadata=conv_meta,
    output_dir=output_dir,
    model_name="Conventional",
    n_neighbors=n_neighbors,
    metric=metric
)

logging.info(f"\nClustering Results:")
logging.info(f"Conventional Model ARI: {conv_ari:.4f}")

# Clean up at the end
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

logging.info("Analysis complete!")

2025-03-26 23:18:16,677 - INFO - Extracting latent space representations...


Extracting latent vectors: 100%|██████████| 1/1 [00:00<00:00,  4.05it/s]

2025-03-26 23:18:16,932 - INFO - Extracted 60 latent vectors
2025-03-26 23:18:16,933 - INFO - Performing clustering and visualization...





2025-03-26 23:18:17,582 - INFO - 
Clustering Results:
2025-03-26 23:18:17,583 - INFO - Conventional Model ARI: 0.0000
2025-03-26 23:18:18,095 - INFO - Analysis complete!
