# Prototypical Networks:
In this notebook, we investigate prototypical networks, which are inherently explainable machine learning (ML) models, applicable to different modalities including imaging. These models work by learning a set of representations for each output class, called prototypes, and calculating the distance of each datapoint from these prototypes to predict its class.

In [None]:
# https://github.com/n0obcoder/NIH-Chest-X-Rays-Multi-Label-Image-Classification-In-Pytorch/tree/master
# https://github.com/cxr-eye-gaze/eye-gaze-dataset
import os
import re
from glob import glob
import torch
import random
import numpy as np
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from imgaug import augmenters as iaa
from sklearn.preprocessing import LabelEncoder
import cv2
from tqdm import tqdm
import pickle
from torch.utils.data.sampler import Sampler

import torch
from torch.nn import functional as F
from torch.nn.modules import Module
from utils.utils import get_roc_auc_score
from utils.utils import prototype_heatmap
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import roc_auc_score, roc_curve, f1_score, recall_score, precision_score
import numpy as np

from torchvision import transforms


import os
import sys
import argparse
import logging
import random
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from torch import optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from datetime import datetime

from model_intepretability.imaging_copy.dataset.dataset import PrototypicalBatchSampler
from sklearn.metrics import roc_auc_score, roc_curve
from torch.nn.parallel import DistributedDataParallel as DDP
from functools import partial
import torch.distributed as dist
from sklearn.metrics import roc_auc_score, roc_curve, f1_score, recall_score, precision_score
from torch.utils.data import DataLoader,SubsetRandomSampler, TensorDataset
from sklearn.model_selection import KFold
from torch.utils.data.distributed import DistributedSampler
import pandas as pd
import torch.nn.functional as F
import pickle
import torchvision.models as models
import torch.distributed as dist


from tqdm import tqdm
from torch.autograd import Variable
plt.rcParams['figure.figsize'] = [25, 10]
import gradio as gr
from glob import glob
from PIL import Image

import argparse
import gc
from pathlib import Path
from utils.utils import get_roc_auc_score
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import pandas as pd
import os

import pickle


In [None]:
def make_parser():
    """Build the command-line parser for the prototypical-network experiment.

    Returns:
        argparse.ArgumentParser: parser exposing data-loading, optimisation
        and episodic-sampling options. Option names and defaults are the ones
        the rest of this file reads from ``args``.
    """
    parser = argparse.ArgumentParser(description='Imaging Explainability')

    # Data loading.
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    parser.add_argument('--resize', type=int, default=224, help='Resizing images')

    # Optimisation.
    parser.add_argument('--batch_size', type=int, default=32, help='batch size')
    parser.add_argument('--epochs', type=int, default=20, help='number of epochs')
    parser.add_argument('--lr', type=float, default=1e-3, help='initial learning rate')
    parser.add_argument('--scheduler', default=False, action='store_true', help='[USE] scheduler')
    parser.add_argument('--step_size', type=int, default=5, help='scheduler step size')
    parser.add_argument('--dropout', type=float, default=0.5, help='dropout')
    parser.add_argument('--rseed', type=int, default=42, help='Seed for reproducibility')
    # weight_decay is a float: with type=int a value such as "0.01" given on
    # the command line was rejected by argparse.
    parser.add_argument('--weight_decay', type=float, default=1e-3,
                        help='weight decay passed to the optimizer, default=1e-3')
    parser.add_argument('--num_epochs', type=int, default=20,
                        help='number of training epochs, default=20')

    # Episodic (prototypical) sampling.
    parser.add_argument('-its', '--iterations',
                        type=int,
                        help='number of episodes per epoch, default=100',
                        default=100)
    parser.add_argument('-cTr', '--classes_per_it_tr',
                        type=int,
                        help='number of random classes per episode for training, default=15',
                        default=15)
    parser.add_argument('-nsTr', '--num_support_tr',
                        type=int,
                        help='number of samples per class to use as support for training, default=5',
                        default=5)
    parser.add_argument('-nqTr', '--num_query_tr',
                        type=int,
                        help='number of samples per class to use as query for training, default=5',
                        default=5)
    parser.add_argument('-cVa', '--classes_per_it_val',
                        type=int,
                        help='number of random classes per episode for validation, default=5',
                        default=5)
    parser.add_argument('-nsVa', '--num_support_val',
                        type=int,
                        help='number of samples per class to use as support for validation, default=5',
                        default=5)
    parser.add_argument('-nqVa', '--num_query_val',
                        type=int,
                        help='number of samples per class to use as query for validation, default=15',
                        default=15)
    return parser

