In [None]:
import os
# Configure GPU visibility through environment variables. These assignments
# only take effect if they run before CUDA is initialised in this process,
# which is why this cell comes before the torch import.
# Order devices by PCI bus id so that the index used below is stable.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
# Expose only device "0" to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

#### Data loader

In [None]:
import numpy as np
import torch
import torch.utils.data as data_utils
from torchvision import datasets, transforms

def mi_collate_img(batch):
    """Collate function for a PyTorch DataLoader over multi-instance bags.

    Every element of ``batch`` is indexed as
    ``(instances, (bag_label, instance_labels), instance_bag_idx)``.
    Instances, bag indices and instance labels of all bags are concatenated
    along axis 0; the bag labels are gathered into a single tensor.

    Returns:
        Tuple ``(bag, bag_idx, bag_label, instance_label)`` of tensors.
    """
    def _concat(per_bag_arrays):
        # Join the per-bag arrays along the first axis and wrap as a tensor.
        return torch.tensor(np.concatenate(per_bag_arrays, axis=0))

    bag = _concat([sample[0] for sample in batch])
    bag_idx = _concat([sample[2] for sample in batch])

    # Each label entry is a (bag_label, instance_labels) pair.
    label_pairs = [sample[1] for sample in batch]
    instance_label = _concat([pair[1] for pair in label_pairs])
    bag_label = torch.tensor([pair[0] for pair in label_pairs])

    return bag, bag_idx, bag_label, instance_label


from dataloaders.dataloader import KMnistBags3, MnistBags3, FashionMnistBags3

### Networks

In [None]:
from util import *
from models.networks_mnists import cmil_mnist

### Training

In [None]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import roc_curve
def training_procedure(FLAGS, input_dim, dataloader):
    """Train a ``cmil_mnist`` model and return it with its best weights loaded.

    Args:
        FLAGS: namespace read for ``initial_learning_rate``, ``weight_decay``,
            ``end_epoch`` and ``task``; it is also passed to ``cmil_mnist``.
        input_dim: unused; kept so existing callers keep working.
        dataloader: sized iterable yielding
            ``(bag, bag_idx, bag_label, instance_label)`` batches. The
            instance labels are ignored during training.

    Returns:
        The model, reloaded from the checkpoint that achieved the lowest
        mean per-batch ELBO.

    Raises:
        ValueError: if ``dataloader`` contains no batches.
    """
    def _as_float(value):
        # Convert to a plain Python float so the running sums do not keep
        # every batch's autograd graph alive for the whole epoch.
        return value.item() if torch.is_tensor(value) else float(value)

    # len(dataloader) is already the number of batches per epoch, so it is
    # the correct denominator for per-batch averages.
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError('dataloader yields no batches; cannot train')

    device = torch.device('cuda')
    model = cmil_mnist(FLAGS).to(device)
    model.train()
    auto_encoder_optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=FLAGS.initial_learning_rate, weight_decay=FLAGS.weight_decay)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(auto_encoder_optimizer, factor=0.1, patience=20, verbose=True)

    # Temporary checkpoint holding the best weights seen so far.
    checkpoint_path = '/home/weijia/temp_'+ FLAGS.task +'.pt'
    # Start from +inf so the first finite epoch loss is always checkpointed.
    best_loss = float('inf')

    for epoch in range(0, FLAGS.end_epoch):
        elbo_epoch = 0.
        recon_epoch = 0.
        y_epoch = 0.
        KL_ins_epoch = 0.
        for bag, bag_idx, bag_label, _ in dataloader:
            auto_encoder_optimizer.zero_grad()
            elbo, class_y_loss, reconstruction_proba, KL_instance = \
                model.loss_function(bag.float().to(device), bag_idx.to(device), bag_label.to(device), epoch)

            elbo.backward()
            auto_encoder_optimizer.step()

            elbo_epoch += _as_float(elbo)
            recon_epoch += _as_float(reconstruction_proba)
            y_epoch += _as_float(class_y_loss)
            KL_ins_epoch += _as_float(KL_instance)

        # Mean of each loss term over the batches of this epoch.
        elbo_epoch /= num_batches
        recon_epoch /= num_batches
        y_epoch /= num_batches
        KL_ins_epoch /= num_batches

        # The learning-rate schedule follows the classification loss.
        scheduler.step(y_epoch)

        if ((epoch + 1) % 10 == 0):
            print('Epoch #' + str(epoch+1) + '..............................................')
            print("Train elbo  {:.5f}, recon_loss {:.5f}, y_loss {:.5f}" \
                  .format (elbo_epoch, recon_epoch,  y_epoch))
            print("KL zx  {:.6f}".format (KL_ins_epoch))

        # Model selection uses the full ELBO, not only the classification term.
        if elbo_epoch < best_loss:
            best_loss = elbo_epoch
            torch.save(model.state_dict(), checkpoint_path)

    model.load_state_dict(torch.load(checkpoint_path))
    return model

