In [2]:
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader



from slide_util import slide_to_tiles

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
from sklearn.metrics import precision_score, recall_score, roc_auc_score


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
BASE_CHECKPOINT = "hf_hub:Snarcy/RedDino-base"

In [4]:
class GatedAttentionMIL_RED(nn.Module):
    def __init__(self, base_checkpoint, M=500, L=128, attention_branches=1):
        super().__init__()
        self.M, self.L, self.B = M, L, attention_branches

        # Backbone returns a per-tile embedding (D=768)
        self.backbone = timm.create_model(base_checkpoint, pretrained=True, num_classes=0)  # no classifier head
        self.embed_dim = 768

        # Project to attention space
        self.feature_projector = nn.Sequential(
            nn.Linear(768, M),   # <<â€” your 768 here
            nn.ReLU(inplace=True),
        )

        # Gated attention
        self.attention_V = nn.Sequential(nn.Linear(M, L), nn.Tanh())
        self.attention_U = nn.Sequential(nn.Linear(M, L), nn.Sigmoid())
        self.attention_w = nn.Linear(L, self.B)

        # Bag classifier
        self.classifier = nn.Sequential(
            nn.Linear(M * self.B, 1),
            nn.Sigmoid()
        )

    def forward(self, tiles, mask=None):
        """
        tiles: [K, 3, 244, 244]
        mask:  [K] boolean (optional)
        """
        H = self.backbone(tiles)          # [K, 768]
        H = self.feature_projector(H)     # [K, M]

        A_V = self.attention_V(H)         # [K, L]
        A_U = self.attention_U(H)         # [K, L]
        A = self.attention_w(A_V * A_U).transpose(0, 1)  # [B, K]

        if mask is not None:
            A = A.masked_fill(~mask.unsqueeze(0), float('-inf'))

        A = F.softmax(A, dim=1)           # over tiles
        Z = A @ H                         # [B, M]
        Y_prob = self.classifier(Z.reshape(1, -1))  # [1,1]
        Y_hat = (Y_prob >= 0.5).float()

        return Y_prob, Y_hat, A

In [15]:
bag_1 = slide_to_tiles(np.asarray(ImageOps.exif_transpose(Image.open("C:\\Code\\DL\\bbosis\\data\\Training data\\79.jpg")).convert("RGB")), 224, 0, -1)
bag_1[0].shape
len(bag_1)

144

In [16]:
tns = torch.tensor(bag_1)
tns.shape

torch.Size([144, 224, 224, 3])

In [17]:
model = GatedAttentionMIL_RED(BASE_CHECKPOINT)
tiles = torch.tensor(bag_1).permute(0,3,1,2).float()  # [K, 3, 224, 224]
tiles.shape
with torch.no_grad():
    Y_prob, Y_hat, A = model(tiles)
Y_prob, Y_hat, A.shape

(tensor([[0.5377]]), tensor([[1.]]), torch.Size([1, 144]))

In [18]:
A