## Data:

In this project, we use the NIH Chest X-ray dataset, which is publicly available [here](https://www.kaggle.com/datasets/nih-chest-xrays/data). It contains 112,120 chest X-rays labelled with 14 different conditions, e.g., Pneumonia. The goal is to train a predictive model that assigns each image one or more of these conditions.

In [None]:
def read_image(image_path):
    """Load an image with OpenCV and scale it by its maximum value.

    Args:
        image_path: path of the image file, passed to ``cv2.imread``.

    Returns:
        The image array divided by its largest pixel value, so the brightest
        pixel becomes 1.0. An all-zero image is returned as zeros.

    Raises:
        OSError: if ``cv2.imread`` could not read the file.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising;
        # fail here with the path rather than with a TypeError on the division.
        raise OSError(f"cv2.imread could not read image: {image_path}")

    peak = np.max(image)
    if peak == 0:
        # Avoid 0/0 -> NaN for a completely black image.
        return image.astype(np.float64)
    return image / peak



class XrayDataset(Dataset):
    """Multi-label chest X-ray dataset backed by a metadata DataFrame.

    Args:
        csv_file: DataFrame with at least the columns ``'Image Index'`` (file
            name) and ``'Finding Labels'`` (labels joined with ``'|'``). A
            ``'numeric_targets'`` column holding one multi-hot tensor per row
            is added to it in place. Rows are addressed with ``.loc[idx]``,
            so the index is expected to be ``0..len-1``.
        image_path_name: root directory searched recursively for image files.

    Side effects:
        Construction writes the per-class sample counts to
        ``all_classes.pkl`` in the current working directory.
    """

    # Channel statistics used to normalise the image tensor.
    _MEAN = [0.485, 0.456, 0.406]
    _STD = [0.229, 0.224, 0.225]
    _INPUT_SIZE = 224
    # Raw string: "\D" in a plain literal is an invalid escape sequence.
    _NON_DIGIT = re.compile(r"\D")

    def __init__(self, csv_file, image_path_name):
        self.path_name = image_path_name
        self.csv_file = csv_file
        self.the_chosen, self.all_classes, self.all_classes_dict = self.choose_the_indices()
        self.csv_file["numeric_targets"] = self.csv_file['Finding Labels'].apply(self.get_tagets)

    def __len__(self):
        # The dataset length is the number of metadata rows, not
        # len(self.the_chosen).
        return len(self.csv_file)

    def choose_the_indices(self):
        """Sample row positions while capping the number of examples per class.

        Rows are visited in random order (drawn from the global NumPy RNG).
        A row is kept when it is tagged 'Hernia' (ultra-minority class, always
        kept) or when every one of its labels is still below the cap.

        Returns:
            tuple: ``(the_chosen, class_names, class_counts)`` - the kept row
            positions, the sorted list of class names, and a dict mapping each
            class name to the number of kept rows that carry it.
        """
        max_examples_per_class = 10000  # upper bound on sampled examples for any class
        the_chosen = []
        all_classes = {}
        length = len(self.csv_file)
        # Pull the column out once instead of doing an iloc row lookup per sample.
        finding_labels = self.csv_file['Finding Labels'].tolist()

        print('\nSampling the huuuge training dataset')
        for i in tqdm(np.random.choice(range(length), length, replace=False)):
            temp = str.split(finding_labels[i], '|')

            keep = 'Hernia' in temp or all(
                all_classes.get(t, 0) < max_examples_per_class for t in temp
            )
            if not keep:
                continue

            # Counting and choosing always happen together. Previously the
            # first single-label sample of a class was counted but not chosen.
            the_chosen.append(i)
            for t in temp:
                all_classes[t] = all_classes.get(t, 0) + 1

        with open('all_classes.pkl', 'wb') as file:
            pickle.dump(all_classes, file)
        return the_chosen, sorted(all_classes), all_classes

    def get_tagets(self, row):
        """Convert a ``'A|B'`` label string into a multi-hot float tensor.

        Raises:
            ValueError: if a label is not present in ``self.all_classes``.
        """
        target = torch.zeros(len(self.all_classes))
        for lab in str.split(row, '|'):
            target[self.all_classes.index(lab)] = 1
        return target

    def get_image(self, idx):
        """Load, resize and normalise the image of row ``idx``.

        Returns:
            tuple: ``(image, target)`` - the float image tensor and the
            multi-hot label tensor of the row.

        Raises:
            FileNotFoundError: if no file named like the row's
                ``'Image Index'`` exists below ``self.path_name``.
        """
        image_name = self.csv_file.loc[idx, 'Image Index']

        matches = glob(os.path.join(self.path_name, '**', image_name), recursive=True)
        if not matches:
            raise FileNotFoundError(
                f"Image {image_name!r} not found under {self.path_name!r}"
            )
        image = read_image(matches[0])
        if image.ndim == 2:
            # Add a trailing channel axis to single-channel images.
            image = np.expand_dims(image, axis=-1)

        target = self.get_tagets(self.csv_file.loc[idx, 'Finding Labels'])

        seq = iaa.Sequential([iaa.Resize((self._INPUT_SIZE, self._INPUT_SIZE))])
        image_transform = transforms.Compose([
            seq.augment_image,
            transforms.ToTensor(),
            transforms.Normalize(mean=self._MEAN, std=self._STD),
        ])
        image = image_transform(image)
        return image.float(), target

    def num_sort(self, filename):
        """Return the integer formed by all digits contained in ``filename``."""
        return int(self._NON_DIGIT.sub("", filename))

    def __getitem__(self, idx):
        image, label = self.get_image(idx)
        return image, label


In [None]:
# extracted from: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/src/protonet.py
def conv_block(in_channels, out_channels):
    '''
    Build one encoder stage: 3x3 convolution (padding 1), batch norm,
    ReLU and 2x2 max-pooling, wrapped in an nn.Sequential.
    '''
    stage = [
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm2d(num_features=out_channels),
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=2),
    ]
    return nn.Sequential(*stage)

In [None]:
# https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/tree/master

class PrototypicalBatchSampler(object):
    '''
    PrototypicalBatchSampler: yields sample indices episode by episode.
    This version supports multi-label datasets, where each sample may belong to multiple classes.

    Indices are yielded one at a time (not as lists), so the object is meant
    to be passed as `sampler=` to a DataLoader, which groups them into batches.
    '''

    # Used only when `labels` is empty and the class count cannot be inferred.
    DEFAULT_NUM_CLASSES = 15

    def __init__(self, labels, classes_per_it, num_samples, iterations):
        '''
        Initialize the PrototypicalBatchSampler object.
        Args:
        - labels: iterable of multi-hot label vectors, one per sample
          (n_samples x n_classes).
        - classes_per_it: number of random classes for each iteration.
        - num_samples: number of samples for each iteration for each class (support + query).
        - iterations: number of iterations (episodes) per epoch.
        '''
        super(PrototypicalBatchSampler, self).__init__()
        self.labels = labels
        self.classes_per_it = classes_per_it
        self.sample_per_class = num_samples
        self.iterations = iterations

        # Group sample positions by active class; the class count is read from
        # the width of the first label vector instead of being hard-coded.
        members = {}
        num_classes = None
        for sample_idx, label_vec in enumerate(self.labels):
            label_vec = torch.as_tensor(label_vec)
            if num_classes is None:
                num_classes = label_vec.numel()
            for c in torch.nonzero(label_vec).flatten().tolist():
                members.setdefault(c, []).append(sample_idx)

        self.num_classes = self.DEFAULT_NUM_CLASSES if num_classes is None else num_classes
        self.classes = [torch.tensor(i) for i in range(self.num_classes)]
        # Integer dtype is forced so that classes without samples still give
        # an index tensor (torch.tensor([]) would be float).
        self.class_to_indices = {
            c: torch.tensor(members.get(c, []), dtype=torch.long)
            for c in range(self.num_classes)
        }

    def __iter__(self):
        '''
        Yield the sample indices of `iterations` episodes, one index at a time.
        '''
        spc = self.sample_per_class  # Samples per class
        cpi = self.classes_per_it    # Classes per iteration

        for _ in range(self.iterations):
            batch_indices = []
            # Randomly sample `cpi` classes
            sampled_classes = torch.randperm(self.num_classes)[:cpi]

            for c in sampled_classes.tolist():
                class_indices = self.class_to_indices[c]
                if len(class_indices) >= spc:
                    # Randomly select `spc` samples from this class
                    sampled_indices = class_indices[torch.randperm(len(class_indices))[:spc]]
                else:
                    # Rare classes with fewer than `spc` samples contribute all of them
                    sampled_indices = class_indices
                batch_indices.extend(sampled_indices.tolist())

            # Shuffle within the episode. A multi-label sample may appear more
            # than once if it was drawn for several of its classes.
            batch_indices = torch.tensor(batch_indices, dtype=torch.long)
            batch_indices = batch_indices[torch.randperm(len(batch_indices))]
            yield from batch_indices.tolist()

    def __len__(self):
        '''
        Return the number of iterations (episodes) per epoch.

        NOTE(review): this is the episode count, not the number of indices
        produced by __iter__; a DataLoader that derives its length from
        len(sampler) will under-report the number of batches.
        '''
        return self.iterations

In [None]:
#  extracted from: https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/src/prototypical_loss.py

class PrototypicalLoss(Module):
    '''
    Module wrapper around the `prototypical_loss` function, holding the
    number of support samples per class.
    '''
    def __init__(self, n_support):
        super(PrototypicalLoss, self).__init__()
        self.n_support = n_support

    def forward(self, input, target, original_image=None):
        '''
        Args:
        - input: model embeddings for the batch.
        - target: multi-hot label matrix for the batch.
        - original_image: optional batch of input images, forwarded as the
          first argument of `prototypical_loss`.

        Returns whatever `prototypical_loss` returns.
        '''
        # `prototypical_loss` takes (original_image, inputs, target, n_support);
        # the previous three-argument call raised a TypeError.
        return prototypical_loss(original_image, input, target, self.n_support)


def euclidean_dist(x, y):
    '''
    Compute pairwise squared euclidean distances between two sets of vectors.

    Args:
    - x: tensor of shape N x D
    - y: tensor of shape M x D

    Returns an N x M tensor whose entry (i, j) is sum((x[i] - y[j]) ** 2).
    No square root is taken.

    Raises ValueError if the feature dimensions of x and y differ.
    '''
    if x.size(1) != y.size(1):
        raise ValueError(
            "euclidean_dist: feature dimensions differ ({} vs {})".format(x.size(1), y.size(1))
        )

    # Broadcasting (N, 1, D) against (1, M, D) yields all pairwise differences.
    diff = x.unsqueeze(1) - y.unsqueeze(0)
    return torch.pow(diff, 2).sum(2)



def prototypical_loss(original_image,inputs, target, n_support):
    """
    Multi-label prototypical loss for one batch.

    For every class that is active in the batch, the first `n_support`
    samples carrying that class form the support set and their mean embedding
    is the class prototype; the remaining samples of the class are queries.
    Queries and prototypes are L2-normalised, squared euclidean distances are
    computed, and the negative distances are used as logits of a binary
    cross-entropy against the queries' multi-hot targets.

    Args:
    - original_image: batch of input images; not used by the computation
      (only referenced by a disabled heatmap call).
    - inputs: the model output for a batch of samples (batch_size x feature_dim).
    - target: binary matrix (batch_size x num_classes), where each row is a multi-hot vector.
    - n_support: number of samples to use when computing barycentres for each class.

    Returns:
    A 4-tuple (loss, auc, query_targets, probs):
    - loss: scalar tensor.
    - auc: value returned by `get_roc_auc_score`.
    - query_targets: (n_query x n_prototypes) multi-hot targets of the queries.
    - probs: (n_query x n_prototypes) softmax over negative distances.
    When the batch yields no prototype or no query sample, the tuple is
    (zero loss, 0.0, empty targets, empty probs).

    NOTE(review): a sample with several labels is listed once per label in
    the query set, and may be support for one class and query for another.
    NOTE(review): distances are >= 0, so the logits `-dists` are <= 0 and the
    sigmoid inside the BCE never exceeds 0.5 - confirm this is intended.
    """
    # Classes that occur at least once in this batch.
    active_classes = torch.nonzero(target.sum(0)).squeeze(1)

    prototypes = []
    prototype_classes = []
    query_idxs = []
    for c in active_classes:
        class_idxs = torch.nonzero(target[:, c]).squeeze(1)
        support_idxs = class_idxs[:n_support]

        if len(support_idxs) > 0:
            prototypes.append(inputs[support_idxs].mean(0))
            # Track which class each prototype belongs to, so the target
            # columns stay aligned even if a class is skipped.
            prototype_classes.append(c)
        else:
            print("No support samples for class {}".format(c))
        query_idxs.extend(class_idxs[n_support:].tolist())

    if not prototypes or not query_idxs:
        # Same arity as the regular return, so callers can always unpack four
        # values; the zero loss is a leaf that supports backward().
        n_cols = len(prototype_classes)
        zero_loss = torch.tensor(0.0, device=inputs.device, requires_grad=True)
        return zero_loss, 0.0, target.new_zeros((0, n_cols)), inputs.new_zeros((0, n_cols))

    prototypes = F.normalize(torch.stack(prototypes), p=2, dim=1)  # (n_prototypes, feature_dim)
    query_samples = F.normalize(inputs[query_idxs], p=2, dim=1)    # (n_query, feature_dim)

    # Squared distances from every query to every prototype.
    dists = euclidean_dist(query_samples, prototypes)

    # prototype_heatmap(original_image[query_idxs],query_samples,prototypes)
    log_p_y = F.log_softmax(-dists, dim=1)
    probs = torch.exp(log_p_y)

    # Targets of the queries, restricted to the classes that have a prototype.
    query_targets = target[query_idxs][:, torch.stack(prototype_classes)]

    loss_val = F.binary_cross_entropy_with_logits(-dists, query_targets.float())

    auc_val = get_roc_auc_score(
        query_targets.detach().cpu().numpy(),
        probs.detach().cpu().numpy(),
    )

    return loss_val, auc_val, query_targets, probs

In [None]:
import matplotlib.pyplot as plt
import pickle
from sklearn.metrics import roc_auc_score, roc_curve, f1_score, recall_score, precision_score
import numpy as np
import torch.nn.functional as F
import torch
from torchvision import transforms



def get_roc_auc_score(y_true, y_probs):
    '''
    Mean per-class ROC AUC (macro average), ignoring the 'No Finding' class.

    Args:
    - y_true: binary array (n_samples x n_columns).
    - y_probs: score array of the same shape.

    Columns of y_true that contain a single distinct value are skipped,
    because ROC AUC is undefined for them. Returns NaN when no column can be
    scored.

    Reads the class table from 'all_classes.pkl' in the current working
    directory; raises FileNotFoundError if it is missing.

    NOTE(review): the 'No Finding' column is located by its position in the
    sorted class names, which assumes the columns of y_true follow that
    order and cover all classes - verify against the caller.
    '''
    # NOTE(security): pickle.load must only be used on files this project
    # wrote itself; never point it at untrusted data.
    with open('all_classes.pkl', 'rb') as handle:
        all_classes = pickle.load(handle)

    # The pickle maps class name -> sample count. The column index is the
    # position among the sorted names, not the stored count, which is what
    # `all_classes.get('No Finding', -1)` used to return.
    class_names = sorted(all_classes)
    no_finding_index = class_names.index('No Finding') if 'No Finding' in class_names else -1

    useful_classes_roc_auc_list = [
        roc_auc_score(y_true[:, i], y_probs[:, i])
        for i in range(y_true.shape[1])
        if i != no_finding_index and len(np.unique(y_true[:, i])) > 1
    ]
    if not useful_classes_roc_auc_list:
        return np.float64(np.nan)
    return np.mean(np.array(useful_classes_roc_auc_list))

def prototype_heatmap(xray_image,image_reps,prototype_rep):
    '''
    Overlay a prototype-similarity map on the first image of a batch and save
    it to "prototype_heatmap.png".

    Args:
    - xray_image: batch of image tensors; channel 0 of the first one is drawn.
    - image_reps: 2-D tensor of flattened image embeddings (one row per image).
    - prototype_rep: 2-D tensor of prototype embeddings (one row per prototype).

    Only the first image is processed (the loop exits after one iteration).
    The flattened similarity vector is reshaped to a square map, so its
    length must be a perfect square (112 x 112 for a 12544-d embedding).

    NOTE(review): cosine_similarity is taken along a dimension of size 1, so
    each map entry only reflects the sign agreement of the two embeddings at
    that position - confirm this is the intended per-pixel similarity.
    '''
    similarities = torch.mm(image_reps, prototype_rep.T)

    for i, (image, image_rep) in enumerate(zip(xray_image, image_reps)):
        # Prototype with the largest dot product to this image.
        assigned_proto = prototype_rep[torch.argmax(similarities[i, :]), :].unsqueeze(0)
        image_rep = F.normalize(image_rep.unsqueeze(0).float(), p=2, dim=1)
        assigned_proto = F.normalize(assigned_proto.float(), p=2, dim=1)

        similarity_map = F.cosine_similarity(image_rep, assigned_proto, dim=0)

        # Min-max scale to [0, 1]; the clamp avoids 0/0 for a constant map.
        lowest = similarity_map.min()
        value_range = (similarity_map.max() - lowest).clamp_min(1e-12)
        similarity_map = (similarity_map - lowest) / value_range

        side = int(round(similarity_map.numel() ** 0.5))
        resized_similarity_map = torch.nn.functional.interpolate(
            similarity_map.reshape(side, side).unsqueeze(0).unsqueeze(0),
            size=image.size()[-2:],  # Height and Width of the image
            mode='bilinear',
            align_corners=False
            ).squeeze()
        plt.imshow(image[0].cpu().numpy(), cmap='gray')
        plt.imshow(resized_similarity_map.cpu().detach().numpy(), cmap='jet', alpha=0.5)
        plt.colorbar(label='Similarity')
        plt.axis('off')
        # Save before showing: after plt.show() returns the figure may already
        # have been closed, and savefig would then write an empty image.
        plt.savefig("prototype_heatmap.png")
        plt.show()
        break

In [None]:
class ProtoNet(nn.Module):
    '''
    Four-stage convolutional encoder as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84

    Args:
    - x_dim: number of input channels.
    - hid_dim: number of channels of the intermediate stages.
    - z_dim: number of channels of the last stage.
    '''
    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super().__init__()
        # Channel widths at the boundaries of the four conv blocks.
        widths = [x_dim, hid_dim, hid_dim, hid_dim, z_dim]
        stages = [
            conv_block(c_in, c_out)
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        self.encoder = nn.Sequential(*stages)

    def forward(self, x):
        # Encode, then flatten everything but the batch dimension.
        features = self.encoder(x)
        batch_size = features.size(0)
        return features.view(batch_size, -1)

In [None]:
def prototype_heatmap(xray_image,image_reps,prototype_rep):
    '''
    Overlay a prototype-similarity map on the first image of a batch and save
    it to "prototype_heatmap.png".

    Args:
    - xray_image: batch of image tensors; channel 0 of the first one is drawn.
    - image_reps: 2-D tensor of flattened image embeddings (one row per image).
    - prototype_rep: 2-D tensor of prototype embeddings (one row per prototype).

    Only the first image is processed (the loop exits after one iteration).
    The flattened similarity vector is reshaped to a square map, so its
    length must be a perfect square (112 x 112 for a 12544-d embedding).

    NOTE(review): cosine_similarity is taken along a dimension of size 1, so
    each map entry only reflects the sign agreement of the two embeddings at
    that position - confirm this is the intended per-pixel similarity.
    '''
    similarities = torch.mm(image_reps, prototype_rep.T)

    for i, (image, image_rep) in enumerate(zip(xray_image, image_reps)):
        # Prototype with the largest dot product to this image.
        assigned_proto = prototype_rep[torch.argmax(similarities[i, :]), :].unsqueeze(0)
        image_rep = F.normalize(image_rep.unsqueeze(0).float(), p=2, dim=1)
        assigned_proto = F.normalize(assigned_proto.float(), p=2, dim=1)

        similarity_map = F.cosine_similarity(image_rep, assigned_proto, dim=0)

        # Min-max scale to [0, 1]; the clamp avoids 0/0 for a constant map.
        lowest = similarity_map.min()
        value_range = (similarity_map.max() - lowest).clamp_min(1e-12)
        similarity_map = (similarity_map - lowest) / value_range

        side = int(round(similarity_map.numel() ** 0.5))
        resized_similarity_map = torch.nn.functional.interpolate(
            similarity_map.reshape(side, side).unsqueeze(0).unsqueeze(0),
            size=image.size()[-2:],  # Height and Width of the image
            mode='bilinear',
            align_corners=False
            ).squeeze()
        plt.imshow(image[0].cpu().numpy(), cmap='gray')
        plt.imshow(resized_similarity_map.cpu().detach().numpy(), cmap='jet', alpha=0.5)
        plt.colorbar(label='Similarity')
        plt.axis('off')
        # Save before showing: after plt.show() returns the figure may already
        # have been closed, and savefig would then write an empty image.
        plt.savefig("prototype_heatmap.png")
        plt.show()
        break

In [None]:
def train_prototype(args, train_ds,test_ds):
    """Cross-validated episodic training and evaluation of a ProtoNet.

    For every fold produced by `splits`, a fresh ProtoNet is trained for
    `args.num_epochs` epochs on episodes drawn from the fold's training part,
    validated on episodes drawn from its held-out part, and finally scored on
    `test_ds`.

    Args:
        args: parsed options; reads num_workers, rseed, batch_size, lr,
            weight_decay, num_epochs, iterations, scheduler and the
            classes_per_it_* / num_support_* / num_query_* settings.
        train_ds: training dataset exposing `csv_file["numeric_targets"]`
            (one multi-hot tensor per sample, in dataset order).
        test_ds: held-out dataset used for the final evaluation.

    Returns:
        int: always 0.

    NOTE(review): `setup`, `worker_init_fn`, `splits`, `load_checkpoint` and
    `save_checkpoint` are not defined in the visible part of this file; they
    must exist at module level before this function is called.
    NOTE(review): the same checkpoint file is loaded at the start of every
    fold and the returned `start_epoch` is not used to skip epochs - confirm
    the intended resume semantics.
    """
    setup()
    init_fn = partial(
        worker_init_fn,
        num_workers=args.num_workers,
        rank=dist.get_rank(),
        seed=args.rseed,
    )

    test_sampler = DistributedSampler(test_ds, shuffle=True)
    test_dl = DataLoader(test_ds, sampler=test_sampler, batch_size=args.batch_size, worker_init_fn=init_fn)

    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    torch.cuda.empty_cache()
    device_id = torch.cuda.current_device()

    num_samples_tr = args.num_support_tr + args.num_query_tr
    num_samples_val = args.num_support_val + args.num_query_val
    classes = train_ds.csv_file["numeric_targets"]

    for fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(len(train_ds)))):

        print("fold:", fold)
        print('Fold {}'.format(fold + 1))

        # Each sampler only sees the labels of its own part of the fold and
        # yields positions inside the matching Subset, so validation episodes
        # never contain training samples.
        train_subset = Subset(train_ds, train_idx)
        val_subset = Subset(train_ds, val_idx)
        train_sampler = PrototypicalBatchSampler(labels=classes.iloc[train_idx],
                                    classes_per_it=args.classes_per_it_tr,
                                    num_samples=num_samples_tr,
                                    iterations=args.iterations)
        valid_sampler = PrototypicalBatchSampler(labels=classes.iloc[val_idx],
                                    classes_per_it=args.classes_per_it_val,
                                    num_samples=num_samples_val,
                                    iterations=args.iterations)

        train_dl = DataLoader(train_subset, sampler=train_sampler, batch_size=args.batch_size,
                              worker_init_fn=init_fn, pin_memory=False, drop_last=True,
                              num_workers=args.num_workers)
        valid_dl = DataLoader(val_subset, sampler=valid_sampler, batch_size=args.batch_size,
                              worker_init_fn=init_fn, pin_memory=False, drop_last=True,
                              num_workers=args.num_workers)

        model2 = ProtoNet()
        model2 = model2.cuda(device_id)
        # model2 = DDP(model2, device_ids=[device_id])#,find_unused_parameters=True)

        optimizer = optim.AdamW(model2.parameters(), lr=args.lr, weight_decay=args.weight_decay)
        try:
            model2, optimizer, start_epoch, _ = load_checkpoint(model2, optimizer, "output_weight/proto.pth")
            print(f"Resuming from epoch {start_epoch + 1}")
        except FileNotFoundError:
            print("No checkpoint found, starting frsom scratch")

        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, gamma=0.995, step_size=1)
        optimizer.zero_grad()

        train_auc = []
        val_auc = []
        for epoch in range(args.num_epochs):

            model2.train()
            num_batches = 0
            epoch_train_auc = []

            for images, labels in tqdm(train_dl):
                optimizer.zero_grad()
                images = images.cuda(device_id, non_blocking=True)
                labels = labels.cuda(device_id, non_blocking=True)
                image_rep = model2(images.float())
                loss, auc, target, prob = prototypical_loss(images, image_rep, target=labels,
                                n_support=args.num_support_tr)

                train_auc.append(auc)
                epoch_train_auc.append(auc)

                # Back-propagate the loss itself. Re-wrapping it in
                # Variable(..., requires_grad=True) created a new leaf and cut
                # it off from the model, so no parameter received a gradient.
                if loss.requires_grad:
                    loss.backward()
                    optimizer.step()

                num_batches += 1
                if num_batches % 100 == 0:
                    save_checkpoint(model2, optimizer, epoch, loss.item(), file_path=f"output_weight/proto_backup.pth")

                if num_batches % 200 == 0:
                    save_checkpoint(model2, optimizer, epoch, loss.item(), file_path=f"output_weight/proto.pth")

            if args.scheduler:
                # The scheduler is only advanced when --scheduler is given.
                scheduler.step()

            # Average over the batches of this epoch only.
            avg_auc = np.mean(epoch_train_auc) if epoch_train_auc else float('nan')

            print('Avg Train AUC: {}'.format(avg_auc))
            model2.eval()

            epoch_val_auc = []
            with torch.no_grad():
                for images, labels in valid_dl:
                    labels = labels.cuda(device_id, non_blocking=True)
                    images = images.cuda(device_id, non_blocking=True)
                    image_rep = model2(images)

                    _, auc, target, prob = prototypical_loss(images, image_rep, target=labels,
                                n_support=args.num_support_val)

                    val_auc.append(auc)
                    epoch_val_auc.append(auc)

            avg_val_auc = np.mean(epoch_val_auc) if epoch_val_auc else float('nan')

            print("epoch", epoch, ":", "train_AUC:", avg_auc, "val_AUC", avg_val_auc)

        # Final evaluation of this fold's model on the held-out test set.
        model2.eval()
        with torch.no_grad():
            test_auc = []
            for test_pass in range(10):
                # Without set_epoch the sampler repeats the same ordering, and
                # all ten passes would see identical batches.
                test_sampler.set_epoch(test_pass)
                for images, labels in test_dl:
                    images, labels = images.cuda(device_id, non_blocking=True), labels.cuda(device_id, non_blocking=True)

                    image_rep = model2(images)
                    _, auc, target, prob = prototypical_loss(images, image_rep, target=labels,
                                n_support=args.num_support_val)
                    # float() accepts both NumPy scalars and the plain-float
                    # placeholder returned for degenerate batches.
                    test_auc.append(float(auc))

            avg_test_auc = np.mean(test_auc) if test_auc else float('nan')
            # The reported metric is an AUC, not an accuracy.
            print('Test AUC: {}'.format(avg_test_auc))

    logging.info('Finished training.')

    dist.destroy_process_group()
    return 0

In [None]:
# --- Entry point: parse options, seed every RNG, build datasets, train. ---
args = make_parser().parse_args()
auc_list , precision_list, recall_list, f1_list = [], [], [], []

# Seed Python, NumPy and PyTorch (CPU and all GPUs).
random.seed(args.rseed)
np.random.seed(args.rseed)
torch.manual_seed(args.rseed)
torch.cuda.manual_seed(args.rseed)
torch.cuda.manual_seed_all(args.rseed)

image_path = "/datasets/nih-chest-xrays"

csv_file = pd.read_csv(os.path.join(image_path,"Data_Entry_2017.csv"))
test_split = os.path.join(image_path,"test_list.txt")
train_val_split = os.path.join(image_path,"train_val_list.txt")
# Each split file lists one image file name per line.
with open(train_val_split, 'r') as f:
    train_val_images = f.read().splitlines()
with open(test_split, 'r') as f:
    test_images = f.read().splitlines()

# reset_index(drop=True) returns a new frame with a 0..n-1 index. XrayDataset
# adds a column to the frame it receives, so it is given an independent frame
# rather than an in-place-modified slice of `csv_file`.
train_df = csv_file[csv_file['Image Index'].isin(train_val_images)].reset_index(drop=True)
test_df = csv_file[csv_file['Image Index'].isin(test_images)].reset_index(drop=True)

train_ds = XrayDataset(train_df, image_path)
test_ds = XrayDataset(test_df, image_path)
train_prototype(args, train_ds,test_ds)

## References
[[1] Prototypical networks for few-shot learning](https://proceedings.neurips.cc/paper_files/paper/2017/hash/cb8da6767461f2812ae4290eac7cbc42-Abstract.html)