In [None]:
import os, glob, cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import timm

In [None]:
def get_image_files(folder):
    """Return the sorted paths of all entries in ``folder`` whose name contains a dot.

    Args:
        folder: Directory to search (not recursive).

    Returns:
        A lexicographically sorted list of paths matching ``<folder>/*.*``.
        The list is empty when nothing matches or the folder does not exist.
    """
    pattern = os.path.join(folder, "*.*")
    matches = glob.glob(pattern)
    matches.sort()
    return matches

def get_chunk(file_list, chunk_index, total_chunks=14):
    """Return the ``chunk_index``-th contiguous slice of ``file_list``.

    The list is cut into ``total_chunks`` pieces of ``len(file_list) // total_chunks``
    items each; the final chunk additionally absorbs any remainder.

    Args:
        file_list: Sequence to slice.
        chunk_index: Zero-based index of the chunk to return.
        total_chunks: Number of chunks the sequence is divided into.

    Returns:
        The slice of ``file_list`` belonging to the requested chunk.
    """
    per_chunk = len(file_list) // total_chunks
    lo = per_chunk * chunk_index
    # The last chunk runs to the end so leftover items are not dropped.
    if chunk_index == total_chunks - 1:
        return file_list[lo:]
    return file_list[lo:lo + per_chunk]

def read_image(path, target_size=(32,32)):
    """Load an image from disk as a float32 RGB array scaled by 1/255.

    Args:
        path: Location of the image file.
        target_size: Size tuple handed to ``cv2.resize``.

    Returns:
        The resized image as ``np.float32`` with pixel values divided by 255.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None`` for ``path``.
    """
    bgr = cv2.imread(path)
    if bgr is None:
        raise FileNotFoundError(f"Image not found: {path}")
    # OpenCV loads BGR; convert to RGB before resizing.
    resized = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), target_size)
    return resized.astype(np.float32) / 255.0


In [None]:
class ImageStatsTransform:
    """Standardise a tensor with statistics taken over its dimensions 1 and 2.

    For each index along dimension 0, the mean over dims (1, 2) is subtracted
    and the result is divided by the standard deviation over the same dims
    plus 1e-6 (to avoid division by zero).
    """

    def __call__(self, x):
        reduce_dims = [1, 2]
        centered = x - x.mean(dim=reduce_dims, keepdim=True)
        scale = x.std(dim=reduce_dims, keepdim=True) + 1e-6
        return centered / scale

def compute_physics_features_tensor(x, eps=1e-6):
    """Compute four per-sample summary features from the first three channels.

    Channels 0, 1 and 2 of ``x`` are treated as the ECAL, HCAL and tracks
    planes respectively (naming taken from the code; confirm against the data).

    Args:
        x: 4-D tensor indexed as (batch, channel, dim2, dim3).
        eps: Small constant added to denominators.

    Returns:
        Tensor of shape (batch, 4) with columns, in order:
        mean of ecal / (hcal + eps), mean of tracks, mean of (ecal - hcal),
        and mean of |ecal - hcal| / (ecal + hcal + eps). All means are taken
        over dims 2 and 3.
    """
    ecal, hcal, tracks = (x[:, c:c + 1, :, :] for c in range(3))
    spatial = [2, 3]
    delta = ecal - hcal
    features = [
        torch.mean(ecal / (hcal + eps), dim=spatial),
        torch.mean(tracks, dim=spatial),
        torch.mean(delta, dim=spatial),
        torch.mean(torch.abs(delta) / (ecal + hcal + eps), dim=spatial),
    ]
    return torch.cat(features, dim=1)

