In [None]:
import numpy as np
import pandas as pd
import os
import tqdm
import torch
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from torchvision.models import alexnet, vgg11, vgg19
from scipy.sparse import load_npz
from sklearn.metrics import roc_curve, auc
import pickle

In [12]:
# CUSTOM DATASET

def create_image(img_path):
    """Load a SciPy sparse ``.npz`` matrix and return it as a dense ``(1, 225, 225)`` array.

    Args:
        img_path: Path to a file readable by ``scipy.sparse.load_npz``.

    Returns:
        ``numpy.ndarray`` of shape ``(1, 225, 225)``.
    """
    dense = load_npz(img_path).toarray()
    # np.resize does not interpolate: it repeats or truncates the flattened data
    # as needed to fill the requested shape.
    return np.resize(dense, (1, 225, 225))

class PhysicsImageDataset(Dataset):
    """Dataset of three-plane sparse images listed in ``{txt_dir}/{subset}_files.txt``.

    Each non-empty line of the listing is split on ``', '``: the first three fields
    are the per-plane file names and the last field is the integer label. Label 0
    maps to the ``ATMO`` sub-directory, any other label to ``PDK``; plane ``i`` is
    looked up under ``file_dir/<i>/<ATMO|PDK>/``.

    Args:
        file_dir: Root directory containing the ``0``, ``1`` and ``2`` plane folders.
        subset: Prefix of the listing file (``'train'``, ``'test'``, ...).
        approach: ``'EF'`` stacks the three planes into one array per sample
            (early fusion); any other value returns a list of three images.
        transform: Optional callable applied to every loaded image.
        target_transform: Optional callable applied to the label.
        txt_dir: Directory holding the ``*_files.txt`` listings.
    """

    def __init__(self, file_dir='path/to/npz', subset='train', approach=None, transform=None,
                 target_transform=None, txt_dir='path/to/txt'):
        self.files = []    # one 3-tuple of image paths per sample
        self.labels = []   # integer label per sample, parallel to self.files
        self.transform = transform
        # Previously accepted but silently dropped; now stored and applied in __getitem__.
        self.target_transform = target_transform
        self.approach = approach

        txt_file = f"{txt_dir}/{subset}_files.txt"

        with open(txt_file, 'r') as file:
            for line in file:
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line would otherwise crash on int('').
                    continue
                fields = line.split(', ')
                filenames = fields[:3]
                label = int(fields[-1])
                class_dir = "ATMO" if label == 0 else "PDK"
                file_paths = tuple(
                    os.path.join(file_dir, str(plane), class_dir, filenames[plane])
                    for plane in range(3)
                )
                self.files.append(file_paths)
                self.labels.append(label)

    def __len__(self):
        """Return the number of samples read from the listing file."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(images, label, img_paths)`` for sample ``idx``.

        ``images`` is a stacked array with the plane axis first when
        ``approach == 'EF'``, otherwise a list with one image per plane.
        """
        img_paths = self.files[idx]
        early_fusion = self.approach == 'EF'
        images = []
        for img_path in img_paths:
            image = create_image(img_path)
            if self.transform:
                image = self.transform(image)
            # Early fusion drops the leading singleton axis so the planes can be stacked.
            images.append(image.squeeze(0) if early_fusion else image)

        if early_fusion:
            images = np.stack(images, axis=0)

        label = self.labels[idx]
        if self.target_transform:
            label = self.target_transform(label)
        return images, label, img_paths

In [5]:
def create_network(model_path, plane = None, network = None):
    """Load a saved model for the selected network.

    The previous version first built a fresh torchvision architecture and then
    immediately replaced it with the result of ``torch.load``; that throw-away
    construction has been removed.

    Args:
        model_path: Mapping from network name to either a single checkpoint path
            (when ``plane`` is None) or a sequence of three per-plane paths.
        plane: ``None`` to load ``model_path[name]`` directly; otherwise 0 or 1
            select that entry and any other value selects entry 2.
        network: Network name to look up. Defaults to the notebook-level
            ``NETWORK`` variable when omitted.

    Returns:
        Whatever object ``torch.load`` deserializes from the chosen checkpoint.

    Raises:
        KeyError: If the network name is not a key of ``model_path``.
    """
    name = NETWORK if network is None else network
    checkpoint = model_path[name]
    if plane is not None:
        index = 0 if plane == 0 else 1 if plane == 1 else 2
        checkpoint = checkpoint[index]

    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    # Recent PyTorch releases default to weights_only=True, which rejects fully
    # pickled modules — confirm against the installed version.
    return torch.load(checkpoint)

