In [1]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision.transforms import v2
from image_patcher import ImagePatcher
from dataset import MILDataset
from data_utils import collate_fn
from model import AttentionMILModel

In [2]:
dataset_path = "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/datasets/eyepacs-aptos-messidor-diabetic-retinopathy-original-preprocessed-color-enhancement/train/multiclass"

output_dataset_path = "data/eyepacs-aptos"
features_output = os.path.join(output_dataset_path, "features")
labels_output = os.path.join(output_dataset_path, "labels")

In [3]:
BATCH_SIZE = 32

In [4]:
transform = v2.Compose([
    v2.ToTensor()
])



In [5]:
patcher = ImagePatcher(patch_size=32, empty_thresh=0.1)

In [6]:
mil_dataset = MILDataset(dataset_path, patcher)

In [7]:
mil_dataloader = DataLoader(mil_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn, shuffle=True)

# Inference

In [8]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [9]:
model = AttentionMILModel(att_dim=128)
model.to(device)



AttentionMILModel(
  (resnet): ResNet(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, 

In [10]:
features, labels, masks, bags_length = next(iter(mil_dataloader))

In [11]:
features = features.to(device)
masks = masks.to(device)

In [12]:
output = model(features, masks, bags_length)

In [13]:
output.shape

torch.Size([32])