This code is based on the implementation (https://github.com/AMLab-Amsterdam/AttentionDeepMIL) of paper [1] and has been modified following paper [2]. <br>
[1] Ilse, Maximilian, Jakub Tomczak, and Max Welling. "Attention-based deep multiple instance learning." International Conference on Machine Learning. PMLR, 2018. <br>
[2] "The Effect of Within-Bag Sampling on End-to-End Multiple Instance Learning", 2021

In [None]:
# External dependencies
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms
import argparse
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import os
import shutil
from pathlib import Path
from PIL import Image
# Internal modules
import attention_model
from dataloaders import ImagenetteBags
from attention_model import Attention_Imagenette_bags_3GPUs, Attention_Imagenette_bags_1GPU
from evaluation import compute_metrics

In [None]:
def list_files_in_folder(image_folder):
    """Return the names of the regular files directly inside `image_folder`.

    Only bare file names are returned (not full paths). Sub-directories are
    skipped. The order is whatever `os.listdir` yields.
    """
    return [entry for entry in os.listdir(image_folder)
            if os.path.isfile(os.path.join(image_folder, entry))]

def create_save_dir(direct, name_subdirectory):
    if not os.path.exists(os.path.join(direct, name_subdirectory)):
        print('make dir')
        os.mkdir(os.path.join(direct, name_subdirectory))
    return os.path.join(direct, name_subdirectory)

In [None]:
# Training settings
parser = argparse.ArgumentParser(description='PyTorch Imagenette bags Example')
parser.add_argument('--epochs', type=int, default=50, metavar='N',
                    help='number of epochs to train (default: 50)')
parser.add_argument('--lr', type=float, default=0.0001, metavar='LR',
                    help='learning rate (default: 0.0001)')
parser.add_argument('--reg', type=float, default=10e-5, metavar='R',
                    help='weight decay (default: 1e-4)')
parser.add_argument('--num_bags_test', type=int, default=10, metavar='NTest',
                    help='number of bags in test set')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('--sampling_percent', type=int, default=30, metavar='P',
                    help='percent of the mean bag length sampled from each bag (default: 30)')

# args = parser.parse_args()
# parse_known_args ignores arguments this parser does not define (e.g. extra
# arguments present when running inside a notebook kernel).
args, unknown = parser.parse_known_args()

data_path = './fold1/Imagenette_0010_0050_0030'
path = Path(data_path)
# Per-channel normalisation statistics applied by the data loaders.
mean_data = [0.485, 0.456, 0.406]
std_data = [0.229, 0.224, 0.225]
# Parameters derived from the dataset folder name (as in "Create_Imagenette_bags_dataset.ipynb"),
# otherwise enter them manually. Expected pattern: Imagenette_<num bags>_<mean bag length>_<xxxx>.
dataset_name = path.parts[-1]
num_bags_train = int(dataset_name[11:15])
mean_bag_length = int(dataset_name[16:-5])
print(num_bags_train,mean_bag_length)
# Three or one GPU: three_gpus is True if three GPUs are used, False if one
three_gpus = False
image_sizeImagenette = (3,112,112)

# Number of instances drawn from a bag (as a float, from np.ceil).
sampling_size_in_instances = np.ceil((args.sampling_percent*mean_bag_length)/100)
N_repeat = 10 # Approximate number of times one image is sampled to bags when sampling percent is lower than 100
if args.sampling_percent < 100:
    # Plain int: the `np.int` alias was removed in NumPy 1.24.
    test_epochs = int(np.ceil((N_repeat*mean_bag_length)/sampling_size_in_instances))
    val_times = 5 # The number of times validation is repeated (with resampling of images from the bags)
else:
    test_epochs = 1
    val_times = 1
window_length = 15 # the moving average window length to find validation model with lowest error


In [None]:
# Seed the RNGs, build the model on the available device(s) and set up the optimizer.
args.cuda = torch.cuda.is_available()
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
    print('\nGPU is ON!')
print('Init Model')
if args.cuda and three_gpus:
    # NOTE(review): not moved with .to(); presumably the multi-GPU model places
    # its sub-modules on devices itself -- confirm in attention_model.
    model = Attention_Imagenette_bags_3GPUs()
else:
    # Single-device model: on cuda:0 when CUDA is available, otherwise kept on
    # the CPU so that `model` is always defined for the optimizer below.
    model = Attention_Imagenette_bags_1GPU()
    if args.cuda:
        model.to('cuda:0')
# Results directory named after the run's settings; the training cell skips
# training when this directory is not empty.
save_weights_dir = create_save_dir(data_path, 'test_weights_epochs_'+str(test_epochs)+'samplingPerc_'+
                                   str(args.sampling_percent)+'_train_epochs_'+str(args.epochs)+'_lr'+str(args.lr))
if os.listdir(save_weights_dir):
    print('Directory is not empty!')
optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.reg)