In [6]:
def misclassified_paths_df(misclassified_paths, NETWORK):
    """Turn misclassified file paths into a table of fields parsed from the file names.

    File names are expected to look like
    ``files_PDK_2_larcv_plane0_185.extracted.npz``: after removing
    ``.extracted.npz`` and splitting on ``_``, token 1 is the class tag, token 2
    an integer, the last character of the second-to-last token the plane number,
    and the last token the event number.

    Args:
        misclassified_paths: Iterable of file paths.
        NETWORK: Network name stored in the ``network`` column of every row.

    Returns:
        ``pandas.DataFrame`` with columns ``filename``, ``network``,
        ``atmonu_pdk``, ``mc_sim``, ``plane`` and ``event``.
    """
    def _parse(path):
        # Only the final path component carries the metadata.
        name = os.path.basename(path)
        tokens = name.replace('.extracted.npz', '').split('_')
        return {
            'filename': name,
            'network' : NETWORK,
            'atmonu_pdk': tokens[1],
            'mc_sim': int(tokens[2]),
            'plane': int(tokens[-2][-1]),
            'event': int(tokens[-1]),
        }

    return pd.DataFrame([_parse(path) for path in misclassified_paths])

In [14]:
# Mini-batch size handed to the DataLoader cells (batch_size=batch_size)
batch_size = 128
# Number of DataLoader worker processes (num_workers=workers)
workers = 4

# Late Fusion

In [16]:
def aggregate_outputs(list_of_outputs, METHOD = "mean_probability", threshold = 0.5):
    """Combine the raw outputs of several networks into one binary decision per sample.

    The outputs are stacked along a new leading axis, so every tensor in
    ``list_of_outputs`` must have the same shape; the class scores are read from
    axis 2 of the stacked tensor (presumably ``(n_models, batch, n_classes)``).

    Args:
        list_of_outputs: List of equally shaped output tensors, one per network.
        METHOD: ``"voting"`` — each network votes with its arg-max class and the
            score is the mean vote; ``"mean_probability"`` — the score is the
            soft-max probability of class 1 averaged over the networks.
        threshold: A sample is predicted positive when its score is ``>= threshold``.

    Returns:
        ``(predictions, scores)``: a boolean tensor and the per-sample score it
        was thresholded from. Both methods now return this pair; "voting"
        previously returned only the predictions, which broke callers that
        unpack two values.

    Raises:
        ValueError: If ``METHOD`` is not one of the two supported names
            (previously ``None`` was returned silently).
    """
    outputs = torch.stack(list_of_outputs)

    if METHOD == "voting":
        # Arg-max class per network and sample
        _, predicted = torch.max(outputs, 2)
        # Mean of the class votes across networks, compared against the threshold
        vote_fraction = predicted.mean(dim=0, dtype=torch.float64)
        return vote_fraction >= threshold, vote_fraction

    if METHOD == "mean_probability":
        # Soft-max over the class axis, then average across networks
        probabilities = torch.nn.functional.softmax(outputs, dim=2)
        final_probs = probabilities.mean(dim=0)
        class1_probs = final_probs[:, 1]
        return class1_probs >= threshold, class1_probs

    raise ValueError(f"Unknown aggregation METHOD: {METHOD!r}")

