# ELECTRON/PHOTON CLASSIFICATION USING ResNet-15

### OBJECTIVE: 
#### Build a Deep Learning Model to Classify between electrons and photons using the provided datasets.
    - Preprocess the datasets 
    - Normalize and prepare the data for model building
    - Train a ResNet CNN model
    - Evaluation and Optimization

### REQUIRED LIBRARIES

In [None]:
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split
from torchsummary import summary

### LOADING THE DATASET

In [None]:
# Open both HDF5 files read-only. The handles are consumed (and closed) by the
# `with` statements in the extraction cell that follows.
electron_file = h5py.File("SingleElectronPt50_IMGCROPS_n249k_RHv1.hdf5", "r")
photon_file = h5py.File("SinglePhotonPt50_IMGCROPS_n249k_RHv1.hdf5", "r")

# List the top-level keys stored in each file (electron file first).
for handle in (electron_file, photon_file):
    print(list(handle.keys()))

In [None]:
# Extract the image tensors ("X") and labels ("y") from each file.
def _load_arrays(h5_file):
    """Copy the "X" (images) and "y" (labels) datasets into NumPy arrays.

    The file object is used as a context manager, so it is closed before
    the arrays are returned.
    """
    with h5_file as handle:
        return np.array(handle["X"]), np.array(handle["y"])


X_electron, y_electron = _load_arrays(electron_file)
X_photon, y_photon = _load_arrays(photon_file)

### INSPECTING THE DATASET

In [None]:
# Shapes of the image tensor and label vector for each particle type.
for caption, imgs, targets in (
    ("Electron dataset shape:", X_electron, y_electron),
    ("Photon dataset shape:", X_photon, y_photon),
):
    print(caption, imgs.shape, targets.shape)

In [None]:
# Element types of the image tensor and label vector for each particle type.
for caption, imgs, targets in (
    ("Electron dataset data type:", X_electron, y_electron),
    ("Photon dataset data type:", X_photon, y_photon),
):
    print(caption, imgs.dtype, targets.dtype)

In [None]:
# Count how often each distinct label value occurs in each file.
for caption, targets in (
    ("Electron label distribution:", y_electron),
    ("Photon label distribution:", y_photon),
):
    print(caption, np.unique(targets, return_counts=True))

Notes: 

    - Total sample size: 498,000 events (249,000 electrons and 249,000 photons), i.e. a 1:1 electron:photon ratio, so the classes are balanced.
    - Image format: 32×32 pixels with 2 channels
                - Channel 1: Hit energy (X[:, :, :, 0])
                - Channel 2: Hit time (X[:, :, :, 1])
    - Labels: 
        - Electrons: 1
        - Photons: 0
    

### EXPLORATORY DATA ANALYSIS (EDA)

#### VISUALIZING SOME SAMPLES

In [None]:
def plot_sample(data, title, index=0):
    """Show one event's two image channels side by side.

    Args:
        data: image array indexed as data[index, row, col, channel]; channel 0
            is drawn as "Energy" and channel 1 as "Time".
        title: prefix used for both subplot titles.
        index: position of the event to plot within ``data``.
    """
    # Slice both channels before creating the figure.
    energy, hit_time = (data[index, :, :, channel] for channel in (0, 1))
    panels = (
        (energy, "viridis", f"{title} - Energy"),
        (hit_time, "magma", f"{title} - Time"),
    )

    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for axis, (image, cmap, subtitle) in zip(axes, panels):
        axis.imshow(image, cmap=cmap)
        axis.set_title(subtitle)
    plt.show()

for sample_set, caption in ((X_electron, "Electron Sample"), (X_photon, "Photon Sample")):
    plot_sample(sample_set, caption, index=0)

Notes:

- Visualizing individual samples shows how electrons and photons differ in the way they deposit energy in the detector.
- These energy-deposition patterns can help with classification.

#### LABEL DISTRIBUTION: VISUALIZED

In [None]:
# Raw labels from both files, electrons first, then photons.
labels = np.concatenate([y_electron, y_photon])
# discrete=True uses unit-width bins centred on each integer value, so each
# bar sits directly above its 0/1 tick below. With bins=2 the bars spanned
# [0, 0.5] and [0.5, 1], and the ticks landed on the bars' outer edges.
sns.histplot(labels, discrete=True)
plt.xticks([0, 1], ["Photon (0)", "Electron (1)"])
plt.title("Label Distribution")
plt.show()

### DATA PREPROCESSING

##### COMBINE AND NORMALIZE DATA

In [None]:
# Merge both samples and label them by file of origin: electrons -> 1, photons -> 0.
X = np.concatenate([X_electron, X_photon], axis=0)
y = np.concatenate([np.ones_like(y_electron), np.zeros_like(y_photon)], axis=0)

# The in-place division below needs a floating dtype. This is a no-op (no
# copy) if the HDF5 images are already float32, which is presumed but not
# verified here.
X = X.astype(np.float32, copy=False)

# Scale each channel (energy, time) by its own maximum in one vectorised pass.
# A channel whose maximum is 0 is left unscaled instead of becoming NaN (0/0).
# NOTE(review): these maxima are taken over train + test before the split
# performed later, so test-set statistics leak into preprocessing. Consider
# computing them on the training split only.
channel_max = X.max(axis=(0, 1, 2))
channel_max[channel_max == 0] = 1
X /= channel_max

print("Dataset Shape:", X.shape, y.shape)

##### TRAIN-TEST SPLIT

