# Try 2

In [7]:
import os
import glob
import xml.etree.ElementTree as ET
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import openslide
from torchvision import transforms as T
from torchinfo import summary
import numpy as np
import tifffile

from openslide import OpenSlideError


from torchvision.models.detection.ssdlite import ssdlite320_mobilenet_v3_large
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead
from torchvision.models.mobilenet import MobileNet_V3_Large_Weights

In [8]:
# Helper: load XML point annotations from ASAP-style XML (Dot annotations)
def load_xml_annotations(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    boxes, labels = [], []
    # annotations under /ASAP_Annotations/Annotations/Annotation
    for ann in root.findall('.//Annotation'):
        coords = ann.find('Coordinates/Coordinate')
        x = float(coords.get('X'))
        y = float(coords.get('Y'))
        # create tiny box around point (e.g. 1x1 pixel)
        boxes.append([x, y, x+1.0, y+1.0])
        labels.append(1)
    return torch.tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)


class PatchDataset(Dataset):
    def __init__(self, img_dir, ann_dir, mask_dir=None, patch_size=512, transforms=None):
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.tif')))
        self.ann_paths = []
        for img_path in self.img_paths:
            base = os.path.splitext(os.path.basename(img_path))[0].split('_')[:2]  # Take only C_P000031
            base = '_'.join(base)
            match = os.path.join(ann_dir, f"{base}.xml")
            if os.path.exists(match):
                self.ann_paths.append(match)
            else:
                raise FileNotFoundError(f"Annotation XML for slide {base} not found in {ann_dir}")
        self.mask_paths = ([os.path.join(mask_dir, os.path.basename(p)) for p in self.img_paths] if mask_dir else [None]*len(self.img_paths))
        self.patch_size = patch_size
        self.transforms = transforms or T.Compose([T.Resize((320,320)), T.ToTensor()])
                # Precompute ROI bounding boxes from binary mask (values > 0 are ROI)
        self.rois = []
        for mask_path in self.mask_paths:
            if mask_path and os.path.exists(mask_path):
                mask_arr = tifffile.imread(mask_path)
                # treat any non-zero pixel as ROI
                mask = mask_arr > 0
                coords = np.argwhere(mask)
                if coords.size > 0:
                    y0, x0 = coords.min(axis=0)
                    y1, x1 = coords.max(axis=0)
                    self.rois.append((x0, y0, x1, y1))
                else:
                    self.rois.append(None)
            else:
                self.rois.append(None)

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        slide_path = self.img_paths[idx]
        slide = openslide.OpenSlide(slide_path)
        w0, h0 = slide.level_dimensions[0]
        roi = self.rois[idx] or (0, 0, w0, h0)
        x0, y0, x1, y1 = roi
        # clamp ROI within slide bounds
        x1 = min(x1, w0); y1 = min(y1, h0)
        w_roi, h_roi = x1 - x0, y1 - y0
        try:
            region = slide.read_region((x0, y0), 0, (w_roi, h_roi)).convert('RGB')
        except OpenSlideError:
            # fallback: read a smaller region around center
            cx = w0 // 2; cy = h0 // 2
            half = self.patch_size // 2
            x0 = max(cx - half, 0); y0 = max(cy - half, 0)
            region = slide.read_region((x0, y0), 0, (self.patch_size, self.patch_size)).convert('RGB')
            boxes, labels = load_xml_annotations(self.ann_paths[idx])
            return [self.transforms(region)], [{'boxes': torch.tensor([[0,0,self.patch_size,self.patch_size]]), 'labels': torch.tensor([1])}]
        # Load annotations and adjust
        boxes, labels = load_xml_annotations(self.ann_paths[idx])
        if roi != (0,0,w0,h0):
            boxes -= torch.tensor([x0, y0, x0, y0], dtype=torch.float32)
        # Sample patches
        patches, targets = [], []
        half = self.patch_size // 2
        for box, label in zip(boxes, labels):
            cx = int((box[0] + box[2]) / 2); cy = int((box[1] + box[3]) / 2)
            x_start = np.clip(cx - half, 0, w_roi - self.patch_size)
            y_start = np.clip(cy - half, 0, h_roi - self.patch_size)
            patch = region.crop((x_start, y_start, x_start + self.patch_size, y_start + self.patch_size))
            tensor = self.transforms(patch)
            adj = box - torch.tensor([x_start, y_start, x_start, y_start], dtype=torch.float32)
            patches.append(tensor)
            targets.append({'boxes': adj.unsqueeze(0), 'labels': label.unsqueeze(0)})
        return patches, targets
    

# Custom collate to flatten
def collate_fn(batch):
    imgs, targets = [], []
    for patch_list, targ_list in batch:
        imgs.extend(patch_list)
        targets.extend(targ_list)
    return imgs, targets


In [9]:
orig_dir = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-original"
diag_dir = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-diagnostic"
cpg_dir  = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-cpg"
ann_dir  = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/annotations/xml"
mask_dir = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/tissue-masks"