In [17]:
def calculate_roc_latef(nets, testloader, pred_threshold, method = 'mean_probability'):
    """Evaluate a late-fusion ensemble (one network per plane) on ``testloader``.

    Every batch must provide three image tensors (one per plane), the true
    labels and three parallel collections of file paths. Inputs are moved to the
    GPU with ``.cuda()``, each network scores its own plane, and the outputs are
    fused by ``aggregate_outputs``.

    Args:
        nets: Sequence of networks, paired in order with the three plane batches.
        testloader: Iterable yielding ``(imgs_lists, true_labels, img_paths)``.
        pred_threshold: Decision threshold forwarded to ``aggregate_outputs``.
        method: Aggregation method forwarded to ``aggregate_outputs``
            (previously a constant set inside the loop).

    Returns:
        ``(predictions, labels, probs, fpr, tpr, roc_auc, misclassified_test_paths)``
        where the first three are per-sample lists, ``fpr``/``tpr`` come from
        ``roc_curve`` and ``roc_auc`` from ``auc``.
    """
    correct = 0
    misclassified_test_paths = []

    predictions = []
    labels = []
    probs = []

    with torch.no_grad():
        for imgs_lists, true_labels, img_paths in tqdm.tqdm(testloader, desc=f'Testing {NETWORK} - Late Fusion'):
            imgs0, imgs1, imgs2 = imgs_lists
            plane_batches = [imgs.cuda().float() for imgs in (imgs0, imgs1, imgs2)]
            true_labels = true_labels.cuda()

            # One forward pass per (plane batch, network) pair
            outputs = [net(imgs) for imgs, net in zip(plane_batches, nets)]

            aggregated_predictions, aggregated_probabilities = aggregate_outputs(outputs, method, pred_threshold)
            predictions.extend(aggregated_predictions.cpu().numpy())
            labels.extend(true_labels.cpu().numpy())
            probs.extend(aggregated_probabilities.cpu().numpy())

            # flatten() always yields a 1-D index list, so zero, one or many
            # misclassified samples are handled without special-casing.
            wrong_indices = torch.nonzero(aggregated_predictions != true_labels).flatten().tolist()
            for idx in wrong_indices:
                # Record the file of every plane belonging to the misclassified sample
                misclassified_test_paths.extend(plane_paths[idx] for plane_paths in img_paths)

            # .item() keeps the running count a plain Python int instead of a CUDA tensor
            correct += (true_labels == aggregated_predictions).sum().item()

    # False/true positive rates for the ROC curve and the area under it
    fpr, tpr, _ = roc_curve(labels, probs)
    roc_auc = auc(fpr, tpr)

    # Accuracy as the percentage of correct predictions
    accuracy = 100*correct/len(labels)
    print(f"{NETWORK.capitalize()} - Late Fusion - Test accuracy: {accuracy:.2f}\n")

    return predictions, labels, probs, fpr, tpr, roc_auc, misclassified_test_paths

In [22]:
# Path for the saved models
# Three checkpoints per network, one for each plane index 0-2
paths_lf = {
    name: [f"{name}_net{plane}_epoch.pth" for plane in range(3)]
    for name in ('alexnet', 'vgg11', 'vgg19')
}

In [None]:
# Build the dataset from the 'test' listing. `approach` keeps its default (None),
# so each sample is a list of three per-plane images, as calculate_roc_latef expects.
test_dataset = PhysicsImageDataset(subset='test')
print("Test dataset created")

In [9]:
#test_dataset.files[-3:]

In [15]:
# Create DataLoader for test dataset.
# shuffle=False: evaluation metrics do not depend on sample order, and a fixed order
# keeps the saved predictions/labels/probs aligned with test_dataset across runs.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=workers)

In [13]:
#print('Counting test labels')
#class0 = 0
#class1 = 0
#for imgs, labels, img_paths in tqdm.tqdm(testloader):
#    class0 += (labels == 0).sum().item()
#    class1 += (labels == 1).sum().item()
#tot = class0 + class1
#print(f'Testloader: AtmoNu: {class0} {100*class0/tot:.2f}%, PDK: {class1} {100*class1/tot:.2f}%')

## alexnet

In [20]:
# Global read by create_network (checkpoint lookup key) and by the calculate_roc_*
# functions (progress-bar and accuracy messages).
NETWORK = 'alexnet'

In [None]:
# Load one late-fusion model per plane index, then move each to the GPU in eval mode
net0, net1, net2 = (create_network(paths_lf, plane) for plane in (0, 1, 2))
for net in (net0, net1, net2):
    net.cuda().eval()

