# Patch Classification Training

## Part 1: Find all the patches in the small training set of Camelyon16

In [1]:
import repath.data.datasets.camelyon17 as camelyon17
from repath.preprocess.patching import GridPatchFinder
from repath.preprocess.patching import SlidesIndex
from repath.preprocess.tissue_detection import TissueDetectorGreyScale

# Index all the patches of the small training set.
# NOTE(review): the notebook headings and the later `split_camelyon16` call refer
# to Camelyon16, but the data loaded here comes from `camelyon17.training_small()`
# — confirm which dataset is intended.
train_data = camelyon17.training_small()
tissue_detector = TissueDetectorGreyScale()
# NOTE(review): the meaning of the four positional arguments (8, 0, 256, 256) is
# not visible from this notebook — TODO document them against GridPatchFinder.
patch_finder = GridPatchFinder(8, 0, 256, 256)
# Build the patch index from the dataset, the tissue detector and the patch finder,
# then display its summary.
train_patches = SlidesIndex.index_dataset(train_data, tissue_detector, patch_finder)
train_patches.summary()

indexing patient_084_node_4.tif
indexing patient_009_node_1.tif


AssertionError: Unknown annoation group encountered.

In [None]:
from repath.preprocess.sampling import split_camelyon16
from repath.utils.paths import project_root

# Everything belonging to this experiment lives under
# <project root>/experiments/repath/<experiment_name>.
experiment_name = "example"
experiment_root = project_root() / "experiments" / "repath" / experiment_name

# Split the patch index into a training and a validation part (split value 0.7).
train, valid = split_camelyon16(train_patches, 0.7)
# Persisting the two indexes is currently disabled:
#train.save(experiment_root / "train_index")
#valid.save(experiment_root / "valid_index")

# Print the slide index of every entry, training subset first, then validation.
for subset_name, subset in (("train", train), ("valid", valid)):
    for sl in subset:
        print(subset_name, sl.slide_idx)

In [None]:
from repath.preprocess.sampling import balanced_sample

# do the train validate split
# NOTE(review): this repeats the 0.7 split already computed in the previous cell;
# if `split_camelyon16` is not deterministic, the slides printed there may differ
# from the ones sampled here — confirm.
train, valid = split_camelyon16(train_patches, 0.7)
# Draw a balanced sample from each subset: 2800 for training, 1200 for validation.
# NOTE(review): whether these counts are totals or per-class is not visible here —
# check `balanced_sample`.
train_samples = balanced_sample(train, 2800)
valid_samples = balanced_sample(valid, 1200)

# save out all the patches
# These two directories are the ones the ImageFolder datasets in Part 2 read from.
train_samples.save_patches(experiment_root / "training_patches")
valid_samples.save_patches(experiment_root / "validation_patches")

## Part 2: Train patch classifier

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.metrics import Accuracy
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

from repath.patch_classification.models.simple import Backbone

class PatchClassifier(pl.LightningModule):
    """Lightning module that trains a patch classification network.

    Args:
        model: Network called on a batch of images; it is expected to return
            one score per class for every image (class scores along dim 1).
        learning_rate: Learning rate handed to the Adam optimiser.
    """

    def __init__(self, model, learning_rate: float = 1e-3) -> None:
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def cross_entropy_loss(self, logits, labels):
        """Return the negative log-likelihood of ``labels``.

        Despite the parameter name, ``logits`` must already be
        log-probabilities (for example the output of ``torch.log_softmax``):
        ``nll_loss`` does not normalise its input itself.
        """
        return F.nll_loss(logits, labels)

    def accuracy(self, logits, labels):
        """Return the fraction of samples whose top-scoring class equals the label.

        Computed directly with tensor operations, so no metric object has to
        be created on every step. The result is a scalar tensor.
        """
        pred = torch.argmax(logits, dim=1)
        return (pred == labels).float().mean()

    def step(self, batch, batch_idx, label):
        """Run one forward pass, log loss and accuracy, and return the loss.

        Args:
            batch: Pair ``(inputs, targets)`` as produced by the data loader.
            batch_idx: Index of the batch (unused, kept for the Lightning API).
            label: Prefix for the logged metric names, giving
                ``"<label>_loss"`` and ``"<label>_accuracy"``.
        """
        x, y = batch
        logits = self.model(x)
        # Normalise the model output, not the input images. log_softmax is
        # idempotent, so this is also safe if the model already returns
        # log-probabilities.
        log_probs = torch.log_softmax(logits, dim=1)
        loss = self.cross_entropy_loss(log_probs, y)
        accu = self.accuracy(logits, y)
        self.log(f"{label}_loss", loss)
        self.log(f"{label}_accuracy", accu)
        return loss

    def training_step(self, batch, batch_idx):
        """Training hook: logs ``train_loss`` / ``train_accuracy`` and returns the loss."""
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Validation hook: logs ``val_loss`` / ``val_accuracy`` and returns the loss."""
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Return an Adam optimiser over all parameters of this module."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
# prepare our data
batch_size = 128
# ImageFolder returns PIL images unless a transform is given; the default
# DataLoader collation needs tensors, so convert every image with ToTensor.
train_set = ImageFolder(root=experiment_root / "training_patches", transform=ToTensor())
valid_set = ImageFolder(root=experiment_root / "validation_patches", transform=ToTensor())
# ImageFolder lists its samples class by class, so the training data has to be
# shuffled or each batch would contain (almost) a single class.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

# configure logging and checkpoints
# `filename` is a plain template: Lightning fills in {epoch} and {val_loss}
# when it writes the checkpoint and appends the ".ckpt" extension itself.
# "val_loss" is the name logged by PatchClassifier.validation_step.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_root / "patch_model",
    filename="checkpoint-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
)

# train our model
model = Backbone()
classifier = PatchClassifier(model)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
# Loaders are passed positionally because the keyword name of the training
# loader differs between Lightning releases.
trainer.fit(classifier, train_loader, valid_loader)