In [None]:
class ChannelWiseFPN(nn.Module):
    """Multi-scale feature extractor applied independently to each input channel.

    Every input channel is passed, as a single-channel map, through the same
    1x1, 3x3 and 5x5 convolutions; the three responses are concatenated and
    fused by a small conv stack. The fused maps of all channels are then
    concatenated, so the output has ``x.shape[1] * out_channels`` channels and
    the same spatial size as the input.

    Args:
        in_channels: Accepted for interface compatibility; not used by the
            layers, which always operate on one channel at a time.
        out_channels: Number of feature maps produced per input channel.
    """

    def __init__(self, in_channels=3, out_channels=16):
        super().__init__()
        # Parallel single-channel branches; padding keeps the spatial size fixed.
        self.conv1x1 = nn.Conv2d(1, out_channels, kernel_size=1, padding=0)
        self.conv3x3 = nn.Conv2d(1, out_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(1, out_channels, kernel_size=5, padding=2)
        # Fuses the three concatenated branch outputs back to out_channels maps.
        self.fuse_conv = nn.Sequential(
            nn.Conv2d(3*out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=1)
        )

    def forward(self, x):
        branches = (self.conv1x1, self.conv3x3, self.conv5x5)
        fused_maps = []
        for idx in range(x.shape[1]):
            plane = x[:, idx:idx+1, :, :]
            pyramid = torch.cat([conv(plane) for conv in branches], dim=1)
            fused_maps.append(self.fuse_conv(pyramid))
        return torch.cat(fused_maps, dim=1)

In [None]:
class QuarkGluonDataset(Dataset):
    """Dataset of index-aligned (original, reconstructed, graph) image triplets.

    The three file lists are indexed in parallel: sample ``i`` is built from
    ``orig_files[i]``, ``rec_files[i]`` and ``graph_files[i]``. The dataset
    length follows the labels (one per entry of ``orig_files``), so
    ``rec_files`` and ``graph_files`` must be at least that long or
    ``__getitem__`` raises ``IndexError``.

    Args:
        orig_files: Paths readable by ``read_image``.
        rec_files: Paths readable by ``read_image``.
        graph_files: Paths readable by ``read_image``.
        transform: Optional callable applied to each (C, H, W) tensor.
        labels: Optional sequence of integer class labels, one per entry of
            ``orig_files``. When omitted, random 0/1 placeholder labels are
            generated (the historical behaviour) and a warning is issued,
            because a model trained on them cannot learn anything meaningful.

    Raises:
        ValueError: If ``labels`` is given but its length differs from
            ``len(orig_files)``.
    """

    def __init__(self, orig_files, rec_files, graph_files, transform=None, labels=None):
        self.orig_files = orig_files
        self.rec_files = rec_files
        self.graph_files = graph_files
        self.transform = transform
        if labels is None:
            import warnings
            warnings.warn(
                "QuarkGluonDataset: no labels supplied; using random placeholder labels.",
                stacklevel=2,
            )
            # Explicit int64 so the collated labels are a LongTensor on every platform.
            self.labels = np.random.randint(0, 2, size=(len(orig_files),), dtype=np.int64)
        else:
            self.labels = np.asarray(labels, dtype=np.int64)
            if self.labels.shape != (len(orig_files),):
                raise ValueError(
                    f"Expected {len(orig_files)} labels, got array of shape {self.labels.shape}"
                )

    def __len__(self):
        return len(self.labels)

    def _load(self, path):
        """Read one image as a (C, H, W) tensor and apply the optional transform."""
        # read_image returns (H, W, C); move channels first.
        tensor = torch.tensor(read_image(path)).permute(2, 0, 1)
        if self.transform:
            tensor = self.transform(tensor)
        return tensor

    def __getitem__(self, idx):
        return {
            "original": self._load(self.orig_files[idx]),
            "reconstructed": self._load(self.rec_files[idx]),
            "graph": self._load(self.graph_files[idx]),
            "label": self.labels[idx],
        }

In [None]:
class ViTWithChannelFPN(nn.Module):
    """Three-branch classifier combining ChannelWiseFPN + ViT features with physics features.

    Each of the three inputs (original, original - reconstructed, graph) goes
    through its own ``ChannelWiseFPN`` (3 channels -> 48 feature maps) and its
    own ``vit_base_patch16_224`` backbone whose patch-embedding conv is
    replaced by a freshly initialised 48-channel conv and whose head is
    removed. The three backbone outputs are averaged, concatenated with the
    four features from ``compute_physics_features_tensor`` and classified by a
    linear layer.

    Args:
        num_classes: Number of output logits.
        vit_input_size: Spatial size (int or (H, W)) the ViT backbones expect.
            FPN feature maps of any other size are bilinearly resized to it
            before entering the backbone, because the timm patch embedding
            rejects inputs that do not match the model's configured image
            size. Pass ``None`` to disable resizing.
    """

    def __init__(self, num_classes=2, vit_input_size=224):
        super(ViTWithChannelFPN, self).__init__()
        if isinstance(vit_input_size, int):
            vit_input_size = (vit_input_size, vit_input_size)
        self.vit_input_size = None if vit_input_size is None else tuple(vit_input_size)
        self.fpn_orig = ChannelWiseFPN(3, 16)
        self.fpn_diff = ChannelWiseFPN(3, 16)
        self.fpn_graph = ChannelWiseFPN(3, 16)
        self.vit_orig = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.vit_diff = timm.create_model('vit_base_patch16_224', pretrained=True)
        self.vit_graph = timm.create_model('vit_base_patch16_224', pretrained=True)
        for net in [self.vit_orig, self.vit_diff, self.vit_graph]:
            # Swap the 3-channel patch projection for a 48-channel one with the
            # same geometry; the replacement conv starts from random weights.
            old_proj = net.patch_embed.proj
            net.patch_embed.proj = nn.Conv2d(48, old_proj.out_channels,
                                             kernel_size=old_proj.kernel_size,
                                             stride=old_proj.stride,
                                             padding=old_proj.padding)
            # Drop the classification head so the backbone returns features.
            net.head = nn.Identity()
        # 768 backbone features + 4 physics features.
        self.classifier = nn.Linear(768 + 4, num_classes)

    def _encode(self, fpn, vit, x):
        """Run one branch: FPN, optional resize to the ViT input size, then the ViT."""
        feats = fpn(x)
        if self.vit_input_size is not None and tuple(feats.shape[-2:]) != self.vit_input_size:
            feats = nn.functional.interpolate(
                feats, size=self.vit_input_size, mode="bilinear", align_corners=False
            )
        return vit(feats)

    def forward(self, x_orig, x_rec, x_graph):
        x_diff = x_orig - x_rec
        feat_orig = self._encode(self.fpn_orig, self.vit_orig, x_orig)
        feat_diff = self._encode(self.fpn_diff, self.vit_diff, x_diff)
        feat_graph = self._encode(self.fpn_graph, self.vit_graph, x_graph)
        # Average the three branch embeddings into one deep feature vector.
        deep_feat = (feat_orig + feat_diff + feat_graph) / 3.0
        phys_feat = compute_physics_features_tensor(x_orig)
        combined_feat = torch.cat([deep_feat, phys_feat], dim=1)
        logits = self.classifier(combined_feat)
        return logits

In [None]:
def save_checkpoint(model, optimizer, epoch, filename="vit_checkpoint.pth"):
    """Write the epoch number plus model and optimizer state dicts to ``filename``.

    Args:
        model: Object exposing ``state_dict()``; stored under ``"model_state"``.
        optimizer: Object exposing ``state_dict()``; stored under ``"optimizer_state"``.
        epoch: Epoch number stored under ``"epoch"``.
        filename: Destination path passed to ``torch.save``.
    """
    state = {}
    state["epoch"] = epoch
    state["model_state"] = model.state_dict()
    state["optimizer_state"] = optimizer.state_dict()
    torch.save(state, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(model, optimizer, filename="vit_checkpoint.pth"):
    """Restore model and optimizer state from ``filename`` if it exists.

    Args:
        model: Module whose ``load_state_dict`` receives ``"model_state"``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``"optimizer_state"``.
        filename: Path of a checkpoint with the keys ``"epoch"``,
            ``"model_state"`` and ``"optimizer_state"``.

    Returns:
        The epoch to resume from (stored epoch + 1), or 0 when no checkpoint
        file exists.
    """
    if not os.path.isfile(filename):
        print("No checkpoint found, starting from scratch.")
        return 0
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can be
    # resumed without one; load_state_dict copies values onto each parameter's device.
    # NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])
    start_epoch = checkpoint["epoch"] + 1
    print(f"Checkpoint loaded from {filename}, resuming at epoch {start_epoch}")
    return start_epoch

In [None]:
def main():
    """Train ``ViTWithChannelFPN`` on the Kaggle inputs, one file chunk at a time.

    File lists for the original, reconstructed and graph inputs are gathered
    from the configured folders, split into ``total_chunks`` slices with
    ``get_chunk``, and each slice is wrapped in a ``QuarkGluonDataset`` and
    trained on in turn. A checkpoint is written after every epoch and training
    resumes from it when present.
    """
    total_chunks = 14
    num_epochs = 5
    batch_size = 32
    # Named after the model actually trained, so an unrelated checkpoint from a
    # different architecture is never loaded into it.
    checkpoint_path = "vit_checkpoint.pth"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Folder paths for the dataset layout on Kaggle.
    orig_folder = "/kaggle/input/genie-extracted-dataset"
    rec_folders = [
        "/kaggle/input/genie-output-part-1/reconstructions",
        "/kaggle/input/genie-common-task-1-output-part-2final/reconstructions"
    ]
    graph_folders = [
        "/kaggle/input/common-task-2-dataset-part-1/processed_jet_graphs",
        "/kaggle/input/output-part-2-of-task-2/processed_jet_graphs",
        "/kaggle/input/output-part-3-of-task-2/processed_jet_graphs",
        "/kaggle/input/part-4-task-2-output/processed_jet_graphs"
    ]
    # Graph file to leave out, keyed by a substring identifying its folder.
    ignored_graph_files = {
        "common-task-2-dataset-part-1": "processed_chunk_120000_130000.pt",
        "output-part-2-of-task-2": "processed_chunk_40000_50000.pt",
        "output-part-3-of-task-2": "processed_chunk_80000_90000.pt",
    }

    # Original image files.
    orig_files_all = get_image_files(orig_folder)

    # Reconstructed image files from both parts, in folder order.
    rec_files_all = [path for folder in rec_folders for path in get_image_files(folder)]

    # Graph files from all parts, minus the ignored ones.
    # NOTE(review): the ignore list names `.pt` files, so these folders presumably
    # hold torch-serialized chunks rather than images, yet QuarkGluonDataset reads
    # every path with read_image (cv2.imread) -- confirm the intended loader.
    graph_files_all = []
    for folder in graph_folders:
        skip = {name for key, name in ignored_graph_files.items() if key in folder}
        graph_files_all.extend(
            f for f in get_image_files(folder) if os.path.basename(f) not in skip
        )

    transform = ImageStatsTransform()
    # The notebook defines ViTWithChannelFPN; the previously referenced
    # ResNetWithChannelFPN does not exist here and raised NameError.
    model = ViTWithChannelFPN(num_classes=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()

    start_epoch = load_checkpoint(model, optimizer, filename=checkpoint_path)

    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        model.train()
        # The three lists are chunked independently; this assumes they are
        # index-aligned and equally long -- TODO confirm against the data.
        for chunk in range(total_chunks):
            print(f"  Processing chunk {chunk+1}/{total_chunks}")
            orig_chunk = get_chunk(orig_files_all, chunk, total_chunks)
            rec_chunk = get_chunk(rec_files_all, chunk, total_chunks)
            graph_chunk = get_chunk(graph_files_all, chunk, total_chunks)

            # A shuffling DataLoader cannot be built over zero samples.
            if len(orig_chunk) == 0:
                print("    Chunk is empty, skipping.")
                continue

            dataset = QuarkGluonDataset(orig_chunk, rec_chunk, graph_chunk, transform=transform)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

            for batch in dataloader:
                x_orig = batch["original"].to(device)
                x_rec = batch["reconstructed"].to(device)
                x_graph = batch["graph"].to(device)
                # The default collate already yields a tensor; as_tensor avoids
                # the copy-construct warning and guarantees the int64 targets
                # CrossEntropyLoss requires.
                y = torch.as_tensor(batch["label"], dtype=torch.long, device=device)

                optimizer.zero_grad()
                outputs = model(x_orig, x_rec, x_graph)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()

            # Release cached GPU memory between chunks.
            if device.type == "cuda":
                torch.cuda.empty_cache()

        save_checkpoint(model, optimizer, epoch, filename=checkpoint_path)
        print(f"Epoch {epoch+1} complete.")

    print("Training complete.")


In [None]:
# Run the training pipeline only when this code executes as the top-level module.
if __name__ == "__main__":
    main()