reference: https://colab.research.google.com/drive/1yny79jQYAxN-ho5Fei2cXGuPEzl7kxUs#scrollTo=rvuec_zBJCy3


In [None]:
from pathlib import Path

import pandas as pd


def prepare_dataset(
    df: pd.DataFrame, root_dir: Path, patches_subfolder: str = "patches"
) -> pd.DataFrame:
    """
    Build the two-column frame expected by the PyTorch datasets.

    Each image path is built as
    ``<root_dir>/<slideName>/<patches_subfolder>/<patchName>.jpg``.

    Args:
        df: Label table, e.g. as read from ``labels.csv``. It must contain the
            columns ``slideName``, ``patchName`` and ``label``. It is NOT
            modified.
        root_dir: Directory holding one sub-directory per slide.
        patches_subfolder: Name of the per-slide folder containing the patches.

    Returns:
        A new DataFrame with exactly the columns ``imgPath`` (posix path
        string) and ``label``, in that order, keeping ``df``'s index.

    Raises:
        ValueError: If any mandatory column is missing.
    """
    required = ["slideName", "patchName", "label"]
    missing = [col for col in required if col not in df.columns]
    if missing:
        raise ValueError(f"Missing mandatory columns in label CSV file: {missing}")

    # astype(str): IDs that pandas parsed as numbers would otherwise make the
    # string concatenation raise a TypeError.
    img_path = (
        root_dir.as_posix()
        + "/"
        + df["slideName"].astype(str)
        + "/"
        + patches_subfolder
        + "/"
        + df["patchName"].astype(str)
        + ".jpg"
    )
    # assign() returns a new frame, so the caller's DataFrame is left untouched.
    return df.assign(imgPath=img_path)[["imgPath", "label"]]


# Read the training label table (first CSV column is the index) and turn it
# into the (imgPath, label) frame consumed by the datasets.
df = prepare_dataset(
    pd.read_csv("../patches_train/labels.csv", index_col=0),
    Path("../patches_train/"),
)
df.head()

In [None]:
import torch
from PIL import Image
from torch.utils.data import Dataset


# Custom dataset
class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``{"image": <RGB PIL image>, "label": int}``.

    Column 0 of ``dataframe`` is read as the image path and column 1 as the
    label, matching the ``[imgPath, label]`` frame built by ``prepare_dataset``.
    No transform is applied.
    """

    def __init__(self, dataframe):
        self.dataframe = dataframe

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        frame = self.dataframe
        with Image.open(frame.iloc[idx, 0]) as img:
            rgb = img.convert("RGB")
        target = int(frame.iloc[idx, 1])
        return {"image": rgb, "label": target}


# Wrap the prepared (imgPath, label) frame; items are dicts holding a raw
# RGB PIL image and an int label.
dataset = CustomImageDataset(dataframe=df)

In [None]:
import torch
from PIL import Image
from transformers import AutoImageProcessor, ViTModel

# Sample patch used to sanity-check feature extraction.
image = Image.open(
    "../patches_train/0SJ4TLT74S_b/patches/0_7_0_3584_1_256_256.jpg"
)

# Phikon preprocessing and backbone from the "owkin/phikon" hub id; the
# pooling layer is dropped because only the token embeddings are used.
image_processor = AutoImageProcessor.from_pretrained(
    "owkin/phikon", use_fast=True
)
model = ViTModel.from_pretrained("owkin/phikon", add_pooling_layer=False)

# Preprocess the single image into a PyTorch batch.
inputs = image_processor(image, return_tensors="pt")

# Forward pass without autograd; keep the first token of the last hidden
# state as the patch feature vector, shape (1, 768).
with torch.no_grad():
    outputs = model(**inputs)
    features = outputs.last_hidden_state[:, 0, :]

In [None]:
import timm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_class_weight
import pandas as pd
import numpy as np
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    LearningRateMonitor,
)
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import AutoImageProcessor, ViTModel


# Custom dataset
class CustomImageDataset(Dataset):
    """Map-style dataset yielding ``(image, label)`` tuples.

    Column 0 of ``dataframe`` is read as the image path and column 1 as the
    label, matching the ``[imgPath, label]`` frame built by ``prepare_dataset``.
    Images are loaded as RGB PIL images and passed through ``transform`` when
    one is given (a falsy ``transform`` is skipped).
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return self.dataframe.shape[0]

    def __getitem__(self, idx):
        frame = self.dataframe
        with Image.open(frame.iloc[idx, 0]) as img:
            sample = img.convert("RGB")
        target = int(frame.iloc[idx, 1])

        if self.transform:
            sample = self.transform(sample)

        return sample, target