In [None]:
def weight_reset(m):
    """Re-initialise ``m`` in place if it is a Conv2d, Linear or ConvTranspose2d.

    Intended for use with ``Module.apply``; any other module type is left
    untouched.
    """
    resettable_types = (nn.Conv2d, nn.Linear, nn.ConvTranspose2d)
    if not isinstance(m, resettable_types):
        return
    m.reset_parameters()

if __name__ == '__main__':
    import argparse
    import torchvision
    import numpy as np
    import pandas as pd
    from sklearn.model_selection  import ParameterGrid
    import matplotlib.pyplot as plt

    # Hyper-parameter grid: one training run per combination and per target.
    param_grid = {'instance_dim': [24], 'aux_loss_multiplier_y': [1000]}
    grid = ParameterGrid(param_grid)

    # Dataset selector; valid values are 'fashion', 'mnist' and 'kmnist'.
    task = 'kmnist'

    # task name -> (bag dataset class, checkpoint file-name prefix)
    train_sets = {
        'mnist': (MnistBags3, 'set_mnist_'),
        'fashion': (FashionMnistBags3, 'set_fmnist_'),
        'kmnist': (KMnistBags3, 'set_kmnist_'),
    }
    if task not in train_sets:
        # Fail early instead of hitting a NameError (or reusing a stale
        # loader) inside the training loop.
        raise ValueError('Unknown task %r; expected one of %s' % (task, sorted(train_sets)))
    dataset_cls, weight_prefix = train_sets[task]

    for params in grid:
        parser = argparse.ArgumentParser()
        # NOTE(review): type=bool turns any non-empty string (including
        # "False") into True. Harmless here because parse_args(args=[]) only
        # ever yields the defaults.
        parser.add_argument('--cuda', type=bool, default=True, help="run the following code on a GPU")
        parser.add_argument('--num_classes', type=int, default=2, help="number of classes on which the data set trained")
        parser.add_argument('--initial_learning_rate', type=float, default=1e-3, help="starting learning rate")
        parser.add_argument("--weight-decay", default=1e-4, type=float)
        parser.add_argument('--instance_dim', type=int, default=params['instance_dim'], help="dimension of instance factor latent space")
        parser.add_argument('--reconstruction_coef', type=float, default=1., help="coefficient for reconstruction term")
        parser.add_argument('--kl_divergence_coef', type=float, default=1, help="coefficient for instance KL-Divergence loss term")
        parser.add_argument('--aux_loss_multiplier_y', type=float, default=params['aux_loss_multiplier_y'])
        parser.add_argument('--start_epoch', type=int, default=0, help="flag to set the starting epoch for training")
        parser.add_argument('--end_epoch', type=int, default=200, help="flag to indicate the final epoch of training")
        parser.add_argument('-w', '--warmup', type=int, default=0, metavar='N', help='number of epochs for warm-up. Set to 0 to turn warmup off.')
        parser.add_argument('--in_channels', type=int, default = 1, help="input channels")
        parser.add_argument('--task', type=str, default = task, help="current task")

        # Empty argv: inside a notebook only the defaults above are used.
        FLAGS = parser.parse_args(args=[])

        # These names are module-level and are read again by the evaluation cells.
        batch_size = 2048
        bag_size = 30

        n_tasks = 10
        print('Current Task is', task)

        # Train one model per target value and store its weights separately.
        for target in range(0, n_tasks):
            print(target)
            train_loader = data_utils.DataLoader(dataset_cls(target_number=target,
                                          mean_bag_length=bag_size,
                                          var_bag_length=0,
                                          num_bag=60000//bag_size,
                                          train=True),
                                          batch_size=batch_size, pin_memory=True,num_workers=0,
                                          shuffle=True,collate_fn=mi_collate_img)
            path = '/home/weijia/Code/weights/' + weight_prefix + str(target) + '.pt'

            model = training_procedure(FLAGS, (1,28,28), train_loader)
            torch.save(model.state_dict(), path)
            # The trained weights are on disk; reset the in-memory copy.
            model.apply(weight_reset)

