In [1]:
import os
import gc
import math
import shutil
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import _LRScheduler
from tqdm import tqdm
import matplotlib.pyplot as plt
import random


In [None]:
from tqdm.notebook import tqdm

In [2]:
class ChunkDataset(Dataset):
    """Map-style dataset holding one ``.npz`` chunk of jet images in memory.

    The array stored under ``key`` supplies the images; every other array in the
    archive is kept as a per-sample property and returned alongside each image.

    Args:
        npz_file: path to the ``.npz`` archive.
        key: name of the image array. ``__getitem__`` transposes with axes
            ``(2, 0, 1)``, so the array is expected channels-last, ``(N, H, W, C)``
            (``(N, 125, 125, 3)`` for the X_jets data used here).
    """

    def __init__(self, npz_file, key='X_jets'):
        print(f"Loading chunk: {npz_file}")
        with np.load(npz_file) as data:
            images = data[key]
            # All other keys are considered per-sample properties.
            self.props = {k: data[k] for k in data.files if k != key}
        # Heuristic: values above 1.0 are treated as 0-255 intensities and rescaled.
        # The ``size`` guard is needed because ``max()`` raises on an empty array.
        needs_rescale = images.size > 0 and images.max() > 1.0
        # Convert to float32 once here instead of on every __getitem__ call.
        images = images.astype('float32', copy=False)
        if needs_rescale:
            images = images / 255.0
        self.images = images
        self.len = self.images.shape[0]

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Return ``(image, properties)``: a float32 ``(C, H, W)`` tensor and a dict of per-sample values."""
        # HWC -> CHW; .copy() gives a contiguous array that the returned tensor owns.
        image = np.transpose(self.images[idx], (2, 0, 1)).copy()
        properties = {k: v[idx] for k, v in self.props.items()}
        return torch.from_numpy(image), properties

In [3]:
class CosineAnnealingWarmupScheduler(_LRScheduler):
    """Per-epoch LR schedule: linear warmup, then cosine decay to ``eta_min``.

    With ``last_epoch`` counting ``step()`` calls (0 right after construction),
    each param group's learning rate is:

    * ``base_lr * (last_epoch + 1) / warmup_epochs`` while ``last_epoch < warmup_epochs``;
    * ``eta_min + (base_lr - eta_min) * (1 + cos(pi * p)) / 2`` afterwards, with
      ``p = min(last_epoch - warmup_epochs, T) / T`` and ``T = max_epochs - warmup_epochs``.

    ``p`` is clamped to 1, so the LR stays at ``eta_min`` once ``max_epochs`` is
    reached instead of climbing back up the cosine. If ``T <= 0`` there is no
    annealing window and the LR is ``eta_min`` after warmup.

    Args:
        optimizer: wrapped optimizer.
        warmup_epochs: number of linear-warmup steps.
        max_epochs: step at which the cosine reaches ``eta_min``.
        eta_min: final learning rate.
        last_epoch: index of the last epoch when resuming (-1 for a fresh start).
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, eta_min=0, last_epoch=-1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.eta_min = eta_min
        super(CosineAnnealingWarmupScheduler, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Warmup phase: linearly scale from base_lr/warmup_epochs up to base_lr.
            return [base_lr * (self.last_epoch + 1) / self.warmup_epochs for base_lr in self.base_lrs]

        total_cos = self.max_epochs - self.warmup_epochs
        if total_cos <= 0:
            # No annealing window: avoid dividing by zero and hold the floor value.
            return [self.eta_min for _ in self.base_lrs]

        # Clamp so the LR does not rise again after max_epochs.
        cos_epoch = min(self.last_epoch - self.warmup_epochs, total_cos)
        cosine = (1 + math.cos(math.pi * cos_epoch / total_cos)) / 2
        return [self.eta_min + (base_lr - self.eta_min) * cosine for base_lr in self.base_lrs]


In [4]:
class ChannelEncoder(nn.Module):
    """Three-stage strided CNN encoder for a single input channel.

    Every stage is Conv2d(kernel 3, stride 2, padding 1) -> BatchNorm2d -> ReLU,
    which maps a side length H to ceil(H / 2): 125 -> 63 -> 32 -> 16 for a
    125x125 input. Channel widths are base, 2*base and 4*base.

    ``forward`` returns ``(latent, skip1, skip2)``, where ``skip1``/``skip2`` are
    the outputs of the first two stages, kept for the decoder's skip connections.
    """

    def __init__(self, in_channels=1, base_channels=32):
        super(ChannelEncoder, self).__init__()
        # Attribute names layer1..layer3 (and Sequential indices 0..2) determine
        # the state_dict keys, so they stay fixed.
        self.layer1 = self._downsample_stage(in_channels, base_channels)
        self.layer2 = self._downsample_stage(base_channels, base_channels * 2)
        self.layer3 = self._downsample_stage(base_channels * 2, base_channels * 4)

    @staticmethod
    def _downsample_stage(c_in, c_out):
        # One Conv-BN-ReLU stage that halves H and W (rounding up).
        return nn.Sequential(
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(c_out),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        first = self.layer1(x)
        second = self.layer2(first)
        return self.layer3(second), first, second

In [5]:
class SelfAttention(nn.Module):
    """Spatial self-attention over all H*W positions with a learnable residual gate.

    Query and key are 1x1 projections to ``max(1, in_channels // reduction)``
    channels; value keeps ``in_channels``. The attention matrix is
    ``(H*W) x (H*W)`` per sample, so memory grows quadratically with the number
    of positions. ``gamma`` is initialised to 0, so at initialisation the module
    returns its input unchanged.

    Args:
        in_channels: channels of the input feature map.
        reduction: divisor for the query/key width (default 8, as before).
    """

    def __init__(self, in_channels, reduction=8):
        super(SelfAttention, self).__init__()
        # At least one channel: with in_channels < reduction a zero-width
        # projection makes the .view(B, -1, N) in forward fail on a 0-element tensor.
        qk_channels = max(1, in_channels // reduction)
        self.query = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.key   = nn.Conv2d(in_channels, qk_channels, kernel_size=1)
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.shape
        N = H * W
        proj_query = self.query(x).view(B, -1, N)   # (B, C', N)
        proj_key   = self.key(x).view(B, -1, N)     # (B, C', N)
        # energy[b, i, j] = <q_i, k_j>; softmax over j (the attended positions).
        energy = torch.bmm(proj_query.permute(0, 2, 1), proj_key)  # (B, N, N)
        attention = F.softmax(energy, dim=-1)
        proj_value = self.value(x).view(B, -1, N)   # (B, C, N)
        # out[b, :, i] = sum_j attention[b, i, j] * v[b, :, j]
        out = torch.bmm(proj_value, attention.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x


In [6]:
class CrossChannelFusion(nn.Module):
    """Channel-mix the concatenated branch latents, then apply spatial self-attention.

    A 1x1 convolution maps ``in_channels`` to ``out_channels`` without changing
    the spatial size; the result is passed through ``SelfAttention(out_channels)``.
    """

    def __init__(self, in_channels, out_channels):
        super(CrossChannelFusion, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.attn = SelfAttention(out_channels)

    def forward(self, x):
        mixed = self.conv(x)
        return self.attn(mixed)



In [7]:
class Decoder(nn.Module):
    """Decoder that fuses skip connections from the three channel encoders.

    Each ConvTranspose2d here (kernel 3, stride 2, padding 1, output_padding 1)
    exactly doubles H and W. For ``b = base_channels`` and a (B, 8b, 16, 16) latent:

    * up1: (B, 8b, 16, 16) -> (B, 4b, 32, 32); concat the three skip2 maps
      (3 x 2b channels) -> 10b channels.
    * up2: (B, 10b, 32, 32) -> (B, 2b, 64, 64); concat the three skip1 maps
      (3 x b channels, bilinearly resized if their size differs, e.g. 63 -> 64)
      -> 5b channels.
    * up3: (B, 5b, 64, 64) -> (B, b, 128, 128).
    * up4: (B, b, 128, 128) -> (B, 3, 256, 256), followed by a sigmoid.
    * Final bilinear resize to ``output_size``.

    Args:
        base_channels: width multiplier ``b``; must match the encoders.
        output_size: (H, W) of the reconstruction; defaults to (125, 125).
    """

    def __init__(self, base_channels=32, output_size=(125, 125)):
        super(Decoder, self).__init__()
        # Plain tuple attribute: not a parameter/buffer, so state_dict is unaffected.
        self.output_size = tuple(output_size)

        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(base_channels*8, base_channels*4, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(base_channels*4),
            nn.ReLU(inplace=True)
        )  # -> (B, 4b, 2H, 2W)

        # Input = up1 output (4b) + three encoders' skip2 (3 x 2b).
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(base_channels*(4 + 2*3), base_channels*2, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(base_channels*2),
            nn.ReLU(inplace=True)
        )  # -> (B, 2b, 2H, 2W)

        # Input = up2 output (2b) + three encoders' skip1 (3 x b).
        self.up3 = nn.Sequential(
            nn.ConvTranspose2d(base_channels*(2 + 3), base_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(base_channels),
            nn.ReLU(inplace=True)
        )  # -> (B, b, 2H, 2W)

        self.up4 = nn.Sequential(
            nn.ConvTranspose2d(base_channels, 3, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )  # -> (B, 3, 2H, 2W), values in (0, 1)

    def forward(self, latent, skips_e, skips_h, skips_t):
        """Decode the fused latent.

        Args:
            latent: fused latent, (B, 8b, h, w) (16x16 for 125x125 inputs).
            skips_e, skips_h, skips_t: ``(skip1, skip2)`` tuples from each encoder.
                skip2 must match up1's output size; skip1 is resized if needed.

        Returns:
            Tensor of shape (B, 3, *output_size).
        """
        x = self.up1(latent)
        skip2_fused = torch.cat([skips_e[1], skips_h[1], skips_t[1]], dim=1)  # (B, 6b, ...)
        x = torch.cat([x, skip2_fused], dim=1)  # (B, 10b, ...)
        x = self.up2(x)

        skip1_fused = torch.cat([skips_e[0], skips_h[0], skips_t[0]], dim=1)  # (B, 3b, ...)
        # up2 doubles exactly (64x64) while the encoder's first stage rounds up (63x63).
        if skip1_fused.shape[-2:] != x.shape[-2:]:
            skip1_fused = F.interpolate(skip1_fused, size=x.shape[-2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip1_fused], dim=1)  # (B, 5b, ...)
        x = self.up3(x)
        x = self.up4(x)
        # up4 produces 256x256 for a 16x16 latent; resize to the requested output.
        x = F.interpolate(x, size=self.output_size, mode='bilinear', align_corners=False)
        return x

In [8]:
class PhysicsInformedAutoencoder(nn.Module):
    """Three-branch autoencoder: one encoder per input channel, fused latent, shared decoder.

    The input is split channel-wise into channel 0 (ECAL-like), 1 (HCAL-like)
    and 2 (tracks); channels beyond index 2 are not used. Each branch is encoded
    separately, the three latents are concatenated (3 x 4b = 12b channels) and
    fused down to 8b channels, and the decoder reconstructs a 3-channel image
    using all branches' skip connections.
    """

    def __init__(self, base_channels=32):
        super(PhysicsInformedAutoencoder, self).__init__()
        # One single-channel encoder per detector view (created in e, h, t order).
        self.encoder_e = ChannelEncoder(1, base_channels)
        self.encoder_h = ChannelEncoder(1, base_channels)
        self.encoder_t = ChannelEncoder(1, base_channels)
        self.fusion = CrossChannelFusion(3 * (base_channels * 4), base_channels * 8)
        self.decoder = Decoder(base_channels)

    def forward(self, x):
        branch_latents = []
        branch_skips = []
        for channel, encoder in enumerate((self.encoder_e, self.encoder_h, self.encoder_t)):
            latent, skip1, skip2 = encoder(x[:, channel:channel + 1, :, :])
            branch_latents.append(latent)
            branch_skips.append((skip1, skip2))
        fused = self.fusion(torch.cat(branch_latents, dim=1))
        return self.decoder(fused, *branch_skips)

In [9]:
def physics_informed_loss(output, target, alpha=0.7):
    """Blend a per-sample total-intensity ("energy") L1 term with pixel-wise MSE.

    ``loss = alpha * L1(sum(output), sum(target)) + (1 - alpha) * MSE(output, target)``,
    where each sum runs over dims 1-3 of a sample and both terms are averaged
    over the batch.

    NOTE(review): the energy term compares sums over all C*H*W elements while
    the MSE is a per-element mean, so the energy term can be orders of
    magnitude larger; ``alpha`` alone does not balance them. Confirm the
    intended weighting.
    """
    pixel_mse = F.mse_loss(output, target)
    per_sample_energy_out = output.sum(dim=[1, 2, 3])
    per_sample_energy_target = target.sum(dim=[1, 2, 3])
    energy_l1 = F.l1_loss(per_sample_energy_out, per_sample_energy_target)
    return (1 - alpha) * pixel_mse + alpha * energy_l1


In [10]:
def train_on_chunk(model, train_loader, optimizer, device, alpha=0.7):
    """Run one training pass over ``train_loader``; return the mean loss per sample.

    Each batch is an ``(images, properties)`` pair; properties are ignored and
    the images are their own reconstruction target.

    Returns:
        Sample-weighted mean of ``physics_informed_loss`` over the batches seen,
        or 0.0 if the loader yielded no samples.
    """
    model.train()
    running_loss = 0.0
    n_seen = 0
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = physics_informed_loss(outputs, images, alpha=alpha)
        loss.backward()
        optimizer.step()
        # The loss is averaged over the batch; weight by batch size for a per-sample mean.
        batch_n = images.size(0)
        running_loss += loss.item() * batch_n
        n_seen += batch_n
    # Divide by samples actually processed (correct under drop_last) and avoid 0/0.
    return running_loss / n_seen if n_seen else 0.0

def eval_on_chunk(model, val_loader, device, alpha=0.7):
    """Evaluate ``model`` on ``val_loader`` without gradients; return the mean loss per sample.

    Each batch is an ``(images, properties)`` pair; properties are ignored and
    the images are their own reconstruction target.

    Returns:
        Sample-weighted mean of ``physics_informed_loss``, or 0.0 if the loader
        yielded no samples (e.g. an empty validation split).
    """
    model.eval()
    total_loss = 0.0
    n_seen = 0
    with torch.no_grad():
        for images, _ in val_loader:
            images = images.to(device)
            outputs = model(images)
            loss = physics_informed_loss(outputs, images, alpha=alpha)
            batch_n = images.size(0)
            total_loss += loss.item() * batch_n
            n_seen += batch_n
    # Divide by samples actually processed; an empty split would otherwise raise ZeroDivisionError.
    return total_loss / n_seen if n_seen else 0.0

In [11]:
def chunk_based_training(
    chunk_dir,
    chunk_files,
    model,
    optimizer,
    scheduler,
    device,
    epochs=5,
    alpha_loss=0.7,
    train_ratio=0.8,
    batch_size=32,
    num_workers=2,
    split_seed=0,
):
    """Train ``model`` for ``epochs`` epochs, streaming one ``.npz`` chunk at a time.

    Only one chunk is held in memory. Each chunk is loaded, split into
    train/val, trained on, evaluated, and then released before the next one.
    The scheduler is stepped once per epoch.

    Args:
        chunk_dir: directory containing the chunk files.
        chunk_files: file names (relative to ``chunk_dir``) processed in order.
        model, optimizer, scheduler, device: training objects.
        epochs: number of passes over all chunks.
        alpha_loss: ``alpha`` forwarded to ``physics_informed_loss``.
        train_ratio: fraction of each chunk used for training.
        batch_size, num_workers: DataLoader settings.
        split_seed: base seed for the per-chunk train/val split. A given chunk
            is split identically in every epoch, so validation samples are
            never trained on.

    Returns:
        List of ``(train_loss, val_loss)`` per epoch: sample-weighted means over
        all chunks, or NaN when a split had no samples at all.

    Raises:
        ValueError: if ``chunk_files`` is empty.
    """
    if not chunk_files:
        raise ValueError("chunk_files is empty; nothing to train on.")

    history = []
    for epoch in range(epochs):
        epoch_train_loss = 0.0
        epoch_val_loss = 0.0
        total_train_samples = 0
        total_val_samples = 0

        print(f"\n=== Epoch {epoch+1}/{epochs} ===")
        for chunk_idx, chunk_file in enumerate(chunk_files):
            chunk_path = os.path.join(chunk_dir, chunk_file)
            dataset = ChunkDataset(chunk_path, key='X_jets')
            n_total = len(dataset)
            n_train = int(train_ratio * n_total)
            n_val = n_total - n_train
            # Fixed per-chunk seed: an unseeded random_split would reshuffle every
            # epoch and leak earlier validation samples into training.
            split_gen = torch.Generator().manual_seed(split_seed + chunk_idx)
            train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=split_gen)

            train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
            val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

            train_loss = train_on_chunk(model, train_loader, optimizer, device, alpha=alpha_loss)
            val_loss = eval_on_chunk(model, val_loader, device, alpha=alpha_loss)

            # Per-chunk losses are per-sample means; re-weight by split size.
            epoch_train_loss += train_loss * n_train
            epoch_val_loss += val_loss * n_val
            total_train_samples += n_train
            total_val_samples += n_val

            # Release the chunk before loading the next one.
            del dataset, train_ds, val_ds, train_loader, val_loader
            gc.collect()

        epoch_train_loss = epoch_train_loss / total_train_samples if total_train_samples else float('nan')
        epoch_val_loss = epoch_val_loss / total_val_samples if total_val_samples else float('nan')
        scheduler.step()
        history.append((epoch_train_loss, epoch_val_loss))
        print(f"Epoch {epoch+1} - Train Loss: {epoch_train_loss:.6f} | Val Loss: {epoch_val_loss:.6f}")
    print("Training complete after all epochs.")
    return history

In [12]:
def save_reconstructed_chunk(chunk_file, model, output_folder, device, batch_size=32, key='X_jets', num_workers=2):
    """Reconstruct every image in one chunk and store the result as a zipped ``.npz``.

    Writes ``<output_folder>/<chunk_name>_recon.zip``, which contains
    ``<chunk_name>_recon.npz`` with a single array ``X_recon`` in channels-last
    layout ``(N, H, W, C)``, in the same sample order as the chunk. The
    intermediate ``.npz`` is deleted after zipping.

    Args:
        chunk_file: path to the input ``.npz`` chunk.
        model: trained model; run in eval mode without gradients.
        output_folder: destination directory (created if missing).
        device: device to run inference on.
        batch_size: inference batch size.
        key: image array name inside the chunk.
        num_workers: DataLoader worker processes (default 2, as before).

    Raises:
        ValueError: if the chunk contains no images.
    """
    chunk_name = os.path.splitext(os.path.basename(chunk_file))[0]
    dataset = ChunkDataset(chunk_file, key=key)
    if len(dataset) == 0:
        raise ValueError(f"No images under key '{key}' in {chunk_file}")
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    model.eval()

    # Preallocate on the first batch and fill in place. This keeps peak memory at
    # one copy of the reconstructions instead of a list of batches plus their
    # concatenation.
    recon_all = None
    offset = 0
    with torch.no_grad():
        for images, _ in tqdm(data_loader, desc=f"Inferencing {chunk_name}"):
            output = model(images.to(device))
            batch = output.cpu().numpy().transpose(0, 2, 3, 1)  # (B,C,H,W) -> (B,H,W,C)
            if recon_all is None:
                recon_all = np.empty((len(dataset),) + batch.shape[1:], dtype=batch.dtype)
            recon_all[offset:offset + len(batch)] = batch
            offset += len(batch)

    os.makedirs(output_folder, exist_ok=True)
    npz_save_path = os.path.join(output_folder, f"{chunk_name}_recon.npz")
    np.savez_compressed(npz_save_path, X_recon=recon_all)
    print(f"Saved reconstructed chunk to: {npz_save_path}")

    # Archive just the .npz file (root_dir=output_folder, base_dir=file name).
    zip_base = os.path.splitext(npz_save_path)[0]
    zip_path = shutil.make_archive(zip_base, 'zip', output_folder, os.path.basename(npz_save_path))
    print(f"Zipped reconstructed chunk: {zip_path}")
    os.remove(npz_save_path)
    print("Removed unzipped npz file to save space.")

    del dataset, data_loader, recon_all
    gc.collect()

In [13]:
def visualize_reconstructions_chunk(chunk_file, model, device, num_samples=10, key='X_jets'):
    """Plot randomly chosen originals next to their reconstructions for one chunk.

    Shows ``min(num_samples, len(chunk))`` random images in an ``n x 2`` grid,
    with the original on the left and the reconstruction on the right. The
    three channels are rendered as RGB by ``imshow``.
    """
    dataset = ChunkDataset(chunk_file, key=key)
    n_show = min(num_samples, len(dataset))
    if n_show <= 0:
        print(f"No samples to show from {chunk_file}")
        return

    # One shuffled batch of n_show images gives a random sample and a single forward pass.
    loader = DataLoader(dataset, batch_size=n_show, shuffle=True, num_workers=0)
    originals, _ = next(iter(loader))
    model.eval()
    with torch.no_grad():
        reconstructions = model(originals.to(device)).cpu()

    # squeeze=False keeps axes 2-D even when n_show == 1.
    fig, axes = plt.subplots(n_show, 2, figsize=(20, n_show * 2), squeeze=False)
    for row in range(n_show):
        panels = ((axes[row, 0], originals[row], "Original"),
                  (axes[row, 1], reconstructions[row], "Reconstructed"))
        for ax, image, title in panels:
            ax.imshow(image.permute(1, 2, 0).numpy())  # CHW -> HWC for imshow
            ax.set_title(title)
            ax.axis('off')
    fig.tight_layout()
    plt.show()

In [None]:
if __name__ == "__main__":
    # Seed every RNG in play (weight init, DataLoader shuffling, random_split) for reproducible runs.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    # Define your chunk directory and get list of .npz chunk files
    chunk_dir = "/kaggle/input/genie-extracted-dataset"
    chunk_files = sorted([f for f in os.listdir(chunk_dir) if f.endswith(".npz") and f.startswith("chunk_")])

    # Hyperparameters, defined before the scheduler so it can share `epochs`.
    epochs = 20
    warmup_epochs = 5
    alpha_loss = 0.7
    train_ratio = 0.9
    batch_size = 100
    num_workers = 2

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = PhysicsInformedAutoencoder(base_channels=32).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    # max_epochs is tied to `epochs` so the cosine reaches eta_min exactly at the end of training.
    scheduler = CosineAnnealingWarmupScheduler(optimizer, warmup_epochs=warmup_epochs, max_epochs=epochs, eta_min=1e-5)

    # Train model epoch-by-epoch, processing each chunk sequentially.
    chunk_based_training(
        chunk_dir=chunk_dir,
        chunk_files=chunk_files,
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        device=device,
        epochs=epochs,
        alpha_loss=alpha_loss,
        train_ratio=train_ratio,
        batch_size=batch_size,
        num_workers=num_workers
    )

    # Save weights before the long inference pass so a failure there cannot lose the trained model.
    final_model_path = "final_model.pth"
    torch.save(model.state_dict(), final_model_path)
    print(f"Model saved to {final_model_path}")

    # Run inference on all chunks except the last three.
    # NOTE(review): `sorted` is lexicographic (chunk_100000_... sorts before chunk_10000_...),
    # so these are not the numerically last chunks, and all chunks were used in training.
    # Confirm this is the intended selection.
    inference_files = chunk_files[:-3] if len(chunk_files) > 3 else chunk_files

    output_folder = "/kaggle/working/reconstructions"
    for cf in inference_files:
        chunk_path = os.path.join(chunk_dir, cf)
        save_reconstructed_chunk(chunk_path, model, output_folder=output_folder, device=device, batch_size=batch_size)



=== Epoch 1/20 ===
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_0_10000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_100000_110000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_10000_20000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_110000_120000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_120000_130000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_130000_139306.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_20000_30000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_30000_40000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_40000_50000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_50000_60000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_60000_70000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_70000_80000.npz
Loading chunk: /kaggle/input/genie-extracted-dataset/chunk_80000_90000.npz
L