# In [None]:

# =============================================================================
# Full Pipeline with Custom Combination of Peak & Frequency via a Neural Network
# 
# Now includes:
#   • BatchNorm on raw ViT pattern embeddings
#   • BatchNorm on raw frequency input
#   • BatchNorm inside the combiner MLP
#   • Remaining LayerNorms (in projection heads and freq MLP) retained
#   • Dropout for regularization
#   • Optional mixed precision (torch.amp)
#   • Extended warmup scheduler, early stopping, and checkpointing
#
# Requirements:
#   • Python 3.8+
#   • PyTorch 2.3+ (needed for torch.amp.GradScaler / torch.amp.autocast)
#   • vit-pytorch>=0.25.6
#   • numpy
#   • pandas
#   • tqdm
#
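# Before running:
#   pip install torch torchvision "vit-pytorch>=0.25.6" numpy pandas tqdm
#   (quote the version spec; unquoted, the shell treats ">" as a redirect)
#   Provide seq_h_dataset.npy, top_out_dataset.npy and top_amount_dataset.npy
#   in the working directory. They are generated with
#   get_clip_encoding(chrom, start, end), which must be on your PYTHONPATH
#   when generating and returns:
#       seq_h_out (5×32), top_out (10×100), top_amount (10,)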
# =============================================================================

import os
import random
import math
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR

from vit_pytorch import ViT

# Optional mixed precision
from torch.amp import GradScaler, autocast

# Assume get_clip_encoding is defined elsewhere and on PYTHONPATH:
# from your_module import get_clip_encoding


# -------------------- Flags & Hyperparameters --------------------
# Dataset I/O switches.
# NOTE(review): neither flag appears to be read below; the .npy arrays in
# section 2 are loaded unconditionally.
LOAD_DATASET = True    # load pre-built arrays from .npy instead of regenerating
SAVE_DATASET = False   # when regenerating, also persist the arrays to .npy

# Optimisation / schedule
n_epochs = 10_000
batch_size = 32
learning_rate = 1e-3   # peak AdamW LR, reached when the linear warmup ends
weight_decay = 5e-8    # AdamW decoupled weight decay
dropout_prob = 0.3     # passed to CLIPModel as its dropout rate
warmup_epochs = 50     # epochs of linear LR warmup before cosine decay
USE_AMP = False        # True -> mixed precision via torch.amp

# Early stopping: epochs without a validation improvement before stopping.
patience = 1000
# -------------------------------------------------------------------

# Run on CUDA when a GPU is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ==============================================================================
# 1. Load promoter coordinates from BED file
# ==============================================================================
promoter_bed = "promoter_view_test.bed"  # replace with actual path

# Headerless, tab-separated BED file: keep only the first three columns.
# NOTE(review): bed_df does not appear to be used later in this script.
bed_df = pd.read_csv(
    promoter_bed, sep="\t", header=None,
    usecols=[0, 1, 2], names=["chrom", "start", "end"],
)







# ==============================================================================
# 2. Load precomputed encodings of the ±25 kb intervals around each promoter start
# ==============================================================================

# Precomputed arrays sharing a leading sample axis (expected shapes noted).
seq_h_dataset = np.load("seq_h_dataset.npy")            # (n_samples, 5, 32)
top_out_dataset = np.load("top_out_dataset.npy")        # (n_samples, 10, 100)
top_amount_dataset = np.load("top_amount_dataset.npy")  # (n_samples, 10)
n_samples = seq_h_dataset.shape[0]


# ==============================================================================
# 3. Split into Training (80 %) and Validation (20 %) Sets
# ==============================================================================
# Reproducible 80/20 split: shuffle the sample indices with a fixed seed.
# NOTE(review): np.random.seed reseeds NumPy's *global* RNG for the whole run.
np.random.seed(42)
indices = np.arange(n_samples)
np.random.shuffle(indices)

split = int(0.8 * n_samples)
train_idx, val_idx = indices[:split], indices[split:]

train_seq_h, train_top_out, train_top_amount = (
    source[train_idx]
    for source in (seq_h_dataset, top_out_dataset, top_amount_dataset)
)
val_seq_h, val_top_out, val_top_amount = (
    source[val_idx]
    for source in (seq_h_dataset, top_out_dataset, top_amount_dataset)
)

# ==============================================================================
# 4. Wrap in PyTorch Dataset and DataLoader
# ==============================================================================
class GenomicClipDataset(Dataset):
    """Map-style dataset yielding ``(dna, pattern, frequency)`` float32 tensors.

    Item ``idx`` is row ``idx`` of three NumPy arrays that share a leading
    sample dimension.
    """

    def __init__(self, seq_h_array, top_array, top_amount_array):
        """
        Args:
            seq_h_array:      numpy array of shape (N, 5, 32)
            top_array:        numpy array of shape (N, 10, 100)
            top_amount_array: numpy array of shape (N, 10)

        Raises:
            ValueError: if the three arrays do not hold the same number of
                samples. An explicit check is used instead of ``assert``
                because asserts are stripped under ``python -O``.
        """
        lengths = (seq_h_array.shape[0], top_array.shape[0], top_amount_array.shape[0])
        if len(set(lengths)) != 1:
            raise ValueError(
                "seq_h_array, top_array and top_amount_array must have the same "
                f"number of samples; got {lengths}"
            )
        self.seq_h      = seq_h_array
        self.top        = top_array
        self.top_amount = top_amount_array

    def __len__(self):
        """Return the number of samples N."""
        return self.seq_h.shape[0]

    def __getitem__(self, idx):
        """Return ``(dna, pat, freq)`` shaped (5, 32), (1, 10, 100), (10,) as float32."""
        dna  = torch.from_numpy(self.seq_h[idx]).float()       # (5, 32)
        pat  = torch.from_numpy(self.top[idx]).float()         # (10, 100)
        freq = torch.from_numpy(self.top_amount[idx]).float()  # (10,)
        # Add a leading channel axis: the pattern ViT is built with channels=1.
        pat = pat.unsqueeze(0)  # -> (1, 10, 100)
        return dna, pat, freq

train_dataset = GenomicClipDataset(train_seq_h, train_top_out, train_top_amount)
val_dataset = GenomicClipDataset(val_seq_h, val_top_out, val_top_amount)

# Training: reshuffle each epoch and drop a trailing partial batch (presumably
# so the model's BatchNorm1d layers never see a size-1 batch in train mode).
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
# Validation: fixed order, every sample kept.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=False)

# ==============================================================================
# 5. Define Improved CLIPModel (ViT + Projection Heads + BatchNorm + Dropout)
# ==============================================================================
class CLIPModel(nn.Module):
    """Two-tower contrastive model pairing DNA encodings with pattern/frequency encodings.

    DNA tower:
        (B, 5, 32) -> unsqueeze channel -> ViT -> dna_proj -> (B, latent_dim)
    Pattern tower:
        pattern (B, 1, 10, 100) -> ViT -> BatchNorm1d --------------------\\
        freq    (B, 10) -> BatchNorm1d -> Linear(10, pat_dim) -> freq_mlp -+-> concat
        concat (B, 2*pat_dim) -> combiner_mlp -> pat_proj -> (B, latent_dim)

    ``forward`` also returns ``exp(clamp(logit_scale_param, -5, 5))``, a
    learnable multiplier for the similarity logits.

    NOTE(review): both projection heads (dna_proj, pat_proj) end in
    ReLU + Dropout, so every latent coordinate is >= 0 and the cosine
    similarity between two latents can never be negative. Confirm this
    restriction of the embedding space is intended.
    """
    def __init__(self,
                 dna_dim: int = 256,
                 dna_depth: int = 2,
                 dna_heads: int = 8,
                 pat_dim: int = 512,
                 pat_depth: int = 2,
                 pat_heads: int = 8,
                 latent_dim: int = 512,
                 dropout: float = 0.4):
        """
        Args:
            dna_dim:    width of the DNA ViT; also its output size (num_classes).
            dna_depth:  number of transformer layers in the DNA ViT.
            dna_heads:  attention heads in the DNA ViT.
            pat_dim:    width of the pattern ViT and its output size; also the
                        hidden width of the frequency branch and combiner MLP.
            pat_depth:  number of transformer layers in the pattern ViT.
            pat_heads:  attention heads in the pattern ViT.
            latent_dim: size of the shared embedding both towers project into.
            dropout:    rate used for ViT dropout/emb_dropout and for every
                        Dropout layer in the projection heads and MLPs.
        """
        super(CLIPModel, self).__init__()

        # --- DNA Encoder (ViT) ---
        # Single-channel (5, 32) "image" cut into (5, 8) patches.
        self.dna_encoder = ViT(
            image_size   = (5, 32),
            patch_size   = (5, 8),
            num_classes  = dna_dim,
            dim          = dna_dim,
            depth        = dna_depth,
            heads        = dna_heads,
            mlp_dim      = dna_dim * 2,
            channels     = 1,
            dropout      = dropout,
            emb_dropout  = dropout
        )
        # Projection head: dna_dim -> latent_dim (ends in ReLU, see class NOTE).
        self.dna_proj = nn.Sequential(
            nn.Linear(dna_dim, latent_dim),
            nn.LayerNorm(latent_dim),   # keep LayerNorm here
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

        # --- Pattern Encoder (ViT) ---
        # Single-channel (10, 100) pattern matrix cut into (2, 10) patches.
        self.pat_encoder = ViT(
            image_size   = (10, 100),
            patch_size   = (2, 10),
            num_classes  = pat_dim,
            dim          = pat_dim,
            depth        = pat_depth,
            heads        = pat_heads,
            mlp_dim      = pat_dim * 2,
            channels     = 1,
            dropout      = dropout,
            emb_dropout  = dropout
        )
        # BatchNorm1d (rather than LayerNorm) on the raw pattern ViT output.
        self.pat_feat_bn = nn.BatchNorm1d(pat_dim, eps=1e-5, momentum=0.1)

        # --- Frequency branch normalization ---
        # BatchNorm1d over the 10 raw frequency values (feature count is fixed at 10).
        self.freq_bn = nn.BatchNorm1d(10, eps=1e-5, momentum=0.1)

        # Lift the 10 normalized frequencies to pat_dim, then a 2-layer MLP.
        self.freq_proj_input = nn.Linear(10, pat_dim)
        self.freq_mlp = nn.Sequential(
            nn.Linear(pat_dim, pat_dim),
            nn.LayerNorm(pat_dim),   # keep LayerNorm inside freq MLP
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(pat_dim, pat_dim),
            nn.LayerNorm(pat_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

        # --- Combiner MLP: (pat_dim*2 → pat_dim) with BatchNorm after first linear
        self.combiner_mlp = nn.Sequential(
            nn.Linear(pat_dim * 2, pat_dim),
            nn.BatchNorm1d(pat_dim, eps=1e-5, momentum=0.1),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(pat_dim, pat_dim),
            nn.LayerNorm(pat_dim),   # final LayerNorm
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

        # Final projection: (pat_dim → latent_dim); ends in ReLU, see class NOTE.
        self.pat_proj = nn.Sequential(
            nn.Linear(pat_dim, latent_dim),
            nn.LayerNorm(latent_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
        )

        # Learnable log of the logit multiplier, initialized to log(1/0.07).
        # forward() clamps it to [-5, 5] before exponentiating.
        self.logit_scale_param = nn.Parameter(torch.ones([]) * math.log(1/0.07))

        # Xavier initialization for all Linear layers.
        # self.modules() also yields the Linear layers inside both ViT
        # encoders, so their construction-time initialization is overwritten
        # too. The iteration order (and thus RNG consumption) follows the
        # registration order above.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, dna_tensor, pat_tensor, freq_tensor):
        """
        Encode a batch of paired DNA / pattern+frequency inputs.

        dna_tensor:  (B, 5, 32)  - a channel axis is added before the ViT
        pat_tensor:  (B, 1, 10, 100)
        freq_tensor: (B, 10) ← raw top_amount
        Returns:
            dna_latent: (B, latent_dim)
            pat_latent: (B, latent_dim)
            logit_scale: scalar > 0, exp of the clamped parameter, so it lies in
                         [e^-5, e^5] ≈ [0.0067, 148.4]
        Note: the BatchNorm1d layers use batch statistics in train mode, so
        B must be > 1 while training.
        """
        # --- DNA branch ---
        dna_input   = dna_tensor.unsqueeze(1)          # (B, 1, 5, 32)
        dna_feat    = self.dna_encoder(dna_input)      # (B, dna_dim)
        dna_latent  = self.dna_proj(dna_feat)          # (B, latent_dim)

        # --- Pattern branch ---
        pat_feat       = self.pat_encoder(pat_tensor)  # (B, pat_dim)
        # BatchNorm1d expects input shape (B, C), so no transpose needed
        pat_feat_norm  = self.pat_feat_bn(pat_feat)    # (B, pat_dim)

        # --- Frequency branch ---
        freq_normed    = self.freq_bn(freq_tensor)     # (B, 10)
        freq_proj      = self.freq_proj_input(freq_normed)  # (B, pat_dim)
        freq_feat      = self.freq_mlp(freq_proj)            # (B, pat_dim)

        # --- Combine pattern + frequency ---
        combined   = torch.cat([pat_feat_norm, freq_feat], dim=1)  # (B, 2*pat_dim)
        comb_feat  = self.combiner_mlp(combined)                   # (B, pat_dim)
        pat_latent = self.pat_proj(comb_feat)                       # (B, latent_dim)

        # Clamp logit_scale to avoid numerical explosion
        logit_scale = torch.clamp(self.logit_scale_param, -5.0, 5.0)
        logit_scale = logit_scale.exp()

        return dna_latent, pat_latent, logit_scale


# Build the two-tower model (depth-8 ViTs in both towers, overriding the
# class default of 2; 32-d shared latent) and move it to `device`.
model = CLIPModel(
    dna_dim=256,
    dna_depth=8,
    dna_heads=8,
    pat_dim=512,
    pat_depth=8,
    pat_heads=8,
    latent_dim=32,
    dropout=dropout_prob,
).to(device)

# ==============================================================================
# 6. Contrastive Loss (InfoNCE with Label Smoothing)
# ==============================================================================
def contrastive_loss(dna_latent, pat_latent, logit_scale, smoothing: float = 0.1):
    """Symmetric InfoNCE loss with label smoothing over a batch of paired embeddings.

    Row ``i`` of ``dna_latent`` and row ``i`` of ``pat_latent`` are a positive
    pair; every other row in the batch acts as a negative. Both inputs are
    L2-normalized, and their scaled cosine similarities form a (B, B) logit
    matrix. The loss averages the DNA->pattern (row-wise) and pattern->DNA
    (column-wise) KL divergences against a label-smoothed identity target.

    Args:
        dna_latent:  (B, D) DNA embeddings.
        pat_latent:  (B, D) pattern embeddings.
        logit_scale: scalar (tensor or float) multiplying the cosine similarities.
        smoothing:   probability mass in [0, 1] spread uniformly over the
                     B - 1 negatives; the positive keeps ``1 - smoothing``.

    Returns:
        Scalar tensor: mean of the two directional ``batchmean`` KL losses.
        For B == 1 there are no negatives, so the target is one-hot and the
        loss is 0. Previously this case raised ZeroDivisionError, e.g. for a
        final validation batch of size 1.

    Raises:
        ValueError: if ``smoothing`` is outside [0, 1].
    """
    if not 0.0 <= smoothing <= 1.0:
        raise ValueError(f"smoothing must be in [0, 1], got {smoothing}")

    batch_size, _ = dna_latent.size()
    eps = 1e-8

    # L2-normalize; eps guards against all-zero rows.
    dna_norm = dna_latent / (dna_latent.norm(dim=1, keepdim=True) + eps)
    pat_norm = pat_latent / (pat_latent.norm(dim=1, keepdim=True) + eps)

    logits = logit_scale * torch.matmul(dna_norm, pat_norm.t())  # (B, B)
    n_classes = logits.size(1)

    # Smoothed targets. With a single class there is nowhere to put the
    # smoothing mass (and smoothing / (n_classes - 1) would divide by zero),
    # so the target collapses to the one-hot label.
    with torch.no_grad():
        if n_classes > 1:
            off_value = smoothing / (n_classes - 1)
            on_value = 1.0 - smoothing
        else:
            off_value, on_value = 0.0, 1.0
        smooth_target = torch.full((batch_size, n_classes), off_value, device=logits.device)
        diag = torch.arange(batch_size, device=logits.device)
        smooth_target[diag, diag] = on_value

    # KL divergence between the log-softmax rows and the smoothed targets,
    # in both directions (the target matrix is symmetric).
    loss_d2p = F.kl_div(F.log_softmax(logits, dim=1), smooth_target, reduction="batchmean")
    loss_p2d = F.kl_div(F.log_softmax(logits.t(), dim=1), smooth_target, reduction="batchmean")
    return (loss_d2p + loss_p2d) / 2.0

# ==============================================================================
# 7. Training & Validation Loop with Optional AMP, Extended Warmup, Early Stopping
# ==============================================================================
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

total_steps  = n_epochs * len(train_loader)
warmup_steps = warmup_epochs * len(train_loader)


def lr_lambda(step):
    """Return the LR multiplier for optimizer step ``step``.

    Linear warmup from 0 to 1 over ``warmup_steps`` steps, then a half-cosine
    decay from 1 to 0 over the remaining ``total_steps - warmup_steps`` steps.
    Progress is clamped at 1 so the multiplier never climbs back up past the
    end of the schedule.
    """
    if step < warmup_steps:
        return float(step) / float(max(1, warmup_steps))
    # max(1, ...) avoids ZeroDivisionError when warmup covers the whole run
    # (warmup_epochs >= n_epochs) or the train loader is empty. LambdaLR
    # evaluates lr_lambda(0) at construction, so the original form crashed
    # immediately in those cases.
    decay_steps = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    return 0.5 * (1.0 + math.cos(math.pi * progress))


scheduler = LambdaLR(optimizer, lr_lambda=lr_lambda)

# Mixed-precision scaler (enabled only if USE_AMP=True), bound to the same
# device type that autocast will run on.
scaler = GradScaler(device.type, enabled=USE_AMP)

# Print initial frequency statistics for debugging. Cast to float64 first:
# mean()/std() raise on integer tensors (e.g. raw count arrays).
freq_all = torch.from_numpy(top_amount_dataset).double()
print(
    "Initial freq stats:",
    f"min={freq_all.min().item():.1e}, max={freq_all.max().item():.1e},",
    f"mean={freq_all.mean().item():.1e}, std={freq_all.std().item():.1e}"
)

best_val_loss = float('inf')
no_improve    = 0
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Anomaly detection re-checks every backward op and slows training
# considerably. Enable it only while hunting the source of a NaN/Inf.
DETECT_ANOMALY = False

for epoch in range(1, n_epochs + 1):
    # --- Training Phase ---
    model.train()
    train_loss_accum = 0.0
    train_bar = tqdm(train_loader, desc=f"Epoch {epoch}/{n_epochs} [Train]", leave=False)

    for dna_batch, pat_batch, freq_batch in train_bar:
        dna_batch = dna_batch.to(device)
        pat_batch = pat_batch.to(device)
        freq_batch = freq_batch.to(device)

        # 1) Reject corrupt inputs before they reach the model.
        for batch_name, batch in (
            ("dna_batch", dna_batch),
            ("pat_batch", pat_batch),
            ("freq_batch", freq_batch),
        ):
            if batch.isnan().any() or batch.isinf().any():
                raise RuntimeError(f"{batch_name} contains NaN or Inf")

        optimizer.zero_grad()

        # 2) Forward pass. torch.amp.autocast requires an explicit device_type
        #    (the bare autocast() call raised TypeError). With enabled=False it
        #    is a no-op, so one code path serves both modes.
        with autocast(device_type=device.type, enabled=USE_AMP):
            dna_latent, pat_latent, logit_scale = model(dna_batch, pat_batch, freq_batch)
            for out_name, out in (
                ("dna_latent", dna_latent),
                ("pat_latent", pat_latent),
                ("logit_scale", logit_scale),
            ):
                if out.isnan().any():
                    raise RuntimeError(f"{out_name} contains NaN")
            loss = contrastive_loss(dna_latent, pat_latent, logit_scale, smoothing=0.1)

        # 3) Check loss for NaN
        if loss.isnan().any():
            raise RuntimeError("Loss is NaN. Inspect dna_latent, pat_latent, and logit_scale.")

        # 4) Backward. A disabled GradScaler returns the loss unchanged and
        #    makes unscale_() a no-op, so this also covers USE_AMP=False.
        with torch.autograd.set_detect_anomaly(DETECT_ANOMALY):
            scaler.scale(loss).backward()
        scaler.unscale_(optimizer)

        # 5) Gradient clipping, then a NaN check outside AMP only.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        if not USE_AMP:
            # Under AMP, occasional inf/NaN gradients from fp16 overflow are
            # expected. GradScaler.step() skips that update and update() lowers
            # the scale, whereas raising here would abort training instead.
            for name, param in model.named_parameters():
                if param.grad is not None and param.grad.isnan().any():
                    raise RuntimeError(f"Gradient for {name} contains NaN")

        # 6) Optimizer step (a disabled scaler simply calls optimizer.step()).
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        loss_value = loss.item()
        train_loss_accum += loss_value
        current_lr = scheduler.get_last_lr()[0]
        train_bar.set_postfix(train_loss=f"{loss_value:.4f}", lr=f"{current_lr:.2e}")
    train_bar.close()

    avg_train_loss = train_loss_accum / len(train_loader)

    # --- Validation Phase ---
    model.eval()
    val_loss_accum = 0.0
    val_bar = tqdm(val_loader, desc=f"Epoch {epoch}/{n_epochs} [Val]", leave=False)

    with torch.no_grad():
        for dna_batch, pat_batch, freq_batch in val_bar:
            dna_batch = dna_batch.to(device)
            pat_batch = pat_batch.to(device)
            freq_batch = freq_batch.to(device)

            with autocast(device_type=device.type, enabled=USE_AMP):
                dna_latent, pat_latent, logit_scale = model(dna_batch, pat_batch, freq_batch)
                loss = contrastive_loss(dna_latent, pat_latent, logit_scale, smoothing=0.1)

            loss_value = loss.item()
            val_loss_accum += loss_value
            val_bar.set_postfix(val_loss=f"{loss_value:.4f}")
    val_bar.close()

    avg_val_loss = val_loss_accum / len(val_loader)
    print(
        f"Epoch {epoch}/{n_epochs} — "
        f"Avg Train Loss: {avg_train_loss:.4f} | "
        f"Avg Val Loss:   {avg_val_loss:.4f} | "
        f"LR: {scheduler.get_last_lr()[0]:.2e}"
    )

    # Early Stopping & Checkpoint
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        no_improve    = 0
        ckpt_path     = os.path.join(checkpoint_dir, "best_clip_model.pth")
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scheduler_state": scheduler.state_dict(),
            "best_val_loss": best_val_loss
        }, ckpt_path)
        print(f"  ► Saved new best checkpoint (Val Loss: {best_val_loss:.4f})")
    else:
        no_improve += 1
        if no_improve >= patience:
            print(f"No improvement for {patience} epochs. Stopping early at epoch {epoch}.")
            break

# ==============================================================================
# 8. Save Final Model Checkpoint (always written, separately from the best-val checkpoint)
# ==============================================================================
# Persist the state after the last epoch that ran (early-stopped or not).
final_ckpt = os.path.join(checkpoint_dir, "last_epoch_clip_model.pth")
torch.save(
    dict(
        epoch=epoch,
        model_state=model.state_dict(),
        optimizer_state=optimizer.state_dict(),
        scheduler_state=scheduler.state_dict(),
        best_val_loss=best_val_loss,
    ),
    final_ckpt,
)
print(f"Training complete. Final checkpoint saved to {final_ckpt}")
