#### Data loader

In [None]:
import torch
import numpy as np
import torch.utils.data as data_utils
from torchvision import datasets, transforms

from dataloaders.dataloader import ColoredMnistBags3_binary, ColoredFashionMnistBags3_binary

def mi_collate_img(batch):
    """Merge a list of multiple-instance bags into one flat mini-batch.

    Used as ``collate_fn`` for a PyTorch ``DataLoader``. Every sample in
    ``batch`` is indexed as ``(instances, (bag_label, instance_labels),
    bag_index)``. Instances, bag indices and instance labels are concatenated
    along axis 0 across all bags; bag labels are gathered one per bag.

    Returns:
        Tuple ``(bag, bag_idx, bag_label, instance_label)`` of torch tensors.
    """
    def _cat(parts):
        # Join the per-bag arrays along the first axis and wrap as a tensor.
        return torch.tensor(np.concatenate(parts, axis=0))

    instances = _cat([sample[0] for sample in batch])
    indices = _cat([sample[2] for sample in batch])

    targets = [sample[1] for sample in batch]
    per_instance = _cat([target[1] for target in targets])
    per_bag = torch.tensor([target[0] for target in targets])
    return instances, indices, per_bag, per_instance

### Training

In [None]:
from util import *
from models.networks_colormnist import cmil_mnist

from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
def training_procedure(FLAGS, input_dim, dataloader, checkpoint_path='/home/temp_cmnist.pt'):
    """Train a ``cmil_mnist`` model and return it with its best-epoch weights.

    Args:
        FLAGS: argparse namespace passed to ``cmil_mnist``. This function also
            reads ``initial_learning_rate``, ``weight_decay``, ``end_epoch``
            and, if present, ``cuda``.
        input_dim: not used by this function; kept for caller compatibility.
        dataloader: iterable of ``(bag, bag_idx, bag_label, instance_label)``
            batches, which is the layout produced by ``mi_collate_img``.
            Instance labels are deliberately ignored during training.
        checkpoint_path: file that the best-so-far ``state_dict`` is written to.

    Returns:
        The model. If any epoch produced a loss lower than all earlier epochs,
        the ``state_dict`` of the epoch with the lowest mean per-batch ELBO
        loss is loaded back before returning.
    """
    # Honour FLAGS.cuda, and fall back to CPU when no GPU is available.
    use_cuda = getattr(FLAGS, 'cuda', True) and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    model = cmil_mnist(FLAGS).to(device)
    # model.apply(weights_init)
    model.train()
    auto_encoder_optimizer = torch.optim.AdamW(
        model.parameters(), lr=FLAGS.initial_learning_rate, weight_decay=FLAGS.weight_decay)

    def _to_float(value):
        # Reduce a (possibly grad-tracking) scalar tensor to a Python float,
        # so the epoch sums do not keep every batch's autograd graph alive.
        return value.item() if torch.is_tensor(value) else float(value)

    # Start at +inf so the first finite epoch loss is always checkpointed.
    best_loss = float('inf')
    saved_checkpoint = False

    for epoch in range(FLAGS.end_epoch):
        elbo_epoch = recon_epoch = y_epoch = KL_ins_epoch = 0.0
        num_batches = 0
        for bag, bag_idx, bag_label, _ in dataloader:  # do not use instance label while training
            auto_encoder_optimizer.zero_grad()
            elbo, class_y_loss, reconstruction_proba, KL_instance = model.loss_function(
                bag.float().to(device), bag_idx.to(device), bag_label.to(device), epoch)
            elbo.backward()
            auto_encoder_optimizer.step()

            elbo_epoch += _to_float(elbo)
            recon_epoch += _to_float(reconstruction_proba)
            y_epoch += _to_float(class_y_loss)
            KL_ins_epoch += _to_float(KL_instance)
            num_batches += 1

        # Mean per batch. The number of batches seen is already the right
        # divisor; it must not be divided by the batch size again.
        num_batches = max(num_batches, 1)
        elbo_epoch /= num_batches
        recon_epoch /= num_batches
        y_epoch /= num_batches
        KL_ins_epoch /= num_batches

        if (epoch + 1) % 20 == 0:
            print('Epoch #' + str(epoch + 1) + '..............................................')
            print("Train elbo  {:.5f}, recon_loss {:.5f}, y_loss {:.5f}"
                  .format(elbo_epoch, recon_epoch, y_epoch))
            print("KL zx  {:.6f}".format(KL_ins_epoch))
        if elbo_epoch < best_loss:
            best_loss = elbo_epoch
            torch.save(model.state_dict(), checkpoint_path)
            saved_checkpoint = True

    # Only reload a checkpoint written by *this* run, never a stale file.
    if saved_checkpoint:
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
    return model

