# Multiple-Instance Learning
* https://github.com/rhgao/Deep-MIML-Network/blob/master/models/MIML.py
* https://github.com/MSKCC-Computational-Pathology/MIL-nature-medicine-2019/blob/master/MIL_train.py
* https://github.com/binli123/dsmil-wsi/blob/master/attention_map.py

* https://github.com/jusiro/mil_histology/tree/main/code

# mil_histology 

## imports

In [22]:
# general
import os
import random
import glob

# data
import pandas as pd
import numpy as np

# plot
from matplotlib import pyplot as plt

# image
import skimage.transform
import skimage.util
import cv2
from PIL import Image

# torch 
import torch
import torchvision



np.random.seed(19)
random.seed(19)
torch.manual_seed(19)

<torch._C.Generator at 0x232706ea110>

## Settings

In [28]:
class Arguments():
    def __init__(self):
        # self.dir_images = ""
        self.dir_csv_data = "C:/Users/Prinzessin/projects/decentnet/datasceyence/data_prep/mil*.csv"
        self.dir_results = "results"
        self.criterion = "z"
        self.experiment_name = "tmp1"
        self.classes = ["glaucoma", "faz", "onh", "dr", "healthy"]
        self.proportions = ["Primary", "Secondary"]
        self.input_shape = [3, 224, 224]
        self.epochs = 100
        self.aggregation = "max"
        self.mode = "instance"
        self.include_background = True
        self.lr = 1*1e-2
        self.pMIL = False
        
        self.alpha_ce = 1
        self.margin = 0.
        self.alpha_ic = 1
        self.alpha_pc = 1
        self.alpha_H = 0
        self.t_ic = 15
        self.t_pc = 5
        self.data_augmentation = True
        self.iterations = 3
        
        self.early_stopping = True
        self.scheduler = True
        self.virtual_batch_size = 1

## dataset

