# Patch Classification Training

## Part 1: Find all the patches in the small training set of Camelyon16

In [1]:
import repath.data.datasets.camelyon16 as camelyon16
from repath.preprocess.patching import GridPatchFinder
from repath.preprocess.patching import SlidesIndex
from repath.preprocess.tissue_detection import TissueDetectorOTSU

# Index every patch of the small Camelyon16 training set.
# The two pipeline components are built first, then the slides they run on.
tissue_detector = TissueDetectorOTSU()
# NOTE(review): the positional arguments (6, 0, 256, 256) are not explained
# here - confirm their meaning against GridPatchFinder's signature.
patch_finder = GridPatchFinder(6, 0, 256, 256)
train_data = camelyon16.training_small()

train_patches = SlidesIndex.index_dataset(
    train_data,
    tissue_detector,
    patch_finder,
)
# Kept as the cell's last expression so the notebook displays the summary.
train_patches.summary()

indexing normal_038.tif
indexing tumor_048.tif
indexing normal_058.tif
indexing normal_115.tif
indexing tumor_111.tif
indexing normal_136.tif
indexing tumor_006.tif
indexing tumor_090.tif
indexing normal_129.tif
indexing normal_027.tif
indexing normal_052.tif
indexing normal_082.tif


Unnamed: 0,slide_path,slide_label,background,normal,tumor
0,normal/normal_038.tif,normal,326683,3365,0
1,tumor/tumor_048.tif,tumor,283824,40047,65
2,normal/normal_058.tif,normal,304555,21673,0
3,normal/normal_115.tif,normal,119997,13955,0
4,tumor/tumor_111.tif,tumor,308863,20219,202
5,normal/normal_136.tif,normal,109954,37886,0
6,tumor/tumor_006.tif,tumor,275070,54908,70
7,tumor/tumor_090.tif,tumor,228774,24320,3834
8,normal/normal_129.tif,normal,235850,49750,0
9,normal/normal_027.tif,normal,328206,1842,0


In [2]:
from repath.preprocess.sampling import split_camelyon16
from repath.utils.paths import project_root

# Everything this notebook writes goes under a single experiment directory.
experiment_name = "example"
experiment_root = project_root() / "experiments" / "repath" / experiment_name

# NOTE(review): 0.7 is presumably the fraction assigned to the training
# split - confirm against split_camelyon16.
train, valid = split_camelyon16(train_patches, 0.7)

# Persist both halves of the split next to each other.
for split_index, index_name in ((train, "train_index"), (valid, "valid_index")):
    split_index.save(experiment_root / index_name)



In [3]:
from repath.preprocess.sampling import balanced_sample

# Reuse the `train` / `valid` split produced (and saved to disk) by the
# previous cell instead of calling split_camelyon16 a second time: a repeated
# split is redundant at best, and if the split is not deterministic the
# patches sampled here would no longer match the saved indexes.
train_samples = balanced_sample(train, 2800)
valid_samples = balanced_sample(valid, 1200)

# save out all the patches
train_samples.save_patches(experiment_root / "training_patches")
valid_samples.save_patches(experiment_root / "validation_patches")

background    1956357
normal         153168
tumor             267
dtype: int64
(153168, 4) 1000
(267, 4) 267
background    2027177
normal         178631
tumor            3904
dtype: int64
(178631, 4) 1200
(3904, 4) 1200
Writing patches for normal/normal_038.tif
Writing patches for tumor/tumor_048.tif
Writing patches for normal/normal_058.tif
Writing patches for tumor/tumor_111.tif
Writing patches for normal/normal_136.tif
Writing patches for normal/normal_027.tif
Writing patches for normal/normal_082.tif
Writing patches for normal/normal_038.tif
Writing patches for normal/normal_058.tif
Writing patches for tumor/tumor_006.tif
Writing patches for tumor/tumor_090.tif
Writing patches for normal/normal_027.tif
Writing patches for normal/normal_052.tif
Writing patches for normal/normal_082.tif


## Part 2: Train patch classifier

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.metrics import Accuracy
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor

from repath.patch_classification.models.simple import Backbone

class PatchClassifier(pl.LightningModule):
    """Lightning module that trains and validates a patch classifier.

    Args:
        model: Network called on a batch of inputs; its output is treated as
            one score per class along dimension 1.
        learning_rate: Step size passed to the Adam optimizer.
    """

    def __init__(self, model, learning_rate: float = 1e-3) -> None:
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def cross_entropy_loss(self, logits, labels):
        """Return the mean cross-entropy of ``logits`` against ``labels``.

        ``F.cross_entropy`` applies ``log_softmax`` internally, so the raw
        model output is passed straight in. Because ``log_softmax`` is
        idempotent, the value is the same if the model already returns
        log-probabilities.
        """
        return F.cross_entropy(logits, labels)

    def accuracy(self, logits, labels):
        """Return the fraction of samples whose arg-max class equals the label."""
        # Plain tensor ops keep the result on the same device as the inputs
        # and avoid constructing a metric object on every step.
        pred = torch.argmax(logits, dim=1)
        return (pred == labels).float().mean()

    def step(self, batch, batch_idx, label):
        """Run one forward pass, log loss and accuracy, and return the loss.

        Args:
            batch: ``(inputs, targets)`` pair produced by the data loader.
            batch_idx: Index of the batch (unused, kept for the Lightning API).
            label: Prefix for the logged metric names, e.g. ``"train"``.
        """
        x, y = batch
        logits = self.model(x)
        loss = self.cross_entropy_loss(logits, y)
        accu = self.accuracy(logits, y)
        self.log(f"{label}_loss", loss)
        self.log(f"{label}_accuracy", accu)
        return loss

    def training_step(self, batch, batch_idx):
        """Lightning hook: one training step, logged under ``train_*``."""
        return self.step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        """Lightning hook: one validation step, logged under ``val_*``."""
        return self.step(batch, batch_idx, "val")

    def configure_optimizers(self):
        """Lightning hook: return the optimizer over all module parameters."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)
    
# prepare our data
batch_size = 128
# Without a transform ImageFolder yields PIL images, which the default
# DataLoader collate function cannot stack into a batch.
to_tensor = ToTensor()
train_set = ImageFolder(root=experiment_root / "training_patches", transform=to_tensor)
valid_set = ImageFolder(root=experiment_root / "validation_patches", transform=to_tensor)
# Shuffle only the training data so each epoch sees the samples in a new order.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=batch_size)

# configure logging and checkpoints
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=experiment_root / "patch_model",
    # Plain string, not an f-string: Lightning fills in {epoch} / {val_loss}
    # when it saves, and appends the ".ckpt" extension itself.
    filename="checkpoint-{epoch:02d}-{val_loss:.2f}",
    save_top_k=1,
    mode="min",
)

# train our model
model = Backbone()
classifier = PatchClassifier(model)
trainer = pl.Trainer(callbacks=[checkpoint_callback])
trainer.fit(classifier, train_dataloader=train_loader, val_dataloaders=valid_loader)