In [5]:
import torch
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# MNIST dataset creation
from milearn.data.mnist import load_mnist, create_bags_or, create_bags_and, create_bags_xor, create_bags_sum
from milearn.network.module.utils import set_seed

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from milearn.wrapper import InstanceWrapper, BagWrapper
from milearn.network.regressor import InstanceNetworkRegressor, BagNetworkRegressor
from milearn.network.classifier import InstanceNetworkClassifier, BagNetworkClassifier
from milearn.network.regressor import (AttentionNetworkRegressor,
                                       TempAttentionNetworkRegressor,
                                       GatedAttentionNetworkRegressor,
                                       MultiHeadAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

from milearn.network.classifier import (AttentionNetworkClassifier,
                                        TempAttentionNetworkClassifier,
                                        GatedAttentionNetworkClassifier,
                                        MultiHeadAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

# Utils
from sklearn.metrics import r2_score, accuracy_score
from sklearn.model_selection import train_test_split
from milearn.preprocessing import BagMinMaxScaler

# Prediction visualisation
from milearn.data.mnist import visualize_bag_with_weights

In [13]:
def accuracy_metric(y_true, y_pred, task=None):
    if task == "classification":
        return accuracy_score(y_true, y_pred)
    elif task == "regression":
        return r2_score(y_true, y_pred)

In [14]:
TASK = "regression"
# TASK = "classification"

In [7]:
bag_size = 10
num_bags = 1000

data, targets = load_mnist()

if TASK == "classification":
    bags, labels, key = create_bags_or(data, targets, bag_size=bag_size, num_bags=num_bags, key_digit=3, key_instances_per_bag=1, random_state=42)
elif TASK == "regression":
    bags, labels, key = create_bags_sum(data, targets, bag_size=bag_size, num_bags=num_bags, random_state=42)

In [8]:
network_hparams = {'hidden_layer_sizes':(256, 128, 64),
                   'num_epoch':300,
                   'batch_size':128,
                   'learning_rate':0.001,
                   'weight_decay':0.001,
                   'instance_weight_dropout':0.01,
                   'init_cuda':False,
                   'verbose':False}

In [9]:
regressor_list = [
                   ("MeanInstanceWrapperRegressor", InstanceWrapper(estimator=RandomForestRegressor(), pool="mean")), 
                   ("MaxInstanceWrapperRegressor", InstanceWrapper(RandomForestRegressor(), pool="max")), 
                   ("MeanBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="mean")), 
                   ("MaxBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="max")), 
                   ("MinBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="min")), 
                   ("ExtremeBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="extreme")),
                   ("MeanInstanceNetworkRegressor", InstanceNetworkRegressor(**network_hparams, pool="mean")),
                   ("MaxInstanceNetworkRegressor", InstanceNetworkRegressor(**network_hparams, pool="max")),
                   ("MeanBagNetworkRegressor", BagNetworkRegressor(**network_hparams, pool="mean")),
                   ("MaxBagNetworkRegressor", BagNetworkRegressor(**network_hparams, pool="max")),
                   ("AttentionNetworkRegressor", AttentionNetworkRegressor(**network_hparams)),
                   ("TempAttentionNetworkRegressor", AttentionNetworkRegressor(**network_hparams)),
                   ("GatedAttentionNetworkRegressor", GatedAttentionNetworkRegressor(**network_hparams)),
                   ("MultiHeadAttentionNetworkRegressor", SelfAttentionNetworkRegressor(**network_hparams)),
                   ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor(**network_hparams)),
                   ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor(**network_hparams)),
                   ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor(**network_hparams))
              ]

classifier_list = [
                   ("MeanInstanceWrapperClassifier", InstanceWrapper(estimator=RandomForestClassifier(), pool="mean")), 
                   ("MaxInstanceWrapperClassifier", InstanceWrapper(RandomForestClassifier(), pool="max")), 
                   ("MeanBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="mean")), 
                   ("MaxBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="max")), 
                   ("MinBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="min")), 
                   ("ExtremeBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="extreme")),
                   ("MeanInstanceNetworkClassifier", InstanceNetworkClassifier(**network_hparams, pool="mean")),
                   ("MaxInstanceNetworkClassifier", InstanceNetworkClassifier(**network_hparams, pool="max")),
                   ("MeanBagNetworkClassifier", BagNetworkClassifier(**network_hparams, pool="mean")),
                   ("MaxBagNetworkClassifier", BagNetworkClassifier(**network_hparams, pool="max")),
                   ("AttentionNetworkClassifier", AttentionNetworkClassifier(**network_hparams)),
                   ("TempAttentionNetworkClassifier", AttentionNetworkClassifier(**network_hparams)),
                   ("GatedAttentionNetworkClassifier", GatedAttentionNetworkClassifier(**network_hparams)),
                   ("MultiHeadAttentionNetworkClassifier", SelfAttentionNetworkClassifier(**network_hparams)),
                   ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier(**network_hparams)),
                   ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier(**network_hparams)),
                   ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier(**network_hparams))
              ]

In [11]:
x_train, x_test, y_train, y_test, key_train, key_test = train_test_split(bags, labels, key, random_state=42)
# 
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled = scaler.transform(x_train)
x_test_scaled = scaler.transform(x_test)

In [15]:
if TASK == "regression":
    method_list = regressor_list
elif TASK == "classification":
    method_list = classifier_list

res_df = pd.DataFrame()
for method_name, model in method_list:
    model.fit(x_train_scaled, y_train)

    if TASK == "regression":
        y_pred = model.predict(x_test_scaled)
    elif TASK == "classification":
        y_prob = model.predict(x_test_scaled)
        y_pred = np.where(y_prob > 0.5, 1, 0)
    
    res_df.loc[method_name, "ACC"] = accuracy_metric(y_test, y_pred, task=TASK)

In [17]:
res_df.sort_values(by="ACC", ascending=False)

Unnamed: 0,ACC
MeanInstanceNetworkRegressor,0.54964
MeanBagNetworkRegressor,0.544492
MeanBagWrapperRegressor,0.450058
AttentionNetworkRegressor,0.406646
GatedAttentionNetworkRegressor,0.376983
TempAttentionNetworkRegressor,0.351514
DynamicPoolingNetworkRegressor,0.343001
MaxBagWrapperRegressor,0.216426
ExtremeBagWrapperRegressor,0.201203
MeanInstanceWrapperRegressor,0.147412