In [None]:
# Testing and saving metrics
# Runs the three per-plane networks over testloader with a 0.5 decision threshold and
# unpacks the 7-tuple returned by calculate_roc_latef into *_alexnet variables.
(predictions_alexnet, labels_alexnet, probs_alexnet, fpr_alexnet, tpr_alexnet, roc_auc_alexnet,
misclassified_test_paths_alexnet) = calculate_roc_latef([net0, net1, net2], testloader, pred_threshold = 0.5)

In [24]:
# Tabulate the misclassified late-fusion file names (fields parsed from each name)
alex_test_df = misclassified_paths_df(misclassified_test_paths_alexnet, 'alexnet')

In [25]:
# Persist the late-fusion alexnet metrics as a single pickled dictionary
with open('roc_results_alexnet_late_fusion.pkl', 'wb') as file:
    pickle.dump(
        dict(
            predictions=predictions_alexnet,
            labels=labels_alexnet,
            probs=probs_alexnet,
            fpr=fpr_alexnet,
            tpr=tpr_alexnet,
            roc_auc=roc_auc_alexnet,
            misclassified_test_paths=misclassified_test_paths_alexnet,
        ),
        file,
    )

## vgg11

## vgg19

## Misclassified Images

In [None]:
# Stack the per-network misclassification tables into one frame with a fresh index.
# NOTE(review): vgg11_test_df and vgg19_test_df are not assigned in any visible cell
# (the vgg11 / vgg19 sections above are empty), so this raises NameError on a fresh
# kernel unless those cells exist elsewhere — confirm before Restart & Run All.
wrong_images_late_fusion_df = pd.concat([alex_test_df, vgg11_test_df, vgg19_test_df], ignore_index=True)

In [None]:
# Saving the paths in csv format (index=False leaves the row index out of the file)
wrong_images_late_fusion_df.to_csv('wrong_images_late_fusion_from_testing.csv', index=False)

# Early Fusion

In [None]:
def prediction(outputs, METHOD = 'mean_probability', pred_threshold = 0.5):
    """Turn the raw outputs of a single network into predictions and class-1 scores.

    Args:
        outputs: Output tensor whose axis 1 holds the class scores
            (presumably ``(batch, n_classes)``).
        METHOD: ``"voting"`` — predict the arg-max class;
            ``"mean_probability"`` — predict positive when the soft-max
            probability of class 1 is ``>= pred_threshold``.
        pred_threshold: Decision threshold, used by ``"mean_probability"`` only.

    Returns:
        ``(predicted, class1_probabilities)``. Both methods now return this pair;
        "voting" previously returned only ``predicted``, which broke callers
        that unpack two values.

    Raises:
        ValueError: If ``METHOD`` is not one of the two supported names
            (previously ``None`` was returned silently).
    """
    class1_probabilities = torch.nn.functional.softmax(outputs, dim=1)[:, 1]

    if METHOD == "voting":
        _, predicted = torch.max(outputs, 1)
        return predicted, class1_probabilities

    if METHOD == "mean_probability":
        return class1_probabilities >= pred_threshold, class1_probabilities

    raise ValueError(f"Unknown prediction METHOD: {METHOD!r}")

