In [1]:
from image_patcher import ImagePatcher
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import v2
import os
from typing import Tuple
from torchvision.io import read_image

In [2]:
# Root directory of the training images (passed to ImageFolder further down).
# The default is the original cluster location; set MIL_DATASET_PATH to run the
# notebook on another machine without editing this cell.
dataset_path = os.environ.get(
    "MIL_DATASET_PATH",
    "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/datasets/"
    "eyepacs-aptos-messidor-diabetic-retinopathy-original-preprocessed-color-enhancement/"
    "train/multiclass",
)

# Output locations, relative to the current working directory.
output_dataset_path = "data/eyepacs-aptos"
features_output = os.path.join(output_dataset_path, "features")
labels_output = os.path.join(output_dataset_path, "labels")

In [3]:
# Single-step pipeline: only the tensor conversion, no resizing or augmentation.
# NOTE(review): torchvision documents v2.ToTensor as deprecated in favour of
# v2.ToImage() + v2.ToDtype(torch.float32, scale=True); confirm ImagePatcher
# accepts that output before switching.
transform = v2.Compose([v2.ToTensor()])



In [4]:
# Plain image/label dataset, used below to try the patcher on a single sample.
dataset = ImageFolder(root=dataset_path, transform=transform)

In [5]:
# Patch extractor used by the demo cell and by MILDataset, built with
# patch_size=32 and empty_thresh=0.1.
# NOTE(review): empty_thresh presumably decides which near-empty patches are
# dropped — confirm against the image_patcher module.
patcher = ImagePatcher(patch_size=32, empty_thresh=0.1)

In [6]:
# Round-trip check on one sample: split the image into patches, then rebuild it.
image = dataset[1][0]

# get_tiles is called with the image height/width before the image is converted to a bag.
c, h, w = image.shape
patcher.get_tiles(h, w)
instances, instances_idx, instances_cords = patcher.convert_img_to_bag(image)

# NOTE(review): the target shape is hard-coded to (3, 640, 640) rather than the
# (c, h, w) read from the image above — confirm every image has that size (or
# that the patcher pads to it) before reusing this cell on other data.
reconstructed_image = patcher.reconstruct_image_from_patches(instances, instances_idx, (3, 640, 640))

# Reorder axes with permute(1, 2, 0) and convert to NumPy — presumably
# channels-last for plotting; no plot call is visible in this notebook.
reconstructed_image = reconstructed_image.permute(1, 2, 0).numpy()

In [7]:
class MILDataset(Dataset):
    """Image-folder dataset that serves every image as a bag of patches.

    Images and labels are read through an ``ImageFolder``; each image is then
    passed to the supplied ``ImagePatcher`` and only the resulting instances
    are returned together with the label.
    """

    def __init__(self, dataset_path: str, image_patcher: ImagePatcher, transform=None) -> None:
        """Wrap the folder at ``dataset_path``.

        Args:
            dataset_path: Root directory handed to ``ImageFolder``.
            image_patcher: Patcher used to turn an image into instances.
            transform: Optional image transform; when omitted, a pipeline
                containing only ``v2.ToTensor()`` is used.
        """
        super().__init__()
        # Fall back to a bare tensor conversion when the caller gives no transform.
        effective_transform = v2.Compose([v2.ToTensor()]) if transform is None else transform

        self.img_folder_dataset = ImageFolder(dataset_path, transform=effective_transform)
        self.image_patcher = image_patcher

    def __len__(self):
        """Number of images in the underlying folder dataset."""
        return len(self.img_folder_dataset)

    def __getitem__(self, index) -> Tuple:
        """Return ``(instances, label)`` for the image at ``index``."""
        sample, target = self.img_folder_dataset[index]

        # The patcher is told the image size before the image is converted.
        _, height, width = sample.shape
        patcher = self.image_patcher
        patcher.get_tiles(height, width)

        # convert_img_to_bag returns three values; only the instances are used here.
        bag, _, _ = patcher.convert_img_to_bag(sample)
        return bag, target

In [44]:
# Bag-of-patches view of the training folder, with the default transform.
mil_dataset = MILDataset(dataset_path=dataset_path, image_patcher=patcher)



In [45]:
def collate_fn(batch):
    """Collate variable-length bags into flat, zero-padded batch tensors.

    Every bag is padded with all-zero instances up to the length of the
    longest bag in the batch, and the padded bags are laid out back to back
    along dimension 0.

    Args:
        batch: Sequence of ``(instances, label)`` pairs, where ``instances``
            is a tensor of shape ``(n_instances, c, h, w)``. The instance
            shape ``(c, h, w)`` is taken from the first bag.

    Returns:
        Tuple ``(features, labels, masks)``:
            features: tensor of shape ``(batch_size * max_bag_length, c, h, w)``
                in the default floating dtype.
            labels: ``torch.long`` tensor of shape ``(batch_size,)``.
            masks: tensor of shape ``(batch_size * max_bag_length,)`` holding
                1 for real instances and 0 for padding.
    """
    # Acquire the important dimensions
    batch_size = len(batch)
    c, h, w = batch[0][0].shape[1:]
    max_bag_length = max(len(x) for x, _ in batch)

    # Zero-initialised buffers already contain the padding (all-zero instances
    # with mask 0), so only the real instances have to be written below.
    features = torch.zeros((batch_size * max_bag_length, c, h, w))
    labels = torch.zeros(batch_size, dtype=torch.long)
    masks = torch.zeros(batch_size * max_bag_length)

    for i, (x, y) in enumerate(batch):
        # Bag i owns rows [i * max_bag_length, (i + 1) * max_bag_length).
        start = i * max_bag_length
        stop = start + len(x)

        features[start:stop] = x
        masks[start:stop] = 1
        labels[i] = y

    return features, labels, masks


