In [1]:
import os, glob
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models

In [None]:
def get_image_files(folder):
    # Return sorted list of image file paths
    return sorted(glob.glob(os.path.join(folder, "*.*")))

In [None]:
def get_chunk(file_list, chunk_index, total_chunks=14):
    n = len(file_list)
    chunk_size = n // total_chunks
    start = chunk_index * chunk_size
    if chunk_index == total_chunks - 1:
        end = n
    else:
        end = start + chunk_size
    return file_list[start:end]


In [None]:
def read_image(path, target_size=(32,32)):
    # Read image using cv2, convert BGR->RGB, and resize
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError(f"Image not found: {path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, target_size)
    # Normalize pixel values to [0,1]
    return img.astype(np.float32) / 255.0


In [None]:
class ImageStatsTransform:
    def __call__(self, x):
        # x: (C, H, W) torch.Tensor
        mean = x.mean(dim=[1,2], keepdim=True)
        std = x.std(dim=[1,2], keepdim=True) + 1e-6
        return (x - mean) / std

In [3]:
def compute_physics_features_tensor(x, eps=1e-6):
    # x: (B, 3, H, W) where channels represent [ecal, hcal, tracks]
    ecal = x[:, 0:1, :, :]
    hcal = x[:, 1:2, :, :]
    tracks = x[:, 2:3, :, :]
    ratio = torch.mean(ecal / (hcal + eps), dim=[2,3])
    mean_tracks = torch.mean(tracks, dim=[2,3])
    diff = torch.mean(ecal - hcal, dim=[2,3])
    norm_diff = torch.mean(torch.abs(ecal - hcal) / (ecal + hcal + eps), dim=[2,3])
    return torch.cat([ratio, mean_tracks, diff, norm_diff], dim=1)  # (B, 4)

In [4]:
class ChannelWiseFPN(nn.Module):
    def __init__(self, in_channels=3, out_channels=16):
        super(ChannelWiseFPN, self).__init__()
        self.conv1x1 = nn.Conv2d(1, out_channels, kernel_size=1, padding=0)
        self.conv3x3 = nn.Conv2d(1, out_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(1, out_channels, kernel_size=5, padding=2)
        self.fuse_conv = nn.Sequential(
            nn.Conv2d(3*out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=1)
        )
    def forward(self, x):
        # x: (B, in_channels, H, W)
        outputs = []
        for i in range(x.shape[1]):
            xi = x[:, i:i+1, :, :]
            feat1 = self.conv1x1(xi)
            feat2 = self.conv3x3(xi)
            feat3 = self.conv5x5(xi)
            cat_feats = torch.cat([feat1, feat2, feat3], dim=1)
            fused = self.fuse_conv(cat_feats)
            outputs.append(fused)
        return torch.cat(outputs, dim=1)  # (B, in_channels*out_channels, H, W)

In [5]:
class QuarkGluonDataset(Dataset):
    def __init__(self, orig_files, rec_files, graph_files, transform=None):
        # Lists of file paths (assumed same order for corresponding samples)
        self.orig_files = orig_files
        self.rec_files = rec_files
        self.graph_files = graph_files
        self.transform = transform
        # For demonstration, labels are randomly assigned. Replace with actual labels.
        self.labels = np.random.randint(0,2,size=(len(orig_files),))
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, idx):
        orig = read_image(self.orig_files[idx])
        rec = read_image(self.rec_files[idx])
        graph = read_image(self.graph_files[idx])
        # Convert to tensor and rearrange to (C, H, W)
        orig = torch.tensor(orig).permute(2,0,1)
        rec = torch.tensor(rec).permute(2,0,1)
        graph = torch.tensor(graph).permute(2,0,1)
        if self.transform:
            orig = self.transform(orig)
            rec = self.transform(rec)
            graph = self.transform(graph)
        label = self.labels[idx]
        return {"original": orig, "reconstructed": rec, "graph": graph, "label": label}


In [6]:
def load_data_chunk(chunk_index):
    # Dummy loader: replace with your actual data loader.
    num_samples = 100
    H, W = 32, 32
    original = np.random.rand(num_samples, 3, H, W).astype(np.float32)
    reconstructed = np.random.rand(num_samples, 3, H, W).astype(np.float32)
    graph = np.random.rand(num_samples, 3, H, W).astype(np.float32)
    labels = np.random.randint(0, 2, size=(num_samples,))
    return original, reconstructed, graph, labels


In [7]:
class ResNetWithChannelFPN(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetWithChannelFPN, self).__init__()
        # Three FPN modules for original, difference, and graph.
        self.fpn_orig = ChannelWiseFPN(3, 16)
        self.fpn_diff = ChannelWiseFPN(3, 16)
        self.fpn_graph = ChannelWiseFPN(3, 16)
        # Three ResNet18 backbones (pretrained), modified to accept 48 channels.
        self.resnet_orig = models.resnet18(pretrained=True)
        self.resnet_diff = models.resnet18(pretrained=True)
        self.resnet_graph = models.resnet18(pretrained=True)
        for branch in [self.resnet_orig, self.resnet_diff, self.resnet_graph]:
            branch.conv1 = nn.Conv2d(48, 64, kernel_size=7, stride=2, padding=3, bias=False)
            branch.fc = nn.Identity()  # Feature output: 512-dim.
        # Final classifier: fuse deep features (512) with physics features (4).
        self.classifier = nn.Linear(512 + 4, num_classes)
    def forward(self, x_orig, x_rec, x_graph):
        # Compute difference image.
        x_diff = x_orig - x_rec
        feat_orig = self.resnet_orig(self.fpn_orig(x_orig))
        feat_diff = self.resnet_diff(self.fpn_diff(x_diff))
        feat_graph = self.resnet_graph(self.fpn_graph(x_graph))
        deep_feat = (feat_orig + feat_diff + feat_graph) / 3.0
        phys_feat = compute_physics_features_tensor(x_orig)
        combined_feat = torch.cat([deep_feat, phys_feat], dim=1)
        logits = self.classifier(combined_feat)
        return logits


In [8]:
def save_checkpoint(model, optimizer, epoch, filename="resnet_checkpoint.pth"):
    torch.save({"epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict()}, filename)
    print(f"Checkpoint saved to {filename}")

def load_checkpoint(model, optimizer, filename="resnet_checkpoint.pth"):
    if os.path.isfile(filename):
        checkpoint = torch.load(filename)
        model.load_state_dict(checkpoint["model_state"])
        optimizer.load_state_dict(checkpoint["optimizer_state"])
        start_epoch = checkpoint["epoch"] + 1
        print(f"Checkpoint loaded from {filename}, resuming at epoch {start_epoch}")
        return start_epoch
    print("No checkpoint found, starting from scratch.")
    return 0


In [9]:
def main():
    total_chunks = 14
    num_epochs = 5
    batch_size = 32
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Folders in /kaggle/input/ (update folder names as needed)
    orig_folder = "/kaggle/input/original_images"
    rec_folder = "/kaggle/input/reconstructed_images"
    graph_folder = "/kaggle/input/graph_images"
    
    # Get list of image files from each folder.
    orig_files_all = get_image_files(orig_folder)
    rec_files_all = get_image_files(rec_folder)
    graph_files_all = get_image_files(graph_folder)
    
    transform = ImageStatsTransform()
    model = ResNetWithChannelFPN(num_classes=2).to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    criterion = nn.CrossEntropyLoss()
    
    start_epoch = load_checkpoint(model, optimizer, filename="resnet_checkpoint.pth")
    for epoch in range(start_epoch, num_epochs):
        print(f"Epoch {epoch+1}/{num_epochs}")
        # Loop through 14 chunks.
        for chunk in range(total_chunks):
            print(f"  Processing chunk {chunk+1}/{total_chunks}")
            # Get the chunk of file names.
            orig_chunk = get_chunk(orig_files_all, chunk, total_chunks)
            rec_chunk = get_chunk(rec_files_all, chunk, total_chunks)
            graph_chunk = get_chunk(graph_files_all, chunk, total_chunks)
            dataset = QuarkGluonDataset(orig_chunk, rec_chunk, graph_chunk, transform=transform)
            dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
            model.train()
            for batch in dataloader:
                x_orig = batch["original"].to(device)
                x_rec = batch["reconstructed"].to(device)
                x_graph = batch["graph"].to(device)
                y = torch.tensor(batch["label"]).to(device)
                optimizer.zero_grad()
                outputs = model(x_orig, x_rec, x_graph)
                loss = criterion(outputs, y)
                loss.backward()
                optimizer.step()
            torch.cuda.empty_cache()
        save_checkpoint(model, optimizer, epoch, filename="resnet_checkpoint.pth")
        print(f"Epoch {epoch+1} complete.")
    print("Training complete.")

In [10]:
if __name__ == "__main__":
    main()

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 166MB/s]


No checkpoint found, starting from scratch.
Epoch 1/5
  Loading chunk 1/14


  y = torch.tensor(batch["label"]).to(device)


  Loading chunk 2/14


KeyboardInterrupt: 