<a href="https://colab.research.google.com/github/MichaelMatley/Hierarchial-VAE-Emergent/blob/main/Hierarchical_VAE_Complete.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Hierarchical Multi-Scale Latent VAE

- A theoretical framework for emergent representation learning
- Stochastic bottleneck forces the model to develop a compressed “language” rather than just memorising patterns

- Complete Colab Notebook: Hierarchical VAE for Emergent Representation Learning

## Part 1: ARCHITECTURE, TRAINING, & β-ANNEALING

Cell 1: Environment Setup

In [None]:
# Check hardware: report the PyTorch build and whether a CUDA device is visible
import torch
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device index 0 is queried; total_memory is divided by 1e9 to print GB
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Install dependencies
!pip install biopython umap-learn scikit-learn matplotlib seaborn tqdm -q

print("✓ Environment ready")

Cell 2: Import Libraries

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import umap

from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord

import warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility: the same fixed seed (42) is given to
# PyTorch, NumPy and, when a GPU is present, explicitly to every CUDA device.
torch.manual_seed(42)
np.random.seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

print("✓ Libraries imported")

Cell 3: DNA Encoding Utilities

In [None]:
class DNAEncoder:
    """
    Convert DNA sequences to numerical representations and back.

    Encoding is one-hot over the four canonical bases in the row order
    A, C, G, T. Any other character (N and other ambiguity codes) becomes an
    all-zero column, and an all-zero column decodes back to 'N', so a
    sequence made of A/C/G/T/N round-trips exactly.
    """

    # Row order of the one-hot matrix; index i of this string is row i.
    BASES = 'ACGT'

    @staticmethod
    def one_hot_encode(sequence):
        """
        One-hot encoding: A=[1,0,0,0], C=[0,1,0,0], G=[0,0,1,0], T=[0,0,0,1]

        Args:
            sequence: DNA string (case-insensitive).
        Returns:
            float32 array of shape (4, seq_length). Columns for characters
            other than A/C/G/T are all zeros.
        """
        seq_upper = str(sequence.upper())
        encoded = np.zeros((4, len(seq_upper)), dtype=np.float32)

        # View the sequence as one byte per character so each base can be
        # located with a single vectorised comparison. Non-ASCII characters
        # are replaced by '?', which matches no base and stays all-zero.
        codes = np.frombuffer(seq_upper.encode('ascii', 'replace'), dtype=np.uint8)
        for row, base in enumerate(DNAEncoder.BASES):
            encoded[row, codes == ord(base)] = 1.0

        return encoded

    @staticmethod
    def decode_one_hot(encoded_array):
        """
        Convert a (4, seq_length) array back to a DNA sequence string.

        Each column is decoded to the base with the largest value, so the
        method also works on real-valued reconstructions. A column that is
        exactly all zeros carries no base information (this is how
        ambiguous bases are encoded) and is decoded as 'N' instead of
        silently becoming 'A' via argmax.

        Args:
            encoded_array: (4, seq_length) array
        Returns:
            DNA sequence string of length seq_length
        """
        scores = np.asarray(encoded_array)
        letters = np.array(list(DNAEncoder.BASES))[np.argmax(scores, axis=0)]
        letters[~scores.any(axis=0)] = 'N'
        return ''.join(letters.tolist())


# Test the encoder: round-trip a short sequence that includes a run of
# ambiguous 'N' bases and print both strings for visual comparison.
test_seq = "ATCGATCGATCGNNNATCG"
encoded = DNAEncoder.one_hot_encode(test_seq)
decoded = DNAEncoder.decode_one_hot(encoded)

print(f"Original:  {test_seq}")
print(f"Decoded:   {decoded}")
print(f"Encoding shape: {encoded.shape}")  # (4, len(test_seq))
print("✓ DNA encoder working")

Cell 4: Dataset Class