In [None]:
# Hold out 20% of events for testing.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,        # keep the class proportions identical in both splits
    random_state=42,   # reproducible shuffling
)
for tag, feats, targets in (("Train:", X_train, y_train), ("Test:", X_test, y_test)):
    print(tag, feats.shape, targets.shape)

### MODEL BUILDING 

##### MODEL ARCHITECTURE

In [None]:
# Residual ("skip") connections give gradients a path around the convolutions,
# which mitigates vanishing gradients in deeper networks.
class ResidualBlock(nn.Module):
    """Basic residual block: two 3x3 conv + BatchNorm layers plus an additive shortcut.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the block.
        stride: stride of the first convolution; values > 1 downsample.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)  # Batch Normalization layer
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Identity shortcut by default. When the stride or channel count changes,
        # project the input with a 1x1 conv so its shape matches the block output.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += identity  # skip connection
        return torch.relu(out)


class ResNetModified(nn.Module):
    """Small ResNet classifier for multi-channel detector images.

    Architecture:
    - a 7x7 stride-2 stem convolution;
    - three stages of two ResidualBlocks each (64->128, 128->256 with stride 2,
      256->512 with stride 2);
    - global average pooling and a linear head.

    Args:
        num_classes: number of output logits.
        in_channels: channels of the input image. The default of 2 matches the
            energy/time images used here, so ``ResNetModified()`` behaves as before.

    Expects input laid out as (batch, in_channels, H, W). Adaptive pooling
    makes the linear head independent of H and W.
    """

    def __init__(self, num_classes=2, in_channels=2):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 128, stride=1)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _make_layer(in_channels, out_channels, stride):
        """Build one stage: a (possibly downsampling) block followed by a stride-1 block."""
        return nn.Sequential(
            ResidualBlock(in_channels, out_channels, stride),
            ResidualBlock(out_channels, out_channels, 1),
        )

    def forward(self, x):
        x = torch.relu(self.bn1(self.conv1(x)))  # stem convolution
        x = self.layer1(x)  # residual stage 1
        x = self.layer2(x)  # residual stage 2
        x = self.layer3(x)  # residual stage 3
        x = self.avg_pool(x)  # global average pooling -> (batch, 512, 1, 1)
        x = torch.flatten(x, 1)  # (batch, 512) for the linear head
        return self.fc(x)  # class logits


# Instantiate the model
model = ResNetModified()
print(model)


### TRAINING THE MODEL

##### LOSS AND OPTIMIZER

In [None]:
# Cross-entropy over the model's class logits (default reduction: batch mean).
criteria = nn.CrossEntropyLoss()
# Adam over every model parameter with a learning rate of 1e-3.
optimizer = optim.Adam(params=model.parameters(), lr=1e-3)

##### TRAINING PHASE

In [None]:
import time 
import torch.utils.data as data
from sklearn.metrics import accuracy_score, precision_score

# Wrap the training split as tensors.
# torch.as_tensor reuses X_train's buffer when it is already float32, while
# torch.tensor always copies the whole array. The tensor then shares memory
# with X_train, so avoid modifying X_train in place afterwards.
# permute(0, 3, 1, 2) moves the channel axis from last to second, giving the
# (batch, channel, H, W) layout nn.Conv2d expects.
train_images = torch.as_tensor(X_train, dtype=torch.float32).permute(0, 3, 1, 2)
train_targets = torch.as_tensor(y_train, dtype=torch.long)  # integer class indices for CrossEntropyLoss
train_dataset = data.TensorDataset(train_images, train_targets)
train_loader = data.DataLoader(train_dataset, batch_size=64, shuffle=True)

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 5
best_accuracy = 0.0
best_model_path = "best_model.pth"

# Train for num_epochs epochs. Each epoch reports the mean loss, accuracy and
# weighted precision on the training data, and the weights are checkpointed
# whenever that epoch's training accuracy improves.
# NOTE(review): the checkpoint is selected by training accuracy averaged over
# an epoch while the weights are still changing. `best_accuracy` is therefore
# not the accuracy of the saved weights, and it says nothing about
# generalisation. Selecting on a held-out validation split, evaluated under
# model.eval(), would be more reliable.
for epoch in range(num_epochs):
    start_time = time.time()

    model.train()
    running_loss = 0.0
    pred_chunks, label_chunks = [], []

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criteria(outputs, labels)
        loss.backward()
        optimizer.step()

        # criteria returns the batch mean. Weight it by batch size so that
        # dividing by the dataset size below gives a true per-sample mean, even
        # when the final batch is smaller than the others.
        running_loss += loss.item() * labels.size(0)

        # Collect predictions as small CPU tensors and convert them once per
        # epoch, rather than extending Python lists one NumPy scalar at a time.
        _, predicted = torch.max(outputs, 1)
        pred_chunks.append(predicted.cpu())
        label_chunks.append(labels.cpu())

    all_preds = torch.cat(pred_chunks).numpy()
    all_labels = torch.cat(label_chunks).numpy()

    accuracy = accuracy_score(all_labels, all_preds) * 100
    # zero_division=0: a class that is never predicted scores 0 precision
    # without emitting an UndefinedMetricWarning.
    precision = precision_score(all_labels, all_preds, average="weighted", zero_division=0) * 100
    epoch_loss = running_loss / len(train_loader.dataset)

    epoch_time = time.time() - start_time

    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), best_model_path)
        print(f"new best model saved with accuracy: {best_accuracy:.2f}%")

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
          f"Accuracy: {accuracy:.2f}%, Precision: {precision:.2f}%, Time: {epoch_time:.2f}s")

print(f"training completed. Best accuracy: {best_accuracy:.2f}%")