In [1]:
import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)

import time
import torch
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# MNIST dataset creation
from milearn.data.mnist import load_mnist, create_bags_or, create_bags_and, create_bags_xor, create_bags_reg

# Preprocessing
from milearn.preprocessing import BagMinMaxScaler

# Network hparams
from milearn.network.module.hopt import DEFAULT_PARAM_GRID

# MIL wrappers
from milearn.network.regressor import BagWrapperMLPNetworkRegressor, InstanceWrapperMLPNetworkRegressor
from milearn.network.classifier import BagWrapperMLPNetworkClassifier, InstanceWrapperMLPNetworkClassifier

# MIL networks
from milearn.network.regressor import (InstanceNetworkRegressor,
                                       BagNetworkRegressor,
                                       AdditiveAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

from milearn.network.classifier import (InstanceNetworkClassifier,
                                        BagNetworkClassifier,
                                        AdditiveAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

# Utils
from sklearn.metrics import r2_score, accuracy_score
from sklearn.model_selection import train_test_split

# Prediction visualisation
from milearn.data.mnist import visualize_bag_with_weights

In [2]:
def accuracy_metric(y_true, y_pred, task=None):
    if task == "classification":
        return accuracy_score(y_true, y_pred)
    elif task == "regression":
        return r2_score(y_true, y_pred)

In [3]:
# TASK = "regression"
TASK = "classification"

In [4]:
bag_size = 10
num_bags = 10000

data, targets = load_mnist()

if TASK == "classification":
    bags, labels, key = create_bags_or(data, targets, bag_size=bag_size, num_bags=num_bags, 
                                       key_digit=3, key_instances_per_bag=1, random_state=42)
elif TASK == "regression":
    bags, labels, key = create_bags_reg(data, targets, bag_size=bag_size, num_bags=num_bags, 
                                        bag_agg="mean", random_state=42)

In [5]:
regressor_list = [

        # wrapper mil networks
        ("MeanBagWrapperMLPNetworkRegressor", BagWrapperMLPNetworkRegressor(pool="mean")),
        ("MeanInstanceWrapperMLPNetworkRegressor", InstanceWrapperMLPNetworkRegressor(pool="mean")),
    
        # classic mil networks
        ("MeanBagNetworkRegressor", BagNetworkRegressor(pool="mean")),
        ("MeanInstanceNetworkRegressor", InstanceNetworkRegressor(pool="mean")),

        # attention mil networks
        ("AdditiveAttentionNetworkRegressor", AdditiveAttentionNetworkRegressor()),
        ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor()),
        ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor()),

        # other mil networks
        ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor()),
    ]

classifier_list = [

        # wrapper mil networks
        ("MeanBagWrapperMLPNetworkClassifier", BagWrapperMLPNetworkClassifier(pool="mean")),
        ("MeanInstanceWrapperMLPNetworkClassifier", InstanceWrapperMLPNetworkClassifier(pool="mean")),
    
        # classic mil networks
        ("MeanBagNetworkClassifier", BagNetworkClassifier(pool="mean")),
        ("MeanInstanceNetworkClassifier", InstanceNetworkClassifier(pool="mean")),

        # attention mil networks
        ("AdditiveAttentionNetworkClassifier", AdditiveAttentionNetworkClassifier()),
        ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier()),
        ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier()),

        # other mil networks
        ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier()),
    ]

In [6]:
# train/test split
x_train, x_test, y_train, y_test, key_train, key_test = train_test_split(bags, labels, key, random_state=42)

# features scaling
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled = scaler.transform(x_train)
x_test_scaled = scaler.transform(x_test)

In [7]:
DEFAULT_PARAM_GRID["hidden_layer_sizes"] = (2048, 1024, 512, 256, 128, 64)

In [8]:
if TASK == "regression":
    method_list = regressor_list
elif TASK == "classification":
    method_list = classifier_list

res_df = pd.DataFrame()
n_methods = len(method_list)

for i, (method_name, model) in enumerate(method_list, start=1):
    print(f"[{i}/{n_methods}] Running {method_name}...")
    start_time = time.time()

    # model.hopt(x_train_scaled, y_train, param_grid=DEFAULT_PARAM_GRID, verbose=False)
    model.fit(x_train_scaled, y_train)

    if TASK == "regression":
        y_pred = model.predict(x_test_scaled)
    elif TASK == "classification":
        y_prob = model.predict(x_test_scaled)
        y_pred = np.where(y_prob > 0.5, 1, 0)

    acc = accuracy_metric(y_test, y_pred, task=TASK)
    res_df.loc[method_name, "Accuracy"] = acc

    elapsed_min = (time.time() - start_time) / 60
    print(f"    → Done. Accuracy = {acc:.2f}, Time: {elapsed_min:.2f} min")

[1/8] Running MeanBagWrapperMLPNetworkClassifier...
    → Done. Accuracy = 0.71, Time: 0.07 min
[2/8] Running MeanInstanceWrapperMLPNetworkClassifier...
    → Done. Accuracy = 0.69, Time: 0.60 min
[3/8] Running MeanBagNetworkClassifier...
    → Done. Accuracy = 0.97, Time: 0.12 min
[4/8] Running MeanInstanceNetworkClassifier...
    → Done. Accuracy = 0.97, Time: 0.11 min
[5/8] Running AdditiveAttentionNetworkClassifier...
    → Done. Accuracy = 0.97, Time: 0.14 min
[6/8] Running SelfAttentionNetworkClassifier...
    → Done. Accuracy = 0.97, Time: 0.10 min
[7/8] Running HopfieldAttentionNetworkClassifier...
    → Done. Accuracy = 0.97, Time: 0.13 min
[8/8] Running DynamicPoolingNetworkClassifier...
    → Done. Accuracy = 0.96, Time: 0.13 min


In [9]:
res_df.sort_values(by="Accuracy", ascending=False)

Unnamed: 0,Accuracy
AdditiveAttentionNetworkClassifier,0.9716
HopfieldAttentionNetworkClassifier,0.9708
SelfAttentionNetworkClassifier,0.97
MeanInstanceNetworkClassifier,0.9664
MeanBagNetworkClassifier,0.9664
DynamicPoolingNetworkClassifier,0.9644
MeanBagWrapperMLPNetworkClassifier,0.712
MeanInstanceWrapperMLPNetworkClassifier,0.6904
