In [1]:
! pip install gdown
! gdown --id 1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr -O filename
! pip install h5p

Downloading...
From (original): https://drive.google.com/uc?id=1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr
From (redirected): https://drive.google.com/uc?id=1WO2K-SfU2dntGU4Bb3IYBp9Rh7rtTYEr&confirm=t&uuid=6d14ea3e-1226-439f-8853-43010e2a1f2d
To: /kaggle/working/filename
100%|█████████████████████████████████████████| 701M/701M [00:04<00:00, 174MB/s]
[31mERROR: Could not find a version that satisfies the requirement h5p (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for h5p[0m[31m
[0m

In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import h5py
from tqdm import tqdm
import os
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import concurrent.futures
import gc

# ========================
# Configuration
# ========================
config = {
    'batch_size': 64,
    'num_epochs': 80,
    'initial_lr': 1e-3,
    'patience': 5,
    'min_lr': 1e-6,
    'num_workers': min(8, os.cpu_count()),  # Limit to 8 workers max
    'pin_memory': True,
    'persistent_workers': True,
    'sparse_threshold': 1e-6,
    'save_full_precision': True,
    'precision_decimals': 8
}



# Set numerical precision
torch.set_printoptions(precision=config['precision_decimals'])
np.set_printoptions(precision=config['precision_decimals'])

In [3]:
class JetAutoencoder(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, 3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(512*8*8, 4096),
            nn.ReLU()
        )
        
        # Decoder
        self.decoder = nn.Sequential(
            nn.Linear(4096, 512*8*8),
            nn.ReLU(),
            nn.Unflatten(1, (512, 8, 8)),
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 3, stride=2, padding=1, output_padding=1),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.decoder(self.encoder(x))

In [4]:
# 3. Your custom loss function
class WeightedMSE(nn.Module):
    def __init__(self, weight_nonzero=1000.0):
        super().__init__()
        self.weight_nonzero = weight_nonzero  # Higher weight for non-zero pixels

    def forward(self, y_pred, y_true):
        # Create weight tensor: 1.0 for zeros, `weight_nonzero` for non-zeros
        weights = torch.where(y_true == 0, 
                            torch.tensor(1.0, device=y_true.device),
                            torch.tensor(self.weight_nonzero, device=y_true.device))
        
        # Calculate weighted MSE
        squared_error = (y_true - y_pred) ** 2
        weighted_loss = weights * squared_error
        return torch.mean(weighted_loss)

In [5]:
class HandleSparseImages:
    def __call__(self, img):
        sparse_mask = (img < config['sparse_threshold'])
        noise = torch.randn_like(img) * config['sparse_threshold'] * 0.1
        return torch.where(sparse_mask, img + noise, img)

In [6]:
# ========================
# Data Loading
# ========================
def load_data(i):
    print("Loading data with full precision...")
    
    end = i*25000
    first = end - 25000
    if end > 125000:
        end = 139306
        first = 125000

    # if i>1:
    #     del X_jets
    #     gc.collect()  # Force garbage collector to run immediately

    with h5py.File('/kaggle/working/filename', 'r') as f:
        X_jets = f['X_jets'][first:end].astype(np.float32)
    
    if config['save_full_precision']:
        #np.save('X_jets_original.npy', X_jets)
        print("Saved original data with full precision")
    
    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: (x - x.min()) / (x.max() - x.min())),
        HandleSparseImages()
    ])
    
    def process_batch(batch):
        return torch.stack([transform(img) for img in batch])
    
    print("Processing images with all workers...")
    batches = np.array_split(X_jets, config['num_workers'])
    results = []
    
    with concurrent.futures.ThreadPoolExecutor(max_workers=config['num_workers']) as executor:
        futures = [executor.submit(process_batch, batch) for batch in batches]
        for future in tqdm(concurrent.futures.as_completed(futures), 
                         total=len(futures),
                         desc="Processing"):
            results.append(future.result())
    
    return torch.cat(results)

In [7]:
# Initialize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Initialize model
model = JetAutoencoder().to(device)

# 4. Load your trained model
model_path = '/kaggle/input/lhc-auto-encoder/pytorch/default/1/best_model.pth'
state_dict  = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)

criterion = WeightedMSE()
optimizer = optim.Adam(model.parameters(), lr=config['initial_lr'])
scheduler = ReduceLROnPlateau(optimizer, 'min', 
                            patience=config['patience'],
                            min_lr=config['min_lr'],
                            verbose=True)
writer = SummaryWriter()

# Training
best_loss = float('inf')
early_stop_counter = 0

Using device: cuda


  state_dict  = torch.load(model_path, map_location=device)


In [8]:
epoch = 0
for i in range(1,6):

    # Load data
    test_data = load_data(i)
    
    val_loader = DataLoader(test_data, batch_size=config['batch_size'],
                               num_workers=config['num_workers'],
                               pin_memory=config['pin_memory'])
    
    # Validation
    model.eval()
    val_loss = 0
    val_pbar = tqdm(val_loader,
                   desc=f'Epoch {epoch+1}/{config["num_epochs"]} [Val]')
    
    with torch.no_grad():
        for batch in val_pbar:
            batch = batch.to(device, non_blocking=True)
            outputs = model(batch)
            loss = criterion(outputs, batch)
            val_loss += loss.item()
            val_pbar.set_postfix({
                'val_loss': format(loss.item(), f".{config['precision_decimals']}f")
            })
    
    val_loss /= len(val_loader)
    scheduler.step(val_loss)
    
    writer.add_scalar('Loss/val', val_loss, epoch)
    writer.add_scalar('Learning Rate', optimizer.param_groups[0]['lr'], epoch)
    
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"Val Loss: {val_loss:.{config['precision_decimals']}f}")

    del test_data
    gc.collect()

Loading data with full precision...
Saved original data with full precision
Processing images with all workers...


Processing: 100%|██████████| 4/4 [00:27<00:00,  6.87s/it]
Epoch 1/80 [Val]: 100%|██████████| 391/391 [00:12<00:00, 31.14it/s, val_loss=0.01738904]



Epoch 1 Summary:
Val Loss: 0.01556302
Loading data with full precision...
Saved original data with full precision
Processing images with all workers...


Processing: 100%|██████████| 4/4 [00:24<00:00,  6.01s/it]
Epoch 1/80 [Val]: 100%|██████████| 391/391 [00:12<00:00, 32.19it/s, val_loss=0.02780784]



Epoch 1 Summary:
Val Loss: 0.02678649
Loading data with full precision...
Saved original data with full precision
Processing images with all workers...


Processing: 100%|██████████| 4/4 [00:24<00:00,  6.20s/it]
Epoch 1/80 [Val]: 100%|██████████| 391/391 [00:12<00:00, 30.94it/s, val_loss=0.02384840]



Epoch 1 Summary:
Val Loss: 0.02975228
Loading data with full precision...
Saved original data with full precision
Processing images with all workers...


Processing: 100%|██████████| 4/4 [00:23<00:00,  5.94s/it]
Epoch 1/80 [Val]: 100%|██████████| 391/391 [00:12<00:00, 31.84it/s, val_loss=0.03107654]



Epoch 1 Summary:
Val Loss: 0.03027023
Loading data with full precision...
Saved original data with full precision
Processing images with all workers...


Processing: 100%|██████████| 4/4 [00:22<00:00,  5.75s/it]
Epoch 1/80 [Val]: 100%|██████████| 391/391 [00:12<00:00, 31.54it/s, val_loss=0.02475519]



Epoch 1 Summary:
Val Loss: 0.03112698