# Focal Loss
class FocalLoss(nn.Module):
    """Binary focal loss on logits, with an optional positive-class weight.

    For every element it computes ``alpha * (1 - p_t) ** gamma * BCE``, where
    ``p_t`` is the predicted probability of the true class and ``BCE`` is the
    (``pos_weight``-weighted) binary cross-entropy. It then reduces the result.

    Args:
        alpha: Global scale applied to every element.
        gamma: Focusing parameter; ``gamma=0`` gives plain (weighted) BCE.
        weight: Forwarded as ``pos_weight`` to ``nn.BCEWithLogitsLoss``.
        reduction: ``"mean"`` (default, scalar), ``"sum"`` or ``"none"``
            (element-wise, same shape as ``logits``).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, alpha=1, gamma=2, weight=None, reduction="mean"):
        super().__init__()
        if reduction not in ("mean", "sum", "none"):
            raise ValueError(f"Unsupported reduction: {reduction!r}")
        self.alpha = alpha
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction
        # Element-wise BCE: the focal factor must be applied per sample BEFORE
        # averaging. Previously the loss was mean-reduced first, so (1 - pt)
        # was computed from the batch mean. The attribute name is unchanged so
        # the registered pos_weight buffer keeps its state_dict key.
        self.bce_loss = nn.BCEWithLogitsLoss(pos_weight=weight, reduction="none")

    def forward(self, logits, targets):
        bce_loss = self.bce_loss(logits, targets)
        # exp(-BCE) equals p_t only for the UNweighted BCE; the
        # pos_weight-scaled loss would yield p_t ** pos_weight for positives.
        unweighted = nn.functional.binary_cross_entropy_with_logits(
            logits, targets, reduction="none"
        )
        pt = torch.exp(-unweighted)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * bce_loss
        if self.reduction == "mean":
            return focal_loss.mean()
        if self.reduction == "sum":
            return focal_loss.sum()
        return focal_loss


# Per-split image pipelines keyed by split name ("train" / "val"). Both end
# with ToTensor, which yields a float tensor in [0, 1] for 8-bit RGB images.
# Neither pipeline normalises or resizes: ImageClassificationModel.forward
# applies the encoder's timm eval transform to each batch instead.
data_transforms = dict(
    # Training: random flips/rotation plus mild colour jitter and blur.
    train=transforms.Compose(
        [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomRotation(45),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
            # NOTE(review): with sigma <= 0.1 a 7x7 Gaussian kernel is
            # numerically almost an identity; widen sigma if real blur
            # augmentation is intended.
            transforms.GaussianBlur(kernel_size=(7, 7), sigma=(0.001, 0.1)),
            transforms.ToTensor(),
        ]
    ),
    # Validation: tensor conversion only, no augmentation.
    val=transforms.Compose([transforms.ToTensor()]),
)


class ImageClassificationModel(pl.LightningModule):
    """Binary patch classifier: frozen timm encoder + small trainable MLP head.

    The encoder is loaded from the HF hub id
    ``1aurent/vit_base_patch16_224.owkin_pancancer``. It is frozen
    (``requires_grad=False``) and kept in eval mode. ``forward`` returns one
    logit per image (last dimension 1).

    Args:
        class_weights: Passed to ``FocalLoss(weight=...)``, which forwards it
            as ``pos_weight`` to ``nn.BCEWithLogitsLoss``.
        lr: Adam learning rate. ``None`` (default) falls back to the
            module-level ``LR`` constant, as before.
    """

    def __init__(self, class_weights, lr=None):
        super().__init__()
        self.lr = lr

        # Pretrained encoder from the timm / HF hub.
        self.encoder = timm.create_model(
            model_name="hf-hub:1aurent/vit_base_patch16_224.owkin_pancancer",
            pretrained=True,
        ).eval()

        # Freeze the encoder: it is used as a fixed feature extractor.
        for param in self.encoder.parameters():
            param.requires_grad = False

        # Encoder-specific eval preprocessing resolved by timm from the model
        # config. It is applied to every batch in forward().
        self.data_config = timm.data.resolve_model_data_config(self.encoder)
        self.transforms = timm.data.create_transform(
            **self.data_config, is_training=False
        )

        num_features = self.encoder.num_features

        # Two-layer head with ReLU, producing a single logit.
        self.classifier = nn.Sequential(
            nn.Linear(num_features, 128),
            nn.ReLU(),
            nn.Linear(128, 1),
        )

        # Focal loss; class_weights becomes BCE's pos_weight.
        self.criterion = FocalLoss(weight=class_weights)

    def train(self, mode=True):
        """Set train/eval mode recursively, but keep the frozen encoder in eval.

        Without this override, switching the module to train mode (as the
        trainer does for training) would undo the ``.eval()`` applied to the
        encoder in ``__init__``.
        """
        super().train(mode)
        self.encoder.eval()
        return self

    def forward(self, x):
        """Return one logit per image for a batch ``x`` of image tensors.

        Assumes ``x`` is a stacked batch of ``data_transforms`` outputs.
        TODO: confirm against the DataLoader setup.
        """
        inputs = self.transforms(x)
        # The encoder output is fed straight into the head, so it must have
        # num_features columns. This presumably relies on the hub config
        # having no classification head. TODO: confirm.
        features = self.encoder(inputs)
        return self.classifier(features)

    def _shared_step(self, batch, stage):
        """Compute the loss for one batch and log ``{stage}_loss/acc/f1``."""
        inputs, labels = batch
        labels = labels.float().view(-1, 1)
        outputs = self(inputs)
        loss = self.criterion(outputs, labels)
        preds = (torch.sigmoid(outputs) > 0.5).int()
        acc = (preds == labels).float().mean()
        # NOTE(review): this is a per-batch F1. With on_epoch=True it is
        # averaged over batches, which is not the epoch-level F1.
        # zero_division=0 returns the same value as sklearn's default but
        # without warnings on batches with no positive labels/predictions.
        f1 = f1_score(
            labels.flatten().cpu(), preds.flatten().cpu(), zero_division=0
        )
        self.log(f"{stage}_loss", loss, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_acc", acc, on_epoch=True, prog_bar=True)
        self.log(f"{stage}_f1", f1, on_epoch=True, prog_bar=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "val")

    def configure_optimizers(self):
        lr = LR if self.lr is None else self.lr
        optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        # Halve the LR when val_loss has not improved for 3 epochs.
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=3, min_lr=1e-7
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val_loss"},
        }