#### Evaluation: all bags at once

In [None]:
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_fscore_support, accuracy_score, roc_curve
import sklearn

def get_accuracy_multiclass(model, bags, bag_idx, bag_label, instance_label, threshold=0.5):
    """Score the instance-level predictions of ``model`` against ground truth.

    Args:
        model: object exposing ``classifier_ins(bags, bag_idx)``.
        bags: instance tensor passed to the classifier.
        bag_idx: per-instance bag index tensor; moved to the device of
            ``bags`` before the forward pass.
        bag_label: unused; kept for interface compatibility.
        instance_label: tensor of binary instance labels (positive class 1).
        threshold: an instance is predicted positive when its score is
            strictly greater than this value.

    Returns:
        Tuple ``(auc, aucpr, accuracy, precision, recall, fscore, tp, fp)``
        where ``tp`` / ``fp`` are the true / false positive counts as floats.
    """
    with torch.no_grad():
        # Keep both inputs on the same device regardless of what the caller did.
        pred_instance = model.classifier_ins(bags, bag_idx.to(bags.device))

    # Flatten with reshape(-1) rather than squeeze() so a single-instance
    # input still yields a 1-D array.
    scores = pred_instance.detach().cpu().numpy().reshape(-1)
    labels = instance_label.cpu().numpy().reshape(-1)
    predicted_positive = scores > threshold

    # Count true / false positives directly instead of deriving them from
    # the floating point precision.
    actual_positive = labels == 1
    tp = float(np.sum(predicted_positive & actual_positive))
    print('tp', tp)
    fp = float(np.sum(predicted_positive & ~actual_positive))
    print('fp',fp)

    instance_auc = roc_auc_score(labels, scores)
    instance_aucpr = average_precision_score(labels, scores)

    accuracy = accuracy_score(labels, predicted_positive)
    precision, recall, fscore, _ = precision_recall_fscore_support(labels, predicted_positive, average='binary')
    return instance_auc, instance_aucpr, accuracy, precision, recall, fscore, tp, fp

# Per-target metrics, one entry per trained model.
auc_list = []
aucpr_list = []
acc_list = []
precision_list = []
recall_list = []
f1_list = []
tp_list = []
fp_list = []

# `task`, `bag_size`, `batch_size` and `FLAGS` come from the training cell.
# task name -> (bag dataset class, extra dataset kwargs, checkpoint prefix)
eval_sets = {
    'mnist': (MnistBags3, {'seed': 1}, 'set_mnist_'),
    'fashion': (FashionMnistBags3, {}, 'set_fmnist_'),
    'kmnist': (KMnistBags3, {}, 'set_kmnist_'),
}
if task not in eval_sets:
    raise ValueError('Unknown task %r; expected one of %s' % (task, sorted(eval_sets)))
dataset_cls, dataset_kwargs, weight_prefix = eval_sets[task]
device = torch.device('cuda')