In [29]:
class MILDataset(object):

    def __init__(self, csv_data, classes, bag_id='bag_name', input_shape=(3, 224, 224),
                 data_augmentation=False, channel_first=True,
                 pMIL=False, proportions=None, only_primary=False, dataframe_instances=False):

        """Dataset object for MIL.
            Dataset object which aims to organize images and labels from a dataset in the form of bags.
        Args:
          x dir_images: (h, w, channels)
          csv_data: pandas dataframe with ground truth information.
                      Each bag is one raw, with 'bag_name' as identifier.
          classes: list of classes of interest in data_fame (i.e. ['G3', 'G4', 'G5'])
          input_shape: image input shape (channels first).
          data_augmentation: whether to perform data augmentation (True) or not (False).

        Returns:
          MILDataset object
        Last Updates: Julio Silva (19/03/21)
        """

        'Internal states initialization'
        # self.dir_images = dir_images
        self.csv_data = csv_data
        self.classes = classes
        self.bag_id = bag_id
        self.data_augmentation = data_augmentation
        self.input_shape = input_shape
        self.channel_first = channel_first
        self.pMIL = pMIL
        self.proportions = proportions
        # self.images = os.listdir(dir_images)
        self.only_primary = only_primary
        self.dataframe_instances = dataframe_instances

        # Filter patches whose slide is not in the dataframe
        idx = np.in1d([ID.split('_')[0] for ID in self.images], self.csv_data[self.bag_id])
        images = [self.images[i] for i in range(self.images.__len__()) if idx[i]]
        self.images = images

        # Filter slides in the dataframe whose patches are not in the images folder
        self.csv_data = self.csv_data[
            np.in1d(self.csv_data[self.bag_id], [ID.split('_')[0] for ID in images])]

        # Organize bags in the form of dictionary: one key clusters indexes from all instances
        self.D = dict()
        for i, item in enumerate([ID.split('_')[0] for ID in self.images]):
            if item not in self.D:
                self.D[item] = [i]
            else:
                self.D[item].append(i)

        self.y = self.csv_data[self.classes].values
        self.indexes = np.arange(len(self.images))

        

    def __len__(self):
        'Denotes the total number of samples'
        return len(self.indexes)

    def __getitem__(self, index):
        'Generates one sample of data'
        # Select sample
        ID = self.images[self.indexes[index]]


        # Load image
        x = Image.open(os.path.join(self.dir_images, ID))
        x = np.asarray(x)
        # Normalization
        x = self.image_normalization(x)

        # data augmentation
        if self.data_augmentation:
            x_augm = self.image_transformation(x.copy())
        else:
            x_augm = None

        return x, x_augm

    def image_transformation(self, img):

        if self.channel_first:
            img = np.transpose(img, (1, 2, 0))

        if random.random() > 0.5:
            img = np.fliplr(img)
        if random.random() > 0.5:
            img = np.flipud(img)
        if random.random() > 0.5:
            angle = random.random() * 60 - 30
            img = skimage.transform.rotate(img, angle)
        #if random.random() > 0.5:
        #    img = skimage.util.random_noise(img, var=random.random() ** 2)
        #if random.random() > 0.5:
        #    img = img + random.random() - 0.5
        #    img = np.clip(img, 0, 1)

        if self.channel_first:
            img = np.transpose(img, (2, 0, 1))

        return img

    def image_normalization(self, x):
        # image resize
        x = cv2.resize(x, (self.input_shape[1], self.input_shape[2]))
        # intensity normalization
        x = x / 255.0
        # channel first
        if self.channel_first:
            x = np.transpose(x, (2, 0, 1))
        # numeric type
        x.astype('float32')
        return x

    def plot_image(self, x, norm_intensity=False):
        # channel first
        if self.channel_first:
            x = np.transpose(x, (1, 2, 0))
        if norm_intensity:
            x = x / 255.0

        plt.imshow(x)
        plt.axis('off')
        plt.show()

    def cifar10_test_dataset(self, dir_dataset):
        files = os.listdir(dir_dataset)
        files = [iFile for iFile in files if iFile != 'Thumbs.db']

        Y = []
        X = []
        for iFile in files:
            if 'Other' in iFile:
                y = 0
            else:
                y = int(iFile.split('_')[-2][-1])

            # Load image
            x = Image.open(os.path.join(dir_dataset, iFile))
            x = np.asarray(x)
            # Normalization
            x = self.image_normalization(x)

            Y.append(y)
            X.append(x)

        return np.array(X), np.array(Y)

    def ordering_matrix(self, p):

        if not self.only_primary:
            nRestrictions = len(np.where(np.array(p) > 0)[0])
        else:
            nRestrictions = len(np.where(np.array(p) > 0)[0]) - 1

        if nRestrictions <= 0:
            return [np.zeros((1, len(p))), np.zeros((1, len(p)))]

        # p: numpy array with proportion of used classes
        O = np.zeros((nRestrictions, len(p)))

        # Sort proportion values
        indexes = np.flip(np.argsort(p))

        for i in np.arange(0, nRestrictions):
            O[i, indexes[i]] = -1

        # p: numpy array with proportion of used classes
        if nRestrictions > 1:
            O2 = np.zeros((nRestrictions-1, len(p)))
            for i in np.arange(0, nRestrictions-1):
                O2[i, indexes[i]] = -1
                O2[i, indexes[i + 1]] = 1
        else:
            O2 = np.zeros((1, len(p)))

        return [O, O2]