In [None]:
def calculate_roc_earlyf(net, testloader, pred_threshold, method = 'mean_probability'):
    """Evaluate a single early-fusion network on ``testloader``.

    Every batch must provide one image tensor, the true labels and a collection
    of per-plane file-path sequences. Inputs are moved to the GPU with
    ``.cuda()`` and the network output is converted by ``prediction``.

    Args:
        net: Network applied to each image batch.
        testloader: Iterable yielding ``(imgs, true_labels, img_paths)``.
        pred_threshold: Decision threshold forwarded to ``prediction``.
        method: Method name forwarded to ``prediction`` (previously a constant
            set inside the loop).

    Returns:
        ``(predictions, labels, probs, fpr, tpr, roc_auc, misclassified_test_paths)``
        where the first three are per-sample lists, ``fpr``/``tpr`` come from
        ``roc_curve`` and ``roc_auc`` from ``auc``.
    """
    correct = 0
    misclassified_test_paths = []

    predictions = []
    labels = []
    probs = []

    with torch.no_grad():
        for imgs, true_labels, img_paths in tqdm.tqdm(testloader, desc=f'Testing {NETWORK} - Early Fusion'):
            imgs, true_labels = imgs.cuda().float(), true_labels.cuda()

            predicted, prob = prediction(net(imgs), method, pred_threshold)
            predictions.extend(predicted.cpu().numpy())
            labels.extend(true_labels.cpu().numpy())
            probs.extend(prob.cpu().numpy())

            # flatten() always yields a 1-D index list, so zero, one or many
            # misclassified samples are handled without special-casing.
            wrong_indices = torch.nonzero(predicted != true_labels).flatten().tolist()
            for idx in wrong_indices:
                # Record the file of every plane belonging to the misclassified sample
                misclassified_test_paths.extend(plane_paths[idx] for plane_paths in img_paths)

            # .item() keeps the running count a plain Python int instead of a CUDA tensor
            correct += (predicted == true_labels).sum().item()

    fpr, tpr, _ = roc_curve(labels, probs)
    roc_auc = auc(fpr, tpr)

    print(f"{NETWORK.capitalize()} - Early fusion - Test accuracy: {100*correct/len(labels):.2f}\n")
    print(f"{NETWORK.capitalize()} evaluation completed\n")

    return predictions, labels, probs, fpr, tpr, roc_auc, misclassified_test_paths

In [None]:
# One early-fusion checkpoint per network
paths_ef = {name: f"{name}_epoch.pth" for name in ('alexnet', 'vgg11', 'vgg19')}

In [None]:
# Early fusion needs approach='EF' so that __getitem__ stacks the three planes into a
# single array per sample; with the default (None) each sample is a list of images and
# `imgs.cuda()` in calculate_roc_earlyf fails.
test_dataset = PhysicsImageDataset(subset='test', approach='EF')
print("Test dataset created")

In [None]:
#test_dataset.files[-3:]

In [None]:
# shuffle=False: evaluation metrics do not depend on sample order, and a fixed order
# keeps the saved predictions/labels/probs aligned with test_dataset across runs.
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size,
                                          shuffle=False, num_workers=workers)

## alexnet

In [None]:
# Global read by create_network (checkpoint lookup key) and by the calculate_roc_*
# functions (progress-bar and accuracy messages).
NETWORK = 'alexnet'

In [None]:
# Load the single early-fusion checkpoint registered for NETWORK (no plane index)
net = create_network(model_path=paths_ef)
# Move the model to the GPU and switch to evaluation mode; as the cell's last
# expression, the returned model is also displayed.
net.cuda().eval()

In [None]:
# Evaluate the early-fusion network with a 0.5 decision threshold.
# NOTE: these are the same variable names the late-fusion alexnet cell assigns,
# so the late-fusion results held in the kernel are overwritten here.
(predictions_alexnet, labels_alexnet, probs_alexnet, fpr_alexnet, tpr_alexnet, roc_auc_alexnet,
misclassified_test_paths_alexnet) = calculate_roc_earlyf(net, testloader, pred_threshold = 0.5)

In [None]:
# Tabulate the misclassified early-fusion file names; this rebinds alex_test_df,
# replacing the late-fusion table of the same name.
alex_test_df = misclassified_paths_df(misclassified_test_paths_alexnet, 'alexnet')

In [None]:
# Persist the early-fusion alexnet metrics as a single pickled dictionary
with open('roc_results_alexnet_early_fusion.pkl', 'wb') as file:
    pickle.dump(
        dict(
            predictions=predictions_alexnet,
            labels=labels_alexnet,
            probs=probs_alexnet,
            fpr=fpr_alexnet,
            tpr=tpr_alexnet,
            roc_auc=roc_auc_alexnet,
            misclassified_test_paths=misclassified_test_paths_alexnet,
        ),
        file,
    )

## vgg11

## vgg19

## Misclassified Images