In [None]:
class GenomicDataset(Dataset):
    """
    PyTorch Dataset for genomic sequences.

    Slides a fixed-length window over every record of a FASTA file and keeps
    the windows whose share of 'N' bases does not exceed a threshold. The
    window strings are held in memory; one-hot encoding happens lazily in
    __getitem__.
    """

    def __init__(self, fasta_file, window_size=1024, stride=512,
                 max_samples=None, filter_n_threshold=0.1):
        """
        Args:
            fasta_file: Path to FASTA file
            window_size: Length of sequence windows (must be positive)
            stride: Sliding window stride (must be positive)
            max_samples: Maximum number of samples to extract (None = all;
                0 yields an empty dataset)
            filter_n_threshold: Maximum proportion of N bases allowed

        Raises:
            ValueError: if window_size or stride is not positive.
        """
        if window_size <= 0 or stride <= 0:
            raise ValueError(
                f"window_size and stride must be positive, "
                f"got window_size={window_size}, stride={stride}"
            )

        self.window_size = window_size
        self.stride = stride
        self.sequences = []

        def reached_limit():
            # Explicit None check so that max_samples=0 really means "none"
            # instead of being treated as "no limit" by truthiness.
            return max_samples is not None and len(self.sequences) >= max_samples

        print(f"Loading sequences from {fasta_file}...")

        for record in SeqIO.parse(fasta_file, "fasta"):
            sequence = str(record.seq).upper()

            # Sliding window extraction; records shorter than one window
            # produce an empty range and contribute nothing.
            for start in range(0, len(sequence) - window_size + 1, stride):
                if reached_limit():
                    break

                chunk = sequence[start:start + window_size]

                # Filter sequences with too many ambiguous bases
                if chunk.count('N') / len(chunk) <= filter_n_threshold:
                    self.sequences.append(chunk)

            if reached_limit():
                break

        print(f"✓ Created dataset: {len(self.sequences)} sequences of {window_size} bp")
        print(f"  Overlap: {window_size - stride} bp")

    def __len__(self):
        """Number of retained windows."""
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return window `idx` as a flat float32 tensor of its one-hot encoding."""
        sequence = self.sequences[idx]

        # One-hot encode
        encoded = DNAEncoder.one_hot_encode(sequence)

        # Flatten row-major: (4, window_size) -> (4 * window_size,)
        encoded_flat = encoded.flatten()

        return torch.tensor(encoded_flat, dtype=torch.float32)

    def get_sequence(self, idx):
        """Return raw sequence string"""
        return self.sequences[idx]


def create_synthetic_genome(length=5_000_000, output_file='synthetic_genome.fasta'):
    """
    Write a random single-record FASTA file to use as test data.

    Bases are drawn independently from NumPy's global RNG with fixed
    probabilities (A and T 32% each, G and C 18% each, i.e. ~36% GC, chosen
    to approximate C. elegans).

    Args:
        length: Number of bases to generate.
        output_file: Destination path of the FASTA file.
    Returns:
        The path that was written (same value as `output_file`).
    """
    # Insertion order fixes the alphabet order passed to the sampler.
    base_probabilities = {'A': 0.32, 'T': 0.32, 'G': 0.18, 'C': 0.18}  # ~36% GC

    drawn_bases = np.random.choice(
        list(base_probabilities.keys()),
        size=length,
        p=list(base_probabilities.values()),
    )
    genome_text = ''.join(drawn_bases)

    genome_record = SeqRecord(
        Seq(genome_text),
        id="synthetic_chr",
        description=f"Synthetic {length/1e6:.1f}Mb genome for testing"
    )
    SeqIO.write(genome_record, output_file, "fasta")

    print(f"✓ Created synthetic genome: {output_file} ({length/1e6:.1f} Mb)")
    return output_file


# Create synthetic data for testing
synthetic_file = create_synthetic_genome(length=5_000_000)  # 5 Mb

# Create dataset: 1024-bp windows taken every 512 bp, so consecutive
# windows overlap by half a window.
dataset = GenomicDataset(
    fasta_file=synthetic_file,
    window_size=1024,
    stride=512,
    max_samples=100_000  # Limit to 100k samples for faster training
)

# Indexing the dataset encodes the window on the fly; print the shape and
# value range of the first sample as a sanity check.
print(f"\nSample shape: {dataset[0].shape}")
print(f"Sample stats - Min: {dataset[0].min():.2f}, Max: {dataset[0].max():.2f}")

Cell 5: Hierarchical VAE Architecture

In [None]:
class HierarchicalVAE(nn.Module):
    """
    Multi-scale Variational Autoencoder with hierarchical latent spaces.

    Architecture (default sizes):
        Input (4096) -> Encoder 2048 -> 1024 -> 512
        Three latent spaces [256, 512, 1024] read from the three encoder stages
        Concatenated latents (1792) -> Decoder 512 -> 1024 -> 2048 -> Output (4096)

    The model learns to represent data at multiple levels of abstraction:
        - Level 1 (256d): from the deepest encoder stage, most abstract
        - Level 2 (512d): from the intermediate encoder stage
        - Level 3 (1024d): from the shallowest encoder stage, fine-grained

    Note that all three latents are derived from the same encoder trunk and
    are simply concatenated for decoding; there is no top-down conditioning
    between levels.
    """

    def __init__(self, input_dim=4096, latent_dims=(256, 512, 1024), dropout=0.3,
                 hidden_dims=(2048, 1024, 512)):
        """
        Args:
            input_dim: Size of the flattened input / reconstruction.
            latent_dims: Sizes of the three latent levels, ordered from most
                abstract (level 1) to most detailed (level 3).
            dropout: Dropout probability used after every hidden stage.
            hidden_dims: Widths of the three encoder stages, shallow to deep.
                The decoder mirrors them in reverse order.

        Raises:
            ValueError: if latent_dims or hidden_dims does not have exactly
                three entries.
        """
        super().__init__()

        # Copy into fresh lists: the defaults are immutable tuples and a
        # caller-supplied list is never aliased by the model.
        latent_dims = list(latent_dims)
        hidden_dims = list(hidden_dims)
        if len(latent_dims) != 3:
            raise ValueError(f"latent_dims must have 3 entries, got {len(latent_dims)}")
        if len(hidden_dims) != 3:
            raise ValueError(f"hidden_dims must have 3 entries, got {len(hidden_dims)}")

        self.input_dim = input_dim
        self.latent_dims = latent_dims
        self.hidden_dims = hidden_dims

        wide, mid, deep = hidden_dims

        # ========================
        # ENCODER PATHWAY
        # ========================
        # Modules are created in a fixed order so that parameter names and
        # the random-initialisation sequence stay stable.
        self.enc1 = self._dense_block(input_dim, wide, dropout)   # input -> wide
        self.enc2 = self._dense_block(wide, mid, dropout)         # wide -> mid
        self.enc3 = self._dense_block(mid, deep, dropout)         # mid -> deep (deepest layer)

        # ========================
        # LATENT SPACE PROJECTIONS
        # ========================

        # Latent level 1: From deepest layer (most abstract)
        self.z1_mu = nn.Linear(deep, latent_dims[0])
        self.z1_logvar = nn.Linear(deep, latent_dims[0])

        # Latent level 2: From intermediate layer
        self.z2_mu = nn.Linear(mid, latent_dims[1])
        self.z2_logvar = nn.Linear(mid, latent_dims[1])

        # Latent level 3: From shallow layer (fine details)
        self.z3_mu = nn.Linear(wide, latent_dims[2])
        self.z3_logvar = nn.Linear(wide, latent_dims[2])

        # ========================
        # DECODER PATHWAY
        # ========================

        total_latent_dim = sum(latent_dims)  # 256 + 512 + 1024 = 1792 by default

        self.dec1 = self._dense_block(total_latent_dim, deep, dropout)  # latents -> deep
        self.dec2 = self._dense_block(deep, mid, dropout)               # deep -> mid
        self.dec3 = self._dense_block(mid, wide, dropout)               # mid -> wide

        # Output layer: wide -> input_dim (reconstruction, no activation)
        self.output = nn.Linear(wide, input_dim)

        # Initialize weights
        self.apply(self._init_weights)

    @staticmethod
    def _dense_block(in_features, out_features, dropout):
        """Linear -> LayerNorm -> GELU -> Dropout stage shared by encoder and decoder."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.LayerNorm(out_features),
            nn.GELU(),
            nn.Dropout(dropout)
        )

    def _init_weights(self, module):
        """Xavier-uniform weights and zero biases for every Linear layer."""
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def reparameterize(self, mu, logvar):
        """
        Reparameterization trick: z = mu + std * epsilon
        Allows gradients to flow through sampling operation.

        Noise is drawn on every call, in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def encode(self, x):
        """
        Encode input into hierarchical latent representations.

        Args:
            x: Tensor whose last dimension is input_dim.
        Returns:
            latents: Tuple of (z1, z2, z3) sampled latent vectors
            params: List of (mu, logvar) tuples for KL divergence calculation
        """
        # Forward through encoder stages
        h1 = self.enc1(x)      # (..., hidden_dims[0])
        h2 = self.enc2(h1)     # (..., hidden_dims[1])
        h3 = self.enc3(h2)     # (..., hidden_dims[2])

        # Level 1: Most abstract (from deepest layer)
        z1_mu = self.z1_mu(h3)
        z1_logvar = self.z1_logvar(h3)
        z1 = self.reparameterize(z1_mu, z1_logvar)

        # Level 2: Intermediate
        z2_mu = self.z2_mu(h2)
        z2_logvar = self.z2_logvar(h2)
        z2 = self.reparameterize(z2_mu, z2_logvar)

        # Level 3: Fine details (from shallowest layer)
        z3_mu = self.z3_mu(h1)
        z3_logvar = self.z3_logvar(h1)
        z3 = self.reparameterize(z3_mu, z3_logvar)

        latents = (z1, z2, z3)
        params = [(z1_mu, z1_logvar), (z2_mu, z2_logvar), (z3_mu, z3_logvar)]

        return latents, params

    def decode(self, latents):
        """
        Decode from hierarchical latent space to reconstruction.

        Args:
            latents: Tuple of (z1, z2, z3)
        Returns:
            Unbounded reconstruction with last dimension input_dim.
        """
        # Concatenate all latent levels along the feature axis
        z = torch.cat(latents, dim=-1)  # (..., sum(latent_dims))

        # Decode through stages
        h = self.dec1(z)
        h = self.dec2(h)
        h = self.dec3(h)

        return self.output(h)

    def forward(self, x):
        """
        Full forward pass: encode -> sample -> decode

        Returns:
            reconstruction: Reconstructed input
            latents: Sampled latent vectors
            params: Distribution parameters for loss calculation
        """
        latents, params = self.encode(x)
        reconstruction = self.decode(latents)

        return reconstruction, latents, params


# Instantiate model; input_dim=4096 matches the flattened 4 x 1024 one-hot
# windows produced by the dataset above.
model = HierarchicalVAE(
    input_dim=4096,
    latent_dims=[256, 512, 1024],
    dropout=0.3
)

# Count parameters (all of them vs. only those with requires_grad set)
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"✓ Model created")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Latent space dimensions: {model.latent_dims}")
print(f"  Total latent dimension: {sum(model.latent_dims)}")

Cell 6: Loss Functions

In [None]:
def vae_loss_function(recon_x, x, latent_params, beta=1.0, kl_weights=(1.0, 1.0, 1.0)):
    """
    VAE loss = Reconstruction loss + β * KL divergence

    Args:
        recon_x: Reconstructed input
        x: Original input (first dimension is the batch)
        latent_params: List of (mu, logvar) tuples for each latent level
        beta: KL divergence weighting factor (β-VAE)
        kl_weights: Per-level KL weights (for hierarchical control); must
            provide at least one weight per latent level

    Returns:
        total_loss: Combined loss (scalar tensor)
        recon_loss: Reconstruction term, summed squared error per sample
        kl_loss: Weighted sum of per-level KL terms (always a scalar tensor,
            zero when latent_params is empty)
        kl_per_level: Unweighted KL divergence for each level as Python floats
    """
    # Reconstruction loss (MSE): summed over all elements, averaged over batch
    recon_loss = F.mse_loss(recon_x, x, reduction='sum') / x.size(0)

    # Start from a tensor zero so the return type does not depend on how
    # many latent levels were supplied.
    kl_loss = recon_loss.new_zeros(())
    kl_per_level = []

    for idx, (mu, logvar) in enumerate(latent_params):
        # Closed-form KL(N(mu, sigma^2) || N(0, 1)), summed over latent
        # dimensions and then averaged over the batch.
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
        kl = kl.mean()

        kl_per_level.append(kl.item())
        kl_loss = kl_loss + kl_weights[idx] * kl

    # Total loss
    total_loss = recon_loss + beta * kl_loss

    return total_loss, recon_loss, kl_loss, kl_per_level


print("✓ Loss function defined")

Cell 7: Training Loop

In [None]:
def _run_vae_epoch(model, loader, device, beta, optimizer=None, desc=None):
    """
    Run one pass over `loader` and return per-batch-averaged metrics.

    With an optimizer the model is put in train mode and updated (gradient
    norm clipped to 1.0); without one it is put in eval mode and run with
    gradients disabled. A tqdm progress bar is shown only when `desc` is given.

    Returns:
        dict with keys 'loss', 'recon', 'kl' (floats) and 'kl_levels'
        (list with one float per latent level).
    """
    is_training = optimizer is not None
    model.train(is_training)

    loss_sum = recon_sum = kl_sum = 0.0
    kl_level_sums = []
    batches = tqdm(loader, desc=desc) if desc is not None else loader

    with torch.set_grad_enabled(is_training):
        for batch in batches:
            x = batch.to(device)

            if is_training:
                optimizer.zero_grad()

            recon, _, params = model(x)
            loss, recon_loss, kl_loss, kl_per_level = vae_loss_function(
                recon, x, params, beta=beta
            )

            if is_training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()

            # Accumulate metrics
            loss_sum += loss.item()
            recon_sum += recon_loss.item()
            kl_sum += kl_loss.item()
            if not kl_level_sums:
                kl_level_sums = [0.0] * len(kl_per_level)
            for level_idx, level_kl in enumerate(kl_per_level):
                kl_level_sums[level_idx] += level_kl

            if desc is not None:
                batches.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'recon': f'{recon_loss.item():.4f}',
                    'kl': f'{kl_loss.item():.4f}',
                    'β': f'{beta:.3f}'
                })

    # Averages are taken over batches (each batch weighs the same).
    n_batches = len(loader)
    return {
        'loss': loss_sum / n_batches,
        'recon': recon_sum / n_batches,
        'kl': kl_sum / n_batches,
        'kl_levels': [total / n_batches for total in kl_level_sums],
    }


def train_hierarchical_vae(model, train_loader, val_loader,
                          epochs=100, lr=1e-3,
                          beta_schedule=None,
                          device='cuda'):
    """
    Training loop with β-annealing and comprehensive monitoring.

    Uses AdamW (weight decay 1e-5), cosine annealing with warm restarts
    stepped once per epoch, gradient-norm clipping at 1.0 and early stopping
    with a patience of 15 epochs. The best model (lowest validation loss) is
    written to 'best_hierarchical_vae.pth' in the working directory.

    Validation losses are only comparable between epochs that used the same
    β, because β scales the KL term. Best-loss tracking and the patience
    counter are therefore restarted whenever β changes from one epoch to the
    next. Without this, the low-β warm-up epochs would always look "best"
    and early stopping could fire before annealing has finished. With a
    schedule that changes β every epoch, the checkpoint is simply the most
    recent epoch and early stopping never triggers.

    Args:
        model: Module whose forward returns (reconstruction, latents, params).
        train_loader: Iterable of input batches used for optimisation.
        val_loader: Iterable of input batches used for validation.
        epochs: Maximum number of epochs.
        lr: Initial learning rate.
        beta_schedule: Function that returns beta value given epoch number
                      If None, uses constant β=1.0
        device: Device the model and batches are moved to.

    Returns:
        history: dict of per-epoch lists ('train_loss', 'train_recon',
        'train_kl', 'val_loss', 'val_recon', 'val_kl', 'kl_level1', ...,
        'beta_values').
    """
    model.to(device)

    # Optimizer with weight decay for regularization
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # Learning rate scheduler: Cosine annealing with warm restarts
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer, T_0=10, T_mult=2, eta_min=1e-6
    )

    # Training history
    history = {
        'train_loss': [], 'train_recon': [], 'train_kl': [],
        'val_loss': [], 'val_recon': [], 'val_kl': [],
        'kl_level1': [], 'kl_level2': [], 'kl_level3': [],
        'beta_values': []
    }

    best_val_loss = float('inf')
    patience_counter = 0
    patience = 15
    previous_beta = None

    print(f"\n{'='*60}")
    print(f"Starting training on {device}")
    print(f"{'='*60}\n")

    for epoch in range(epochs):
        # Determine β value for this epoch
        beta = beta_schedule(epoch) if beta_schedule is not None else 1.0
        history['beta_values'].append(beta)

        # A different β means a different objective: restart model selection
        # so that losses are only compared like-for-like.
        if previous_beta is not None and beta != previous_beta:
            best_val_loss = float('inf')
            patience_counter = 0
        previous_beta = beta

        train_metrics = _run_vae_epoch(
            model, train_loader, device, beta,
            optimizer=optimizer, desc=f"Epoch {epoch+1}/{epochs} [Train]"
        )
        val_metrics = _run_vae_epoch(model, val_loader, device, beta)

        avg_train_loss = train_metrics['loss']
        avg_val_loss = val_metrics['loss']
        avg_kl_levels = train_metrics['kl_levels']

        # Update history
        history['train_loss'].append(avg_train_loss)
        history['train_recon'].append(train_metrics['recon'])
        history['train_kl'].append(train_metrics['kl'])
        history['val_loss'].append(avg_val_loss)
        history['val_recon'].append(val_metrics['recon'])
        history['val_kl'].append(val_metrics['kl'])
        for level_number, level_kl in enumerate(avg_kl_levels, start=1):
            history.setdefault(f'kl_level{level_number}', []).append(level_kl)

        # Learning rate scheduling
        scheduler.step()
        current_lr = optimizer.param_groups[0]['lr']

        # Print epoch summary
        kl_summary = " | ".join(
            f"L{level_number}={level_kl:.2f}"
            for level_number, level_kl in enumerate(avg_kl_levels, start=1)
        )
        print(f"\nEpoch {epoch+1}/{epochs}")
        print(f"  Train Loss: {avg_train_loss:.4f} | Recon: {train_metrics['recon']:.4f} | KL: {train_metrics['kl']:.4f}")
        print(f"  Val Loss:   {avg_val_loss:.4f} | Recon: {val_metrics['recon']:.4f} | KL: {val_metrics['kl']:.4f}")
        print(f"  KL Levels:  {kl_summary}")
        print(f"  LR: {current_lr:.2e} | β: {beta:.3f}")

        # Early stopping check
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            patience_counter = 0

            # Save best model
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': avg_train_loss,
                'val_loss': avg_val_loss,
                'history': history
            }, 'best_hierarchical_vae.pth')

            print(f"  ✓ Best model saved (val_loss: {best_val_loss:.4f})")
        else:
            patience_counter += 1
            print(f"  Patience: {patience_counter}/{patience}")

        if patience_counter >= patience:
            print(f"\n{'='*60}")
            print(f"Early stopping triggered at epoch {epoch+1}")
            print(f"{'='*60}\n")
            break

    print("\n✓ Training complete")
    return history


print("✓ Training function defined")

Cell 8: β-Annealing Schedule

In [None]:
def beta_annealing_schedule(epoch, warmup_epochs=20, max_beta=1.0, mode='linear',
                            cycle_length=20):
    """
    β-annealing for VAE training.
    Start with β=0 (pure autoencoder) and gradually increase.

    Modes:
        'linear': Linear increase from 0 to max_beta over `warmup_epochs`,
            then constant at max_beta
        'cyclical': Sawtooth that ramps from 0 towards max_beta over
            `cycle_length` epochs and then restarts at 0 (warmup_epochs is
            not used in this mode)
        'constant': No annealing, use max_beta throughout

    Any other mode string falls back to max_beta.

    Args:
        epoch: Zero-based epoch index.
        warmup_epochs: Length of the linear ramp; 0 or less means no ramp.
        max_beta: Upper value of β.
        mode: One of the modes above.
        cycle_length: Period of the cyclical schedule in epochs.

    Returns:
        β for the given epoch.

    Raises:
        ValueError: if mode is 'cyclical' and cycle_length is not positive.
    """
    if mode == 'linear':
        if epoch < warmup_epochs:
            return (epoch / warmup_epochs) * max_beta
        return max_beta

    if mode == 'cyclical':
        if cycle_length <= 0:
            raise ValueError(f"cycle_length must be positive, got {cycle_length}")
        cycle_progress = (epoch % cycle_length) / cycle_length
        return cycle_progress * max_beta

    # 'constant' and unrecognised modes
    return max_beta


# Test the schedule: evaluate the linear warm-up for 50 epochs and plot it
# so the ramp (first 20 epochs) and the plateau are visible.
epochs_test = 50
betas = [beta_annealing_schedule(e, warmup_epochs=20, mode='linear') for e in range(epochs_test)]

plt.figure(figsize=(10, 4))
plt.plot(betas, linewidth=2)
plt.xlabel('Epoch')
plt.ylabel('β value')
plt.title('β-Annealing Schedule (Linear Warmup)')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

print("✓ β-annealing schedule defined")

Cell 9: Create Data Loaders and Train

In [None]:
# Split dataset 80% / 10% / remainder; the test split takes whatever is left
# so the three sizes always sum to len(dataset).
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# A dedicated, seeded generator makes the split reproducible.
train_dataset, val_dataset, test_dataset = random_split(
    dataset,
    [train_size, val_size, test_size],
    generator=torch.Generator().manual_seed(42)
)

# Create data loaders (only the training loader shuffles)
batch_size = 128
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                         num_workers=2, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                       num_workers=2, pin_memory=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                        num_workers=2, pin_memory=True)

print(f"Dataset splits:")
print(f"  Train: {len(train_dataset):,} samples")
print(f"  Val:   {len(val_dataset):,} samples")
print(f"  Test:  {len(test_dataset):,} samples")
print(f"  Batch size: {batch_size}")
print(f"  Batches per epoch: {len(train_loader):,}\n")

# Train the model on the GPU when one is available, otherwise on the CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# β ramps linearly from 0 to 1 over the first 20 epochs, then stays at 1.
history = train_hierarchical_vae(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    epochs=100,
    lr=1e-3,
    beta_schedule=lambda epoch: beta_annealing_schedule(epoch, warmup_epochs=20, mode='linear'),
    device=device
)

## Part 2: Latent Space Analysis & Visualization

Cell 10: Load Best Model

In [None]:
# Load the best checkpoint written during training. map_location remaps the
# stored tensors onto the current device, so a checkpoint saved from a GPU
# session can still be loaded on a CPU-only runtime (and vice versa).
checkpoint = torch.load('best_hierarchical_vae.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# 'epoch' is stored zero-based, hence the +1 for display.
print(f"✓ Loaded best model from epoch {checkpoint['epoch']+1}")
print(f"  Final validation loss: {checkpoint['val_loss']:.4f}")

Cell 11: Training History Visualization

In [None]:
def plot_training_history(history):
    """
    Draw a 2x2 overview of the training run and save it as
    'training_history.png'.

    Panels: total loss, reconstruction loss, total KL (train vs. validation)
    and the per-level KL terms. `history` must provide the list-valued keys
    referenced in the panel table below.
    """
    # (grid position, y-axis label, title, [(history key, legend label, extra plot kwargs)])
    panel_specs = [
        ((0, 0), 'Loss', 'Total Loss (Reconstruction + KL)', [
            ('train_loss', 'Train Loss', {}),
            ('val_loss', 'Val Loss', {}),
        ]),
        ((0, 1), 'Reconstruction Loss', 'Reconstruction Loss (MSE)', [
            ('train_recon', 'Train Recon', {}),
            ('val_recon', 'Val Recon', {}),
        ]),
        ((1, 0), 'KL Divergence', 'KL Divergence', [
            ('train_kl', 'Train KL', {'color': 'crimson'}),
            ('val_kl', 'Val KL', {'color': 'darkred'}),
        ]),
        ((1, 1), 'KL Divergence', 'KL Divergence by Hierarchical Level', [
            ('kl_level1', 'Level 1 (256d)', {}),
            ('kl_level2', 'Level 2 (512d)', {}),
            ('kl_level3', 'Level 3 (1024d)', {}),
        ]),
    ]

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    for (row, col), y_label, title, curves in panel_specs:
        panel = axes[row, col]
        for key, legend_label, extra_style in curves:
            panel.plot(history[key], label=legend_label, linewidth=2, alpha=0.8,
                       **extra_style)
        panel.set_xlabel('Epoch')
        panel.set_ylabel(y_label)
        panel.set_title(title, fontsize=12, fontweight='bold')
        panel.legend()
        panel.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Training history visualized")


# Plot the curves recorded during the training run above
plot_training_history(history)

Cell 12: Extract Latent Representations

In [None]:
def extract_latent_representations(model, dataloader, device, max_samples=10000,
                                   use_mean=True):
    """
    Extract every hierarchical latent level from the model.

    Args:
        model: Module whose forward returns (reconstruction, latents, params),
            where params is a list of (mu, logvar) per level.
        dataloader: Iterable of input batches.
        device: Device the batches are moved to.
        max_samples: Upper bound on the number of samples returned.
        use_mean: If True (default), return the posterior means (mu), which
            are deterministic. Sampled vectors add unit-variance prior noise
            to unused dimensions, which masks posterior collapse in
            variance-based analyses such as PCA. Pass False to get the
            stochastic samples z instead.

    Returns:
        latents_dict: Dictionary with keys 'level1', 'level2', ... (one per
        latent level), each a numpy array of shape (num_samples, latent_dim)

    Raises:
        ValueError: if no batch was processed (empty dataloader or
            max_samples <= 0).
    """
    model.eval()

    per_level_batches = []  # one list of numpy batches per latent level
    samples_processed = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Extracting latents"):
            if samples_processed >= max_samples:
                break

            x = batch.to(device)

            # Get latent representations
            _, sampled, params = model(x)
            vectors = [mu for mu, _ in params] if use_mean else list(sampled)

            if not per_level_batches:
                per_level_batches = [[] for _ in vectors]
            for level_batches, level_vectors in zip(per_level_batches, vectors):
                level_batches.append(level_vectors.cpu().numpy())

            samples_processed += len(x)

    if not per_level_batches:
        raise ValueError("No batches were processed; cannot extract latent representations")

    # Concatenate all batches; the last batch may overshoot max_samples
    latents_dict = {
        f'level{level_number}': np.concatenate(level_batches, axis=0)[:max_samples]
        for level_number, level_batches in enumerate(per_level_batches, start=1)
    }

    print(f"\n✓ Extracted latent representations:")
    for level, level_array in latents_dict.items():
        print(f"  {level}: {level_array.shape}")

    return latents_dict


# Extract latents from test set (held out from training), capped at 10k samples
latents_dict = extract_latent_representations(model, test_loader, device, max_samples=10000)

Cell 13: Intrinsic Dimensionality Analysis

In [None]:
def analyze_intrinsic_dimensionality(latents_dict, variance_threshold=0.95):
    """
    Measure the intrinsic dimensionality of each latent level using PCA.

    Intrinsic dimensionality = number of principal components needed to
    explain `variance_threshold` (default 95%) of the variance. This is a
    linear estimate of how much of its latent capacity the model uses.

    One panel per level is drawn and the figure is saved as
    'intrinsic_dimensionality.png'.

    Args:
        latents_dict: Mapping of level name -> (num_samples, latent_dim) array.
        variance_threshold: Fraction of variance that must be explained.

    Returns:
        Dict mapping level name -> dict with 'nominal_dim', 'intrinsic_dim',
        'explained_variance_ratio' and 'cumsum_variance'.
    """
    results = {}

    # One panel per level; squeeze=False keeps `axes` two-dimensional even
    # for a single panel, and at least one panel is always requested.
    n_panels = max(len(latents_dict), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
    axes = axes[0]

    for idx, (level_name, latents) in enumerate(latents_dict.items()):
        # Fit PCA
        pca = PCA()
        pca.fit(latents)

        # Calculate cumulative explained variance
        cumsum_variance = np.cumsum(pca.explained_variance_ratio_)

        # First component count whose cumulative variance reaches the
        # threshold. argmax over an all-False mask would return 0, so fall
        # back to "all components" when the threshold is never reached.
        reached = cumsum_variance >= variance_threshold
        if reached.any():
            intrinsic_dim = int(np.argmax(reached)) + 1
        else:
            intrinsic_dim = len(cumsum_variance)

        # Store results
        results[level_name] = {
            'nominal_dim': latents.shape[1],
            'intrinsic_dim': intrinsic_dim,
            'explained_variance_ratio': pca.explained_variance_ratio_,
            'cumsum_variance': cumsum_variance
        }

        # Plot against a 1-based component count so the curve and the
        # vertical marker share the same x coordinates.
        component_counts = np.arange(1, len(cumsum_variance) + 1)
        ax = axes[idx]
        ax.plot(component_counts, cumsum_variance, linewidth=2.5, color='darkblue')
        ax.axhline(y=variance_threshold, color='red', linestyle='--', linewidth=2, alpha=0.7,
                   label=f'{variance_threshold:.0%} threshold')
        ax.axvline(x=intrinsic_dim, color='green', linestyle='--', linewidth=2, alpha=0.7,
                   label=f'Intrinsic dim: {intrinsic_dim}')
        ax.set_xlabel('Number of Components', fontsize=11)
        ax.set_ylabel('Cumulative Explained Variance', fontsize=11)
        ax.set_title(f'{level_name.capitalize()} ({latents.shape[1]}d)',
                     fontsize=12, fontweight='bold')
        ax.legend(fontsize=10)
        ax.grid(alpha=0.3)
        ax.set_ylim([0, 1.05])

    plt.tight_layout()
    plt.savefig('intrinsic_dimensionality.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Print summary
    print("\n" + "="*60)
    print("INTRINSIC DIMENSIONALITY ANALYSIS")
    print("="*60)
    for level_name, result in results.items():
        utilization = (result['intrinsic_dim'] / result['nominal_dim']) * 100
        print(f"\n{level_name.upper()}:")
        print(f"  Nominal dimension:    {result['nominal_dim']}")
        print(f"  Intrinsic dimension:  {result['intrinsic_dim']}")
        print(f"  Utilization:          {utilization:.1f}%")
    print("="*60)

    return results


# Run the PCA-based analysis on the extracted latents and keep the per-level results
intrinsic_dims = analyze_intrinsic_dimensionality(latents_dict)

Cell 14: Latent Space Visualization with UMAP

In [None]:
def visualize_latent_space_umap(latents_dict, n_samples=5000):
    """
    Project each latent level to 2-D with UMAP and plot the results side by
    side in a three-panel figure saved as 'latent_space_umap.png'.

    Levels with more than `n_samples` rows are randomly subsampled (without
    replacement, using NumPy's global RNG) before fitting. Points are
    coloured by their position in the array that was embedded.
    """
    umap_settings = dict(
        n_components=2,
        n_neighbors=15,
        min_dist=0.1,
        metric='euclidean',
        random_state=42,
    )

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    for panel_idx, (level_name, level_latents) in enumerate(latents_dict.items()):
        # Keep everything unless the level is larger than the sample budget
        points = level_latents
        if len(level_latents) > n_samples:
            chosen_rows = np.random.choice(len(level_latents), n_samples, replace=False)
            points = level_latents[chosen_rows]

        print(f"Computing UMAP for {level_name}...")

        projection = umap.UMAP(**umap_settings).fit_transform(points)

        panel = axes[panel_idx]
        dots = panel.scatter(
            projection[:, 0],
            projection[:, 1],
            c=np.arange(len(projection)),  # row index within the embedded array
            cmap='viridis',
            s=10,
            alpha=0.6,
            rasterized=True
        )

        panel.set_xlabel('UMAP 1', fontsize=11)
        panel.set_ylabel('UMAP 2', fontsize=11)
        panel.set_title(f'{level_name.capitalize()} ({level_latents.shape[1]}d → 2d)',
                        fontsize=12, fontweight='bold')
        # UMAP axes have no intrinsic units, so tick marks are hidden
        panel.set_xticks([])
        panel.set_yticks([])

        colorbar = plt.colorbar(dots, ax=panel)
        colorbar.set_label('Sample Index', fontsize=10)

    plt.tight_layout()
    plt.savefig('latent_space_umap.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ UMAP visualization complete")


# Embed at most 5000 points per latent level
visualize_latent_space_umap(latents_dict, n_samples=5000)

Cell 15: t-SNE Visualization (Alternative)

In [None]:
def visualize_latent_space_tsne(latents_dict, n_samples=3000):
    """
    Visualize using t-SNE (preserves local structure better than global).

    One panel per latent level is drawn and the figure is saved as
    'latent_space_tsne.png'. Levels with more than `n_samples` rows are
    randomly subsampled (without replacement, NumPy global RNG) first.
    Points are coloured by their position in the array that was embedded.

    Args:
        latents_dict: Mapping of level name -> (num_samples, latent_dim) array.
        n_samples: Maximum number of points embedded per level.
    """
    import inspect  # local import: only needed for the signature probe below

    # scikit-learn renamed TSNE's `n_iter` to `max_iter`; use whichever name
    # the installed version accepts so the call works on old and new releases.
    tsne_parameters = inspect.signature(TSNE.__init__).parameters
    iteration_kwarg = 'max_iter' if 'max_iter' in tsne_parameters else 'n_iter'

    # One panel per level; squeeze=False keeps `axes` indexable for any count.
    n_panels = max(len(latents_dict), 1)
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 5), squeeze=False)
    axes = axes[0]

    for idx, (level_name, latents) in enumerate(latents_dict.items()):
        # Subsample for faster computation
        if len(latents) > n_samples:
            indices = np.random.choice(len(latents), n_samples, replace=False)
            latents_subset = latents[indices]
        else:
            latents_subset = latents

        print(f"Computing t-SNE for {level_name}...")

        # t-SNE requires perplexity < number of samples; 30 is used whenever
        # enough points are available.
        perplexity = min(30, max(1, len(latents_subset) - 1))

        # Fit t-SNE
        tsne = TSNE(
            n_components=2,
            perplexity=perplexity,
            random_state=42,
            **{iteration_kwarg: 1000}
        )
        embedding = tsne.fit_transform(latents_subset)

        # Plot
        ax = axes[idx]
        scatter = ax.scatter(
            embedding[:, 0],
            embedding[:, 1],
            c=np.arange(len(embedding)),  # row index within the embedded array
            cmap='plasma',
            s=10,
            alpha=0.6,
            rasterized=True
        )

        ax.set_xlabel('t-SNE 1', fontsize=11)
        ax.set_ylabel('t-SNE 2', fontsize=11)
        ax.set_title(f'{level_name.capitalize()} t-SNE',
                     fontsize=12, fontweight='bold')
        ax.set_xticks([])
        ax.set_yticks([])

        cbar = plt.colorbar(scatter, ax=ax)
        cbar.set_label('Sample Index', fontsize=10)

    plt.tight_layout()
    plt.savefig('latent_space_tsne.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ t-SNE visualization complete")


# Embed at most 3000 points per latent level
visualize_latent_space_tsne(latents_dict, n_samples=3000)

Cell 16: Reconstruction Quality Assessment

In [None]:
def evaluate_reconstruction_quality(model, dataloader, device, num_samples=10):
    """
    Compare reconstructed sequences against their inputs, base by base.

    Up to `num_samples` sequences are taken in dataloader order. Each input
    and its reconstruction are decoded with DNAEncoder.decode_one_hot and
    compared position by position. The first 100 bases of every pair are
    drawn as a two-row alignment (matching bases green, mismatches red).

    Args:
        model: Network whose forward pass returns a 3-tuple with the
            reconstruction as its first element.
        dataloader: Iterable yielding input batches.
        device: Device the batches are moved to before the forward pass.
        num_samples: Number of sequences to evaluate and plot.

    Returns:
        List with one accuracy (fraction of matching positions) per sequence.

    Side effects:
        Saves 'reconstruction_quality.png', shows the figure, and prints
        summary statistics.
    """
    model.eval()

    preview_len = 100  # only the first 100 bases are drawn
    per_sample_accuracy = []
    plotted = 0

    fig, axes = plt.subplots(num_samples, 1, figsize=(16, num_samples * 1.5))
    if num_samples == 1:
        # A single-row grid comes back as a bare Axes, not an array.
        axes = [axes]

    with torch.no_grad():
        for batch in dataloader:
            if plotted >= num_samples:
                break

            inputs = batch.to(device)
            outputs, _, _ = model(inputs)

            # Never take more rows than are still needed.
            take = min(len(inputs), num_samples - plotted)
            for row in range(take):
                truth = DNAEncoder.decode_one_hot(
                    inputs[row].cpu().numpy().reshape(4, 1024)
                )
                guess = DNAEncoder.decode_one_hot(
                    outputs[row].cpu().numpy().reshape(4, 1024)
                )

                # 1 where the bases agree, 0 where they differ.
                hits = [int(t == g) for t, g in zip(truth, guess)]
                acc = sum(hits) / len(hits)
                per_sample_accuracy.append(acc)

                panel = axes[plotted]
                aligned = zip(truth[:preview_len],
                              guess[:preview_len],
                              hits[:preview_len])
                for col, (t, g, hit) in enumerate(aligned):
                    panel.text(col, 1, t, ha='center', va='center', fontsize=8,
                               family='monospace', color='black')
                    panel.text(col, 0, g, ha='center', va='center', fontsize=8,
                               family='monospace',
                               color='green' if hit else 'red',
                               fontweight='bold')

                n_errors = len(hits) - sum(hits)

                panel.set_xlim(-1, preview_len)
                panel.set_ylim(-0.5, 1.5)
                panel.set_yticks([0, 1])
                panel.set_yticklabels(['Recon', 'Original'], fontsize=9)
                panel.set_title(f'Sample {plotted + 1} | Accuracy: {acc:.2%} | '
                                f'Errors: {n_errors}/{len(hits)}',
                                fontsize=10, fontweight='bold', loc='left')
                panel.set_xticks([])
                for side in ('top', 'right', 'bottom'):
                    panel.spines[side].set_visible(False)

                plotted += 1

    plt.tight_layout()
    plt.savefig('reconstruction_quality.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Summary statistics over the evaluated sequences
    rule = "="*60
    print("\n" + rule)
    print("RECONSTRUCTION QUALITY STATISTICS")
    print(rule)
    print(f"Mean accuracy:   {np.mean(per_sample_accuracy):.4f}")
    print(f"Median accuracy: {np.median(per_sample_accuracy):.4f}")
    print(f"Std deviation:   {np.std(per_sample_accuracy):.4f}")
    print(f"Min accuracy:    {np.min(per_sample_accuracy):.4f}")
    print(f"Max accuracy:    {np.max(per_sample_accuracy):.4f}")
    print(rule)

    return per_sample_accuracy


reconstruction_accuracies = evaluate_reconstruction_quality(
    model, test_loader, device, num_samples=10
)

Cell 17: Latent Space Interpolation

In [None]:
def interpolate_latent_space(model, dataloader, device, num_steps=10):
    """
    Linearly interpolate between two encoded sequences and decode each step.

    The two endpoints are the first two samples of the first batch yielded by
    `dataloader` (they are only random if the loader shuffles). Every latent
    level returned by `model.encode` is interpolated with the same weight,
    the result is decoded, and the first 100 bases of each decoded sequence
    are drawn on their own row.

    Args:
        model: Model exposing `encode(x) -> (latents, _)` and
            `decode(latents)`, where `latents` is an iterable of tensors.
        dataloader: Iterable yielding input batches.
        device: Device the batch is moved to.
        num_steps: Number of interpolation points, endpoints included.

    Raises:
        ValueError: If the first batch holds fewer than two samples.

    Side effects:
        Saves 'latent_interpolation.png', shows the figure, prints a
        completion message.
    """
    model.eval()

    # Endpoints: first two samples of the first batch
    batch = next(iter(dataloader)).to(device)
    if len(batch) < 2:
        raise ValueError("interpolation needs a batch with at least 2 samples")
    x1, x2 = batch[0:1], batch[1:2]

    # Keeping the weights around lets the plot titles use the exact values
    # (the previous idx / (num_steps - 1) divided by zero for num_steps == 1).
    alphas = np.linspace(0, 1, num_steps)

    with torch.no_grad():
        # Encode to latent space
        latents1, _ = model.encode(x1)
        latents2, _ = model.encode(x2)

        decoded_steps = []
        for alpha in alphas:
            weight = float(alpha)  # plain float keeps tensor dtype untouched
            # Interpolate at each hierarchical level
            interp_latents = tuple(
                (1 - weight) * z1 + weight * z2
                for z1, z2 in zip(latents1, latents2)
            )

            # Decode; [0] drops the batch dimension of size 1
            recon = model.decode(interp_latents)
            decoded_steps.append(recon[0].cpu().numpy())

    # Visualize. squeeze=False keeps `axes` indexable when num_steps == 1.
    fig, axes = plt.subplots(num_steps, 1, figsize=(16, num_steps * 1),
                             squeeze=False)
    axes = axes[:, 0]

    color_map = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'}

    for idx, (alpha, interp) in enumerate(zip(alphas, decoded_steps)):
        # Decode to sequence
        seq = DNAEncoder.decode_one_hot(interp.reshape(4, 1024))

        # Show first 100 bases
        display_seq = seq[:100]

        ax = axes[idx]
        for pos, base in enumerate(display_seq):
            ax.text(pos, 0, base, ha='center', va='center', fontsize=8,
                   family='monospace', color=color_map.get(base, 'black'))

        ax.set_xlim(-1, len(display_seq))
        ax.set_ylim(-0.5, 0.5)
        ax.set_yticks([])
        ax.set_xticks([])
        ax.set_title(f'Step {idx+1}/{num_steps} (α={alpha:.2f})',
                    fontsize=9, loc='left')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)
        ax.spines['left'].set_visible(False)

    plt.tight_layout()
    plt.savefig('latent_interpolation.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Latent space interpolation visualized")


interpolate_latent_space(model, test_loader, device, num_steps=10)

Cell 18: Clustering Analysis

In [None]:
def analyze_latent_clustering(latents_dict, n_clusters=10):
    """
    Run k-means on every latent level and report cluster-quality metrics.

    For each entry of `latents_dict` the latents are clustered, the
    silhouette and Davies-Bouldin scores are computed on the full array, and
    a bar chart of cluster sizes is drawn in that level's panel.

    Args:
        latents_dict: Mapping of level name -> 2-D array of latent vectors.
            The figure has three panels, one per level.
        n_clusters: Number of k-means clusters.

    Returns:
        Dict mapping level name -> dict with keys 'silhouette',
        'davies_bouldin', 'inertia' and 'cluster_labels'.

    Side effects:
        Saves 'clustering_analysis.png', shows the figure, prints a summary.
    """
    from sklearn.cluster import KMeans
    from sklearn.metrics import silhouette_score, davies_bouldin_score

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    summary = {}

    for panel, (level_name, latents) in zip(axes, latents_dict.items()):
        print(f"Clustering {level_name}...")

        # Fixed seed and 10 restarts keep the assignment reproducible.
        model_km = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
        labels = model_km.fit_predict(latents)

        sil = silhouette_score(latents, labels)
        dbi = davies_bouldin_score(latents, labels)

        summary[level_name] = {
            'silhouette': sil,
            'davies_bouldin': dbi,
            'inertia': model_km.inertia_,
            'cluster_labels': labels
        }

        # Cluster-size histogram for this level
        cluster_ids, sizes = np.unique(labels, return_counts=True)
        panel.bar(cluster_ids, sizes, color='steelblue', alpha=0.8)
        panel.set_xlabel('Cluster ID', fontsize=11)
        panel.set_ylabel('Number of Samples', fontsize=11)
        panel.set_title(f'{level_name.capitalize()} Clustering\n'
                        f'Silhouette: {sil:.3f} | DB: {dbi:.3f}',
                        fontsize=11, fontweight='bold')
        panel.grid(alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('clustering_analysis.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Text summary
    rule = "="*60
    print("\n" + rule)
    print("CLUSTERING ANALYSIS")
    print(rule)
    for level_name, metrics in summary.items():
        print(f"\n{level_name.upper()}:")
        print(f"  Silhouette score:     {metrics['silhouette']:.4f} (higher is better, range [-1, 1])")
        print(f"  Davies-Bouldin score: {metrics['davies_bouldin']:.4f} (lower is better)")
        print(f"  Inertia:              {metrics['inertia']:.2f}")
    print(rule)

    return summary


clustering_results = analyze_latent_clustering(latents_dict, n_clusters=10)

Cell 19: Latent Space Arithmetic

In [None]:
def latent_arithmetic(model, dataloader, device):
    """
    Decode the latent combination (A - B) + C next to A, B and C themselves.

    A, B and C are the first three samples of the first batch from
    `dataloader`. The arithmetic is applied independently at every latent
    level returned by `model.encode`, in the spirit of word2vec's
    "king - man + woman = queen". The first 100 bases of each decoded
    sequence are drawn on separate rows.

    Args:
        model: Model exposing `encode(x) -> (latents, _)` and
            `decode(latents)`.
        dataloader: Iterable yielding input batches.
        device: Device the batch is moved to.

    Side effects:
        Saves 'latent_arithmetic.png', shows the figure, prints a completion
        message.
    """
    model.eval()

    batch = next(iter(dataloader)).to(device)

    with torch.no_grad():
        # Encode samples 0, 1 and 2 one at a time (batch dimension kept).
        encoded = []
        for i in range(3):
            level_latents, _ = model.encode(batch[i:i + 1])
            encoded.append(level_latents)

        # (A - B) + C, level by level
        combined = tuple(
            (a - b) + c
            for a, b, c in zip(*encoded)
        )

        # Decode every latent set, then turn each output into a sequence.
        decoded = [
            model.decode(latents).cpu().numpy()[0].reshape(4, 1024)
            for latents in (*encoded, combined)
        ]
        sequences = [DNAEncoder.decode_one_hot(arr) for arr in decoded]

    labels = ['Sequence A', 'Sequence B', 'Sequence C', 'Result: (A - B) + C']
    base_colors = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'}

    fig, axes = plt.subplots(4, 1, figsize=(16, 8))

    for panel, seq, label in zip(axes, sequences, labels):
        shown = seq[:100]

        for col, base in enumerate(shown):
            panel.text(col, 0, base, ha='center', va='center', fontsize=9,
                       family='monospace',
                       color=base_colors.get(base, 'black'),
                       fontweight='bold')

        panel.set_xlim(-1, len(shown))
        panel.set_ylim(-0.5, 0.5)
        panel.set_yticks([])
        panel.set_xticks([])
        panel.set_ylabel(label, fontsize=10, fontweight='bold')
        for side in ('top', 'right', 'bottom', 'left'):
            panel.spines[side].set_visible(False)

    plt.tight_layout()
    plt.savefig('latent_arithmetic.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Latent arithmetic visualization complete")


latent_arithmetic(model, test_loader, device)

Cell 20: Save Complete Analysis Report

In [None]:
def generate_analysis_report(history, intrinsic_dims, clustering_results,
                            reconstruction_accuracies):
    """
    Assemble the analysis results into a plain-text report.

    Args:
        history: Dict with list values under 'train_loss', 'val_loss',
            'train_recon' and 'train_kl'; the last element of each is
            reported, and the length of 'train_loss' is the epoch count.
        intrinsic_dims: Mapping of level name -> dict with 'intrinsic_dim'
            and 'nominal_dim'.
        clustering_results: Mapping of level name -> dict with 'silhouette'
            and 'davies_bouldin'.
        reconstruction_accuracies: Sequence of per-sample accuracies.

    Returns:
        The report as a single newline-joined string.

    Side effects:
        Writes the report to 'analysis_report.txt' and prints it.
    """
    heavy_rule = "="*80
    light_rule = "-" * 80

    # Header and section 1: training summary
    lines = [
        heavy_rule,
        "HIERARCHICAL VAE ANALYSIS REPORT",
        heavy_rule,
        "",
        "1. TRAINING SUMMARY",
        light_rule,
        f"   Final training loss:     {history['train_loss'][-1]:.4f}",
        f"   Final validation loss:   {history['val_loss'][-1]:.4f}",
        f"   Final reconstruction:    {history['train_recon'][-1]:.4f}",
        f"   Final KL divergence:     {history['train_kl'][-1]:.4f}",
        f"   Total epochs:            {len(history['train_loss'])}",
        "",
    ]

    # Section 2: intrinsic vs nominal dimensionality per level
    lines += ["2. INTRINSIC DIMENSIONALITY", light_rule]
    for level, dims in intrinsic_dims.items():
        usage_pct = (dims['intrinsic_dim'] / dims['nominal_dim']) * 100
        lines.extend([
            f"   {level}:",
            f"     Nominal:    {dims['nominal_dim']}",
            f"     Intrinsic:  {dims['intrinsic_dim']}",
            f"     Usage:      {usage_pct:.1f}%",
        ])
    lines.append("")

    # Section 3: clustering metrics per level
    lines += ["3. CLUSTERING QUALITY", light_rule]
    for level, metrics in clustering_results.items():
        lines.extend([
            f"   {level}:",
            f"     Silhouette:     {metrics['silhouette']:.4f}",
            f"     Davies-Bouldin: {metrics['davies_bouldin']:.4f}",
        ])
    lines.append("")

    # Section 4: reconstruction accuracy statistics, then the footer
    lines += [
        "4. RECONSTRUCTION QUALITY",
        light_rule,
        f"   Mean accuracy:   {np.mean(reconstruction_accuracies):.4f}",
        f"   Median accuracy: {np.median(reconstruction_accuracies):.4f}",
        f"   Std deviation:   {np.std(reconstruction_accuracies):.4f}",
        "",
        heavy_rule,
        "END OF REPORT",
        heavy_rule,
    ]

    report_text = "\n".join(lines)

    # Persist, then echo to the cell output
    with open('analysis_report.txt', 'w') as out_file:
        out_file.write(report_text)

    print(report_text)
    print("\n✓ Report saved to 'analysis_report.txt'")

    return report_text


report = generate_analysis_report(
    history,
    intrinsic_dims,
    clustering_results,
    reconstruction_accuracies
)

Cell 21: Download All Artifacts

In [None]:
# Download all saved files produced by the earlier cells.
# Missing files are reported and skipped so one absent artifact does not stop
# the remaining downloads.
import os

from google.colab import files

print("Downloading artifacts...")

artifacts = [
    'best_hierarchical_vae.pth',
    'training_history.png',
    'intrinsic_dimensionality.png',
    'latent_space_umap.png',
    'latent_space_tsne.png',
    'reconstruction_quality.png',
    'latent_interpolation.png',
    'clustering_analysis.png',
    'latent_arithmetic.png',
    'analysis_report.txt'
]

for artifact in artifacts:
    # Check up front so a skipped cell yields a clear message.
    if not os.path.exists(artifact):
        print(f"✗ Could not download: {artifact} (file not found)")
        continue

    try:
        files.download(artifact)
    except Exception as exc:
        # `Exception` rather than a bare `except:` so that KeyboardInterrupt
        # and SystemExit still stop the cell; the reason is shown to help
        # debugging.
        print(f"✗ Could not download: {artifact} ({exc})")
    else:
        print(f"✓ Downloaded: {artifact}")

print("\n✓ Download complete")

## Part 3: Interpretability & Emergent Structure Analysis

Cell 22: Latent Activation Patterns

In [None]:
def analyze_latent_activation_patterns(model, dataloader, device, num_samples=1000):
    """
    Summarize how strongly each latent dimension is used at each level.

    Latents from `model.encode` are collected for roughly `num_samples`
    inputs. For each of the three levels the mean absolute activation per
    dimension is plotted (dimensions below 0.01 are counted as "dead"),
    together with a heatmap of up to 500 randomly chosen samples.

    Args:
        model: Model exposing `encode(x) -> (latents, _)` where `latents`
            can be indexed 0..2.
        dataloader: Iterable yielding input batches.
        device: Device the batches are moved to.
        num_samples: Upper bound on the number of rows kept per level.

    Returns:
        Dict with keys 'level1', 'level2', 'level3' mapping to the collected
        activation arrays.

    Side effects:
        Saves 'activation_patterns.png', shows the figure, prints statistics.
    """
    model.eval()

    level_keys = ('level1', 'level2', 'level3')
    dead_cutoff = 0.01  # mean |activation| below this counts as dead

    # One list of per-batch arrays for each level
    chunks_per_level = [[] for _ in level_keys]
    n_seen = 0

    with torch.no_grad():
        for batch in dataloader:
            if n_seen >= num_samples:
                break

            inputs = batch.to(device)
            latents, _ = model.encode(inputs)

            for level_idx, chunks in enumerate(chunks_per_level):
                chunks.append(latents[level_idx].cpu().numpy())

            n_seen += len(inputs)

    # Stack batches and trim the overshoot from the last batch.
    activations = {
        key: np.concatenate(chunks, axis=0)[:num_samples]
        for key, chunks in zip(level_keys, chunks_per_level)
    }

    # Mean absolute activation per dimension, reused by plots and printout
    mean_abs = {
        key: np.mean(np.abs(act), axis=0)
        for key, act in activations.items()
    }

    fig, axes = plt.subplots(3, 2, figsize=(16, 12))

    for row, (level_name, act) in enumerate(activations.items()):
        level_mean = mean_abs[level_name]
        n_dead = np.sum(level_mean < dead_cutoff)

        # Left panel: mean activation per dimension with the dead threshold
        bar_ax = axes[row, 0]
        bar_ax.bar(range(len(level_mean)), level_mean, color='steelblue', alpha=0.7)
        bar_ax.set_xlabel('Latent Dimension', fontsize=11)
        bar_ax.set_ylabel('Mean |Activation|', fontsize=11)
        bar_ax.set_title(f'{level_name.capitalize()} - Mean Activation per Dimension',
                         fontsize=11, fontweight='bold')
        bar_ax.grid(alpha=0.3, axis='y')
        bar_ax.axhline(y=dead_cutoff, color='red', linestyle='--', linewidth=2,
                       alpha=0.7, label=f'Dead neurons: {n_dead}')
        bar_ax.legend()

        # Right panel: heatmap over a random subset of samples
        heat_ax = axes[row, 1]
        picked = np.random.choice(len(act), min(500, len(act)), replace=False)

        image = heat_ax.imshow(act[picked].T, aspect='auto', cmap='RdBu_r',
                               interpolation='nearest', vmin=-3, vmax=3)
        heat_ax.set_xlabel('Sample Index', fontsize=11)
        heat_ax.set_ylabel('Latent Dimension', fontsize=11)
        heat_ax.set_title(f'{level_name.capitalize()} - Activation Heatmap',
                          fontsize=11, fontweight='bold')

        colorbar = plt.colorbar(image, ax=heat_ax)
        colorbar.set_label('Activation Value', fontsize=10)

    plt.tight_layout()
    plt.savefig('activation_patterns.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Text summary
    rule = "="*60
    print("\n" + rule)
    print("LATENT ACTIVATION ANALYSIS")
    print(rule)

    for level_name, act in activations.items():
        level_mean = mean_abs[level_name]
        n_dead = np.sum(level_mean < dead_cutoff)
        used_pct = (1 - n_dead / len(level_mean)) * 100

        print(f"\n{level_name.upper()}:")
        print(f"  Total dimensions:      {act.shape[1]}")
        print(f"  Dead neurons (<0.01):  {n_dead}")
        print(f"  Active utilization:    {used_pct:.1f}%")
        print(f"  Mean activation:       {np.mean(level_mean):.4f}")
        print(f"  Std activation:        {np.std(level_mean):.4f}")

    print(rule)

    return activations


activation_patterns = analyze_latent_activation_patterns(
    model, test_loader, device, num_samples=1000
)

Cell 23: Latent Dimension Importance Ranking

In [None]:
def rank_latent_dimensions_by_importance(model, dataloader, device, num_samples=500):
    """
    Ablation study: which level-1 latent dimensions matter most for
    reconstruction?

    Each level-1 dimension is zeroed in turn, the latents are decoded, and
    the increase in mean squared reconstruction error over the baseline is
    recorded as that dimension's importance.

    Args:
        model: Model whose forward pass returns a 3-tuple with the
            reconstruction first, and which exposes
            `encode(x) -> (latents, _)` and `decode(latents)` with three
            latent levels.
        dataloader: Iterable yielding input batches.
        device: Device the batches are moved to.
        num_samples: Approximate number of inputs used for the error
            estimates.

    Returns:
        Tuple `(importance_scores, ranked_indices)`: the per-dimension error
        increase, and the dimension indices sorted by descending importance.

    Raises:
        ValueError: If `dataloader` yields no batches.

    Side effects:
        Saves 'dimension_importance.png', shows the figure, prints the top
        ten dimensions.
    """
    model.eval()

    # Get baseline reconstruction error
    print("Computing baseline reconstruction error...")
    baseline_errors = []

    # Inputs and their latents are encoded once here and reused for every
    # ablated dimension, instead of re-running the encoder per dimension.
    cached_batches = []

    with torch.no_grad():
        for idx, batch in enumerate(dataloader):
            if idx * batch.size(0) >= num_samples:
                break

            x = batch.to(device)
            recon, _, _ = model(x)
            error = F.mse_loss(recon, x, reduction='none').mean(dim=1)
            baseline_errors.append(error.cpu().numpy())

            latents, _ = model.encode(x)
            cached_batches.append((x, latents))

    if not cached_batches:
        raise ValueError("dataloader yielded no batches to evaluate")

    baseline_errors = np.concatenate(baseline_errors)[:num_samples]
    baseline_mean = np.mean(baseline_errors)
    # NOTE(review): the baseline uses model(x) while ablations use
    # encode -> decode; if the two paths sample latents differently the
    # scores include that offset. Confirm against the model definition.

    print(f"Baseline error: {baseline_mean:.6f}")

    # Test each dimension in level 1 (most abstract)
    print("\nTesting dimension importance (Level 1 only, for speed)...")

    # Read the width from the encoder output rather than hard-coding it.
    latent_dim = cached_batches[0][1][0].shape[1]
    importance_scores = []

    with torch.no_grad():
        for dim_idx in tqdm(range(latent_dim), desc="Ablating dimensions"):
            ablation_errors = []

            for x, latents in cached_batches:
                # Ablate specific dimension in level 1; clone so the cached
                # latents stay intact for the next dimension.
                z1_ablated = latents[0].clone()
                z1_ablated[:, dim_idx] = 0

                latents_ablated = (z1_ablated, latents[1], latents[2])

                # Decode and measure error
                recon = model.decode(latents_ablated)
                error = F.mse_loss(recon, x, reduction='none').mean(dim=1)
                ablation_errors.append(error.cpu().numpy())

            ablation_errors = np.concatenate(ablation_errors)[:num_samples]
            ablation_mean = np.mean(ablation_errors)

            # Importance = increase in error when dimension is removed
            importance_scores.append(ablation_mean - baseline_mean)

    importance_scores = np.array(importance_scores)

    # Rank dimensions
    ranked_indices = np.argsort(importance_scores)[::-1]  # Descending

    # Visualize
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    # Plot 1: Importance scores
    ax = axes[0]
    ax.bar(range(len(importance_scores)), importance_scores[ranked_indices],
           color='crimson', alpha=0.7)
    ax.set_xlabel('Dimension (sorted by importance)', fontsize=11)
    ax.set_ylabel('Importance Score (Δ Error)', fontsize=11)
    ax.set_title('Latent Dimension Importance (Level 1)',
                fontsize=12, fontweight='bold')
    ax.grid(alpha=0.3, axis='y')

    # Plot 2: Top dimensions (at most 20, fewer if the level is narrower)
    ax = axes[1]
    top_k = min(20, latent_dim)
    top_dims = ranked_indices[:top_k]
    top_scores = importance_scores[top_dims]

    ax.barh(range(top_k), top_scores, color='darkgreen', alpha=0.7)
    ax.set_yticks(range(top_k))
    ax.set_yticklabels([f'Dim {d}' for d in top_dims], fontsize=9)
    ax.set_xlabel('Importance Score', fontsize=11)
    ax.set_title(f'Top {top_k} Most Important Dimensions',
                fontsize=12, fontweight='bold')
    ax.grid(alpha=0.3, axis='x')
    ax.invert_yaxis()  # most important dimension at the top

    plt.tight_layout()
    plt.savefig('dimension_importance.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Print top dimensions
    print("\n" + "="*60)
    print("TOP 10 MOST IMPORTANT DIMENSIONS (LEVEL 1)")
    print("="*60)
    for rank, dim_idx in enumerate(ranked_indices[:10], 1):
        print(f"{rank:2d}. Dimension {dim_idx:3d} | Importance: {importance_scores[dim_idx]:.6f}")
    print("="*60)

    return importance_scores, ranked_indices


importance_scores, ranked_dims = rank_latent_dimensions_by_importance(
    model, test_loader, device, num_samples=500
)

Cell 24: Directional Latent Space Exploration

In [None]:
def explore_latent_directions(model, dataloader, device, num_directions=5):
    """
    Move a base sample along the principal directions of the level-1 latents.

    PCA is fitted on the level-1 latents of every batch in `dataloader`. The
    first sample of the first batch is then shifted along each principal
    component by -3 to +3 standard deviations of the data along that
    component, with the other levels left unchanged, and the first 80 bases
    of each decoded sequence are drawn.

    Args:
        model: Model exposing `encode(x) -> (latents, _)` and
            `decode(latents)` with three latent levels.
        dataloader: Iterable yielding input batches.
        device: Device the batches are moved to.
        num_directions: Number of principal components to explore.

    Side effects:
        Saves 'latent_directions.png', shows the figure, prints progress and
        the explained-variance total.
    """
    model.eval()

    # Extract latents for PCA
    print("Extracting latents for PCA...")
    latents_l1 = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Encoding"):
            x = batch.to(device)
            latents, _ = model.encode(x)
            latents_l1.append(latents[0].cpu().numpy())

    latents_l1 = np.concatenate(latents_l1, axis=0)

    # Fit PCA to find principal directions
    print("Computing principal directions...")
    pca = PCA(n_components=num_directions)
    pca.fit(latents_l1)

    # Get a base sample
    base_sample = next(iter(dataloader))[0:1].to(device)

    with torch.no_grad():
        base_latents, _ = model.encode(base_sample)
        base_z1 = base_latents[0]

    # squeeze=False keeps `axes` indexable when num_directions == 1.
    fig, axes = plt.subplots(num_directions, 1,
                             figsize=(16, num_directions * 1.5),
                             squeeze=False)
    axes = axes[:, 0]

    color_map = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'}
    alphas = np.linspace(-3, 3, 7)  # -3σ to +3σ

    for dir_idx in range(num_directions):
        direction = torch.tensor(pca.components_[dir_idx],
                                device=device, dtype=torch.float32)

        # PCA components are unit vectors. Scaling by the standard deviation
        # of the data along this component makes a step of `alpha` equal to
        # alpha·σ, as the axis labels claim.
        component_std = float(np.sqrt(pca.explained_variance_[dir_idx]))

        # Generate samples along this direction
        sequences = []

        with torch.no_grad():
            for alpha in alphas:
                # Move along direction
                step = float(alpha) * component_std
                z1_modified = base_z1 + step * direction.unsqueeze(0)

                # Keep other levels unchanged
                modified_latents = (z1_modified, base_latents[1], base_latents[2])

                # Decode
                recon = model.decode(modified_latents)
                recon_reshaped = recon[0].cpu().numpy().reshape(4, 1024)
                seq = DNAEncoder.decode_one_hot(recon_reshaped)
                sequences.append(seq[:80])  # First 80 bases

        # Visualize this direction: one text row per alpha, top to bottom
        ax = axes[dir_idx]

        for row_idx, seq in enumerate(sequences):
            for pos, base in enumerate(seq):
                ax.text(pos, -row_idx, base, ha='center', va='center',
                       fontsize=7, family='monospace',
                       color=color_map.get(base, 'black'))

        ax.set_xlim(-1, 80)
        ax.set_ylim(-len(alphas) + 0.5, 0.5)
        ax.set_yticks(-np.arange(len(alphas)))
        ax.set_yticklabels([f'{a:+.1f}σ' for a in alphas], fontsize=9)
        ax.set_xticks([])
        ax.set_title(f'PC{dir_idx+1} | Explained Var: {pca.explained_variance_ratio_[dir_idx]:.2%}',
                    fontsize=10, fontweight='bold', loc='left')
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_visible(False)

    plt.tight_layout()
    plt.savefig('latent_directions.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("\n✓ Latent direction exploration complete")
    print(f"   Top {num_directions} PCs explain: {pca.explained_variance_ratio_.sum():.2%} of variance")


explore_latent_directions(model, test_loader, device, num_directions=5)

Cell 25: Latent Space Density Analysis

In [None]:
def analyze_latent_space_density(latents_dict):
    """
    Plot the distribution of pairwise distances within each latent level.

    For every level up to 2000 latent vectors are drawn at random, all
    pairwise Euclidean distances between them are computed, and their
    histogram is plotted with the median marked. The shape of the histogram
    hints at how the data is spread out; no explicit void or cluster
    detection is performed.

    Args:
        latents_dict: Mapping of level name -> 2-D array of latent vectors.
            One panel is drawn per entry.

    Side effects:
        Saves 'latent_density.png', shows the figure, prints progress.
    """
    from scipy.spatial.distance import pdist

    n_levels = len(latents_dict)
    # squeeze=False keeps `axes` 2-D so indexing works for any panel count.
    fig, axes = plt.subplots(1, n_levels, figsize=(6 * n_levels, 5),
                             squeeze=False)

    for idx, (level_name, latents) in enumerate(latents_dict.items()):
        # Subsample for computational efficiency: the number of pairs grows
        # quadratically with the number of samples.
        n_samples = min(2000, len(latents))
        subsample_idx = np.random.choice(len(latents), n_samples, replace=False)
        latents_subset = latents[subsample_idx]

        print(f"Computing density for {level_name}...")

        # Condensed vector of all pairwise distances
        distances = pdist(latents_subset, metric='euclidean')

        # Plot distance distribution
        ax = axes[0, idx]
        ax.hist(distances, bins=50, color='steelblue', alpha=0.7, edgecolor='black')
        ax.set_xlabel('Pairwise Distance', fontsize=11)
        ax.set_ylabel('Frequency', fontsize=11)
        ax.set_title(f'{level_name.capitalize()} - Distance Distribution\n'
                    f'Mean: {np.mean(distances):.2f} | Std: {np.std(distances):.2f}',
                    fontsize=11, fontweight='bold')
        ax.grid(alpha=0.3, axis='y')

        # Mark the median distance
        median_distance = np.median(distances)
        ax.axvline(median_distance, color='red', linestyle='--', linewidth=2,
                  alpha=0.7, label=f'Median: {median_distance:.2f}')
        ax.legend()

    plt.tight_layout()
    plt.savefig('latent_density.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Density analysis complete")


analyze_latent_space_density(latents_dict)

Cell 26: Manifold Continuity Test

In [None]:
def test_manifold_continuity(model, dataloader, device, num_tests=100):
    """
    Score how evenly reconstructions change along latent interpolations.

    Consecutive samples of the first batch are paired up (sample i with
    sample i+1). For each pair the latents are interpolated linearly in 20
    steps and decoded; the score is the variance of the step-to-step change
    in the reconstruction, so lower values mean more uniform transitions.

    Args:
        model: Model exposing `encode(x) -> (latents, _)` and
            `decode(latents)`.
        dataloader: Iterable yielding input batches; only the first batch is
            used.
        device: Device the batch is moved to.
        num_tests: Maximum number of pairs to score (capped at batch size
            minus one).

    Side effects:
        Saves 'manifold_continuity.png', shows the figure, prints statistics.
    """
    model.eval()

    batch = next(iter(dataloader)).to(device)
    n_pairs = min(num_tests, len(batch) - 1)
    n_interp = 20  # interpolation points per pair, endpoints included

    scores = []

    print("Testing manifold continuity...")

    for start in tqdm(range(n_pairs)):
        left = batch[start:start + 1]
        right = batch[start + 1:start + 2]

        with torch.no_grad():
            z_left, _ = model.encode(left)
            z_right, _ = model.encode(right)

            frames = []
            for alpha in np.linspace(0, 1, n_interp):
                blended = tuple(
                    (1 - alpha) * a + alpha * b
                    for a, b in zip(z_left, z_right)
                )
                frames.append(model.decode(blended)[0].cpu().numpy())

            # Size of each step between neighbouring frames; a perfectly
            # even path has identical step sizes, i.e. zero variance.
            step_deltas = np.diff(np.array(frames), axis=0)
            step_sizes = np.linalg.norm(step_deltas, axis=1)
            scores.append(np.var(step_sizes))

    smoothness_scores = np.array(scores)

    # Histogram on the left, box plot on the right
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.hist(smoothness_scores, bins=30, color='purple', alpha=0.7, edgecolor='black')
    plt.xlabel('Smoothness Score (lower = smoother)', fontsize=11)
    plt.ylabel('Frequency', fontsize=11)
    plt.title('Manifold Smoothness Distribution', fontsize=12, fontweight='bold')
    plt.grid(alpha=0.3, axis='y')

    plt.subplot(1, 2, 2)
    plt.boxplot(smoothness_scores, vert=True)
    plt.ylabel('Smoothness Score', fontsize=11)
    plt.title('Smoothness Statistics', fontsize=12, fontweight='bold')
    plt.grid(alpha=0.3, axis='y')

    plt.tight_layout()
    plt.savefig('manifold_continuity.png', dpi=150, bbox_inches='tight')
    plt.show()

    rule = "="*60
    print("\n" + rule)
    print("MANIFOLD CONTINUITY ANALYSIS")
    print(rule)
    print(f"Mean smoothness:   {np.mean(smoothness_scores):.6f}")
    print(f"Median smoothness: {np.median(smoothness_scores):.6f}")
    print(f"Std deviation:     {np.std(smoothness_scores):.6f}")
    print("\nLower scores indicate smoother manifolds (better continuity)")
    print(rule)


test_manifold_continuity(model, test_loader, device, num_tests=100)

Cell 27: Generate Synthetic Sequences from Prior

In [None]:
def generate_from_prior(model, device, num_samples=10, latent_dims=(256, 512, 1024)):
    """
    Sample latents from the standard normal prior N(0, 1) and decode them.

    One latent tensor per level is drawn for every sample, the tuple is
    passed to `model.decode`, and the first 100 bases of the decoded
    sequence are drawn together with its GC content.

    Args:
        model: Model exposing `decode(latents)`.
        device: Device on which the latents are created.
        num_samples: Number of sequences to generate.
        latent_dims: Width of each latent level, in the order expected by
            `model.decode`. The default matches the previously hard-coded
            sizes.

    Side effects:
        Saves 'generated_sequences.png', shows the figure, prints progress.
    """
    model.eval()

    print(f"Generating {num_samples} synthetic sequences from prior...")

    # squeeze=False keeps `axes` indexable when num_samples == 1 (a single
    # row would otherwise come back as a bare Axes and break axes[idx]).
    fig, axes = plt.subplots(num_samples, 1, figsize=(16, num_samples * 1),
                             squeeze=False)
    axes = axes[:, 0]

    color_map = {'A': 'green', 'C': 'blue', 'G': 'orange', 'T': 'red'}

    with torch.no_grad():
        for idx in range(num_samples):
            # Sample from standard normal, one tensor per latent level
            latents = tuple(
                torch.randn(1, dim, device=device) for dim in latent_dims
            )

            # Decode
            generated = model.decode(latents)
            generated_reshaped = generated[0].cpu().numpy().reshape(4, 1024)
            seq = DNAEncoder.decode_one_hot(generated_reshaped)

            # Calculate GC content
            gc_content = (seq.count('G') + seq.count('C')) / len(seq)

            # Visualize first 100 bases
            display_seq = seq[:100]

            ax = axes[idx]
            for pos, base in enumerate(display_seq):
                ax.text(pos, 0, base, ha='center', va='center', fontsize=8,
                       family='monospace', color=color_map.get(base, 'black'))

            ax.set_xlim(-1, len(display_seq))
            ax.set_ylim(-0.5, 0.5)
            ax.set_yticks([])
            ax.set_xticks([])
            ax.set_title(f'Sample {idx+1} | GC content: {gc_content:.2%}',
                        fontsize=9, loc='left')
            ax.spines['top'].set_visible(False)
            ax.spines['right'].set_visible(False)
            ax.spines['bottom'].set_visible(False)
            ax.spines['left'].set_visible(False)

    plt.tight_layout()
    plt.savefig('generated_sequences.png', dpi=150, bbox_inches='tight')
    plt.show()

    print("✓ Synthetic sequence generation complete")


generate_from_prior(model, device, num_samples=10)

Cell 28: Final Comprehensive Summary

In [None]:
def create_comprehensive_summary():
    """
    Assemble the fixed interpretive report, persist it, and echo it.

    The report text is static: it is not computed from any analysis result.
    It is written to 'final_summary.txt' in the current working directory
    (overwriting any existing file) and then printed to stdout.

    Returns:
        None
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    # One entry per output line; empty strings are the blank separator lines.
    report_lines = [
        heavy_rule,
        "HIERARCHICAL VAE: EMERGENT STRUCTURE ANALYSIS",
        heavy_rule,
        "",
        # --- What the model learned ---
        "WHAT DID THE MODEL LEARN?",
        light_rule,
        "",
        "1. COMPRESSION STRATEGY",
        "   The model developed a hierarchical compression scheme:",
        "   - Level 1 (256d): Most abstract patterns, highest compression",
        "   - Level 2 (512d): Intermediate structural features",
        "   - Level 3 (1024d): Fine-grained sequence details",
        "",
        "2. LATENT SPACE ORGANIZATION",
        "   - The model self-organized data into distinct clusters",
        "   - Manifold structure shows smooth interpolation is possible",
        "   - Some latent dimensions are 'dead' (underutilized capacity)",
        "   - Principal components capture dominant variation modes",
        "",
        "3. GENERATIVE CAPABILITY",
        "   - Can generate novel sequences from prior distribution",
        "   - Latent arithmetic demonstrates compositional structure",
        "   - Interpolations produce coherent intermediate sequences",
        "",
        "4. WHAT THIS MEANS",
        "   Without any supervision or semantic labels, the model:",
        "   - Discovered statistical regularities in genomic sequences",
        "   - Developed internal representations at multiple scales",
        "   - Created a structured latent 'language' for representing data",
        "   - Learned to compress information in a lossy but structured way",
        "",
        "5. LIMITATIONS & OBSERVATIONS",
        "   - Some latent capacity remains unused (dead neurons)",
        "   - Reconstruction isn't perfect (information loss in bottleneck)",
        "   - No explicit biological meaning emerged (just patterns)",
        "   - The 'language' is purely distributional, not semantic",
        "",
        # --- Conclusion ---
        heavy_rule,
        "CONCLUSION",
        heavy_rule,
        "",
        "The model developed its own internal representation system for genomic",
        "data without human guidance. It learned to organize, compress, and",
        "generate sequences using a hierarchical latent structure. However,",
        "this 'emergent language' is statistical pattern-matching, not",
        "understanding in any semantic sense.",
        "",
        "The representations are useful for:",
        "- Dimensionality reduction",
        "- Anomaly detection",
        "- Generative modeling",
        "- Transfer learning to downstream tasks",
        "",
        "But they don't inherently 'mean' anything beyond the training data.",
        heavy_rule,
    ]

    summary_text = "\n".join(report_lines)

    # No trailing newline is written; the file holds exactly the joined text.
    with open('final_summary.txt', 'w') as out_file:
        out_file.write(summary_text)

    print(summary_text)
    print("\n✓ Final summary saved to 'final_summary.txt'")


# Write the static interpretive report to 'final_summary.txt' and print it.
create_comprehensive_summary()

Cell 29: Export Model for Inference

In [None]:
class InferenceWrapper(nn.Module):
    """
    Convenience layer around a trained VAE for string-in / string-out use.

    Defined at module scope on purpose: torch.save pickles the wrapper by
    reference to its class, and pickle cannot serialize a class that is
    defined inside a function body (saving would raise).
    """

    def __init__(self, vae_model):
        super().__init__()
        self.model = vae_model

    def _device(self):
        """Return the device that holds the wrapped model's parameters."""
        return next(self.model.parameters()).device

    def encode_sequence(self, sequence_str):
        """
        Encode a DNA sequence string to its latent representation.

        Args:
            sequence_str: DNA sequence passed to DNAEncoder.one_hot_encode.
                Presumably its encoded, flattened length must match the
                model's input size (4096 in the exported metadata) -- TODO
                confirm against DNAEncoder / the model definition.

        Returns:
            dict with keys 'level1', 'level2', 'level3', each a NumPy array
            holding the corresponding latent tensor (batch dimension of 1).
        """
        encoded = DNAEncoder.one_hot_encode(sequence_str)
        # Flatten and add a leading batch dimension of size 1.
        encoded_flat = torch.tensor(encoded.flatten(), dtype=torch.float32).unsqueeze(0)

        with torch.no_grad():
            latents, _ = self.model.encode(encoded_flat.to(self._device()))

        return {
            'level1': latents[0].cpu().numpy(),
            'level2': latents[1].cpu().numpy(),
            'level3': latents[2].cpu().numpy()
        }

    def decode_latents(self, z1, z2, z3):
        """
        Decode latent vectors back to a sequence string.

        Args:
            z1, z2, z3: Latent values for the three levels, as NumPy arrays,
                tensors, or nested sequences (anything torch.as_tensor takes).

        Returns:
            The result of DNAEncoder.decode_one_hot on the first decoded
            sample reshaped to (4, 1024).
        """
        device = self._device()
        # as_tensor accepts tensors as well as arrays without the
        # copy-construct warning torch.tensor emits for tensor inputs.
        latents = tuple(
            torch.as_tensor(z, dtype=torch.float32).to(device)
            for z in (z1, z2, z3)
        )

        with torch.no_grad():
            recon = self.model.decode(latents)

        recon_reshaped = recon[0].cpu().numpy().reshape(4, 1024)
        return DNAEncoder.decode_one_hot(recon_reshaped)

    def reconstruct_sequence(self, sequence_str):
        """Run a full encode-decode cycle and return the decoded sequence."""
        latents = self.encode_sequence(sequence_str)
        return self.decode_latents(
            latents['level1'],
            latents['level2'],
            latents['level3']
        )


def export_model_for_inference():
    """
    Package the trained model for standalone use.

    Wraps the global `model` in an InferenceWrapper (switched to eval mode)
    and saves it, together with the latent/input dimensions, to
    'inference_model.pth' in the current working directory.

    The checkpoint stores the full pickled module, not just a state_dict, so
    loading it requires InferenceWrapper and the wrapped model's classes to be
    defined or importable in the loading process.

    Returns:
        None
    """
    # Wrap model
    inference_model = InferenceWrapper(model)
    inference_model.eval()

    # Save
    torch.save({
        'model': inference_model,
        'latent_dims': [256, 512, 1024],
        'input_dim': 4096
    }, 'inference_model.pth')

    print("✓ Inference model exported to 'inference_model.pth'")
    print("\nUsage example:")
    print("  # InferenceWrapper and the VAE classes must be defined before loading")
    # weights_only=False is needed on PyTorch >= 2.6, where torch.load
    # defaults to weights-only unpickling and rejects full pickled modules.
    print("  loaded = torch.load('inference_model.pth', weights_only=False)")
    print("  model = loaded['model']")
    print("  latents = model.encode_sequence('ATCGATCG...')")
    print("  reconstructed = model.reconstruct_sequence('ATCGATCG...')")


# Wrap the trained global `model` and save it to 'inference_model.pth'.
export_model_for_inference()

Cell 30: Completion Summary (Generated Artifacts)

In [None]:
# Final banner: list every artifact the notebook is expected to have produced.
_banner_rule = "=" * 80

# (file name, short description) pairs, in the order they are reported.
_generated_artifacts = (
    ("best_hierarchical_vae.pth", "Trained model checkpoint"),
    ("inference_model.pth", "Standalone inference wrapper"),
    ("training_history.png", "Loss curves"),
    ("intrinsic_dimensionality.png", "Capacity utilization"),
    ("latent_space_umap.png", "UMAP projections"),
    ("latent_space_tsne.png", "t-SNE projections"),
    ("reconstruction_quality.png", "Sequence reconstructions"),
    ("latent_interpolation.png", "Smooth transitions"),
    ("latent_arithmetic.png", "Vector arithmetic"),
    ("clustering_analysis.png", "K-means results"),
    ("activation_patterns.png", "Neuron utilization"),
    ("dimension_importance.png", "Ablation study"),
    ("latent_directions.png", "Principal components"),
    ("latent_density.png", "Space filling"),
    ("manifold_continuity.png", "Smoothness test"),
    ("generated_sequences.png", "Samples from prior"),
    ("analysis_report.txt", "Numerical summary"),
    ("final_summary.txt", "Interpretive analysis"),
)

print("\n" + _banner_rule)
print("HIERARCHICAL VAE TRAINING COMPLETE")
print(_banner_rule)
print("\nAll analyses finished. Generated artifacts:")
for _artifact_name, _artifact_desc in _generated_artifacts:
    print(f"  ✓ {_artifact_name} - {_artifact_desc}")
print("\n" + _banner_rule)