## Imports

In [7]:
import os
import glob
import xml.etree.ElementTree as ET
from PIL import Image
# Allow loading of very large images
Image.MAX_IMAGE_PIXELS = None
import tifffile

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms as T
from torchinfo import summary

## Dataset Loading

In [2]:
# Helper: load XML annotations
def load_xml_annotations(xml_path):
    """Parse a Pascal-VOC style XML annotation file into detection tensors.

    Every <object> element is mapped to the single foreground class 1 (MNL);
    its <name> text (e.g. 'monocytes', 'lymphocytes') is currently ignored.

    Args:
        xml_path: path to the XML annotation file.

    Returns:
        (boxes, labels): ``boxes`` is a float32 tensor of shape (N, 4) in
        [xmin, ymin, xmax, ymax] order and ``labels`` is an int64 tensor of
        shape (N,). When the file has no objects, N == 0 and ``boxes`` still has
        shape (0, 4), which is the target layout torchvision detectors expect.
    """
    root = ET.parse(xml_path).getroot()
    boxes = []
    for obj in root.findall('object'):
        bndbox = obj.find('bndbox')
        boxes.append([float(bndbox.find(tag).text) for tag in ('xmin', 'ymin', 'xmax', 'ymax')])
    # background=0, MNL=1: all annotated cell types are merged into one class
    labels = torch.ones(len(boxes), dtype=torch.int64)
    # reshape keeps an empty annotation at (0, 4) instead of collapsing to (0,)
    return torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4), labels

In [3]:
class MNLDataset(Dataset):
    """Detection dataset pairing each ``*.tif`` in ``img_dir`` with a same-stem
    ``.xml`` annotation file in ``ann_dir``.

    Each item is ``(img, target)``: ``img`` is the tensor produced by
    ``T.ToTensor()`` from an HxWx3 array, and ``target`` is
    ``{'boxes': ..., 'labels': ...}`` as returned by ``load_xml_annotations``.
    If ``transforms`` is given it is called as ``transforms(img, target)``.

    NOTE(review): each access decodes the whole TIFF into memory. For very large
    (slide-level) images this is a likely cause of the CPU out-of-memory error in
    the training cell; patch/tile extraction is probably needed — TODO confirm.
    """

    def __init__(self, img_dir, ann_dir, transforms=None):
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.tif')))
        # Swap only the file extension, not any '.tif' substring elsewhere in the name.
        self.ann_paths = [
            os.path.join(ann_dir, os.path.splitext(os.path.basename(p))[0] + '.xml')
            for p in self.img_paths
        ]
        self.transforms = transforms

    def __len__(self):
        return len(self.img_paths)

    @staticmethod
    def _ensure_rgb(img_np):
        """Return an HxWx3 array from an HxW or HxWxC array.

        Grayscale / fewer-than-3-channel inputs replicate their first channel to
        3 channels; extra channels (e.g. alpha) are dropped.
        """
        if img_np.ndim == 2:
            img_np = img_np[:, :, None]
        # NOTE(review): assumes channels-last layout; a channels-first (C, H, W)
        # TIFF would be mis-sliced here — confirm against the data.
        if img_np.shape[2] < 3:
            img_np = img_np[:, :, :1].repeat(3, axis=2)
        elif img_np.shape[2] > 3:
            img_np = img_np[:, :, :3]
        return img_np

    def __getitem__(self, idx):
        # Read with tifffile to avoid decompression issues
        img_np = self._ensure_rgb(tifffile.imread(self.img_paths[idx]))
        # NOTE(review): ToTensor only rescales uint8 input to [0, 1]; other dtypes
        # (e.g. uint16) pass through unscaled — confirm the TIFF dtype.
        img = T.ToTensor()(img_np)

        boxes, labels = load_xml_annotations(self.ann_paths[idx])
        target = {'boxes': boxes, 'labels': labels}

        if self.transforms:
            img, target = self.transforms(img, target)
        return img, target

def collate_fn(batch):
    """Transpose a list of (img, target) samples into (imgs, targets) tuples.

    Detection images differ in size, so they are kept as tuples rather than
    stacked into one tensor.
    """
    return tuple(map(tuple, zip(*batch)))

In [4]:
# Prepare train/val/test splits
# NOTE(review): user-specific absolute path — adjust DATA_ROOT for your machine.
DATA_ROOT = "C:/Users/luukn/AIMI_MONKEY2/monkey-training"
SEED = 42  # fixes the train/val split so results are reproducible across restarts

orig_dir = f"{DATA_ROOT}/images/pas-original"
diag_dir = f"{DATA_ROOT}/images/pas-diagnostic"
cpg_dir  = f"{DATA_ROOT}/images/pas-cpg"
ann_dir  = f"{DATA_ROOT}/annotations/xml"

ds_orig = MNLDataset(orig_dir, ann_dir)
ds_diag = MNLDataset(diag_dir, ann_dir)
full_train = torch.utils.data.ConcatDataset([ds_orig, ds_diag])
# glob() silently returns nothing for a wrong path; fail loudly instead.
assert len(full_train) > 0, f"no .tif images found in {orig_dir} or {diag_dir}"

# 80/20 split (only for pas-diagnostic and pas-original, test set is 100% pas-cpg)
val_size = int(0.2 * len(full_train))
train_size = len(full_train) - val_size
train_ds, val_ds = random_split(
    full_train, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)
test_ds = MNLDataset(cpg_dir, ann_dir)
assert len(test_ds) > 0, f"no .tif images found in {cpg_dir}"

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=2, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=2, shuffle=False, collate_fn=collate_fn)


## Build the Faster R-CNN model

In [6]:
# Build Faster R-CNN model
def get_model(num_classes=2):
    """Return a pretrained Faster R-CNN (ResNet-50 FPN) whose box predictor is
    replaced by a fresh ``FastRCNNPredictor`` for ``num_classes`` classes
    (the count includes background).
    """
    # `weights="DEFAULT"` replaces the deprecated `pretrained=True` (torchvision >= 0.13).
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights="DEFAULT")
    in_feats = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_feats, num_classes)
    return model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(num_classes=2).to(device)
# Hand only trainable parameters to the optimizer; some backbone layers are frozen.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.SGD(trainable_params, lr=0.005, momentum=0.9, weight_decay=0.0005)

# Print model summary
summary(model)

Layer (type:depth-idx)                                  Param #
FasterRCNN                                              --
├─GeneralizedRCNNTransform: 1-1                         --
├─BackboneWithFPN: 1-2                                  --
│    └─IntermediateLayerGetter: 2-1                     --
│    │    └─Conv2d: 3-1                                 (9,408)
│    │    └─FrozenBatchNorm2d: 3-2                      --
│    │    └─ReLU: 3-3                                   --
│    │    └─MaxPool2d: 3-4                              --
│    │    └─Sequential: 3-5                             (212,992)
│    │    └─Sequential: 3-6                             1,212,416
│    │    └─Sequential: 3-7                             7,077,888
│    │    └─Sequential: 3-8                             14,942,208
│    └─FeaturePyramidNetwork: 2-2                       --
│    │    └─ModuleList: 3-9                             984,064
│    │    └─ModuleList: 3-10                            2,360,320
│    

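## Train Model

**FIXME:** the last run of the training cell failed with a CPU out-of-memory error (an allocation of roughly 100 GB). The full-resolution TIFFs most likely need to be split into patches or tiles before training.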

In [6]:
NR_EPOCHS: int = 1  # number of full passes over train_loader

In [None]:
# Training loop
def train_one_epoch(model, loader, optimizer, device):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    ``model(imgs, targets)`` must return a dict of scalar loss tensors, as
    torchvision detection models do in train mode; their sum is back-propagated.
    Batches are counted while iterating, so ``loader`` need not define ``__len__``.

    Raises:
        ValueError: if ``loader`` yields no batches (the mean loss is undefined).
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for imgs, targets in loader:
        imgs = [img.to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(imgs, targets)
        losses = sum(loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
        n_batches += 1
    if n_batches == 0:
        raise ValueError("loader yielded no batches; cannot compute mean loss")
    return total_loss / n_batches

# Run NR_EPOCHS training passes, reporting the mean loss after each one.
for epoch in range(NR_EPOCHS):
    loss = train_one_epoch(model, train_loader, optimizer, device)
    report = f"Epoch {epoch+1:02d}, Loss: {loss:.4f}"
    print(report)

RuntimeError: [enforce fail at alloc_cpu.cpp:114] data. DefaultCPUAllocator: not enough memory: you tried to allocate 99651944448 bytes.

## Evaluate on pas-cpg

In [None]:
# Evaluate on pas-cpg
@torch.no_grad()
def evaluate(model, loader, device):
    """Run ``model`` in eval mode over every batch in ``loader``, without gradients.

    Targets yielded by the loader are ignored. Returns a flat list with one
    entry per image: that image's output dict with every tensor moved to CPU.
    """
    model.eval()
    predictions = []
    for batch_imgs, _unused_targets in loader:
        batch_outputs = model([img.to(device) for img in batch_imgs])
        for output in batch_outputs:
            predictions.append({key: tensor.cpu() for key, tensor in output.items()})
    return predictions

# Predict on the held-out pas-cpg test set and persist the raw per-image outputs.
preds = evaluate(model, test_loader, device)
out_path = 'predictions_cpg.pt'
torch.save(preds, out_path)
print("Saved predictions_cpg.pt")