### Introduction

Each bag contains a fixed number of MNIST digits (e.g., 10). Each digit belongs to a class (0–9), and the bag label is determined from the digit classes using one of the logical rules described below.

**Instance:** One MNIST digit image.

**Bag:** A collection of digits (e.g., a list of 10 MNIST digits).

**Label:** Positive if, for example, the bag contains a key digit (e.g., 3); otherwise negative.

**Key instance:** A key digit image.

### Logic-Based Labeling Rules

🟢 **OR Setting: Easy** 

The bag is positive if at least one instance belongs to a target class (e.g., digit 3).

* Bag Label = 1 if any digit in the bag is a 3

* Key instance = any one of the 3s (all of them are “valid” keys)

**Intuition:** Easy; this matches the classic MIL assumption that a single positive instance makes the whole bag positive.

**Example:**
Bag = [1, 5, 7, 3, 0] → Label = 1 / Key instance: 3
  
🔴 **AND Setting: Medium**

The bag is positive only if every target class is present (e.g., both a 3 and a 7 must appear in the bag).

* Bag Label = 1 if both a 3 and a 7 are in the bag.
* Key instances = the 3s and the 7s; if only one of the two digits is present, label = 0.

**Intuition:** Harder than OR because no single instance decides the label; one missing target digit makes the bag negative.

**Example:**  
Bag = [0, 3, 4, 6, 7] → Label = 1 / Key instances: 3 and 7  
Bag = [0, 3, 3, 6, 8] → Label = 0

🟡 **XOR Setting: Hard**

The bag is positive if exactly one of the two target classes is present (e.g., a 3 or a 7, but not both).

* Bag Label = 1 if the bag contains a 3 or a 7, but not both.
* Key instance = whichever target digit is present; the label depends on the combination of digits, not on the presence of one digit alone.

**Intuition:** Very hard, because the label depends on interactions between instances.

**Example:**  
Bag = [1, 3, 7, 2] → Label = 0 (3 and 7 at the same time)  
Bag = [1, 3, 6, 2] → Label = 1 (only 3 but not 7) / Key instances: 3  
Bag = [1, 5, 7, 2] → Label = 1 (only 7 but not 3) / Key instances: 7    
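The three rules can be sketched on plain lists of digit classes. This is a minimal illustration, not the code used by `create_bags_or`, `create_bags_and`, or `create_bags_xor`; the helper names `label_or`, `label_and`, and `label_xor` are hypothetical, and each returns the bag label together with the positions of the key instances.

```python
# Hypothetical helpers (not part of milearn): label a bag from its digit
# classes and list the positions of its key instances.

def label_or(digits, key_digit=3):
    # Positive if at least one instance is the key digit.
    keys = [i for i, d in enumerate(digits) if d == key_digit]
    return int(bool(keys)), keys

def label_and(digits, key_digits=(3, 7)):
    # Positive only if every key digit occurs somewhere in the bag.
    positive = all(k in digits for k in key_digits)
    keys = [i for i, d in enumerate(digits) if d in key_digits] if positive else []
    return int(positive), keys

def label_xor(digits, key_digits=(3, 7)):
    # Positive if exactly one of the two key digits occurs in the bag.
    first, second = key_digits
    positive = (first in digits) != (second in digits)
    keys = [i for i, d in enumerate(digits) if d in key_digits] if positive else []
    return int(positive), keys

print(label_or([1, 5, 7, 3, 0]))   # (1, [3])
print(label_and([0, 3, 4, 6, 7]))  # (1, [1, 4])
print(label_and([0, 3, 3, 6, 8]))  # (0, [])
print(label_xor([1, 3, 7, 2]))     # (0, [])
print(label_xor([1, 3, 6, 2]))     # (1, [1])
print(label_xor([1, 5, 7, 2]))     # (1, [2])
```

The printed key positions are indices into the bag, which is the form the key instance detection functions later in this notebook expect.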

In [None]:
import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)

import time
import torch
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# MNIST dataset creation
from milearn.data.mnist import load_mnist, create_bags_or, create_bags_and, create_bags_xor, create_bags_reg

# Preprocessing
from milearn.preprocessing import BagMinMaxScaler

# Network hparams
from milearn.network.module.hopt import DEFAULT_PARAM_GRID

# MIL wrappers
from milearn.network.regressor import BagWrapperMLPNetworkRegressor, InstanceWrapperMLPNetworkRegressor
from milearn.network.classifier import BagWrapperMLPNetworkClassifier, InstanceWrapperMLPNetworkClassifier

# MIL networks
from milearn.network.classifier import (InstanceNetworkClassifier,
                                        BagNetworkClassifier,
                                        AdditiveAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

# Utils
from sklearn.metrics import r2_score, accuracy_score
from sklearn.model_selection import train_test_split

# Prediction visualisation
from milearn.data.mnist import visualize_bag_with_weights

### Key Instance Detection Accuracy Functions

These functions evaluate how well a model's predicted attention weights or importance scores match the true key instance indices in positive bags, according to the logic used during bag labeling.

Each function returns the fraction of positive bags whose key instances are correctly identified. A bag counts as correct if at least one (OR), all (AND), or exactly one (XOR) of its true key instances appear among the top-k instances ranked by predicted weight. Negative bags are ignored.
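For a single bag, the three criteria can be sketched in plain Python. The names `top_k_indices` and `bag_hit` are hypothetical and introduced only for illustration; the notebook's own functions use `np.argsort` for the ranking and average the result over all positive bags.

```python
# Hypothetical single-bag version of the KID check (illustration only).

def top_k_indices(weights, k):
    """Return the set of indices of the k largest weights."""
    order = sorted(range(len(weights)), key=lambda i: weights[i])
    return set(order[-k:])

def bag_hit(weights, keys, rule, k):
    """Decide whether the top-k instances recover the true key instances."""
    top = top_k_indices(weights, k)
    matches = sum(idx in top for idx in keys)
    if rule == "or":
        return matches >= 1          # at least one key recovered
    if rule == "and":
        return matches == len(keys)  # every key recovered
    if rule == "xor":
        return matches == 1          # exactly one key recovered
    raise ValueError(f"unknown rule: {rule}")

weights = [0.05, 0.60, 0.10, 0.05, 0.20]  # toy attention weights for 5 instances

print(bag_hit(weights, keys=[1], rule="or", k=1))      # True
print(bag_hit(weights, keys=[1, 4], rule="and", k=2))  # True
print(bag_hit(weights, keys=[1, 4], rule="and", k=1))  # False
print(bag_hit(weights, keys=[2], rule="xor", k=1))     # False
```

Note that the AND criterion can only succeed when `k` is at least the number of key instances in the bag, which is why `kid_accuracy_and` defaults to `top_k=2`.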

In [None]:
def kid_accuracy_or(key_indices, predicted_weights, bag_labels, top_k=1):
    """Fraction of positive bags with at least one true key instance in the top-k weights."""
    correct, total = 0, 0
    for weights, keys, label in zip(predicted_weights, key_indices, bag_labels):
        if label == 1:
            pred_top_k = np.argsort(weights)[-top_k:]
            if any(idx in pred_top_k for idx in keys):
                correct += 1
            total += 1

    return correct / total if total > 0 else 0.0

def kid_accuracy_and(key_indices, predicted_weights, bag_labels, top_k=2):
    """Fraction of positive bags with all true key instances in the top-k weights."""
    correct, total = 0, 0
    for weights, keys, label in zip(predicted_weights, key_indices, bag_labels):
        if label == 1:
            pred_top_k = set(np.argsort(weights)[-top_k:])
            if all(k in pred_top_k for k in keys):
                correct += 1
            total += 1

    return correct / total if total > 0 else 0.0

def kid_accuracy_xor(key_indices, predicted_weights, bag_labels, top_k=1):
    """Fraction of positive bags with exactly one true key instance in the top-k weights."""
    correct, total = 0, 0
    for weights, keys, label in zip(predicted_weights, key_indices, bag_labels):
        if label == 1:
            pred_top_k = np.argsort(weights)[-top_k:]
            match_count = sum(idx in pred_top_k for idx in keys)
            if match_count == 1:
                correct += 1
            total += 1

    return correct / total if total > 0 else 0.0

### 1. Create MNIST dataset

In [None]:
bag_size = 10
num_bags = 10000
setting_type = "or"  # one of "or", "and", "xor"

data, targets = load_mnist()
if setting_type == "or":
    bags, labels, key = create_bags_or(data, targets, bag_size=bag_size, num_bags=num_bags, key_digit=3, key_instances_per_bag=1, random_state=42)
    kid_accuracy = kid_accuracy_or
elif setting_type == "and":
    bags, labels, key = create_bags_and(data, targets,  bag_size=bag_size, num_bags=num_bags, key_digits=(3, 7), random_state=42)
    kid_accuracy = kid_accuracy_and
elif setting_type == "xor":
    bags, labels, key = create_bags_xor(data, targets, bag_size=bag_size, num_bags=num_bags, key_digits=(3, 7), random_state=42)
    kid_accuracy = kid_accuracy_xor

### 2. Build model

In [None]:
# train/test split
x_train, x_test, y_train, y_test, key_train, key_test = train_test_split(bags, labels, key, random_state=42)

# features scaling
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled = scaler.transform(x_train)
x_test_scaled = scaler.transform(x_test)

In [None]:
model = AdditiveAttentionNetworkClassifier(verbose=True)
# model.hopt(x_train_scaled, y_train, param_grid=DEFAULT_PARAM_GRID, verbose=False)
model.fit(x_train_scaled, y_train)

In [None]:
# predict labels
y_prob = model.predict(x_test_scaled)
y_pred = np.where(y_prob > 0.5, 1, 0)

# predict instance weights
w_pred = model.get_instance_weights(x_test_scaled)
w_pred = [w.flatten() for w in w_pred]

In [None]:
print(f"Classification accuracy: {accuracy_score(y_test, y_pred):.2f}")
print(f"KID accuracy: {kid_accuracy(key_test, w_pred, y_test):.2f}")

In [None]:
# bags with correct positive class prediction (use these indexes for visualization)
positive_bags = list(np.where((y_pred == 1) & (np.array(y_test) == 1))[0])

In [None]:
N = positive_bags[0]
visualize_bag_with_weights(x_test[N], w_pred[N], 
                           title=f"Bag {N}\nPredicted label: {y_pred[N].item()}\nTrue label: {y_test[N]}")

### 3. KID benchmark

In [None]:
classifier_list = [
    
        # attention mil networks
        ("AdditiveAttentionNetworkClassifier", AdditiveAttentionNetworkClassifier()),
        ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier()),
        ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier()),

        # other mil networks
        ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier()),
    ]

In [None]:
DEFAULT_PARAM_GRID["hidden_layer_sizes"] = (2048, 1024, 512, 256, 128, 64)

In [None]:
bag_size = 10
num_bags = 10000

data, targets = load_mnist()

res_df = pd.DataFrame()
for task_idx, task in enumerate(["or", "and", "xor"], 1):
    print(f"\n[Task {task_idx}/3] Starting task: '{task}'")

    # create dataset
    if task == "or":
        bags, labels, key = create_bags_or(data, targets, bag_size=bag_size, num_bags=num_bags, key_digit=3, key_instances_per_bag=1)
        kid_accuracy = kid_accuracy_or
    elif task == "and":
        bags, labels, key = create_bags_and(data, targets,  bag_size=bag_size, num_bags=num_bags, key_digits=(3, 7))
        kid_accuracy = kid_accuracy_and
    elif task == "xor":
        bags, labels, key = create_bags_xor(data, targets, bag_size=bag_size, num_bags=num_bags, key_digits=(3, 7))
        kid_accuracy = kid_accuracy_xor

    # prepare data
    x_train, x_test, y_train, y_test, key_train, key_test = train_test_split(bags, labels, key, random_state=42)
    scaler = BagMinMaxScaler()
    scaler.fit(x_train)
    x_train_scaled, x_test_scaled = scaler.transform(x_train), scaler.transform(x_test)

    # build models
    for model_idx, (name, model) in enumerate(classifier_list, 1):
        print(f"  [Model {model_idx}/{len(classifier_list)}] Training model: '{name}' for task '{task}'")

        # train model
        # model.hopt(x_train_scaled, y_train, param_grid=DEFAULT_PARAM_GRID, verbose=False)
        model.fit(x_train_scaled, y_train)
        # predict
        y_prob = model.predict(x_test_scaled)
        y_pred = np.where(y_prob > 0.5, 1, 0)
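        w_pred = model.get_instance_weights(x_test_scaled)
        # flatten per-bag weights to 1-D, as in the single-model run above
        w_pred = [w.flatten() for w in w_pred]
        # store results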
        res_df.loc[name, f"PRED_{task}"] = accuracy_score(y_test, y_pred)
        res_df.loc[name, f"KID_{task}"] = kid_accuracy(key_test, w_pred, y_test)

print("\nAll tasks and models completed.")

In [None]:
res_df.round(2)