In [None]:
def _make_imagenette_bags_loader(train, valid, shuffle):
    """Build a batch-size-1 DataLoader over one non-test ImagenetteBags split.

    Each call creates its own ToTensor + Normalize(mean_data, std_data)
    transform pipeline and uses the notebook-level image size, sampling size
    and data path.
    """
    normalize_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=mean_data, std=std_data),
    ])
    bags = ImagenetteBags(train=train,
                          valid=valid,
                          test=False,
                          image_size=image_sizeImagenette,
                          transform=normalize_pipeline,
                          sampling_size=sampling_size_in_instances,
                          data_path=data_path)
    return data_utils.DataLoader(bags, batch_size=1, shuffle=shuffle)


# Training split, reshuffled every epoch.
train_loader = _make_imagenette_bags_loader(train=True, valid=False, shuffle=True)
# Validation split, iterated in a fixed order.
valid_loader = _make_imagenette_bags_loader(train=False, valid=True, shuffle=False)


In [None]:
# Train, validate each epoch, and keep the checkpoint with the lowest validation
# error inside the best moving-average window of validation errors.
path = Path(data_path)
all_train_loss = []
all_train_error = []
all_valid_loss = []
all_valid_error = []
# The following three are not used below; kept so later cells relying on them still work.
count_stable_epochs = []
min_valid_loss = torch.tensor([float('inf')])
if args.cuda:
    min_valid_loss = min_valid_loss.to('cuda:1' if three_gpus else 'cuda:0')
min_valid_error = np.inf
window_with_best_avg = np.inf
num_epochs = args.epochs
lr = args.lr
sampling_percent = args.sampling_percent
# Shrink the window for runs shorter than window_length; otherwise no best model
# would ever be selected and the final shutil.move would hit undefined names.
effective_window = min(window_length, num_epochs)
name_best_inMovAv_val_error_model = None
epoch_best_movingAvg = None


def _epoch_model_path(ep):
    """Path of the per-epoch checkpoint written for training epoch `ep` (1-based)."""
    return os.path.join(str(path), 'saved_model_IMAGENETTE_ResNet_' + str(path.parts[-1]) +
                        '_currentEpoch' + str(ep) + '_overallEpochs' + str(num_epochs) +
                        '_lr' + str(lr) + '_sampPerc' + str(sampling_percent) + '.pt')


def _to_model_device(data, bag_label):
    """Move a bag and its label to the devices used by the model (no-op on CPU)."""
    if not args.cuda:
        return data, bag_label
    if three_gpus:
        return data.to('cuda:2'), bag_label.to('cuda:1')
    return data.to('cuda:0'), bag_label.to('cuda:0')


