# 📘 Common Task 1 – Variational Autoencoder on Jet Images (DeepFalcon GSoC 2025)

In [None]:
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
!pip install -q torch-geometric
!pip install -q h5py imageio seaborn open3d tqdm

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision.utils import make_grid

In [None]:
class JetDataset(Dataset):
    """Map-style dataset of 3-channel jet images read from a single ``.npz`` archive.

    The archive must contain arrays named ``'ecal'``, ``'hcal'`` and ``'track'``
    with the same number of samples along axis 0. Sample ``idx`` is the three
    per-detector images stacked along a new leading channel axis (ecal, hcal,
    track), returned as a float32 tensor.

    Raises:
        ValueError: if the three arrays do not have the same number of samples.
    """

    def __init__(self, npz_path):
        # np.load on a .npz returns a lazily-read NpzFile that holds the file open.
        # Indexing a key reads that array fully into memory, so the file can be
        # closed as soon as all three arrays have been extracted.
        with np.load(npz_path) as data:
            self.ecal = data['ecal']
            self.hcal = data['hcal']
            self.track = data['track']

        # Fail fast instead of raising IndexError mid-epoch (or silently ignoring
        # extra samples) when the channels disagree in length.
        n_ecal, n_hcal, n_track = len(self.ecal), len(self.hcal), len(self.track)
        if not (n_ecal == n_hcal == n_track):
            raise ValueError(
                "ecal, hcal and track must have the same number of samples; "
                f"got {n_ecal}, {n_hcal}, {n_track}"
            )

    def __len__(self):
        return len(self.ecal)

    def __getitem__(self, idx):
        # Stack the three detector images into a (3, ...) channel-first array.
        image = np.stack([self.ecal[idx], self.hcal[idx], self.track[idx]], axis=0)
        return torch.tensor(image, dtype=torch.float32)

In [None]:
# Build the dataset from the preprocessed archive and serve shuffled mini-batches of 64.
data_path = '/content/jet_images_3ch.npz'  # presumably a Google Colab runtime path — adjust locally
dataset = JetDataset(data_path)
dataloader = DataLoader(dataset, shuffle=True, batch_size=64)

In [None]:
class Encoder(nn.Module):
    """Convolutional encoder producing a diagonal-Gaussian posterior over the latent space.

    Three stride-2 convolutions (kernel 4, padding 1) each halve the spatial size,
    so the flattened feature size equals 64 * 16 * 16 only for 3-channel inputs of
    spatial size 128 x 128.

    Args:
        latent_dim: dimensionality of the latent vector.

    forward(x) returns ``(mu, logvar)``, each of shape ``(batch, latent_dim)``.
    """

    def __init__(self, latent_dim):
        super().__init__()
        channels = (3, 16, 32, 64)
        layers = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            layers += [nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1), nn.ReLU()]
        layers.append(nn.Flatten())
        self.conv = nn.Sequential(*layers)

        flat_features = channels[-1] * 16 * 16
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_logvar = nn.Linear(flat_features, latent_dim)

    def forward(self, x):
        features = self.conv(x)
        return self.fc_mu(features), self.fc_logvar(features)