tensor([[0.0070, 0.0069, 0.0069, 0.0070, 0.0069, 0.0070, 0.0069, 0.0069, 0.0070,
         0.0069, 0.0070, 0.0070, 0.0070, 0.0069, 0.0069, 0.0070, 0.0069, 0.0070,
         0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069,
         0.0070, 0.0069, 0.0069, 0.0070, 0.0069, 0.0069, 0.0070, 0.0070, 0.0069,
         0.0070, 0.0069, 0.0070, 0.0069, 0.0069, 0.0069, 0.0070, 0.0070, 0.0070,
         0.0070, 0.0070, 0.0069, 0.0070, 0.0070, 0.0070, 0.0069, 0.0070, 0.0069,
         0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0070, 0.0069, 0.0070,
         0.0069, 0.0070, 0.0069, 0.0069, 0.0070, 0.0069, 0.0070, 0.0069, 0.0069,
         0.0069, 0.0070, 0.0069, 0.0069, 0.0070, 0.0069, 0.0069, 0.0069, 0.0070,
         0.0070, 0.0070, 0.0070, 0.0069, 0.0070, 0.0070, 0.0070, 0.0069, 0.0069,
         0.0070, 0.0069, 0.0069, 0.0069, 0.0069, 0.0069, 0.0070, 0.0069, 0.0070,
         0.0070, 0.0069, 0.0069, 0.0070, 0.0070, 0.0069, 0.0070, 0.0069, 0.0069,
         0.0069, 0.0069, 0.0

In [5]:
class ImageBagDataset(Dataset):
    def __init__(self, csv_file):
        self.data = pd.read_csv(csv_file)
        self.path = "C:\\Code\\DL\\bbosis\\data\\Training data\\"

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        bag = slide_to_tiles(np.asarray(ImageOps.exif_transpose(Image.open(f"{self.path}{row['img_pth']}.jpg")).convert("RGB")), 224, 0, -1)

        label = row['label']
        id_ = row['img_pth']
        bag_tensor = torch.tensor(bag).permute(0,3,1,2).float()  # [K, 3, 224, 224]
        return bag_tensor, label, id_

In [7]:
train_dataset = ImageBagDataset("C:\\Code\\DL\\bbosis\\data\\train_set.csv")
test_dataset = ImageBagDataset("C:\\Code\\DL\\bbosis\\data\\test_set.csv")

bag, label, id_ = train_dataset[0]
bag.shape, label, id_

  bag_tensor = torch.tensor(bag).permute(0,3,1,2).float()  # [K, 3, 224, 224]


(torch.Size([144, 3, 224, 224]), np.int64(0), np.int64(953))

In [1]:
import math

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
num_epochs = 10
learning_rate = 1e-4

model = GatedAttentionMIL_RED(BASE_CHECKPOINT).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCELoss()

def single_slide_collate(batch):
    """Ensure each batch contains a single slide bag without extra stacking."""
    return batch[0]

def normalize_slide_id(slide_id):
    if isinstance(slide_id, (list, tuple)):
        slide_id = slide_id[0]
    if hasattr(slide_id, "item"):
        slide_id = slide_id.item()
    return str(slide_id)

train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=single_slide_collate)
val_loader = DataLoader(test_dataset, batch_size=4, shuffle=False, collate_fn=single_slide_collate)

history = []
misclassified_history = []

for epoch in range(num_epochs):
    model.train()
    train_losses, train_probs, train_preds, train_labels = [], [], [], []
    train_misclassified = []

    for bag, label, slide_id in train_loader:
        bag = bag.to(device).float() / 255.0
        label_tensor = torch.as_tensor(label, dtype=torch.float32, device=device).view(1, 1)

        optimizer.zero_grad()
        y_prob, _, _ = model(bag)
        loss = criterion(y_prob, label_tensor)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())
        prob = y_prob.detach().cpu().item()
        pred = int(prob >= 0.5)
        true = int(label_tensor.detach().cpu().item())
        train_probs.append(prob)
        train_preds.append(pred)
        train_labels.append(true)

        if pred != true:
            train_misclassified.append(normalize_slide_id(slide_id))

    train_loss = float(np.mean(train_losses)) if train_losses else float('nan')
    train_precision = precision_score(train_labels, train_preds, zero_division=0) if train_labels else 0.0
    train_recall = recall_score(train_labels, train_preds, zero_division=0) if train_labels else 0.0
    try:
        train_auc = roc_auc_score(train_labels, train_probs)
    except ValueError:
        train_auc = float('nan')

    model.eval()
    val_losses, val_probs, val_preds, val_labels = [], [], [], []
    val_misclassified = []
    with torch.no_grad():
        for bag, label, slide_id in val_loader:
            bag = bag.to(device).float() / 255.0
            label_tensor = torch.as_tensor(label, dtype=torch.float32, device=device).view(1, 1)

            y_prob, _, _ = model(bag)
            loss = criterion(y_prob, label_tensor)

            val_losses.append(loss.item())
            prob = y_prob.detach().cpu().item()
            pred = int(prob >= 0.5)
            true = int(label_tensor.detach().cpu().item())
            val_probs.append(prob)
            val_preds.append(pred)
            val_labels.append(true)

            if pred != true:
                val_misclassified.append(normalize_slide_id(slide_id))

    val_loss = float(np.mean(val_losses)) if val_losses else float('nan')
    val_precision = precision_score(val_labels, val_preds, zero_division=0) if val_labels else 0.0
    val_recall = recall_score(val_labels, val_preds, zero_division=0) if val_labels else 0.0
    try:
        val_auc = roc_auc_score(val_labels, val_probs)
    except ValueError:
        val_auc = float('nan')

    history.append({
        "epoch": epoch + 1,
        "train": {"loss": train_loss, "precision": train_precision, "recall": train_recall, "auc": train_auc},
        "val": {"loss": val_loss, "precision": val_precision, "recall": val_recall, "auc": val_auc},
    })
    misclassified_history.append({
        "epoch": epoch + 1,
        "train": train_misclassified,
        "val": val_misclassified,
    })

    def fmt(value):
        return f"{value:.4f}" if not (isinstance(value, float) and math.isnan(value)) else "nan"

    print(f"Epoch {epoch + 1}/{num_epochs}")
    print(
        f"  Train -> Loss: {fmt(train_loss)} | Precision: {fmt(train_precision)} | Recall: {fmt(train_recall)} | AUC: {fmt(train_auc)}"
    )
    print(
        f"  Val   -> Loss: {fmt(val_loss)} | Precision: {fmt(val_precision)} | Recall: {fmt(val_recall)} | AUC: {fmt(val_auc)}"
    )
    print(
        f"  Misclassified slides -> Train: {len(train_misclassified)} | Val: {len(val_misclassified)}"
    )

history, misclassified_history


NameError: name 'torch' is not defined