print('Current Task is', task)
for target in range(0,10):
    print(target)
    test_loader = data_utils.DataLoader(dataset_cls(target_number=target,
                                  mean_bag_length=bag_size,
                                  var_bag_length=0,
                                  num_bag=60000//bag_size,
                                  train=False,
                                  **dataset_kwargs),
                                  batch_size=batch_size, pin_memory=True,num_workers=0,
                                  shuffle=True,collate_fn=mi_collate_img)
    path = '/home/weijia/Code/weights/' + weight_prefix + str(target) + '.pt'

    # Only the first batch is evaluated. Use the built-in next(); iterator
    # objects are not guaranteed to expose a .next() method.
    test_bag, test_bag_idx, test_bag_label, test_instance_label = next(iter(test_loader))

    model = cmil_mnist(FLAGS).to(device)
    model.load_state_dict(torch.load(path))
    # Switch train-only layers (dropout, batch norm, ...) to inference behaviour.
    model.eval()

    test_auc, test_aucpr, accuracy, precision, recall, fscore,tp, fp = get_accuracy_multiclass(model, test_bag.float().to(device),
                                                                           test_bag_idx.to(device), test_bag_label, test_instance_label)
    print('Target is', target)
    print('AUC', test_auc)
    print('MAP', test_aucpr)
    print('accuracy', accuracy)
    print('precision', precision)
    print('recall', recall)

    auc_list.append(test_auc)
    aucpr_list.append(test_aucpr)
    acc_list.append(accuracy)
    precision_list.append(precision)
    recall_list.append(recall)
    f1_list.append(fscore)
    tp_list.append(tp)
    fp_list.append(fp)
print('Mean instance prediction AUC is： {:.5f}, AUC-PR: {:.5f}, ACC: {:.5f}'.format(np.mean(auc_list), np.mean(aucpr_list), np.mean(acc_list) ))
print('Mean instance prediction Precision is： {:.5f}, Recall: {:.5f}, Fscore: {:.5f}'.format(np.mean(precision_list), np.mean(recall_list), np.mean(f1_list) ))

# tp / (tp + fp) pooled over all targets is the micro-averaged precision,
# not an accuracy, so the label says so.
print('Micro-averaged instance prediction Precision: {:.5f}'.format(np.sum(tp_list)/(np.sum(tp_list)+np.sum(fp_list)) ))


#### Qualitative Evaluation Random Order