class MILDataGenerator(object):

    def __init__(self, dataset, batch_size=1, shuffle=False, max_instances=512):

        """Data Generator object for MIL.
            Process a MIL dataset object to output batches of instances and its respective labels.
        Args:
          dataset: MIL datasetdataset object.
          batch_size: batch size (number of bags). It will be usually set to 1.
          shuffle: whether to shuffle the bags (True) or not (False).
          max_instances: maximum amount of instances allowed due to computational limitations.

        Returns:
          MILDataGenerator object
        Last Updates: Julio Silva (19/03/21)
        """

        'Internal states initialization'
        self.dataset = dataset
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indexes = np.arange(len(self.dataset.csv_data))
        self.max_instances = max_instances

        self._idx = 0
        self._reset()

    def __len__(self):

        N = len(self.indexes)
        b = self.batch_size
        return N // b + bool(N % b)

    def __iter__(self):

        return self

    def __next__(self):

        # If dataset is completed, stop iterator
        if self._idx >= len(self.dataset.csv_data):
            self._reset()
            raise StopIteration()

        # Get samples of data frame to use in the batch
        df_row = self.dataset.csv_data.iloc[self.indexes[self._idx]]

        # Get bag-level label
        Y = df_row[self.dataset.classes].to_list()
        Y = np.expand_dims(np.array(Y), 0)

        # Get ordering matrix
        if self.dataset.pMIL:
            O = self.dataset.O[self.indexes[self._idx]]

        # Select instances from bag
        ID = list(df_row[[self.dataset.bag_id]].values)[0]
        images_id = self.dataset.D[ID]

        # Memory limitation of patches in one slide
        if len(images_id) > self.max_instances:
            images_id = random.sample(images_id, self.N)
        # Minimum number os patches in a slide (by precaution).
        if len(images_id) < 4:
            images_id.extend(images_id)

        self.instances_indexes = images_id

        # Load images and include into the batch
        X = []
        X_augm = []
        for i in images_id:
            x, x_augm = self.dataset.__getitem__(i)
            X.append(x)
            X_augm.append(x_augm)

        # Update bag index iterator
        self._idx += self.batch_size

        if self.dataset.pMIL:
            if self.dataset.data_augmentation:
                return np.array(X).astype('float32'), np.array(Y).astype('float32'), O, np.array(X_augm).astype('float32')
            else:
                return np.array(X).astype('float32'), np.array(Y).astype('float32'), O, None
        else:
            if self.dataset.data_augmentation:
                return np.array(X).astype('float32'), np.array(Y).astype('float32'), None, np.array(X_augm).astype('float32')
            else:
                return np.array(X).astype('float32'), np.array(Y).astype('float32'), None, None

    def _reset(self):

        if self.shuffle:
            random.shuffle(self.indexes)
        self._idx = 0


## architecture

In [30]:
class MILArchitecture(torch.nn.Module):

    def __init__(self, classes, mode='embedding', aggregation='mean', backbone='VGG19', include_background=False):
        super(MILArchitecture, self).__init__()

        """Data Generator object for MIL.
            CNN based architecture for MIL classification.
        Args:
          classes: 
          mode:
          aggregation: max, mean, attentionMIL, mcAttentionMIL
          backbone:
          include_background:

        Returns:
          MILDataGenerator object
        Last Updates: Julio Silva (19/03/21)
        """

        'Internal states initialization'

        self.classes = classes
        self.mode = mode
        self.aggregation = aggregation
        self.backbone = backbone
        self.include_background = include_background
        self.C = []
        self.prototypical = False

        if self.include_background:
            self.nClasses = len(classes) + 1
        else:
            self.nClasses = len(classes)
        self.eps = 1e-6

        # Backbone
        self.bb = Encoder(pretrained=True, backbone=backbone, aggregation=True)
        
        # Classifiers
        self.classifier = torch.nn.Linear(512, self.nClasses)
            
        # MIL aggregation
        self.milAggregation = MILAggregation(aggregation=aggregation, nClasses=self.nClasses, mode=self.mode)

    def forward(self, images):
        # Patch-Level feature extraction
        features = self.bb(images)

        # if self.mode == 'instance':
        # Classification
        patch_classification = torch.softmax(self.classifier(torch.squeeze(features)), 1)

        # MIL aggregation
        global_classification = self.milAggregation(patch_classification)

        if self.include_background:
            global_classification = global_classification[1:]

        return global_classification, patch_classification, features