if not os.listdir(save_weights_dir):
    print("Training started")
    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.
        train_error = 0.
        for data, label, *_ in train_loader:
            data, bag_label = _to_model_device(data, label[0])

            # reset gradients
            optimizer.zero_grad()
            # calculate loss and metrics
            loss, _ = model.calculate_objective(data, bag_label)
            train_loss += loss.data[0]
            error, _ = model.calculate_classification_error(data, bag_label)
            train_error += error

            # backward pass
            loss.backward()
            optimizer.step()
        # calculate loss and error for epoch
        train_loss /= len(train_loader)
        train_error /= len(train_loader)

        # Validation loss and error, averaged over val_times resampled passes.
        valid_loss = 0.
        valid_error = 0.
        model.eval()
        with torch.no_grad():  # no autograd graph needed for validation
            for _rep in range(val_times):
                for data, label, *_ in valid_loader:
                    data, bag_label = _to_model_device(data, label[0])
                    loss, _ = model.calculate_objective(data, bag_label)
                    valid_loss += loss.data[0]
                    error, _ = model.calculate_classification_error(data, bag_label)
                    valid_error += error
        valid_loss /= len(valid_loader)*val_times
        valid_error /= len(valid_loader)*val_times

        # Save all losses for all epochs both valid and train
        all_train_loss.append(train_loss.cpu().data.numpy())
        all_train_error.append(train_error)
        all_valid_loss.append(valid_loss.cpu().data.numpy())
        all_valid_error.append(valid_error)

        # Best validation error within the best average window of error
        name_epoch_model = _epoch_model_path(epoch)
        torch.save(model.state_dict(), name_epoch_model)
        if len(all_valid_error) >= effective_window:
            recent_errors = np.asarray(all_valid_error)[-effective_window:]
            avg_valid_error_window = np.mean(recent_errors)

            if window_with_best_avg > avg_valid_error_window:
                idx_model_in_window_with_min_loss = np.argmin(recent_errors)
                window_with_best_avg = avg_valid_error_window
                # The window covers epochs epoch-effective_window+1 .. epoch.
                epoch_best_movingAvg = epoch - effective_window + idx_model_in_window_with_min_loss + 1
                name_best_inMovAv_val_error_model = _epoch_model_path(epoch_best_movingAvg)
                print("Current best moving average model, epoch: ", str(epoch_best_movingAvg))
        # Delete checkpoints that fell out of the current window, except the best one.
        for ep in range(1, epoch - effective_window + 1):
            model_nametmp = _epoch_model_path(ep)
            if model_nametmp != name_best_inMovAv_val_error_model and os.path.isfile(model_nametmp):
                os.remove(model_nametmp)

        print('Epoch: {}, Loss: {:.4f}, Train error: {:.4f}'.format(epoch, train_loss.cpu().numpy()[0], train_error))
        print('Epoch: {}, Loss: {:.4f}, Valid error: {:.4f}'.format(epoch, valid_loss.cpu().numpy()[0], valid_error))

    if name_best_inMovAv_val_error_model is None:
        print('No epochs were run; no best moving-average model to save.')
    else:
        # Move the selected checkpoint into the results directory.
        new_name_best_inMovAv_val_error_model = os.path.join(
            save_weights_dir, 'saved_model_best_movingAvg_IMAGENETTE_ResNet_' +
            str(path.parts[-1]) + 'epochSaved_' + str(epoch_best_movingAvg) +
            '_overallEpochs' + str(num_epochs) + '_lr' + str(lr) +
            '_sampPerc' + str(sampling_percent) + '.pt')
        shutil.move(name_best_inMovAv_val_error_model, new_name_best_inMovAv_val_error_model)

    # Remove any per-epoch checkpoints still left in the data folder.
    for ep in range(1, num_epochs + 1):
        model_nametmp = _epoch_model_path(ep)
        if os.path.isfile(model_nametmp):
            os.remove(model_nametmp)