In [None]:
def show_img(img, figsize=(12,10)):
    """Display a channel-first image tensor with matplotlib.

    Args:
        img: tensor laid out as (channels, height, width); it may live on
            any device and may require grad.
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    # detach() + cpu() make the conversion safe for tensors that are still
    # attached to an autograd graph or stored on the GPU.
    npimg = img.detach().cpu().numpy()
    # Clamp to the [0, 1] range that imshow expects for float images.
    npimg = np.clip(npimg, 0., 1.)
    # Move the channel axis last, as required by imshow.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

# task = 'fashion'
#         task = 'mnist'
#         task = 'kmnist'
print('Current Task is', task)

# `task`, `n_tasks`, `bag_size`, `batch_size` and `FLAGS` come from the training cell.
# task name -> (bag dataset class, checkpoint file-name prefix)
random_order_sets = {
    'mnist': (MnistBags3, 'set_mnist_'),
    'fashion': (FashionMnistBags3, 'set_fmnist_'),
    'kmnist': (KMnistBags3, 'set_kmnist_'),
}
if task not in random_order_sets:
    raise ValueError('Unknown task %r; expected one of %s' % (task, sorted(random_order_sets)))
dataset_cls, weight_prefix = random_order_sets[task]
device = torch.device('cuda')

for target in range(0,n_tasks):
    print(target)
    test_loader = data_utils.DataLoader(dataset_cls(target_number=target,
                                  mean_bag_length=bag_size,
                                  var_bag_length=0,
                                  num_bag=60000//bag_size,
                                  train=False),
                                  batch_size=batch_size, pin_memory=True,num_workers=0,
                                  shuffle=True,collate_fn=mi_collate_img)
    path = '/home/weijia/Code/weights/' + weight_prefix + str(target) + '.pt'

    # Only the first (shuffled) batch is visualised.
    batch2 = next(iter(test_loader))

    model = cmil_mnist(FLAGS).to(device)
    model.load_state_dict(torch.load(path))
    # Switch train-only layers (dropout, batch norm, ...) to inference behaviour.
    model.eval()

    # No gradients are needed for visualisation; this also keeps the outputs
    # convertible to numpy. The input is cast to float as in training.
    with torch.no_grad():
        temp2 = model.get_encoding(batch2[0].float().to(device), batch2[1].to(device))
        img2 = model.reconstruct(temp2)

    print(temp2.shape)
    print(img2.shape)

    # nrow is the number of images per grid row; make_grid has no `ncol` parameter.
    show_img(torchvision.utils.make_grid(img2[0:100].cpu(), nrow=10))
    show_img(torchvision.utils.make_grid(batch2[0][0:100], nrow=10))

### Qualitative Ordered

In [None]:
from dataloaders.dataloader_draw import KMnistBagsDraw, MnistBagsDraw, FashionMnistBagsDraw

def show_img(img, figsize=(12,12)):
    """Display a channel-first image tensor with matplotlib on a square figure.

    Args:
        img: tensor laid out as (channels, height, width); it may live on
            any device and may require grad.
        figsize: matplotlib figure size in inches.
    """
    plt.figure(figsize=figsize)
    # detach() + cpu() make the conversion safe for tensors that are still
    # attached to an autograd graph or stored on the GPU.
    npimg = img.detach().cpu().numpy()
    # Clamp to the [0, 1] range that imshow expects for float images.
    npimg = np.clip(npimg, 0., 1.)
    # Move the channel axis last, as required by imshow.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.axis('off')
    plt.show()

print('Current Task is', task)

# `task`, `bag_size`, `batch_size` and `FLAGS` come from the training cell.
# task name -> (drawing dataset class, extra dataset kwargs, checkpoint prefix)
draw_sets = {
    'mnist': (MnistBagsDraw, {'seed': 1}, 'set_mnist_'),
    'fashion': (FashionMnistBagsDraw, {}, 'set_fmnist_'),
    'kmnist': (KMnistBagsDraw, {}, 'set_kmnist_'),
}
if task not in draw_sets:
    raise ValueError('Unknown task %r; expected one of %s' % (task, sorted(draw_sets)))
dataset_cls, dataset_kwargs, weight_prefix = draw_sets[task]
device = torch.device('cuda')

for target in range(0,10):
    print(target)
    test_loader = data_utils.DataLoader(dataset_cls(target_number=target,
                                  mean_bag_length=bag_size,
                                  var_bag_length=0,
                                  num_bag=60000//bag_size,
                                  train=False,
                                  **dataset_kwargs),
                                  batch_size=batch_size, pin_memory=True,num_workers=0,
                                  shuffle=True,collate_fn=mi_collate_img)
    path = '/home/weijia/Code/weights/' + weight_prefix + str(target) + '.pt'

    model = cmil_mnist(FLAGS).to(device)
    model.load_state_dict(torch.load(path))
    # Switch train-only layers (dropout, batch norm, ...) to inference behaviour.
    model.eval()

    # Only the first batch is visualised.
    batch3 = next(iter(test_loader))
    # No gradients are needed for visualisation; this also keeps the outputs
    # convertible to numpy. The input is cast to float as in training.
    with torch.no_grad():
        temp3 = model.get_encoding(batch3[0].float().to(device), batch3[1].to(device))
        img3 = model.reconstruct(temp3)

    # Originals first, reconstructions second.
    # nrow is the number of images per grid row; make_grid has no `ncol` parameter.
    show_img(torchvision.utils.make_grid(batch3[0][0:100], nrow = 10))
    show_img(torchvision.utils.make_grid(img3[0:100].cpu(), nrow = 10))