class Encoder(torch.nn.Module):

    def __init__(self, pretrained=True, backbone='resnet18', aggregation=False):
        super(Encoder, self).__init__()

        self.aggregation = aggregation
        self.pretrained = pretrained
        self.backbone = backbone

        if backbone == 'resnet18':
            resnet = torchvision.models.resnet18(pretrained=pretrained)
            self.F = torch.nn.Sequential(resnet.conv1,
                                         resnet.bn1,
                                         resnet.relu,
                                         resnet.maxpool,
                                         resnet.layer1,
                                         resnet.layer2,
                                         resnet.layer3,
                                         resnet.layer4)
        elif backbone == 'vgg19':
            vgg19 = torchvision.models.vgg16(pretrained=pretrained)
            self.F = vgg19.features

        # placeholder for the gradients
        self.gradients = None

    def forward(self, x):
        out = self.F(x)

        # register the hook
        h = out.register_hook(self.activations_hook)

        if self.aggregation:
            out = torch.nn.AdaptiveAvgPool2d((1, 1))(out)

        return out

    # method for the gradient extraction
    def get_activations_gradient(self):
        return self.gradients

    # method for the activation exctraction
    def get_activations(self, x):
        return self.features_conv(x)

    # hook for the gradients of the activations
    def activations_hook(self, grad):
        self.gradients = grad


class MILAggregation(torch.nn.Module):
    def __init__(self, aggregation='mean', nClasses=2, mode='embedding'):
        super(MILAggregation, self).__init__()

        """Aggregation module for MIL.
        Args:
          aggregation:

        Returns:
          MILAggregation module for CNN MIL Architecture
        Last Updates: Julio Silva (19/03/21)
        """

        self.mode = mode
        self.aggregation = aggregation
        self.nClasses = nClasses


    def forward(self, feats):

        if self.aggregation == 'max':
            embedding = torch.max(feats, dim=0)[0]
            return embedding
        elif self.aggregation == 'mean':
            embedding = torch.mean(feats, dim=0)
            return embedding



## trainer