In [None]:
class Decoder(nn.Module):
    """Transposed-convolution decoder mapping latent vectors back to 3-channel images.

    A linear layer expands ``z`` to a (64, 16, 16) feature map. Three stride-2
    transposed convolutions (kernel 4, padding 1) each double the spatial size,
    giving ``(batch, 3, 128, 128)`` outputs. A final Sigmoid squashes them into (0, 1).

    Args:
        latent_dim: dimensionality of the latent vector.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 64 * 16 * 16)

        channels = (64, 32, 16, 3)
        up = [
            nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1)
            for c_in, c_out in zip(channels[:-1], channels[1:])
        ]
        self.deconv = nn.Sequential(
            up[0], nn.ReLU(),
            up[1], nn.ReLU(),
            up[2], nn.Sigmoid(),
        )

    def forward(self, z):
        seed_map = self.fc(z).view(-1, 64, 16, 16)
        return self.deconv(seed_map)

In [None]:
class VAE(nn.Module):
    """Variational autoencoder that combines ``Encoder`` and ``Decoder`` around a reparameterized sample.

    Args:
        latent_dim: size of the latent vector (default 32).

    forward(x) returns ``(x_recon, mu, logvar)``.
    """

    def __init__(self, latent_dim=32):
        super().__init__()
        self.encoder = Encoder(latent_dim)
        self.decoder = Decoder(latent_dim)

    def reparameterize(self, mu, logvar):
        # Reparameterization trick: z = mu + sigma * eps with eps ~ N(0, I), so
        # gradients flow through mu and logvar rather than the random draw.
        sigma = (0.5 * logvar).exp()
        return mu + torch.randn_like(sigma) * sigma

    def forward(self, x):
        mu, logvar = self.encoder(x)
        x_recon = self.decoder(self.reparameterize(mu, logvar))
        return x_recon, mu, logvar

In [None]:
def vae_loss(x, x_recon, mu, logvar):
    """Return the batch-summed VAE objective (negative ELBO up to constants).

    Args:
        x: target batch.
        x_recon: reconstruction of ``x`` (same shape).
        mu, logvar: mean and log-variance of the diagonal-Gaussian posterior.

    Returns:
        Scalar tensor equal to the summed squared reconstruction error plus
        the analytic KL divergence KL(N(mu, exp(logvar)) || N(0, I)), both
        summed over every element of the batch.
    """
    kl_div = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    recon_loss = F.mse_loss(x_recon, x, reduction='sum')
    return recon_loss + kl_div

In [None]:
# Train the VAE with Adam. Use the GPU when one is available; otherwise fall back
# to the CPU instead of crashing on a hard-coded .cuda() call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae = VAE(latent_dim=32).to(device)
optimizer = torch.optim.Adam(vae.parameters(), lr=1e-3)

num_epochs = 10
for epoch in range(num_epochs):
    vae.train()
    total_loss = 0.0
    for batch in dataloader:
        batch = batch.to(device)
        recon, mu, logvar = vae(batch)
        loss = vae_loss(batch, recon, mu, logvar)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # vae_loss sums over the batch, so dividing by the dataset size reports the
    # mean loss per sample for this epoch.
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {total_loss/len(dataset):.4f}")

In [None]:
def show_images(original, reconstructed, n=8):
    """Plot the first ``n`` images (top row) above their reconstructions (bottom row).

    Args:
        original, reconstructed: (batch, channels, H, W) tensors on any device.
            Assumed to have 3 channels so imshow renders them as RGB.
        n: number of columns to show; clamped to the available batch size.

    Raises:
        ValueError: if there is no image to show.

    Note:
        matplotlib clips float RGB data outside [0, 1]. Reconstructions come
        from a Sigmoid and are in range. Originals are shown as stored — TODO
        confirm the input data is scaled to [0, 1].
    """
    n = min(n, len(original), len(reconstructed))
    if n < 1:
        raise ValueError("show_images needs at least one image")
    original = original[:n].detach().cpu()
    reconstructed = reconstructed[:n].detach().cpu()

    # squeeze=False keeps axs 2-D even when n == 1, so axs[row, col] always works.
    fig, axs = plt.subplots(2, n, figsize=(n * 2, 4), squeeze=False)
    rows = ((original, "Original"), (reconstructed, "Reconstructed"))
    for row, (images, title) in enumerate(rows):
        for col in range(n):
            ax = axs[row, col]
            # imshow expects channels-last (H, W, C) arrays.
            ax.imshow(images[col].permute(1, 2, 0).numpy())
            ax.set_title(title)
            ax.axis('off')
    plt.tight_layout()
    plt.show()


vae.eval()
# Run on whatever device the model lives on instead of assuming CUDA.
model_device = next(vae.parameters()).device
with torch.no_grad():
    test_batch = next(iter(dataloader)).to(model_device)
    recon_batch, _, _ = vae(test_batch)

show_images(test_batch, recon_batch)

In [None]:
# Save only the learned parameters (the state_dict), not the whole pickled module.
weights = vae.state_dict()
torch.save(weights, "vae_jet.pth")