In [None]:
# With replacement: instances are re-sampled from the test bags on each of the
# `test_epochs` passes (presumably by ImagenetteBags -- confirm in dataloaders).
def test(save_weights_dir, subfolder, test_loader, model):
    """Run the model over the test bags and save per-instance attention weights.

    For each of `test_epochs` passes over `test_loader` (batch size 1), every
    instance's attention weight is saved with np.save to
    `save_weights_dir/subfolder/<bag name, zero-padded to 4>/` as
    `<image name>_<pass, 2 digits>_<instance label>_<bag name>_<predicted label>.npy`.

    Relies on the notebook-level `test_epochs`, `args`, `three_gpus` and
    `all_names_test`.

    Returns:
        (mean test loss, mean test error) over all bag evaluations, or
        (0., 0.) when no bags were evaluated.
    """
    model.eval()
    test_loss = 0.
    test_error = 0.
    subfolder_dir = create_save_dir(save_weights_dir, subfolder)
    with torch.no_grad():  # inference only; no autograd graph needed
        for epoch in range(0, test_epochs):
            print('test_epoch', epoch, end='\r')
            for data, label, sample_indices, index, bages_names in test_loader:
                bag_label = label[0]
                if args.cuda:
                    if three_gpus:
                        data, bag_label = data.to('cuda:2'), bag_label.to('cuda:1')
                    else:
                        data, bag_label = data.to('cuda:0'), bag_label.to('cuda:0')
                loss, attention_weights = model.calculate_objective(data, bag_label)
                test_loss += loss.data[0]
                error, predicted_label = model.calculate_classification_error(data, bag_label)
                test_error += error

                # Per-bag values, computed once instead of once per instance.
                bag_name = str(bages_names[0].cpu().numpy()).zfill(4)
                save_weights_bag = create_save_dir(subfolder_dir, bag_name)
                predicted_str = str(predicted_label.cpu().numpy())
                bag_attention = attention_weights.detach().cpu().numpy()[0]
                instance_labels = label[1][0]

                for i in range(data.shape[1]):
                    label_name = int(instance_labels[i].cpu().numpy())
                    # Save attention weights
                    np.save(os.path.join(save_weights_bag, str(all_names_test[index][sample_indices[0][i]])+'_'+
                                         str(epoch).zfill(2)+'_'+str(label_name)+'_'+
                                         bag_name+'_'+predicted_str+".npy"),
                            bag_attention[i])
    num_evaluations = test_epochs * len(test_loader)
    if num_evaluations:
        test_loss /= num_evaluations
        test_error /= num_evaluations
    return test_loss, test_error

In [None]:
# Build the test loader, load the best moving-average model and save its
# attention weights on the test bags (unless they were already saved).
test_loader = data_utils.DataLoader(ImagenetteBags(train=False,
                                                   valid=False,
                                                   test=True,
                                                   image_size=image_sizeImagenette,
                                    transform=transforms.Compose([transforms.ToTensor(),
                                                                  transforms.Normalize(mean=mean_data,
                                                                                       std=std_data)
                                    ]), sampling_size=sampling_size_in_instances, data_path=data_path),
                                    batch_size=1,
                                    shuffle=False)
# Image names of the test bags, indexed by `test` as all_names_test[index][sample_index].
# allow_pickle is needed for object arrays; only load files produced by the dataset script.
all_names_test = np.load(os.path.join(data_path,'test_imgs_lists.npy'), allow_pickle=True)

print('Start Testing')
if args.cuda and three_gpus:
    model = Attention_Imagenette_bags_3GPUs()
else:
    # Single-device model on cuda:0, or on the CPU when CUDA is unavailable.
    model = Attention_Imagenette_bags_1GPU()
    if args.cuda:
        model.to('cuda:0')

# The training cell defines this name only when it actually trained; when
# training was skipped, look the best model up in the results directory.
try:
    best_model_path = new_name_best_inMovAv_val_error_model
except NameError:
    best_candidates = sorted(Path(save_weights_dir).glob('saved_model_best_movingAvg_*.pt'))
    if not best_candidates:
        raise FileNotFoundError('No saved_model_best_movingAvg_*.pt found in ' + save_weights_dir)
    best_model_path = str(best_candidates[-1])
# Load onto the CPU first; load_state_dict copies into the model's own devices.
model.load_state_dict(torch.load(best_model_path, map_location='cpu'))

subfolder='best_inMovAv_val_loss_model_weights'
# Check the weights subfolder itself: save_weights_dir always contains the moved
# best model after training, so checking it would skip testing every time.
weights_subfolder_path = os.path.join(save_weights_dir, subfolder)
if not os.path.isdir(weights_subfolder_path) or not os.listdir(weights_subfolder_path):
    test(save_weights_dir, subfolder, test_loader, model)
else:
    print('Attention weights already saved in ' + weights_subfolder_path + '; skipping testing.')


### Evaluation

In [None]:
# Evaluate the attention weights saved by `test` for the best moving-average
# model (compute_metrics is defined in the evaluation module).
weights_subfolder_name = 'best_inMovAv_val_loss_model_weights'
num_test_bags = args.num_bags_test
compute_metrics(weights_subfolder_name, num_test_bags, save_weights_dir, data_path, test_epochs)