In [None]:
def weight_reset(m):
    """Re-initialise the parameters of conv, transposed-conv and linear layers.

    Meant to be used as ``model.apply(weight_reset)``, which calls it on every
    submodule. Modules of any other type are left untouched.
    """
    # Use ``torch.nn`` explicitly: the bare name ``nn`` is never imported in
    # this file and may not exist here.
    if isinstance(m, (torch.nn.Conv2d, torch.nn.Linear, torch.nn.ConvTranspose2d)):
        m.reset_parameters()

if __name__ == '__main__':
    import argparse
    import torchvision
    import numpy as np
    import pandas as pd
    from sklearn.model_selection import ParameterGrid
    import matplotlib.pyplot as plt

    def _str2bool(value):
        """Parse a boolean command-line value.

        ``type=bool`` is an argparse pitfall: ``bool('False')`` is ``True``,
        so any non-empty string would turn the flag on.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('1', 'true', 't', 'yes', 'y'):
            return True
        if lowered in ('0', 'false', 'f', 'no', 'n'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))

    # Hyper-parameter sweep: every combination trains one model and saves it to `path`.
    param_grid = {'instance_dim': [24], 'aux_loss_multiplier_y': [10]}
    grid = ParameterGrid(param_grid)
    for params in grid:
        parser = argparse.ArgumentParser()
        parser.add_argument('--cuda', type=_str2bool, default=True, help="run the following code on a GPU")
        parser.add_argument('--num_classes', type=int, default=2, help="number of classes on which the data set trained")
        parser.add_argument('--initial_learning_rate', type=float, default=1e-3, help="starting learning rate")
        parser.add_argument("--weight-decay", default=1e-4, type=float)
        parser.add_argument('--instance_dim', type=int, default=params['instance_dim'], help="dimension of instance factor latent space")
        parser.add_argument('--reconstruction_coef', type=float, default=1., help="coefficient for reconstruction term")
        parser.add_argument('--kl_divergence_coef', type=float, default=1, help="coefficient for instance KL-Divergence loss term")
        parser.add_argument('--aux_loss_multiplier_y', type=float, default=params['aux_loss_multiplier_y'])
        parser.add_argument('--start_epoch', type=int, default=0, help="flag to set the starting epoch for training")
        parser.add_argument('--end_epoch', type=int, default=200, help="flag to indicate the final epoch of training")
        parser.add_argument('-w', '--warmup', type=int, default=0, metavar='N', help='number of epochs for warm-up. Set to 0 to turn warmup off.')
        parser.add_argument('--in_channels', type=int, default=3, metavar='N', help='number of input channels, used for colored mnists')

        # Parse an empty argv so the defaults are used (e.g. when run inside a notebook).
        FLAGS = parser.parse_args(args=[])

        batch_size = 1024
        bag_size = 10

        # task = 'fashion'  # switch to train on Colored Fashion-MNIST
        task = 'mnist'
        print('Current Task is', task)
        if task == 'mnist':
            train_dataset = ColoredMnistBags3_binary(
                mean_bag_length=bag_size,
                var_bag_length=0,
                num_bag=40000 // bag_size,
                seed=1,
                train=True)
            path = '/home/weights/C_binary_bagsize_' + str(bag_size) + '.pt'
        else:
            train_dataset = ColoredFashionMnistBags3_binary(
                mean_bag_length=bag_size,
                var_bag_length=0,
                num_bag=40000 // bag_size,
                train=True)
            path = '/home/weights/C_Fashion_binary_bagsize_' + str(bag_size) + '.pt'
        train_loader = data_utils.DataLoader(train_dataset,
                                             batch_size=batch_size, pin_memory=True, num_workers=0,
                                             shuffle=True, collate_fn=mi_collate_img)

        # The input_dim argument (1, 28, 28) is not read by training_procedure.
        model = training_procedure(FLAGS, (1, 28, 28), train_loader)
        torch.save(model.state_dict(), path)
        # Resets the in-memory weights only after they have been saved to `path`.
        model.apply(weight_reset)

#### Evaluation on all test bags at once

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_fscore_support, accuracy_score, roc_curve

def get_accuracy(model, bags, bag_idx, bag_label, instance_label, thresh=0.5):
    """Score the model's instance-level predictions against ground-truth labels.

    Args:
        model: object exposing ``classifier_ins(bags, bag_idx)``, which returns
            one score per instance. Presumably these are positive-class
            probabilities (suggested by ``thresh=0.5``); TODO confirm.
        bags: instance tensor passed straight to ``classifier_ins``.
        bag_idx: bag-membership tensor passed straight to ``classifier_ins``.
        bag_label: unused; kept for caller compatibility.
        instance_label: tensor of true 0/1 instance labels.
        thresh: scores strictly greater than this are predicted positive.

    Returns:
        ``(auc, average_precision, accuracy, precision, recall, fscore)``.
        ``precision``/``recall``/``fscore`` are per-class arrays, because
        ``precision_recall_fscore_support`` is called with its default
        ``average=None``.
    """
    with torch.no_grad():
        pred_instance = model.classifier_ins(bags, bag_idx)

    # Convert once to flat 1-D NumPy arrays. Unlike squeeze(), reshape(-1)
    # keeps a single-instance input 1-D instead of producing a 0-d array.
    scores = pred_instance.detach().cpu().numpy().reshape(-1)
    labels = instance_label.cpu().numpy().reshape(-1)

    instance_auc = roc_auc_score(labels, scores)
    instance_aucpr = average_precision_score(labels, scores)

    predicted = scores > thresh  # vectorised threshold instead of a Python loop
    accuracy = accuracy_score(labels, predicted)
    precision, recall, fscore, _ = precision_recall_fscore_support(labels, predicted)
    return instance_auc, instance_aucpr, accuracy, precision, recall, fscore

# Per-run metric accumulators; the means printed at the end are taken over these lists.
auc_list = []
aucpr_list = []
acc_list = []
precision_list = []
recall_list = []
f1_list = []

bag_size_test = 2
num_test_bags = 20000 // bag_size_test
# task = 'mnist'
# `task`, `bag_size` (training bag size, used for the weight file) and `FLAGS`
# are defined by the training cell above.
print(task)
if task == 'mnist':
    test_dataset = ColoredMnistBags3_binary(
        mean_bag_length=bag_size_test,
        var_bag_length=0,
        num_bag=num_test_bags,
        train=False)  # use test set
    path = '/home/weights/C_binary_bagsize_' + str(bag_size) + '.pt'
else:
    test_dataset = ColoredFashionMnistBags3_binary(
        mean_bag_length=bag_size_test,
        var_bag_length=0,
        num_bag=num_test_bags,
        train=False)  # use test set
    path = '/home/weights/C_Fashion_binary_bagsize_' + str(bag_size) + '.pt'

# A single batch holding every test bag, so the whole test set is scored at once.
# Both tasks must size the batch by the *test* bag count; using the training
# bag size here would silently evaluate only part of the test set.
test_loader = data_utils.DataLoader(test_dataset,
                                    batch_size=num_test_bags, pin_memory=True, num_workers=0,
                                    shuffle=True, collate_fn=mi_collate_img)

# Use the builtin next(): recent PyTorch DataLoader iterators have no `.next()` method.
test_bag, test_bag_idx, test_bag_label, test_instance_label = next(iter(test_loader))

print("number of instances", len(test_instance_label))
print("number of positive instances", sum(test_instance_label))
print("number of positive bags", sum(test_bag_label))


device = torch.device('cuda')
model = cmil_mnist(FLAGS).to(device)
model.load_state_dict(torch.load(path, map_location=device))
# A freshly built module is in training mode; switch any dropout/batch-norm
# layers to inference behaviour before scoring.
model.eval()

# bag_idx goes to the model's device as well, mirroring the training loop,
# which passes bag_idx.to(device) into the model.
test_auc, test_aucpr, accuracy, precision, recall, fscore = get_accuracy(
    model, test_bag.float().to(device), test_bag_idx.to(device), test_bag_label, test_instance_label)
print(test_auc)
print(test_aucpr)
print('Accuracy', accuracy)
print('precision', precision)
print('recall', recall)

auc_list.append(test_auc)
aucpr_list.append(test_aucpr)
acc_list.append(accuracy)
precision_list.append(precision)
recall_list.append(recall)
f1_list.append(fscore)

print('Mean instance prediction AUC is： {:.5f}, AUC-PR: {:.5f}, ACC: {:.5f}'.format(np.mean(auc_list), np.mean(aucpr_list), np.mean(acc_list) ))
print('Mean instance prediction Precision is： {:.5f}, Recall: {:.5f}, Fscore: {:.5f}'.format(np.mean(precision_list), np.mean(recall_list), np.mean(f1_list) ))
