# Slide-level Classification with MIL (Intro)

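Multiple Instance Learning (MIL) treats each whole-slide image as a *bag* of patch-level *instances*: only the slide carries a label, so a model must aggregate the patch features into a single slide-level prediction. Below we sketch the core of a simple attention-based MIL model in PyTorch: a learned attention score per patch, a softmax-weighted pooling of the patch features, and a linear classifier on the pooled embedding.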

In [None]:
# Toy example: attention MIL head over patch features
import torch, torch.nn as nn, torch.nn.functional as F
class AttnMIL(nn.Module):
    """Attention-based MIL head over a bag of instance (patch) features.

    Each instance receives a scalar score from a small tanh MLP. The scores are
    softmax-normalised across the instances of a bag, the bag embedding is the
    attention-weighted sum of the instance features, and a linear layer maps
    that embedding to class logits.

    Args:
        d_in: Dimensionality of each instance feature vector.
        d_attn: Hidden width of the attention MLP.
        n_classes: Number of output classes.
    """

    def __init__(self, d_in=2048, d_attn=256, n_classes=2):
        super().__init__()
        self.attn = nn.Sequential(nn.Linear(d_in, d_attn), nn.Tanh(), nn.Linear(d_attn, 1))
        self.clf = nn.Linear(d_in, n_classes)

    def forward(self, feats, mask=None):
        """Pool a bag (or a batch of padded bags) of instance features into logits.

        Args:
            feats: Instance features of shape ``[N, d_in]`` for a single bag, or
                ``[..., N, d_in]`` for a batch of bags padded to a common ``N``.
            mask: Optional tensor of shape ``[..., N]`` (converted to bool);
                ``True`` marks real instances and ``False`` marks padding, which
                receives exactly zero attention and does not contribute to the
                bag embedding. Every bag must contain at least one ``True``
                entry, otherwise its weights are NaN.

        Returns:
            ``(logits, weights)``: ``logits`` has shape ``[..., n_classes]``;
            ``weights`` has shape ``[..., N]`` and sums to 1 over the instance axis.

        Raises:
            ValueError: If ``feats`` has fewer than 2 dimensions.
        """
        if feats.dim() < 2:
            raise ValueError(
                f"feats must have shape [N, d_in] or [..., N, d_in], got {tuple(feats.shape)}"
            )
        scores = self.attn(feats)  # [..., N, 1]
        if mask is not None:
            valid = mask.to(torch.bool).unsqueeze(-1)  # [..., N, 1]
            # exp(-inf) == 0, so padded instances get exactly zero weight.
            scores = scores.masked_fill(~valid, float("-inf"))
            # Zero padded features too, so non-finite padding cannot leak in via 0 * nan.
            feats = feats.masked_fill(~valid, 0.0)
        # Normalise over the instance axis (dim -2), not the feature axis; for a
        # single [N, d_in] bag this is dim 0, and it also supports leading batch dims.
        w = torch.softmax(scores, dim=-2)  # attention weights [..., N, 1]
        z = (w * feats).sum(dim=-2)  # bag embedding [..., d_in]
        logits = self.clf(z)
        return logits, w.squeeze(-1)
# Demo with random features (seeded so the displayed output is reproducible).
torch.manual_seed(0)
N, D = 64, 2048  # number of instances (patches) in the bag, feature width
feats = torch.randn(N, D)
model = AttnMIL(D)
model.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    logits, weights = model(feats)
logits, weights.shape