In [31]:
class MILTrainer():
    def __init__(self, dir_out, network, lr=1*1e-4, pMIL=False, margin=0, t_ic=10,
                 t_pc=10, alpha_ic=1, alpha_pc=1, alpha_ce=1, id='', early_stopping=False,
                 scheduler=False, virtual_batch_size=1, criterion='auc', alpha_H=0.01):

        self.dir_results = dir_out
        if not os.path.isdir(self.dir_results):
            os.mkdir(self.dir_results)

        # Other
        self.best_auc = 0.
        self.init_time = 0
        self.lr = lr
        self.L_epoch = 0
        self.L_lc = []
        self.Lce_lc_val = []
        self.macro_auc_lc_val = []
        self.macro_auc_lc_train = []
        self.i_epoch = 0
        self.epochs = 0
        self.i_iteration = 0
        self.iterations = 0
        self.network = network
        self.test_generator = []
        self.train_generator = []
        self.preds_train = []
        self.refs_train = []
        self.pMIL = pMIL
        self.alpha_ce = alpha_ce
        self.best_criterion = 0
        self.best_epoch = 0
        self.metrics = {}
        self.id = id
        self.early_stopping = early_stopping
        self.scheduler = scheduler
        self.virtual_batch_size = virtual_batch_size
        self.constrain_cumpliment_lc = []
        self.constrain_proportion_lc = []
        self.criterion = criterion
        self.alpha_H = alpha_H
        self.H_iteration = 0.
        self.H_epoch = 0.

        # Set optimizers
        self.params = list(self.network.parameters())

        if self.pMIL:
            self.Lp_iteration = 0
            self.Lp_epoch = 0
            self.Lp_lc = []
            self.m = margin
            self.t_ic = t_ic
            self.t_pc = t_pc
            self.alpha_ic = alpha_ic
            self.alpha_pc = alpha_pc
            self.constrain_cumpliment = 0.
            self.constraint_proportion = 0.

        self.opt = torch.optim.SGD(self.params, lr=self.lr)
        #self.opt = torch.optim.Adam(self.params, lr=self.lr)

        # Set losses
        # if network.mode == 'instance':
        self.L = torch.nn.BCELoss().cuda()

    def train(self, train_generator, val_generator, test_generator, epochs):
        self.epochs = epochs
        self.iterations = len(train_generator)
        self.train_generator = train_generator
        self.val_generator = val_generator
        self.test_generator = test_generator
        self.preds_train = []
        self.refs_train = []

        # Move network to gpu
        self.network.cuda()

        self.init_time = timer()
        for i_epoch in range(epochs):
            self.i_epoch = i_epoch
            # init epoch losses
            self.L_epoch = 0
            self.Lpc_iteration = 0
            self.Lic_iteration = 0
            self.Lic_epoch = 0
            self.Lpc_epoch = 0
            self.H_iteration = 0.
            self.H_epoch = 0.
            self.constrain_cumpliment_iteration = 0.
            self.constrain_cumpliment_epoch = 0.
            self.constrain_proportion_epoch = 0.
            self.constrain_ic_proportion_epoch = 0.
            self.j = 0.
            self.jj = 0.
            n = 0
            nn = 0

            if self.scheduler:
                if (self.i_epoch + 1) % 50 == 0:
                    for g in self.opt.param_groups:
                        g['lr'] = self.lr / 2

            # Loop over training dataset
            print('[Training]: at bag level...')
            for self.i_iteration, (X, Y, O, X_augm) in enumerate(self.train_generator):

                X = torch.tensor(X).cuda().float()
                if X_augm is None:
                    X_augm = X
                else:
                    X_augm = torch.tensor(X_augm).cuda().float()
                Y = torch.tensor(Y).cuda().float()

                # Set model to training mode and clear gradients
                self.network.train()

                # Forward network
                Yhat, yhat, features = self.network(X_augm)

                # if self.network.mode == 'instance':
                Yhat = torch.clip(Yhat, min=0.01, max=0.98)

                # Estimate losses
                Lce = self.L(Yhat, torch.squeeze(Y))

                # Update overall losses
                L = Lce * self.alpha_ce

                if self.alpha_H > 0:
                    H = torch.mean(-torch.sum(yhat * torch.log(yhat + 1e-12), dim=(-1)))
                    self.H_iteration = H

                    L += - self.alpha_H * self.H_iteration


                # Backward gradients
                L = L / self.virtual_batch_size
                L.backward()

                # Update weights and clear gradients
                if ((self.i_epoch + 1) % self.virtual_batch_size) == 0:
                    self.opt.step()
                    self.opt.zero_grad()

                ######################################
                ## --- Iteration/Epoch end

                # Save predictions
                self.preds_train.append(Yhat.detach().cpu().numpy())
                self.refs_train.append(Y.detach().cpu().numpy())

                # Display losses per iteration
                self.display_losses(self.i_epoch + 1, self.epochs, self.i_iteration + 1, self.iterations,
                                    Lce.cpu().detach().numpy(),
                                    end_line='\r')

                # Update epoch's losses
                self.L_epoch += Lce.cpu().detach().numpy() / len(self.train_generator)
                

            # Epoch-end processes
            

            self.on_epoch_end()

            if self.early_stopping:
                if self.i_epoch + 1 == (self.best_epoch + 20):
                    break

    def on_epoch_end(self):

        # Obtain epoch-level metrics
        macro_auc = roc_auc_score(np.squeeze(np.array(self.refs_train)), np.array(self.preds_train), multi_class='ovr')
        self.macro_auc_lc_train.append(macro_auc)

        # Display losses
        self.display_losses(self.i_epoch + 1, self.epochs, self.iterations, self.iterations, self.L_epoch, macro_auc,
                            end_line='\n')
        # Update learning curves
        self.L_lc.append(self.L_epoch)

        # Obtain results on validation set
        Lce_val, macro_auc_val = self.test_bag_level_classification(self.val_generator)

        # Save loss value into learning curve
        self.Lce_lc_val.append(Lce_val)
        self.macro_auc_lc_val.append(macro_auc_val)

        metrics = {'epoch': self.i_epoch + 1, 'AUCtrain': np.round(self.macro_auc_lc_train[-1], 4),
                   'AUCval': np.round(self.macro_auc_lc_val[-1], 4)}
        with open(self.dir_results + self.id + 'metrics.json', 'w') as fp:
            json.dump(metrics, fp)
        print(metrics)

        if (self.i_epoch + 1) > 10:
            if self.criterion == 'auc':
                if self.best_criterion < self.macro_auc_lc_val[-1]:
                    self.best_criterion = self.macro_auc_lc_val[-1]
                    self.best_epoch = (self.i_epoch + 1)

                    torch.save(self.network, self.dir_results + self.id + 'network_weights_best.pth')

            elif self.criterion == 'z':
                if self.best_criterion < (-self.constrain_proportion_epoch):
                    self.best_criterion = -self.constrain_proportion_epoch
                    self.best_epoch = (self.i_epoch + 1)

                    torch.save(self.network, self.dir_results + self.id + 'network_weights_best.pth')

        # Each xx epochs, test models and plot learning curves
        if (self.i_epoch + 1) % 5 == 0:
            # Save weights
            torch.save(self.network, self.dir_results + self.id + 'network_weights.pth')

            # Plot learning curve
            self.plot_learning_curves()

            # Test at instance level
            X = self.test_generator.dataset.X[self.test_generator.dataset.y_instances[:, 0] != -1, :, :, :]
            Y = self.test_generator.dataset.y_instances[self.test_generator.dataset.y_instances[:, 0] != -1, :]
            acc, f1, k2 = self.test_instance_level_classification(X, Y, self.test_generator.dataset.classes)

        if (self.epochs == (self.i_epoch + 1)) or (self.early_stopping and (self.i_epoch + 1 == (self.best_epoch + 20))):
            print('-' * 20)
            print('-' * 20)

            self.network = torch.load(self.dir_results + self.id + 'network_weights_best.pth')

            # Obtain results on validation set
            Lce_val, macro_auc_val = self.test_bag_level_classification(self.val_generator)

            # Obtain results on validation set
            Lce_test, macro_auc_test = self.test_bag_level_classification(self.test_generator)

            # Test at instance level
            X = self.test_generator.dataset.X[self.test_generator.dataset.y_instances[:, 0] != -1, :, :, :]
            Y = self.test_generator.dataset.y_instances[self.test_generator.dataset.y_instances[:, 0] != -1, :]
            acc, f1, k2 = self.test_instance_level_classification(X, Y, self.test_generator.dataset.classes)

            metrics = {'epoch': self.best_epoch, 'AUCtest': np.round(macro_auc_test, 4),
                       'AUCval': np.round(macro_auc_val, 4), 'acc': np.round(acc, 4),
                       'f1': np.round(f1, 4), 'k2': np.round(k2, 4),
                       }

            if self.alpha_pc:
                metrics['constrain_cumpliment'] = np.round(self.constrain_cumpliment_lc[self.best_epoch-1], 4)
                metrics['constrain_proportion'] = np.round(self.constrain_proportion_lc[self.best_epoch-1], 4)

            with open(self.dir_results + self.id + 'best_metrics.json', 'w') as fp:
                json.dump(metrics, fp)
            print(metrics)

            self.metrics = metrics
            print('-' * 20)
            print('-' * 20)

    def plot_learning_curves(self):
        def plot_subplot(axes, x, y, y_axis):
            axes.grid()
            for i in range(x.shape[0]):
                axes.plot(x[i, :], y[i, :], 'o-')
            axes.set_ylabel(y_axis)

        fig, axes = plt.subplots(2, 1, figsize=(20, 15))
        plot_subplot(axes[0], np.tile(np.arange(self.i_epoch + 1), (2, 1)) + 1, np.array([self.L_lc, self.Lce_lc_val]), "Lce")
        plot_subplot(axes[1], np.tile(np.arange(self.i_epoch + 1), (2, 1)) + 1, np.array([self.macro_auc_lc_train, self.macro_auc_lc_val]), "mAUC")

        plt.savefig(self.dir_results + self.id + 'learning_curve.png')

    def display_losses(self, i_epoch, epochs, iteration, total_iterations, Lce, macro_auc=0, end_line=''):

        info = "[INFO] Epoch {}/{}  -- Step {}/{}: Lce={:.4f} ; AUC={:.4f}".format(
                i_epoch, epochs, iteration, total_iterations, Lce, macro_auc)

        if self.alpha_H > 0:
            if end_line == '\n':
                info += ' ; H=' + str(np.round(self.H_epoch, 4))
            else:
                info += ' ; H=' + str(np.round(self.H_iteration.cpu().detach().numpy(), 4))

        if self.pMIL and end_line == '\n':
            if self.alpha_pc > 0:
                info += ' ; IC=' + str(np.round(self.Lic_epoch, 4))
                info += '{' + str(np.round(self.constrain_ic_proportion_epoch, 4)) + '}'
            if self.alpha_ic > 0:
                info += ' ; PC=' + str(np.round(self.Lpc_epoch, 4))
                info += '{' + str(np.round(self.constrain_cumpliment_epoch, 4)) + '}'
                info += '{' + str(np.round(self.constrain_proportion_epoch, 4)) + '}  '
        if self.pMIL and end_line == '\r':
            if self.alpha_pc > 0:
                info += ' ; IC=' + str(np.round(self.Lic_iteration.cpu().detach().numpy(), 4))
            if self.alpha_ic > 0:
                info += ' ; PC=' + str(np.round(self.Lpc_iteration.cpu().detach().numpy(), 4))
                info += '{' + str(np.round(self.constrain_cumpliment_iteration, 4)) + '}  '

        # Print losses
        et = str(datetime.timedelta(seconds=timer() - self.init_time))
        print(info + ',ET=' + et, end=end_line)

    def test_instance_level_classification(self, X, Y, classes):
        classes = ['NC'] + classes

        self.network.eval()
        print(['INFO: Testing at instance level...'])

        Yhat = []
        for iInstance in np.arange(0, X.shape[0]):
            print(str(iInstance+1) + '/' + str(X.shape[0]), end='\r')

            # Tensorize input
            x = torch.tensor(X[iInstance, :, :, :]).cuda().float()
            x = x.unsqueeze(0)

            
            # Make prediction
            yhat = torch.softmax(- torch.cdist(torch.squeeze(self.network.bb(x)).unsqueeze(0), self.network.C, p=2.0), 1)
            yhat = torch.argmax(yhat).detach().cpu().numpy()

            Yhat.append(yhat)
        Yhat = np.array(Yhat)
        Y = np.argmax(Y, 1)

        cr = classification_report(Y, Yhat, target_names=classes, digits=4)
        acc = accuracy_score(Y, Yhat)
        f1 = f1_score(Y, Yhat, average='macro')
        cm = confusion_matrix(Y, Yhat)
        k2 = cohen_kappa_score(Y, Yhat, weights='quadratic')

        print('Instance Level kappa: ' + str(np.round(k2, 4)), end='\n')

        f = open(self.dir_results + self.id + 'report.txt', 'w')
        f.write('Title\n\nClassification Report\n\n{}\n\nConfusion Matrix\n\n{}\n\nKappa\n\n{}\n'.format(cr, cm, k2))
        f.close()

        return acc, f1, k2

    def test_bag_level_classification(self, test_generator, binary=False):
        self.network.eval()
        print('[VALIDATION]: at bag level...')

        # Loop over training dataset
        Y_all = []
        Yhat_all = []
        Lce_e = 0
        for self.i_iteration, (X, Y, O, _) in enumerate(test_generator):
            X = torch.tensor(X).cuda().float()
            Y = torch.tensor(Y).cuda().float()

            # Set model to training mode and clear gradients

            # Forward network
            Yhat, _, _ = self.network(X)
            # Estimate losses
            Lce = self.L(Yhat, torch.squeeze(Y))
            Lce_e += Lce.cpu().detach().numpy() / len(test_generator)

            Y_all.append(Y.detach().cpu().numpy())
            Yhat_all.append(Yhat.detach().cpu().numpy())

            # Display losses per iteration
            self.display_losses(self.i_epoch + 1, self.epochs, self.i_iteration + 1, len(test_generator),
                                Lce.cpu().detach().numpy(),
                                end_line='\r')
        # Obtain overall metrics
        Yhat_all = np.array(Yhat_all)
        Y_all = np.squeeze(np.array(Y_all))

        if binary:
            Yhat_all = np.max(Yhat_all, 1)
            Y_all = np.max(Y_all, 1)

        macro_auc = roc_auc_score(Y_all, Yhat_all, multi_class='ovr')

        # Display losses per epoch
        self.display_losses(self.i_epoch + 1, self.epochs, self.i_iteration + 1, len(test_generator),
                            Lce_e, macro_auc,
                            end_line='\n')

        return Lce_e, macro_auc

