## Leukocytes Image Classification with <b>Vision Transformer from Hugging Face 🤗</b> and Pytorch Lightning⚡
![img](https://assets.stickpng.com/images/6308b84661b3e2a522f01468.png)
***
Computer vision project _Leukocytes classification from blood smear images - LCBSI_


The notebook author: @AgataPolejowska

### 🔨 📁 Setup the environment - install and import essential packages

It is recommended to run the notebook within the `lcbsi` environment created from the `environment.yml` file provided in the root directory of the project. The instructions for installing the environment are in the `README.md` file.

In [None]:
!pip install -q transformers datasets
!pip install -q huggingface_hub
!python -c "import wandb" || pip install -q wandb
!python -c "import roboflow" || pip install -q roboflow

!pip install -U -q torch
!pip install -U -q torchvision
!pip install -U -q torchmetrics

!python -c "import pytorch_lightning" || pip install -q pytorch-lightning

!python -c "import matplotlib" || pip install -q matplotlib

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import PIL
import plotly.express as px
import pytorch_lightning as pl
import seaborn as sns
import torch
import torch.nn as nn
import torchmetrics
import torchvision
from datasets import load_dataset
from roboflow import Roboflow
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from transformers import AdamW, AutoFeatureExtractor, ViTForImageClassification

print(torch.__version__)
print(torchvision.__version__)
print(torchmetrics.__version__)

print(pl.__version__)


sns.set_style()


### 📣📉📈  Configure W&B for results monitoring
![img](https://raw.githubusercontent.com/wandb/assets/main/wandb-logo-yellow-dots-black-wb.svg)

Weights & Biases is used as an MLOps platform to track the experiment results.

In [None]:
import wandb

wandb.login()

PROJECT = "lcbsi-vit-pl"
ENTITY = "polejowska"

wandb.init(project=PROJECT, entity=ENTITY)


In [None]:
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger

# W&B logger that uploads every checkpoint as an artifact (`log_model="all"`).
# The former `score="accuracy"` keyword was removed: it is not a `WandbLogger`
# parameter, and extra keywords are forwarded to `wandb.init`, which has no
# such argument.
wandb_logger = WandbLogger(project=PROJECT, log_model="all")


# The monitored names must match the keys the LightningModule logs in
# `validation_step` ("validation_loss" / "validation_accuracy"); otherwise
# the callbacks cannot find the metric and abort the run.
early_stopping_callback = EarlyStopping(
    monitor="validation_loss", min_delta=0.00, patience=3, verbose=False, mode="min"
)

checkpoint_callback = ModelCheckpoint(monitor="validation_accuracy", mode="max")


### 🌱 Ensure results reproducibility by setting the seed for Pytorch Lightning

In [None]:
pl.seed_everything(0)


### 🔧🔒 Create Pytorch Lightning DataModule - LeukocytesDataModule

This module is created so that all the steps needed to process data are encapsulated and easily reusable.

Within this module the following can be performed:
- setting up the dataset by downloading it from Roboflow
- getting essential information about the dataset
- visualizing sample images from the dataset
- logging to wandb information about the dataset
- <b>getting dataloaders and datasets</b>

In [None]:
class LeukocytesDataModule(pl.LightningDataModule):
    """Data module that downloads the LCBSI leukocytes dataset from Roboflow.

    The module downloads the train/valid/test projects into a single local
    directory, builds ``ImageFolder`` datasets and dataloaders from it, and
    offers helpers that print/log dataset information to Weights & Biases.

    Typical usage::

        dm = LeukocytesDataModule(batch_size=32)
        dm.setup()
        dm.train_dataset(transforms)   # must precede dm.train_dataloader()
        loader = dm.train_dataloader()

    Args:
        batch_size: Number of samples per batch for every dataloader.
        num_workers: Number of worker processes for every dataloader.
        num_classes: Expected number of classes (stored, logged to W&B).
    """

    def __init__(
        self,
        batch_size: int = 32,
        num_workers: int = 4,
        num_classes: int = 5,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = num_classes
        self.num_workers = num_workers

        # Defined up front so the helper methods fail with a clear message
        # (instead of AttributeError) when they are called too early.
        self.data_dir = "lcbsi_dataset"
        self.path_to_dataset = Path(self.data_dir)
        self.project = None
        self.class_names = []

        # Datasets live under private names: assigning them to
        # `self.train_dataset` etc. would overwrite the builder methods of
        # the same name and make a second call impossible.
        self._train_dataset = None
        self._val_dataset = None
        self._test_dataset = None

    def setup(self, stage=None):
        """Download all splits from Roboflow and gather them in ``data_dir``.

        Args:
            stage: Stage name passed by the Lightning ``Trainer``; unused, but
                accepted so the hook matches the signature Lightning calls.
        """
        # NOTE(review): a fallback API key is committed in the source. Prefer
        # supplying ROBOFLOW_API_KEY through the environment and rotate/remove
        # the embedded key.
        api_key = os.environ.get("ROBOFLOW_API_KEY", "3a7r2eyLT0LA2P5AUyvr")
        rf = Roboflow(api_key=api_key)
        workspace = rf.workspace("lcbsiwbc")

        project_train = workspace.project("lcbsi-wbc-train")
        dataset_train = project_train.version(1).download("folder")

        project_valid = workspace.project("lcbsi-wbc-valid")
        dataset_valid = project_valid.version(1).download("folder")

        project_test = workspace.project("lcbsi-wbc")
        dataset_test = project_test.version(1).download("folder")

        # Remember a project handle for `get_info_about_dowloaded_dataset`.
        # NOTE(review): the "lcbsi-wbc" project is assumed to be the one whose
        # version information should be reported - confirm.
        self.project = project_test

        os.makedirs(self.data_dir, exist_ok=True)

        for downloaded in (dataset_train, dataset_valid, dataset_test):
            self._move_folder(downloaded.location, self.data_dir)

        print(f"Dataset downloaded to {self.data_dir}")
        print(f"Path to dataset: {self.path_to_dataset}")

    @staticmethod
    def _move_folder(src, dst):
        """Move every entry of ``src`` into ``dst``, replacing existing ones."""
        import shutil

        for entry in os.listdir(src):
            new_path = os.path.join(dst, entry)
            if os.path.exists(new_path):
                # Remove a stale copy first so the move never nests or fails.
                if os.path.isdir(new_path):
                    shutil.rmtree(new_path)
                else:
                    os.remove(new_path)
            # shutil.move also works across filesystems, unlike os.rename.
            shutil.move(os.path.join(src, entry), new_path)

    def log_module_data_to_wb(self):
        """Record the module's configuration in the active W&B run config."""
        wandb.config.update(
            {
                "batch_size": self.batch_size,
                "num_workers": self.num_workers,
                "num_classes": self.num_classes,
                "data_dir": self.data_dir,
            },
            allow_val_change=True,
        )

    def collate_fn(self, examples):
        """Collate samples into a ``{"pixel_values", "labels"}`` batch.

        Accepts both ``(image, label)`` tuples, as produced by ``ImageFolder``,
        and dict samples with ``"pixel_values"`` / ``"label"`` keys.
        """
        images = []
        labels = []
        for example in examples:
            if isinstance(example, dict):
                image, label = example["pixel_values"], example["label"]
            else:
                image, label = example
            images.append(torch.as_tensor(image))
            labels.append(label)
        return {"pixel_values": torch.stack(images), "labels": torch.tensor(labels)}

    def _make_loader(self, dataset, split, shuffle=False):
        """Build a dataloader for ``dataset`` or fail if it was never created."""
        if dataset is None:
            raise RuntimeError(
                f"The {split} dataset has not been created yet; "
                f"call {split}_dataset(transforms) first."
            )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            collate_fn=self.collate_fn,
        )

    def train_dataset(self, train_transforms):
        """Create, store and return the training ``ImageFolder`` dataset."""
        self._train_dataset = ImageFolder(
            self.path_to_dataset / "train", train_transforms
        )
        return self._train_dataset

    def train_dataloader(self):
        """Return a shuffling dataloader over the training dataset."""
        return self._make_loader(self._train_dataset, "train", shuffle=True)

    def val_dataset(self, val_transforms):
        """Create, store and return the validation ``ImageFolder`` dataset."""
        self._val_dataset = ImageFolder(self.path_to_dataset / "valid", val_transforms)
        return self._val_dataset

    def val_dataloader(self):
        """Return a dataloader over the validation dataset."""
        return self._make_loader(self._val_dataset, "val")

    def test_dataset(self, val_transforms):
        """Create, store and return the test ``ImageFolder`` dataset."""
        self._test_dataset = ImageFolder(self.path_to_dataset / "test", val_transforms)
        return self._test_dataset

    def test_dataloader(self):
        """Return a dataloader over the test dataset."""
        return self._make_loader(self._test_dataset, "test")

    def get_info_about_dowloaded_dataset(self):
        """Print the first version's metadata and store it in the W&B config.

        Raises:
            RuntimeError: If ``setup()`` has not been run, or the project
                reports no versions.
        """
        if self.project is None:
            raise RuntimeError("No Roboflow project available; call setup() first.")
        version_information = self.project.get_version_information()
        if not version_information:
            raise RuntimeError("The Roboflow project reports no versions.")
        version = version_information[0]
        print(f"Dataset name: {version['name']}")
        print(f"Dataset ID: {version['id']}")
        print(f"Dataset images: {version['images']}")
        print(f"Dataset splits: {version['splits']}")
        print(f"Dataset augmentations: {version['augmentation']}")
        print(f"Dataset preprocessing: {version['preprocessing']}")
        wandb.config.update(
            {
                "dataset_name": version["name"],
                "dataset_id": version["id"],
                "dataset_splits": version["splits"],
                "dataset_augmentations": version["augmentation"],
                "dataset_preprocessing": version["preprocessing"],
            }
        )

    def visualize_dataset(self):
        """Show up to nine random test images in a 3x3 grid and log them."""
        if not self.class_names:
            self.get_classes_info()

        test_dir = os.path.join(self.data_dir, "test")
        test_files = [
            os.path.join(test_dir, name, file_name)
            for name in self.class_names
            for file_name in os.listdir(os.path.join(test_dir, name))
        ]
        print(f"Number of test files: {len(test_files)}")
        if not test_files:
            return

        with PIL.Image.open(test_files[0]) as img:
            img_size = img.size
        print(f"Image size: {img_size}")

        # Sampling without replacement cannot draw more items than exist.
        sample_size = min(9, len(test_files))
        sampled_files = np.random.choice(test_files, sample_size, replace=False)
        _, axes = plt.subplots(3, 3, figsize=(12, 12))
        # Hide every axis up front so unused grid cells stay blank.
        for ax in axes.flat:
            ax.axis("off")
        for i, file in enumerate(sampled_files):
            with PIL.Image.open(file) as img:
                ax = axes[i // 3, i % 3]
                # The parent directory name is the class label.
                ax.set_title(os.path.basename(os.path.dirname(file)))
                ax.imshow(img)
                wandb.log({f"test_images_{i}": wandb.Image(img)})
        plt.tight_layout()
        plt.show()

    def get_classes_info(self):
        """Read class names from the test split's sub-directories and print them.

        Returns:
            The sorted list of class names (also stored in ``self.class_names``).
        """
        test_dir = os.path.join(self.data_dir, "test")
        self.class_names = sorted(
            name
            for name in os.listdir(test_dir)
            if os.path.isdir(os.path.join(test_dir, name))
        )
        num_classes = len(self.class_names)
        print(f"Class names: {self.class_names} ({num_classes} classes)")
        return self.class_names


In [None]:
# Number of leukocyte classes and the batch size shared by all dataloaders.
NUM_CLASSES = 5
BATCH_SIZE = 32
# The Hugging Face dataloader cells refer to split-specific batch sizes that
# were never defined; both default to the common batch size.
TRAIN_BATCH_SIZE = BATCH_SIZE
EVAL_BATCH_SIZE = BATCH_SIZE


In [None]:
leukocytes_data_module = LeukocytesDataModule(
    batch_size=BATCH_SIZE, num_classes=NUM_CLASSES
)
leukocytes_data_module.setup()
leukocytes_data_module.log_module_data_to_wb()


In [None]:
leukocytes_data_module.get_classes_info()
leukocytes_data_module.get_info_about_dowloaded_dataset()
leukocytes_data_module.visualize_dataset()


### 🔧 Configure dataset for Vision Transformer from Hugging Face 🤗

Fine-tuning is performed using Vision Transformer model checkpoint pre-trained on ImageNet-21k (_https://huggingface.co/google/vit-base-patch16-224-in21k_).

In [None]:
MODEL_CHECKPOINT_PATH = "google/vit-base-patch16-224-in21k"

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_CHECKPOINT_PATH)

print(f"Feature extractor type: {type(feature_extractor)}")
print(f"Feature extractor config: {feature_extractor}")


In [None]:
# Input resolution expected by the checkpoint. Depending on the installed
# transformers version `size` is either a {"height", "width"} dict or an int.
VIT_INPUT_SIZE = feature_extractor.size
print(f"VIT input size: {VIT_INPUT_SIZE}")

if isinstance(VIT_INPUT_SIZE, dict):
    _vit_height, _vit_width = VIT_INPUT_SIZE["height"], VIT_INPUT_SIZE["width"]
else:
    _vit_height = _vit_width = VIT_INPUT_SIZE

# torchvision's Resize takes the target size as (height, width).
RESIZE_TRANSFORM = transforms.Resize((_vit_height, _vit_width))

# Normalize with the statistics the checkpoint's feature extractor uses.
NORMALIZE_TRANSFORM = transforms.Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Full preprocessing pipeline: resize -> tensor -> normalize.
TRANSFORMS = transforms.Compose(
    [
        RESIZE_TRANSFORM,
        transforms.ToTensor(),
        NORMALIZE_TRANSFORM,
    ]
)


In [None]:
train_dataset = leukocytes_data_module.train_dataset(TRANSFORMS)
train_loader = leukocytes_data_module.train_dataloader()

val_dataset = leukocytes_data_module.val_dataset(TRANSFORMS)
val_loader = leukocytes_data_module.val_dataloader()

test_dataset = leukocytes_data_module.test_dataset(TRANSFORMS)
test_loader = leukocytes_data_module.test_dataloader()


### 🤗 HuggingFace dataset setup

The dataset can be pushed to  Hugging Face hub.


The following section of the notebook loads the downloaded dataset from the local directory, transforms it appropriately (using `map`), and pushes the result to the Hugging Face Hub.


In [None]:
# Build a Hugging Face dataset from the local directory used by the data module.
dataset = load_dataset(leukocytes_data_module.data_dir)
# Run the feature extractor batch-wise and keep its `pixel_values` output
# as a new column.
dataset = dataset.map(
    lambda example: {
        "pixel_values": feature_extractor(
            example["image"], return_tensors="pt"
        ).pixel_values
    },
    batched=True,
)
# Map each integer label to its class name; requires `class_names` to have
# been filled by `get_classes_info()` beforehand.
# NOTE(review): if `label` is a ClassLabel feature the stored value may be
# re-encoded to an integer on write - confirm the resulting column type.
dataset = dataset.map(
    lambda example: {"label": leukocytes_data_module.class_names[int(example["label"])]}
)
# Drop the raw image column now that `pixel_values` is available.
dataset = dataset.map(remove_columns=["image"], batched=True)

print(f"Dataset: {dataset}")
# Upload the processed dataset to the Hugging Face Hub.
# NOTE(review): this repo id differs from the one loaded in the next cell
# ("polejowska/lcbsi-wbc") - confirm which repository is intended.
dataset.push_to_hub("polejowska/lcbsi-leukocytes-dataset")


In [None]:
dataset = load_dataset("polejowska/lcbsi-wbc")


In [None]:
train_dataset = dataset["train"]
val_dataset = dataset["validation"]
test_dataset = dataset["test"]


In [None]:
# Integer class id -> class name, taken from the dataset's label feature.
# Built with dict(enumerate(...)) to avoid shadowing the builtin `id`.
id2label = dict(enumerate(train_dataset.features["label"].names))
id2label


In [None]:
# Inverse mapping: class name -> integer class id (avoids shadowing `id`).
label2id = {label: class_id for class_id, label in id2label.items()}
label2id


In [None]:
def collate_fn(examples):
    """Collate dataset rows into a batch dict for the ViT model.

    Args:
        examples: Sequence of dict rows, each with a ``"pixel_values"`` entry
            (tensor, array or nested list) and an integer ``"label"``.

    Returns:
        Dict with ``"pixel_values"`` stacked along a new leading batch
        dimension and ``"labels"`` as a 1-D tensor.
    """
    # Rows may hold nested lists rather than tensors (e.g. when the dataset
    # has no tensor format set); torch.stack only accepts tensors, so convert
    # each item first. Existing tensors are passed through unchanged.
    pixel_values = torch.stack(
        [torch.as_tensor(example["pixel_values"]) for example in examples]
    )
    labels = torch.tensor([example["label"] for example in examples])
    return {"pixel_values": pixel_values, "labels": labels}


In [None]:
# Dataloaders over the Hugging Face splits; all of them batch with the
# module-level `collate_fn`. Only the training loader shuffles.
train_dataloader = DataLoader(
    train_dataset, shuffle=True, collate_fn=collate_fn, batch_size=TRAIN_BATCH_SIZE
)
# Validation and test loaders keep dataset order and share EVAL_BATCH_SIZE.
val_dataloader = DataLoader(
    val_dataset, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
)
test_dataloader = DataLoader(
    test_dataset, collate_fn=collate_fn, batch_size=EVAL_BATCH_SIZE
)


In [None]:
batch = next(iter(train_dataloader))
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        print(k, v.shape)


In [None]:
assert batch["pixel_values"].shape == (TRAIN_BATCH_SIZE, 3, 224, 224)
assert batch["labels"].shape == (TRAIN_BATCH_SIZE,)


In [None]:
next(iter(val_dataloader))["pixel_values"].shape


### ✏️💡 Define the ViT model with PyTorch Lightning Module ⚡

In [None]:
class ViTLightningModule(pl.LightningModule):
    """Lightning wrapper that fine-tunes only the classifier head of a ViT.

    The pretrained ``ViTForImageClassification`` backbone is frozen; only its
    ``classifier`` layer is trained. The class-name mappings ``id2label`` and
    ``label2id`` and the three dataloaders are read from notebook globals.

    Args:
        model_checkpoint_path: Hugging Face checkpoint name or local path.
        num_classes: Number of output classes.
        criterion: Loss applied to ``(logits, labels)``. A fresh
            ``nn.CrossEntropyLoss`` is created when omitted, so instances do
            not share one loss module created at definition time.
        learning_rate: Learning rate for the AdamW optimizer.
    """

    def __init__(
        self,
        model_checkpoint_path,
        num_classes,
        criterion=None,
        learning_rate=5e-5,
    ):
        super().__init__()
        self.vit = ViTForImageClassification.from_pretrained(
            model_checkpoint_path,
            num_labels=num_classes,
            id2label=id2label,
            label2id=label2id,
        )
        # `dropout` and `classifier` are not used by `forward` (the wrapped
        # model has its own head); they are kept so the module's state_dict
        # keys stay unchanged.
        self.dropout = nn.Dropout(0.1)
        self.num_labels = num_classes
        self.classifier = nn.Linear(self.vit.config.hidden_size, self.num_labels)
        self.criterion = criterion if criterion is not None else nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

        # Freeze everything in the wrapped model except its classification head.
        trainable_names = {"classifier.weight", "classifier.bias"}
        for name, param in self.vit.named_parameters():
            if name not in trainable_names:
                param.requires_grad = False

    def forward(self, pixel_values):
        """Return the classification logits for a batch of images."""
        outputs = self.vit(pixel_values=pixel_values)
        return outputs.logits

    def configure_optimizers(self):
        """Create the AdamW optimizer over the trainable parameters only."""
        # torch.optim.AdamW replaces the deprecated transformers.AdamW;
        # weight_decay=0.0 and eps=1e-6 mirror that implementation's defaults.
        trainable_params = [p for p in self.parameters() if p.requires_grad]
        return torch.optim.AdamW(
            trainable_params, lr=self.learning_rate, eps=1e-6, weight_decay=0.0
        )

    def common_step(self, batch, batch_idx):
        """Compute loss and accuracy for one batch.

        Args:
            batch: Dict with ``"pixel_values"`` and ``"labels"`` tensors.
            batch_idx: Index of the batch (unused).

        Returns:
            Tuple ``(loss, accuracy)``; accuracy is the fraction of correct
            predictions in the batch.
        """
        pixel_values = batch["pixel_values"]
        labels = batch["labels"]
        logits = self(pixel_values)
        loss = self.criterion(logits, labels)
        predictions = logits.argmax(-1)
        # Kept as a tensor: avoids the host sync that `.item()` forces.
        accuracy = (predictions == labels).float().mean()
        return loss, accuracy

    def training_step(self, batch, batch_idx):
        """Log training loss/accuracy and return the loss for backprop."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("training_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(
            "training_accuracy", accuracy, on_step=True, on_epoch=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy for one batch."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("validation_loss", loss, on_step=True, on_epoch=True, logger=True)
        self.log(
            "validation_accuracy", accuracy, on_step=True, on_epoch=True, logger=True
        )
        return loss

    def test_step(self, batch, batch_idx):
        """Log test loss/accuracy for one batch."""
        loss, accuracy = self.common_step(batch, batch_idx)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        self.log("test_acc", accuracy, prog_bar=True, logger=True)
        return loss

    def train_dataloader(self):
        """Return the notebook-level training dataloader."""
        return train_dataloader

    def val_dataloader(self):
        """Return the notebook-level validation dataloader."""
        return val_dataloader

    def test_dataloader(self):
        """Return the notebook-level test dataloader."""
        return test_dataloader


In [None]:
model = ViTLightningModule(
    model_checkpoint_path=MODEL_CHECKPOINT_PATH, num_classes=NUM_CLASSES
)

wandb_logger.watch(model)


### 🚅💪 Train the model using Pytorch Lightning ⚡

In [None]:
MAX_EPOCHS = 2
LOG_EVERY_N_STEPS = 1
CALLBACKS = [checkpoint_callback, early_stopping_callback]


In [None]:
# Trainer wired to the W&B logger and to the checkpoint / early-stopping
# callbacks collected in CALLBACKS.
trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    enable_progress_bar=True,
    enable_model_summary=True,
    logger=wandb_logger,
    callbacks=CALLBACKS,
    log_every_n_steps=LOG_EVERY_N_STEPS,
)

# No dataloaders are passed: the model supplies them through its own
# train_dataloader / val_dataloader hooks.
trainer.fit(model)


### 🎯🏆 Test the trained model

In [None]:
trainer.test(model, test_loader)


In [None]:
logged_metrics = trainer.logged_metrics

print(f"Logged metrics: {logged_metrics}")
wandb.log(logged_metrics)


In [None]:
y_pred = []
y_true = []

# Inference only: switch to eval mode and disable gradient tracking.
model.eval()
with torch.no_grad():
    for batch in test_loader:
        # Batches are dicts when built by a dict-returning collate function
        # and (inputs, labels) pairs with the default collation.
        if isinstance(batch, dict):
            x, y = batch["pixel_values"], batch["labels"]
        else:
            x, y = batch
        # Call the module itself (it has no `net` attribute) on whichever
        # device its parameters currently live on.
        y_hat = model(x.to(model.device))
        y_pred.append(y_hat.cpu().numpy())
        y_true.append(y.cpu().numpy())

y_pred = np.concatenate(y_pred)
y_true = np.concatenate(y_true)

# Results get their own names so the imported sklearn functions
# `accuracy_score` and `classification_report` stay callable in later cells.
test_accuracy = accuracy_score(y_true, y_pred.argmax(axis=1))
print(f"Accuracy score: {test_accuracy}")
wandb.log({"accuracy_score": test_accuracy})

report_text = classification_report(y_true, y_pred.argmax(axis=1))
print(f"Classification report:\n {report_text}")
wandb.log({"classification_report": report_text})


In [None]:
# Confusion matrix of true labels vs. arg-max predictions.
cm = confusion_matrix(y_true, y_pred.argmax(axis=1))

# One text annotation per cell; white text on cells above half the maximum
# count, black text otherwise.
cell_annotations = []
for i in range(cm.shape[1]):
    for j in range(cm.shape[0]):
        text_color = "white" if cm[j][i] > cm.max() / 2 else "black"
        cell_annotations.append(
            dict(
                x=i,
                y=j,
                text=str(cm[j][i]),
                showarrow=False,
                font=dict(color=text_color),
            )
        )

axis_labels = dict(x="Predicted label", y="True label", color="Count")
fig = px.imshow(cm, labels=axis_labels)
fig.update_layout(
    title="Confusion matrix",
    xaxis_title="Predicted label",
    yaxis_title="True label",
    annotations=cell_annotations,
    width=800,
    height=800,
)

fig.show()

# Store the interactive figure in the W&B run.
wandb.log({"confusion_matrix": fig})


In [None]:
# Log the classification report rendered as a W&B Html object.
# NOTE(review): `classification_report` must still refer to the sklearn
# function here; an earlier cell rebinds that name to the report string.
wandb.log(
    {
        "classification_report": wandb.Html(
            classification_report(y_true, y_pred.argmax(axis=1))
        )
    }
)
# Log W&B's built-in confusion-matrix plot from the hard predictions.
# NOTE(review): the "confusion_matrix" key is also used for the plotly figure
# logged earlier - confirm that reusing the key is intended.
wandb.log(
    {
        "confusion_matrix": wandb.plot.confusion_matrix(
            probs=None, y_true=y_true, preds=y_pred.argmax(axis=1)
        )
    }
)


### 📥 Save the model after training

In [None]:
torch.save(model.state_dict(), "model_after_training.pt")
wandb.save("model_after_training.pt")


### 👏 Finish

In [None]:
wandb.finish()