# Build datasets and loaders
full_ds = torch.utils.data.ConcatDataset([
    PatchDataset(orig_dir, ann_dir, mask_dir),
    PatchDataset(diag_dir, ann_dir, mask_dir)
])
val_size = int(0.2 * len(full_ds))
train_size = len(full_ds) - val_size
train_ds, val_ds = random_split(full_ds, [train_size, val_size])
test_ds = PatchDataset(cpg_dir, ann_dir, None)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=2, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=2, shuffle=False, collate_fn=collate_fn)


In [10]:
# Build SSD-Lite MobileNetV3 model for MNL detection
def get_model(num_classes=2):
    # use pretrained MobilenetV3 backbone, random head
    model = ssdlite320_mobilenet_v3_large(
        weights=None,
        weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V2,
        num_classes=num_classes
    )
    return model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(num_classes=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Print summary
summary(model, input_size=[(4, 3, 320, 320)])

Layer (type:depth-idx)                                       Output Shape              Param #
SSD                                                          [133, 4]                  --
├─GeneralizedRCNNTransform: 1-1                              [4, 3, 320, 320]          --
├─SSDLiteFeatureExtractorMobileNet: 1-2                      [4, 128, 1, 1]            --
│    └─Sequential: 2-1                                       --                        --
│    │    └─Sequential: 3-1                                  [4, 672, 20, 20]          869,096
│    │    └─Sequential: 3-2                                  [4, 960, 10, 10]          2,102,856
│    └─ModuleList: 2-2                                       --                        --
│    │    └─Sequential: 3-3                                  [4, 512, 5, 5]            381,184
│    │    └─Sequential: 3-4                                  [4, 256, 3, 3]            100,480
│    │    └─Sequential: 3-5                                  [4, 256, 2, 

In [11]:
NR_EPOCHS = 1

In [12]:
for epoch in range(NR_EPOCHS):
    model.train()
    total_loss = 0
    batch = 0 
    for imgs, targets in train_loader:
        print('START BATCH: ', batch)
        imgs_resized = [T.Resize((320, 320))(img).to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(imgs_resized, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
        print(f"Epoch {epoch+1}, Batch {batch}, Loss: {losses.item():.4f}")
    print('---------------------------------------------------------------------------')
    print(f"Epoch {epoch+1:02d}, Loss: {total_loss/len(train_loader):.4f}")

OpenSlideError: Cannot read raw tile

In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    all_predictions = []
    for imgs, targets in loader:
        imgs_resized = [T.Resize((320, 320))(img).to(device) for img in imgs]
        outputs = model(imgs_resized)
        all_predictions.extend([{k: v.cpu() for k, v in out.items()} for out in outputs])
    return all_predictions

# Evaluate on test
preds = evaluate(model, test_loader, device)
torch.save(preds, 'predictions_cpg_ssd.pt')
print("Saved SSD predictions_cpg_ssd.pt")


# Try 1

In [1]:
import os
import glob
import xml.etree.ElementTree as ET
from PIL import Image
# Allow loading of very large images
Image.MAX_IMAGE_PIXELS = None
import tifffile

import torch
from torch.utils.data import Dataset, DataLoader, random_split
import torchvision
from torchvision.models.detection.ssdlite import ssdlite320_mobilenet_v3_large
from torchvision.models.detection.ssdlite import SSDLiteClassificationHead
from torchvision.models.mobilenet import MobileNet_V3_Large_Weights
from torchvision import transforms as T
from torchinfo import summary
import numpy as np

In [2]:
# Helper: load XML annotations
def load_xml_annotations(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()
    boxes, labels = [], []
    for obj in root.findall('object'):
        name = obj.find('name').text
        b = obj.find('bndbox')
        xmin, ymin = float(b.find('xmin').text), float(b.find('ymin').text)
        xmax, ymax = float(b.find('xmax').text), float(b.find('ymax').text)
        boxes.append([xmin, ymin, xmax, ymax])
        labels.append(1)
    return torch.tensor(boxes, dtype=torch.float32), torch.tensor(labels, dtype=torch.int64)

def crop_tissue(img_np, padding=10, bg_threshold=240):
    # img_np: HxWxC uint8
    gray = np.mean(img_np, axis=2)
    mask = gray < bg_threshold
    coords = np.argwhere(mask)
    if coords.size == 0:
        return img_np
    y0, x0 = coords.min(axis=0)
    y1, x1 = coords.max(axis=0)
    # pad and clip
    y0 = max(y0 - padding, 0)
    x0 = max(x0 - padding, 0)
    y1 = min(y1 + padding, img_np.shape[0])
    x1 = min(x1 + padding, img_np.shape[1])
    return img_np[y0:y1, x0:x1]

class MNLDataset(Dataset):
    def __init__(self, img_dir, ann_dir, transforms=None, crop_whitespace=False):
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.tif')))
        self.ann_paths = [
            os.path.join(ann_dir, os.path.basename(p).replace('.tif', '.xml'))
            for p in self.img_paths
        ]
        self.transforms = transforms
        self.crop_whitespace = crop_whitespace

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img_np = tifffile.imread(self.img_paths[idx])
        # ensure RGB
        if img_np.ndim == 2:
            img_np = np.stack([img_np] * 3, axis=-1)
        elif img_np.shape[2] > 3:
            img_np = img_np[:, :, :3]
        # crop
        if self.crop_whitespace:
            img_np = crop_tissue(img_np)
        img = T.ToTensor()(img_np)

        boxes, labels = load_xml_annotations(self.ann_paths[idx])
        # adjust boxes if cropped
        if self.crop_whitespace:
            h_off, w_off = np.argwhere(np.mean(img_np, axis=2) < 240).min(axis=0)
            boxes -= torch.tensor([w_off, h_off, w_off, h_off])
            # clamp
            boxes[:, [0,2]] = boxes[:, [0,2]].clamp(0, img.shape[2])
            boxes[:, [1,3]] = boxes[:, [1,3]].clamp(0, img.shape[1])
        target = {'boxes': boxes, 'labels': labels}

        if self.transforms:
            img, target = self.transforms(img, target)
        return img, target

def collate_fn(batch):
    return tuple(zip(*batch))

In [4]:
# Prepare datasets and loaders
orig_dir = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-original"
diag_dir = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-diagnostic"
cpg_dir  = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/images/pas-cpg"
ann_dir  = "C:/Users/luukn/AIMI_MONKEY2/monkey-training/annotations/xml"

ds_orig = MNLDataset(orig_dir, ann_dir)
ds_diag = MNLDataset(diag_dir, ann_dir)
full_train = torch.utils.data.ConcatDataset([ds_orig, ds_diag])

# Enable cropping in diagnostic/original, keep test uncropped
train_ds = MNLDataset(orig_dir, ann_dir, crop_whitespace=True)
train_diag = MNLDataset(diag_dir, ann_dir, crop_whitespace=True)
full_train = torch.utils.data.ConcatDataset([train_ds, train_diag])
val_size = int(0.2 * len(full_train))
train_size = len(full_train) - val_size
train_ds, val_ds = random_split(full_train, [train_size, val_size])
test_ds = MNLDataset(cpg_dir, ann_dir, crop_whitespace=True)

# loaders
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader   = DataLoader(val_ds,   batch_size=2, shuffle=False, collate_fn=collate_fn)
test_loader  = DataLoader(test_ds,  batch_size=2, shuffle=False, collate_fn=collate_fn)

In [5]:
# Build SSD-Lite MobileNetV3 model for MNL detection
def get_model(num_classes=2):
    # use pretrained MobilenetV3 backbone, random head
    model = ssdlite320_mobilenet_v3_large(
        weights=None,
        weights_backbone=MobileNet_V3_Large_Weights.IMAGENET1K_V2,
        num_classes=num_classes
    )
    return model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_model(num_classes=2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9, weight_decay=0.0005)

# Print summary
summary(model, input_size=[(4, 3, 320, 320)])

Layer (type:depth-idx)                                       Output Shape              Param #
SSD                                                          [156, 4]                  --
├─GeneralizedRCNNTransform: 1-1                              [4, 3, 320, 320]          --
├─SSDLiteFeatureExtractorMobileNet: 1-2                      [4, 128, 1, 1]            --
│    └─Sequential: 2-1                                       --                        --
│    │    └─Sequential: 3-1                                  [4, 672, 20, 20]          869,096
│    │    └─Sequential: 3-2                                  [4, 960, 10, 10]          2,102,856
│    └─ModuleList: 2-2                                       --                        --
│    │    └─Sequential: 3-3                                  [4, 512, 5, 5]            381,184
│    │    └─Sequential: 3-4                                  [4, 256, 3, 3]            100,480
│    │    └─Sequential: 3-5                                  [4, 256, 2, 

In [18]:
NR_EPOCHS = 1

In [19]:
for epoch in range(NR_EPOCHS):
    # training loop can include resizing and transforms similar to eval
    model.train()
    total_loss = 0
    batch = 0 
    for imgs, targets in train_loader:
        print('START BATCH: ', batch)
        imgs_resized = [T.Resize((320, 320))(img).to(device) for img in imgs]
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
        loss_dict = model(imgs_resized, targets)
        losses = sum(loss for loss in loss_dict.values())
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        total_loss += losses.item()
        print(f"Epoch {epoch+1}, Batch {batch}, Loss: {losses.item():.4f}")
    print('---------------------------------------------------------------------------')
    print(f"Epoch {epoch+1:02d}, Loss: {total_loss/len(train_loader):.4f}")

MemoryError: Unable to allocate 40.0 GiB for an array with shape (1, 1, 155904, 91904, 3) and data type uint8

In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    all_predictions = []
    for imgs, targets in loader:
        imgs_resized = [T.Resize((320, 320))(img).to(device) for img in imgs]
        outputs = model(imgs_resized)
        all_predictions.extend([{k: v.cpu() for k, v in out.items()} for out in outputs])
    return all_predictions

# Evaluate on test
preds = evaluate(model, test_loader, device)
torch.save(preds, 'predictions_cpg_ssd.pt')
print("Saved SSD predictions_cpg_ssd.pt")
