# Case Study Owkin - Pathology AI - 2025

## Context

Cell segmentation in [hematoxylin and eosin (H&E)](https://www.cancer.gov/publications/dictionaries/cancer-terms/def/hematoxylin-and-eosin-staining) stained histopathology slides is a critical task in digital pathology, aiding in diagnosing diseases like cancer and quantifying key [biomarkers](https://en.wikipedia.org/wiki/Biomarker). Accurate segmentation enables automated workflows for identifying diverse cell types, improving consistency and efficiency in pathology assessments. However, the heterogeneity of tissue structures, staining variability, and presence of rare cell types make this task challenging. Developing robust models can enhance the precision of diagnoses, streamline workflows, and facilitate personalized medicine by providing actionable insights from tissue samples.

## Accessing data

For this challenge, you will work with the ConSep dataset, which we have already preprocessed and tiled for you. You can find it [here](https://www.kaggle.com/datasets/rftexas/tiled-consep-224x224px/). The dataset contains two subdirectories:

1. `tiles`: contains all the raw `.png` images,
2. `labels`: contains all the `*.mat` objects that store the ground truth labels.

We have already written the data loaders that fetch the images and the ground truth labels from these two subdirectories.

We also provide the [weights](https://www.kaggle.com/models/afiltowk/phikon/pyTorch/default) of the Phikon backbone, which you will need for the model to run.

## Accessing GPUs

For this challenge, you will need a GPU to train your models. In the Kaggle Notebook interface, open the right-hand sidebar, go to Session options > Accelerator, and select GPU P100. This provides enough compute to train the model. You are granted 30 GPU hours per week. **Remember to switch off your environment when you are done working!** Otherwise, it will keep consuming GPU hours.
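
As a quick sanity check after changing the accelerator, a minimal sketch such as the following can confirm that PyTorch actually sees the GPU. It only uses standard `torch.cuda` calls; the `device` variable name is our own choice and is not required by the provided package.

```python
import torch

# Pick the GPU when the Kaggle accelerator is enabled, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

if device.type == "cuda":
    # Name of the first visible GPU (for instance a P100 on Kaggle).
    print(torch.cuda.get_device_name(0))
```

If this prints `cpu`, the accelerator is most likely not enabled for the current session.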

## Persistence of your environment

To keep the files and outputs you generate in your working directory at `/kaggle/working`, **we strongly advise you to enable persistence**. In the Kaggle Notebook interface, open the right-hand sidebar and go to Session options > Persistence > Variables and Files. Otherwise, you will lose everything when you relaunch your environment!

## Task description

Your task is to implement a training loop for a cell segmentation model called [CellViT](https://arxiv.org/abs/2306.15350).
The model will be trained on the ConSep dataset, which contains 977 images with corresponding cell instance segmentation masks. We already provide the model architecture, the loss functions, and the data loaders.

We ask you to implement the training loop, train the model, and properly validate its performance on the dataset. You will be evaluated on the clarity and efficiency of your implementation, your evaluation procedure, and your analysis of the results. We expect you to implement, or at least explain, how you evaluated (or would evaluate) the convergence of your model: metrics, visualizations, etc.
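
To make the "metrics" part concrete, here is a minimal sketch of two overlap metrics that are commonly tracked for segmentation, the Dice coefficient and the intersection over union (IoU), computed on binary masks with NumPy. The function names and the toy masks are illustrative assumptions and are not part of the provided package. Instance-level metrics are also common for this kind of task but are not shown here.

```python
import numpy as np


def dice_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Dice coefficient between two binary masks of identical shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    return float((2.0 * intersection + eps) / (pred.sum() + target.sum() + eps))


def iou_score(pred: np.ndarray, target: np.ndarray, eps: float = 1e-7) -> float:
    """Intersection over union between two binary masks of identical shape."""
    pred = pred.astype(bool)
    target = target.astype(bool)
    intersection = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return float((intersection + eps) / (union + eps))


# Toy example: two 4x4 masks with 3 foreground pixels each, 2 of them shared.
pred = np.zeros((4, 4), dtype=np.uint8)
target = np.zeros((4, 4), dtype=np.uint8)
pred[0, :3] = 1
target[0, 1:] = 1

print(f"Dice: {dice_score(pred, target):.3f}")
print(f"IoU:  {iou_score(pred, target):.3f}")
```

Logging such metrics on the validation set at every epoch, next to the training and validation losses, is one way to judge convergence and spot overfitting.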

Below is the structure of the package we provide, which contains helper functions, including visualization scripts.

```shell
.
├── __init__.py
├── data
│   ├── __init__.py
│   ├── dataset.py
│   ├── label.py
│   ├── label_transform.py
│   └── tiled_dataset.py
├── loss
│   ├── __init__.py
│   ├── cellvit_loss.py
│   ├── dice.py
│   ├── focal.py
│   ├── focal_tversky.py
│   └── msge.py
├── model
│   ├── __init__.py
│   ├── cellvit.py
│   ├── decoder.py
│   ├── extractor
│   ├── neck.py
│   └── utils.py
├── postprocess
│   ├── __init__.py
│   └── instance_map.py
└── viz
    ├── __init__.py
    ├── utils.py
    └── visualize.py
```

**Important**: Take notes about what you're trying and what's working/not working so that we can understand your thought process.

Once you have a baseline, we expect you to propose new ideas to improve the model's performance on ConSep. These can come from the literature or be your own. Bonus points will be granted if you manage to implement and evaluate your ideas.

Among other ideas, you might want to explore the following, though we don't expect you to spend too much time on them. They are ranked by increasing implementation complexity:

1. How would you make the training loop more efficient in terms of speed and/or GPU memory?

2. How would you implement data augmentation, and how would you integrate it into the present code? You might want to copy-paste code from `data/dataset::TrainingDataset` into your notebook to modify the training augmentations.

3. How would you use the larger version of Phikon (ViT-B), namely [Phikon-v2](https://huggingface.co/owkin/phikon-v2) (ViT-Large)? How would you implement it, and what are the pros and cons of using such a model? You might want to copy-paste code from `model/extractor/extractor::PhikonViT` and try adapting it to Phikon-v2, which can be loaded using the `transformers` API (Hugging Face). You will find a code snippet to generate features for this model in the last section.
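
For the data augmentation idea, one possible starting point is a geometric augmentation applied identically to the image and its label. The sketch below is an assumption-laden illustration in plain NumPy: `random_flip_rot90` is a hypothetical helper, and the `(H, W, 3)` image and `(H, W)` instance map layouts are assumptions rather than the actual interface of the provided dataset classes.

```python
import numpy as np


def random_flip_rot90(
    image: np.ndarray, instance_map: np.ndarray, rng: np.random.Generator
) -> tuple[np.ndarray, np.ndarray]:
    """Apply the same random flips and 90-degree rotation to an image and its mask.

    image:        (H, W, 3) array
    instance_map: (H, W) integer array, one id per cell (0 = background)
    """
    if rng.random() < 0.5:  # horizontal flip
        image, instance_map = image[:, ::-1], instance_map[:, ::-1]
    if rng.random() < 0.5:  # vertical flip
        image, instance_map = image[::-1], instance_map[::-1]
    k = int(rng.integers(0, 4))  # number of 90-degree rotations
    image = np.rot90(image, k, axes=(0, 1))
    instance_map = np.rot90(instance_map, k, axes=(0, 1))
    # Return contiguous copies: views with negative strides cannot be
    # converted with torch.from_numpy.
    return np.ascontiguousarray(image), np.ascontiguousarray(instance_map)


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(224, 224, 3), dtype=np.uint8)
instance_map = np.zeros((224, 224), dtype=np.int32)
instance_map[10:20, 30:45] = 1  # one hypothetical cell covering 150 pixels

aug_image, aug_map = random_flip_rot90(image, instance_map, rng)
print(aug_image.shape, aug_map.shape, int((aug_map == 1).sum()))
```

Direction-dependent targets, such as horizontal and vertical distance maps, do not stay valid under flips and rotations, so it is probably safest to augment the instance map first and derive those targets afterwards.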

## Deliverables

You have a maximum of 7 days to work on these points. Nonetheless, do not spend too much time on them!
We expect this challenge to take you no more than 10 hours of cumulative work.

Once you are finished, we ask you to share with us:

1. This notebook, with the code completed.
2. A 3-page report that details your understanding of the challenge, your thought process, the ideas you tried, and any difficulties you faced.

## Bonus: How to use Phikon-v2

```python
import requests
from PIL import Image
import torch
from transformers import AutoImageProcessor, AutoModel


# Load an image
image = Image.open(
    requests.get(
        "https://github.com/owkin/HistoSSLscaling/blob/main/assets/example.tif?raw=true",
        stream=True
    ).raw
)

# Load phikon-v2
processor = AutoImageProcessor.from_pretrained("owkin/phikon-v2")
model = AutoModel.from_pretrained("owkin/phikon-v2")
model.eval()

# Process the image
inputs = processor(image, return_tensors="pt")

# Get the features
with torch.inference_mode():
    outputs = model(**inputs)
    features = outputs.last_hidden_state[:, 0, :]  # (1, 1024) shape

assert features.shape == (1, 1024)
```
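
The snippet above keeps only the class token, whereas a dense decoder needs spatial feature maps. The sketch below shows one way to reshape a ViT token sequence into a 2D grid. It assumes that the sequence is one class token followed by 14 x 14 patch tokens in row-major order (224 px input, patch size 16, no extra tokens); please check this against `outputs.last_hidden_state.shape` and the model configuration. `tokens_to_feature_map` is a hypothetical helper, and a random array stands in for the real model output so that the example runs without downloading weights.

```python
import numpy as np


def tokens_to_feature_map(tokens: np.ndarray, grid_size: int) -> np.ndarray:
    """Reshape ViT tokens (B, 1 + grid_size**2, D) into a (B, D, grid_size, grid_size) map."""
    batch, n_tokens, dim = tokens.shape
    if n_tokens != 1 + grid_size**2:
        raise ValueError(
            f"Expected {1 + grid_size**2} tokens (class token + patches), got {n_tokens}."
        )
    patch_tokens = tokens[:, 1:, :]  # drop the class token
    grid = patch_tokens.reshape(batch, grid_size, grid_size, dim)
    return grid.transpose(0, 3, 1, 2)  # channels-first layout


# Dummy stand-in for `outputs.last_hidden_state` (assumed shape: 1 x 197 x 1024).
tokens = np.random.default_rng(0).normal(size=(1, 1 + 14 * 14, 1024)).astype(np.float32)
feature_map = tokens_to_feature_map(tokens, grid_size=14)
print(feature_map.shape)
```

If features from intermediate layers are needed as well, the `transformers` models accept `output_hidden_states=True`, and the same reshaping could then be applied to each returned hidden state.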

In [8]:
import sys
from utils import fetch_kaggle_paths

weight_path, data_path, helper_path = fetch_kaggle_paths()
sys.path.append(helper_path)

In [None]:
from math import ceil, floor
from pathlib import Path


import torch
from owkin_case_study.data.dataset import TrainDataset
from owkin_case_study.data.tiled_dataset import TiledTrainingDataset
from owkin_case_study.loss import CellViTCriterion
from owkin_case_study.model import CellViT
from owkin_case_study.model.extractor import PhikonViT
from torch.utils import data
from torch.utils.data import DataLoader

from utils import NUMBER_CELL_TYPES


############################
# We show below how to load the data and create data loaders that can be used to
# train and validate your model. Feel free to change how to load your data, how
# you validate the model. This is just an example.
############################


ddb_dataset = TiledTrainingDataset.from_tiled_datasets(
    *[TiledTrainingDataset(root_path=data_path)]
)

tiles_path = Path(data_path) / "tiles"

train_tiles = [x.stem for x in tiles_path.glob("train_*.png")]
test_tiles = [x.stem for x in tiles_path.glob("test_*.png")]

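# 80/20 train/validation split. The validation size is obtained by subtraction
# so that both lengths always sum to the dataset size, as `random_split` requires.
full_train = TrainDataset(ddb_dataset, subset=train_tiles)
n_train = floor(0.8 * len(full_train))
dataset_train, dataset_val = data.random_split(
    full_train,
    [n_train, len(full_train) - n_train],
    generator=torch.Generator().manual_seed(42),
)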
dataset_test = TrainDataset(ddb_dataset, subset=test_tiles)


print(
    f"Dataset loaded! Found {len(train_tiles)} training/validation samples and {len(dataset_test)} test samples."
)

train_loader = DataLoader(
    dataset_train, batch_size=2, shuffle=True, pin_memory=True, num_workers=7
)
val_loader = DataLoader(
    dataset_val, batch_size=16, shuffle=False, pin_memory=True, num_workers=7
)
test_loader = DataLoader(
    dataset_test, batch_size=16, shuffle=False, pin_memory=True, num_workers=7
)


###########################

backbone = PhikonViT(weights_path=weight_path)

model = CellViT(backbone=backbone, number_cell_types=NUMBER_CELL_TYPES)

criterion = CellViTCriterion(num_classes=NUMBER_CELL_TYPES)

In [None]:
import lightning.pytorch as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import WandbLogger
from torch.nn.functional import l1_loss, mse_loss
from torchmetrics.functional.classification import accuracy

import wandb
from lightning_model import LightningModel
from utils import partial

litmodel = LightningModel(
    model,
    criterion,
    clf_keys={"np", "tp"},
    reg_keys={"hv"},
    metrics_dict=dict(
        np=[partial(accuracy, task="multiclass", num_classes=2)],
        tp=[partial(accuracy, task="multiclass", num_classes=NUMBER_CELL_TYPES)],
        hv=[mse_loss, l1_loss],
    ),
)

wandb.init(project="Case_study_Pathology_AI_2025")
trainer = L.Trainer(
    max_epochs=1,
    logger=WandbLogger(),
    callbacks=[ModelCheckpoint(monitor="val_loss", mode="min")],  # lower loss is better
    accelerator="auto",  # use the GPU when one is available, otherwise the CPU
)
# trainer.validate(litmodel, val_loader)
trainer.fit(litmodel, train_loader, val_loader)
wandb.finish()