In [None]:
!pip install torch
!pip install torchensemble
!pip install torchvision

In [None]:
import numpy as np
import pandas as pd
import random

from sklearn.model_selection import train_test_split, KFold

import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataset import Subset
from torchensemble import FastGeometricClassifier, SnapshotEnsembleClassifier
from torchvision import datasets, transforms
from torchvision.models import resnet18

In [None]:
# Number of base estimators handed to each torchensemble classifier.
n_estimators = 10
# SGD settings forwarded to model.set_optimizer(...) by the set_model_* helpers.
lr = 1e-1
weight_decay = 5e-4
momentum = 0.9
# Epochs passed to model.fit(); also used as T_max of the cosine scheduler.
epochs = 10 # alternative value left by the author: 200 (presumably the full-length run)

# Mini-batch size used for both the train and the test DataLoader.
batch_size = 256 # alternative value left by the author: 16

# Root directory given to the torchvision dataset constructors.
data_dir = './data'

In [None]:
def set_model_FGE():
    """Build a Fast Geometric Ensemble classifier around ResNet-18.

    Reads the module-level ``n_estimators``, ``lr``, ``weight_decay``,
    ``momentum`` and ``epochs`` settings.

    Returns:
        FastGeometricClassifier: ensemble with an SGD optimizer and a
        ``CosineAnnealingLR`` scheduler already configured.
    """
    model = FastGeometricClassifier(
        # resnet18() is created with torchvision's default arguments.
        estimator=resnet18(),
        n_estimators=n_estimators,
        # Use the GPU when there is one; fall back to the CPU instead of
        # crashing on machines without CUDA.
        cuda=torch.cuda.is_available(),
    )

    model.set_optimizer(
        "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum
    )
    model.set_scheduler("CosineAnnealingLR", T_max=epochs)

    return model

In [None]:
# Replacement for the ensemble's _adjust_lr method (bound to a model in set_model_FGE_modified)
def _adjust_lr_modeified(
    self, optimizer, epoch, i, n_iters, cycle, alpha_1, alpha_2
):
    """
    Set the internal learning rate scheduler for fast geometric ensemble.
    Please refer to the original paper for details.
    """

    def scheduler(i):
        t = ((epoch % cycle) + i) / cycle
        '''
        if t < 0.5:
            return alpha_1 * (1.0 - 2.0 * t) + alpha_2 * 2.0 * t
        else:
            return alpha_1 * (2.0 * t - 1.0) + alpha_2 * (2.0 - 2.0 * t)
        '''
        r = (1 - (1 - 2 * t) ** 2) ** 0.5
        return alpha_1 * (1.0 - r) + alpha_2 * r

    lr = scheduler(i / n_iters)
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr

    return lr

In [None]:
def set_model_FGE_modified():
    """Return an FGE classifier whose ``_adjust_lr`` is ``_adjust_lr_modeified``."""
    fge = set_model_FGE()

    # Re-use the bound-method type of the existing attribute so the
    # replacement receives this instance as ``self``; only this instance is
    # patched.
    bound_method_type = type(fge._adjust_lr)
    fge._adjust_lr = bound_method_type(_adjust_lr_modeified, fge)

    return fge

In [None]:
def set_model_SSE():
    """Build a Snapshot Ensemble classifier around ResNet-18.

    Reads the module-level ``n_estimators``, ``lr``, ``weight_decay``,
    ``momentum`` and ``epochs`` settings.

    Returns:
        SnapshotEnsembleClassifier: ensemble with an SGD optimizer configured
        and ``set_scheduler("CosineAnnealingLR", ...)`` already called.
    """
    model = SnapshotEnsembleClassifier(
        # resnet18() is created with torchvision's default arguments.
        estimator=resnet18(),
        n_estimators=n_estimators,
        # Use the GPU when there is one; fall back to the CPU instead of
        # crashing on machines without CUDA.
        cuda=torch.cuda.is_available(),
    )

    model.set_optimizer(
        "SGD", lr=lr, weight_decay=weight_decay, momentum=momentum
    )
    # NOTE(review): confirm against the torchensemble docs whether Snapshot
    # Ensemble honours a user-supplied scheduler or installs its own.
    model.set_scheduler("CosineAnnealingLR", T_max=epochs)

    return model

In [None]:
# Preprocessing passed to every dataset constructor: ToTensor only, no other transform.
transformer = transforms.Compose([transforms.ToTensor()])

In [None]:
from sklearn.metrics import accuracy_score, recall_score, precision_score, roc_auc_score, average_precision_score, confusion_matrix

In [None]:
# reference from https://stackoverflow.com/questions/50666091/true-positive-rate-and-false-positive-rate-tpr-fpr-for-multi-class-data-in-py

def get_fpr(y_true, y_prediction):
  """Return the one-vs-rest false positive rate of every class.

  Built from the confusion matrix of ``y_true`` against ``y_prediction``:
  for each class, FPR = FP / (FP + TN). The result is a float array with one
  entry per row/column of that matrix.
  """
  matrix = confusion_matrix(y_true, y_prediction)

  hits = np.diag(matrix)
  false_pos = matrix.sum(axis=0) - hits   # column total minus the diagonal
  false_neg = matrix.sum(axis=1) - hits   # row total minus the diagonal
  true_neg = matrix.sum() - (false_pos + false_neg + hits)

  false_pos = false_pos.astype(float)
  true_neg = true_neg.astype(float)

  return false_pos / (false_pos + true_neg)

In [None]:
# modified version of evaluate function of torchensemble

from torchensemble.utils.io import split_data_target

def evaluate_modified(self, test_loader, return_loss=False):
    """Evaluate the ensemble on ``test_loader`` and return a metrics dict.

    Args:
        self: the ensemble model (called as ``evaluate_modified(model, ...)``);
            must provide ``eval()``, ``forward()`` and ``device``.
        test_loader: iterable of batches understood by ``split_data_target``.
        return_loss: kept only so the signature matches the method this
            function replaces; it is ignored.

    Returns:
        dict with the keys 'Accuracy', 'TPR' (macro recall), 'FPR' (macro
        one-vs-rest false positive rate), 'Precision' (macro), 'AUC',
        'PR-Curve' (macro average precision) and 'Inference Time'
        (milliseconds per batch).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    self.eval()

    output_batches = []
    predicted_batches = []
    target_batches = []

    start = time.time()
    # Inference only: without no_grad every batch would keep its autograd
    # graph alive for as long as its output is stored below.
    with torch.no_grad():
        for elem in test_loader:
            data, target = split_data_target(elem, self.device)
            output = self.forward(*data)
            _, predicted = torch.max(output, 1)
            output_batches.append(output)
            predicted_batches.append(predicted)
            target_batches.append(target)
    inference_time = time.time() - start

    if not output_batches:
        raise ValueError("test_loader yielded no batches; nothing to evaluate")

    # Concatenate once instead of re-copying the running tensors every batch.
    labels = torch.cat(target_batches)
    predicts = torch.cat(predicted_batches).detach().cpu().numpy()

    # Derive the class count from ALL labels so the score columns and the
    # one-hot labels always have the same width (the estimator may expose
    # more output columns than there are classes in the data).
    num_classes = int(labels.max().item()) + 1
    outputs = torch.cat(output_batches).detach().cpu().numpy()[:, :num_classes]
    labels_one_hot = (
        torch.nn.functional.one_hot(labels, num_classes).cpu().numpy()
    )
    labels = labels.cpu().numpy()

    result = {}
    result['Accuracy'] = accuracy_score(labels, predicts)
    result['TPR'] = recall_score(labels, predicts, average='macro')
    # 1 - TPR is the false NEGATIVE rate; the false positive rate is
    # FP / (FP + TN) per class, macro-averaged here to a single number.
    result['FPR'] = float(np.mean(get_fpr(labels, predicts)))
    result['Precision'] = precision_score(labels, predicts, average='macro')
    result['AUC'] = roc_auc_score(labels_one_hot, outputs, multi_class='ovr')
    result['PR-Curve'] = average_precision_score(labels_one_hot, outputs, average='macro')
    # Wall-clock seconds for the whole loop -> milliseconds per batch.
    result['Inference Time'] = inference_time / len(test_loader) * 1000

    return result

In [None]:
def eval_10cv(model_type, dataset_type, dataset_split=0, seed=0):
    """Run 10-fold cross-validation of one ensemble on one fifth of a dataset.

    Args:
        model_type: 'FGE', 'FGEm' or 'SSE'.
        dataset_type: 'CIFAR10', 'CIFAR100', 'ImageNet' or 'MNIST'.
        dataset_split: which fifth (0-4) of the shuffled training set to use.
        seed: seed of the shuffle that defines the fifths. Keeping it fixed
            makes the five ``dataset_split`` values disjoint parts of one
            permutation and makes runs repeatable.

    Returns:
        pandas.DataFrame with one row of metrics per fold.

    Raises:
        ValueError: if ``model_type`` or ``dataset_type`` is not supported.
        AssertionError: if ``dataset_split`` is outside 0..4.
    """
    columns = ['Dataset Name', 'Algorithm Name', 'Cross Validation', 'Accuracy', 'TPR', 'FPR', 'Precision', 'AUC', 'PR-Curve', 'Training Time', 'Inference Time']

    model_builders = {
        'FGE': set_model_FGE,
        'FGEm': set_model_FGE_modified,
        'SSE': set_model_SSE,
    }
    # Fail before any download or training instead of hitting a NameError later.
    if model_type not in model_builders:
        raise ValueError('wrong argument: model_type')

    if dataset_type == 'CIFAR10':
        dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transformer)
    elif dataset_type == 'CIFAR100':
        dataset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transformer)
    elif dataset_type == 'ImageNet':
        dataset = datasets.ImageNet(data_dir, split='train', download=True, transform=transformer)
    elif dataset_type == 'MNIST':
        dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transformer)
    else:
        raise ValueError('wrong argument: dataset_type')

    assert 0 <= dataset_split <= 4
    n_total = len(dataset)
    # One seeded permutation shared by every dataset_split value.
    shuffled = random.Random(seed).sample(range(n_total), n_total)
    split_index = shuffled[dataset_split * n_total // 5:(dataset_split + 1) * n_total // 5]
    dataset = Subset(dataset, split_index)

    # Collect plain dicts and build the frame once; DataFrame.append no
    # longer exists in pandas 2.x.
    records = []
    folds = KFold(n_splits=10).split(range(len(dataset)))
    for cv_num, (train_index, test_index) in enumerate(folds, start=1):
        model = model_builders[model_type]()

        train_loader = DataLoader(Subset(dataset, train_index), batch_size, shuffle=True)
        # Order does not affect the metrics, so the test split is not shuffled.
        test_loader = DataLoader(Subset(dataset, test_index), batch_size, shuffle=False)

        start = time.time()
        model.fit(
            train_loader,
            epochs=epochs,
        )
        train_time = time.time() - start

        result = evaluate_modified(model, test_loader)
        result['Training Time'] = train_time
        result['Dataset Name'] = dataset_type + '_' + str(dataset_split + 1)
        result['Algorithm Name'] = model_type
        result['Cross Validation'] = cv_num
        print(result)
        records.append(result)

    return pd.DataFrame(records, columns=columns)

In [None]:
# Single run: cross-validate one model on one fifth of one dataset, then save the folds.
model_type, dataset_type = 'FGE', 'CIFAR10'   # ensemble under test / dataset
data_split = 0                                # which fifth of the data (0-4)
df = eval_10cv(model_type, dataset_type, data_split)
df.to_csv('_'.join([dataset_type, model_type, str(data_split)]) + '.csv')

In [None]:
# Full sweep: cross-validate the model on each of the five fifths, one CSV per fifth.
model_type, dataset_type = 'FGE', 'CIFAR100'   # ensemble under test / dataset
for data_split in range(5):
  df = eval_10cv(model_type, dataset_type, data_split)
  df.to_csv('_'.join([dataset_type, model_type, str(data_split)]) + '.csv')

In [None]:
# Display the DataFrame currently bound to df (the result of the last eval_10cv call that was run).
df

In [None]:
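# Resume helper: when a run was interrupted, modify this copy of eval_10cv so the folds that already finished are filled in from their recorded results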

def eval_10cv(model_type, dataset_type, dataset_split=0, seed=0):
    """Resume variant of the 10-fold cross-validation run.

    Folds whose results are already recorded in ``completed`` below are
    copied into the output instead of being trained again; the remaining
    folds are trained and evaluated.

    Args:
        model_type: 'FGE', 'FGEm' or 'SSE'.
        dataset_type: 'CIFAR10', 'CIFAR100', 'ImageNet' or 'MNIST'.
        dataset_split: which fifth (0-4) of the shuffled training set to use.
        seed: seed of the shuffle that defines the fifths, so repeated calls
            select the same samples.

    Returns:
        pandas.DataFrame with one row of metrics per fold.

    Raises:
        ValueError: if ``model_type`` or ``dataset_type`` is not supported.
        AssertionError: if ``dataset_split`` is outside 0..4.
    """
    columns = ['Dataset Name', 'Algorithm Name', 'Cross Validation', 'Accuracy', 'TPR', 'FPR', 'Precision', 'AUC', 'PR-Curve', 'Training Time', 'Inference Time']

    # Results of folds that already finished (presumably pasted from the
    # printed output of the interrupted run - edit these when resuming a
    # different run).
    completed = [
        {'Accuracy': 0.498, 'TPR': 0.5015319855358373, 'FPR': 0.49846801446416267, 'Precision': 0.5000659568832656, 'AUC': 0.8892649826662705, 'PR-Curve': 0.546320074481208, 'Inference Time': 274.2760181427002, 'Training Time': 1043.5634968280792, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 1},
        {'Accuracy': 0.556, 'TPR': 0.5575471448862332, 'FPR': 0.44245285511376675, 'Precision': 0.5598966906084876, 'AUC': 0.9071498228363295, 'PR-Curve': 0.5964666483979707, 'Inference Time': 270.8180546760559, 'Training Time': 1044.0145936012268, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 2},
        {'Accuracy': 0.547, 'TPR': 0.5429857096417257, 'FPR': 0.45701429035827434, 'Precision': 0.5431857477841769, 'AUC': 0.9016243399998078, 'PR-Curve': 0.583741794720911, 'Inference Time': 268.0376172065735, 'Training Time': 1042.560531616211, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 3},
        {'Accuracy': 0.524, 'TPR': 0.5229544236788614, 'FPR': 0.47704557632113864, 'Precision': 0.5291013122659229, 'AUC': 0.8895413867011539, 'PR-Curve': 0.5628302842973161, 'Inference Time': 272.3956108093262, 'Training Time': 1043.731341123581, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 4},
        {'Accuracy': 0.573, 'TPR': 0.5691689652587358, 'FPR': 0.43083103474126416, 'Precision': 0.5717630822138045, 'AUC': 0.9132241303668558, 'PR-Curve': 0.6119640876794608, 'Inference Time': 268.7094211578369, 'Training Time': 1043.161875963211, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 5},
        {'Accuracy': 0.543, 'TPR': 0.5462239083448645, 'FPR': 0.45377609165513555, 'Precision': 0.5413901388199379, 'AUC': 0.895189039275121, 'PR-Curve': 0.574292156360536, 'Inference Time': 271.90470695495605, 'Training Time': 1044.248327255249, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 6},
        {'Accuracy': 0.553, 'TPR': 0.5596245081520613, 'FPR': 0.44037549184793867, 'Precision': 0.5554395041994891, 'AUC': 0.9005340104482151, 'PR-Curve': 0.5809220223042436, 'Inference Time': 269.21379566192627, 'Training Time': 1040.0652377605438, 'Dataset Name': 'CIFAR10_4', 'Algorithm Name': 'FGEm', 'Cross Validation': 7},
    ]

    model_builders = {
        'FGE': set_model_FGE,
        'FGEm': set_model_FGE_modified,
        'SSE': set_model_SSE,
    }
    # Fail before any download or training instead of hitting a NameError later.
    if model_type not in model_builders:
        raise ValueError('wrong argument: model_type')

    if dataset_type == 'CIFAR10':
        dataset = datasets.CIFAR10(data_dir, train=True, download=True, transform=transformer)
    elif dataset_type == 'CIFAR100':
        dataset = datasets.CIFAR100(data_dir, train=True, download=True, transform=transformer)
    elif dataset_type == 'ImageNet':
        dataset = datasets.ImageNet(data_dir, split='train', download=True, transform=transformer)
    elif dataset_type == 'MNIST':
        dataset = datasets.MNIST(data_dir, train=True, download=True, transform=transformer)
    else:
        raise ValueError('wrong argument: dataset_type')

    assert 0 <= dataset_split <= 4
    n_total = len(dataset)
    # One seeded permutation shared by every dataset_split value.
    shuffled = random.Random(seed).sample(range(n_total), n_total)
    split_index = shuffled[dataset_split * n_total // 5:(dataset_split + 1) * n_total // 5]
    dataset = Subset(dataset, split_index)

    # Only reuse recorded folds that belong to the run being requested;
    # rows of another dataset/model must not leak into this table.
    run_name = dataset_type + '_' + str(dataset_split + 1)
    cached = {
        row['Cross Validation']: row
        for row in completed
        if row['Dataset Name'] == run_name and row['Algorithm Name'] == model_type
    }

    # Collect plain dicts and build the frame once; DataFrame.append no
    # longer exists in pandas 2.x.
    records = []
    folds = KFold(n_splits=10).split(range(len(dataset)))
    for cv_num, (train_index, test_index) in enumerate(folds, start=1):
        if cv_num in cached:
            records.append(cached[cv_num])
            print('check:', cached[cv_num])
            continue

        model = model_builders[model_type]()

        train_loader = DataLoader(Subset(dataset, train_index), batch_size, shuffle=True)
        # Order does not affect the metrics, so the test split is not shuffled.
        test_loader = DataLoader(Subset(dataset, test_index), batch_size, shuffle=False)

        start = time.time()
        model.fit(
            train_loader,
            epochs=epochs,
        )
        train_time = time.time() - start

        result = evaluate_modified(model, test_loader)
        result['Training Time'] = train_time
        result['Dataset Name'] = run_name
        result['Algorithm Name'] = model_type
        result['Cross Validation'] = cv_num
        print(result)
        records.append(result)

    return pd.DataFrame(records, columns=columns)