In [60]:
# Shuffled batches of 32 bags; collate_fn pads the bags to a common length.
mil_dataloader = DataLoader(
    dataset=mil_dataset,
    batch_size=32,
    shuffle=True,
    collate_fn=collate_fn,
)

# Inference

In [47]:
from torchmil.nn import masked_softmax
from torchvision.models import resnet18, ResNet18_Weights


class ABMIL(torch.nn.Module):
    """Attention-based MIL classifier on top of a ResNet-18 instance encoder.

    Each instance is embedded by a ResNet-18 whose final ``fc`` layer is
    replaced by an identity. A two-layer attention head scores the instances,
    the scores are normalised with ``masked_softmax``, and the attention-weighted
    sum of the embeddings is mapped to one output value per bag.
    """

    def __init__(self, att_dim):
        """Build the encoder, the attention head and the bag classifier.

        Args:
            att_dim: Hidden size of the attention head.
        """
        super().__init__()

        # Feature extractor. The weights are requested through the `weights`
        # keyword with an enum member; passing the enum class positionally only
        # works through torchvision's deprecated legacy-argument handling.
        self.resnet = resnet18(weights=ResNet18_Weights.DEFAULT)
        emb_dim = self.resnet.fc.in_features

        # Drop the classification layer so the network returns embeddings.
        self.resnet.fc = torch.nn.Identity()

        # Attention head: emb_dim -> att_dim -> 1 score per instance.
        self.fc1 = torch.nn.Linear(emb_dim, att_dim)
        self.fc2 = torch.nn.Linear(att_dim, 1)

        # Bag-level classifier producing a single value per bag.
        self.classifier = torch.nn.Linear(emb_dim, 1)

    def forward(self, X, mask, batch_size, return_att=False):
        """Compute one output value per bag.

        Args:
            X: Instances of all bags stacked along dimension 0, i.e.
                ``batch_size * bag_size`` rows, padding included.
            mask: One entry per row of ``X``; non-zero marks a real instance,
                zero marks padding.
            batch_size: Number of bags stacked in ``X``.
            return_att: When true, also return the attention weights.

        Returns:
            ``y`` of shape ``(batch_size,)``, or ``(y, att_s)`` with ``att_s``
            of shape ``(batch_size, bag_size, 1)`` when ``return_att`` is set.
        """
        # Integer division keeps the bag size exact (no float round-trip).
        bag_size = X.shape[0] // batch_size

        # Process only instances that are not masked (i.e., valid instances, not padding)
        valid = mask != 0
        emb = self.resnet(X[valid])  # (n_valid, emb_dim)

        # Put the processed instances back at their original positions so the
        # shape is as if all instances, including padding, were processed.
        # new_zeros matches the dtype and device of the embeddings, which the
        # indexed assignment below requires.
        resnet_output = emb.new_zeros((batch_size * bag_size, emb.shape[1]))
        resnet_output[valid] = emb
        X = resnet_output

        # Reshaping to separate bags from batches
        X = X.reshape((batch_size, bag_size, -1))  # (batch_size, bag_size, emb_dim)
        mask = mask.reshape((batch_size, bag_size))  # (batch_size, bag_size)

        H = torch.tanh(self.fc1(X))  # (batch_size, bag_size, att_dim)
        att = torch.sigmoid(self.fc2(H))  # (batch_size, bag_size, 1)

        # Normalise the scores within each bag; presumably masked_softmax
        # excludes the positions where mask is 0 — confirm in torchmil.
        att_s = masked_softmax(att, mask)  # (batch_size, bag_size, 1)
        X = torch.bmm(att_s.transpose(1, 2), X).squeeze(1)  # (batch_size, emb_dim)
        y = self.classifier(X).squeeze(1)  # (batch_size,)
        if return_att:
            return y, att_s
        else:
            return y

In [48]:
# The notebook is pinned to the CPU; the commented line below is the automatic
# GPU/CPU selection, kept for reference.
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"

In [49]:
# Build the model with a 128-unit attention layer and move it to `device`.
# model.to(device) is the last expression of the cell, so the module summary
# is displayed as the cell output.
model = ABMIL(att_dim=128)
model.to(device)



ABMIL(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_runnin

In [61]:
# Pull one collated batch from the loader to try the model on.
features, labels, masks = next(iter(mil_dataloader))

In [62]:
# Move the two model inputs to the model's device; labels are not passed to
# the forward call and are left where they are.
features = features.to(device)
masks = masks.to(device)

In [53]:
# Take the number of bags from the batch itself rather than repeating the
# loader's batch size, so a shorter (e.g. final) batch is handled correctly.
output = model(features, masks, batch_size=labels.shape[0])

In [55]:
# Expect one value per bag, i.e. a 1-D shape of length batch_size.
output.shape

torch.Size([32])