## main

In [32]:
def main(args, reduced_data=False):

    metrics = []
    #for i_iteration in np.arange(0, args.iterations):
    # why 3 times???
    #    id = str(i_iteration) + '_'      
    
    
    print(glob.glob(args.dir_csv_data))

    csv_list = []
    for i, filename in enumerate(glob.glob(args.dir_csv_data)):
        df = pd.read_csv(filename, delimiter=";")
        # df["dataset_type"] = [i]*len(df.index)
        csv_list.append(df)
    csv_data = pd.concat(csv_list, axis=0, ignore_index=False)

    # csv_data = csv_data[self.csv_data["mode"].str.contains(mode)]

    if reduced_data:
        csv_data = csv_data.sample(frac=1).reset_index(drop=True)
        csv_data = csv_data.head(200)



    # Set data generators
    dataset_train = MILDataset(csv_data[csv_data['mode'] == 'train'], args.classes,
                               bag_id='slide_name', input_shape=args.input_shape,
                               data_augmentation=args.data_augmentation,
                               pMIL=args.pMIL, proportions=args.proportions)
    data_generator_train = MILDataGenerator(dataset_train, batch_size=1, shuffle=True, max_instances=512)

    dataset_val = MILDataset(csv_data[csv_data['mode'] == 'val'], args.classes,
                             bag_id='slide_name', input_shape=args.input_shape,
                             data_augmentation=args.data_augmentation,
                             pMIL=args.pMIL, proportions=args.proportions)
    data_generator_val = MILDataGenerator(dataset_val, batch_size=1, shuffle=False, max_instances=512)

    dataset_test = MILDataset(csv_data[csv_data['mode'] == 'test'], args.classes,
                              bag_id='slide_name', input_shape=args.input_shape,
                              data_augmentation=args.data_augmentation,
                              pMIL=args.pMIL, proportions=args.proportions)
    data_generator_test = MILDataGenerator(dataset_test, batch_size=1, shuffle=False, max_instances=512)

    # Set network architecture
    network = MILArchitecture(args.classes, mode=args.mode, aggregation=args.aggregation,
                              backbone='vgg19', include_background=args.include_background)

    # Perform training
    trainer = MILTrainer(args.dir_results + args.experiment_name + '/', network,
                         lr=args.lr, pMIL=args.pMIL, margin=args.margin,
                         alpha_ic=args.alpha_ic, alpha_pc=args.alpha_pc, t_ic=args.t_ic,
                         t_pc=args.t_pc, alpha_ce=args.alpha_ce, id=id,
                         early_stopping=args.early_stopping, scheduler=args.scheduler,
                         virtual_batch_size=args.virtual_batch_size,
                         criterion=args.criterion,
                         alpha_H=args.alpha_H)
    trainer.train(train_generator=data_generator_train, val_generator=data_generator_val,
                  test_generator=data_generator_test, epochs=args.epochs)

    metrics.append([list(trainer.metrics.values())[1:]])

    # Get overall metrics
    metrics = np.squeeze(np.array(metrics))

    mu = np.mean(metrics, axis=0)
    std = np.std(metrics, axis=0)

    info = "AUCtest={:.4f}({:.4f}) ; AUCval={:.4f}({:.4f})  ; acc={:.4f}({:.4f}) ; f1-score={:.4f}({:.4f}) ; k2={:.4f}({:.4f})".format(
          mu[0], std[0], mu[1], std[1], mu[2], std[2], mu[3], std[3], mu[4], std[4])
    if args.alpha_pc > 0:
        info += " ; constrain_cumpliment={:.4f}({:.4f}) ; constrain_proportion={:.4f}({:.4f})".format(
          mu[5], std[5], mu[6], std[6])

    f = open(args.dir_results + args.experiment_name + '/' + 'method_metrics.txt', 'w')
    f.write(info)
    f.close()        

args = Arguments()
main(args)

['C:/Users/Prinzessin/projects/decentnet/datasceyence/data_prep\\mil_data_octa500_unknown.csv', 'C:/Users/Prinzessin/projects/decentnet/datasceyence/data_prep\\mil_data_ravir_unknown.csv']


AttributeError: 'MILDataset' object has no attribute 'images'