# 🔬 Leukocyte image classification using Vision Transformers

***

## 📁 Set up the environment

#### Installing dependencies

In [None]:
!pip install -q wandb
!pip install -q datasets transformers
!pip install -q plotly-express

#### Importing libraries and logging in

In [None]:
from huggingface_hub import notebook_login

# Authenticate this notebook session with the Hugging Face Hub.
notebook_login()


In [None]:
import wandb

# Authenticate with Weights & Biases before any run is created.
wandb.login()


# W&B project and entity passed to the `wandb.init(...)` calls in this notebook.
PROJECT = "lcbsi-vits-sweeps"
ENTITY = "polejowska"


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import requests
import seaborn as sns
import tensorflow as tf
import torch
from datasets import load_dataset, load_metric
from PIL import Image
from tensorflow import keras
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from tqdm import tqdm
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

import gdown

sns.set_style("whitegrid")


## 📚 Loading the dataset



In [None]:
from datasets import load_dataset

# Name of the dataset repository under the "polejowska" namespace; the same
# value is reused later as the local directory handed to `artifact.add_dir`.
dataset_path = "lcbsi-wbc-ap"
dataset = load_dataset(f"polejowska/{dataset_path}")


In [None]:
# Save under `dataset_path` so the directory always matches the one that is
# uploaded with `artifact.add_dir(dataset_path)`.
dataset.save_to_disk(dataset_path)


In [None]:
# Log the raw dataset directory to W&B as a "dataset" artifact.
with wandb.init(project=PROJECT, entity=ENTITY, job_type="upload") as run:
    artifact = wandb.Artifact("lcbsi-wbc-ap", type="dataset")
    # `dataset_path` names the local directory the dataset was saved to.
    artifact.add_dir(dataset_path)
    run.log_artifact(artifact)
    # Set the run's display name.
    run.name = "dataset-upload"


In [None]:
# run = wandb.init(project=PROJECT, entity=ENTITY, job_type="download")
# artifact = run.use_artifact('polejowska/lcbsi-wbc-monai-ai/raw_data:v0', type='dataset')
# artifact_dir = artifact.download()
# dataset_path = "artifacts/raw_data-v0"
# run.finish()


In [None]:
# dataset = load_dataset("imagefolder", data_files={"train": "/content/artifacts/raw_data-v0/train/**", "test": "artifacts/raw_data-v0/test/**", "valid": "artifacts/raw_data-v0/valid/**"})


## 🔍 Explore the dataset

In [None]:
# Quick overview of the loaded dataset: its splits, the training-set size,
# the first training example, the feature schema and the class names.
print(f"Dataset structure: {dataset}\n")
print(f"Number of training examples: {len(dataset['train'])}\n")
print(f"Dataset sample (image, label): {dataset['train'][0]}\n")
print(f"Dataset features: {dataset['train'].features}\n")
print(f"Class labels: {dataset['train'].features['label'].names}\n")


In [None]:
# Class names in label-id order, taken from the training split's schema.
labels = dataset["train"].features["label"].names

# Two-way lookup between class names and their integer ids.
label2id = {class_name: class_id for class_id, class_name in enumerate(labels)}
id2label = {class_id: class_name for class_id, class_name in enumerate(labels)}


#### Visualize data and display essential information

In [None]:
def plot_class_distribution(dataset, id2label, dataset_name="dataset"):
    """Show and return a histogram of the class names in one dataset split.

    Args:
        dataset: Split whose ``"label"`` column holds the label ids.
        id2label: Mapping from label id to class name.
        dataset_name: Name of the split, used only in the plot title.

    Returns:
        The plotly figure that was displayed.
    """
    class_names = [id2label[class_id] for class_id in dataset["label"]]
    histogram = px.histogram(
        x=class_names,
        title=f"Distribution of classes in the {dataset_name}",
    )
    histogram.update_layout(xaxis_title="Class", yaxis_title="Number of examples")
    histogram.show()
    return histogram


In [None]:
# One class-distribution histogram per split; the figures are kept so they can
# be logged to W&B later.
train_dataset_fig = plot_class_distribution(dataset["train"], id2label, "train")
valid_dataset_fig = plot_class_distribution(dataset["valid"], id2label, "valid")
test_dataset_fig = plot_class_distribution(dataset["test"], id2label, "test")


In [None]:
def display_random_images(dataset, label2id, id2label):
    """Plot four randomly chosen images from ``dataset`` in a 2x2 grid.

    Indices are drawn independently, so the same image may appear twice.

    Args:
        dataset: Indexable collection whose items are mappings with an
            ``"image"`` and a ``"label"`` entry.
        label2id: Unused; kept so existing call sites keep working.
        id2label: Mapping from label id to class name, used in the titles.

    Returns:
        The matplotlib figure that was drawn.
    """
    fig = plt.figure(figsize=(10, 10))
    for position in range(1, 5):
        # Fetch the example once: indexing the dataset separately for the
        # image and for the label would load the same example twice.
        example = dataset[np.random.randint(0, len(dataset))]
        label = example["label"]

        ax = fig.add_subplot(2, 2, position)
        ax.imshow(example["image"])
        ax.set_title(f"Class: {label} ({id2label[label]})")
        ax.axis("off")
    plt.show()
    return fig


In [None]:
# One 2x2 grid of random samples per split; the figures are kept so they can
# be logged to W&B later.
random_images_train = display_random_images(dataset["train"], label2id, id2label)
random_images_valid = display_random_images(dataset["valid"], label2id, id2label)
random_images_test = display_random_images(dataset["test"], label2id, id2label)


In [None]:
def create_table(dataset):
    """Build a ``wandb.Table`` with one row per example of ``dataset``.

    Each row holds the image (wrapped in ``wandb.Image``), its label id and
    the class name looked up in the module-level ``id2label`` mapping.

    Args:
        dataset: Sized iterable of mappings with ``"image"`` and ``"label"``
            entries.

    Returns:
        The populated ``wandb.Table``.
    """
    table = wandb.Table(columns=["image", "label", "class name"])

    # Iterate over the examples directly so every example is fetched once,
    # rather than indexing the dataset twice per row.
    for example in tqdm(dataset, total=len(dataset)):
        label = example["label"]
        table.add_data(wandb.Image(example["image"]), label, id2label[label])

    return table


In [None]:
# Log the dataset information (plots, class names, split sizes and full
# per-split tables) to W&B in a dedicated "dataset-info" run.
import pandas as pd

with wandb.init(project=PROJECT, entity=ENTITY, job_type="dataset-info") as run:
    # Class-distribution histograms created earlier in the notebook.
    run.log({"train-distribution": train_dataset_fig})
    run.log({"valid-distribution": valid_dataset_fig})
    run.log({"test-distribution": test_dataset_fig})
    # Grids of random sample images created earlier in the notebook.
    run.log({"train-random-images": random_images_train})
    run.log({"valid-random-images": random_images_valid})
    run.log({"test-random-images": random_images_test})
    # Class labels as a one-column table indexed by label id.
    run.log(
        {
            "class-labels": wandb.Table(
                dataframe=pd.DataFrame.from_dict(
                    id2label, orient="index", columns=["class-labels"]
                )
            )
        }
    )
    # number of training examples
    run.log({"number-of-training-examples": len(dataset["train"])})
    # number of validation examples
    run.log({"number-of-validation-examples": len(dataset["valid"])})
    # number of test examples
    run.log({"number-of-test-examples": len(dataset["test"])})
    # Set the run's display name.
    run.name = "dataset-info"

    # Full tables (image, label id, class name) for every split; building
    # them walks each split once.
    train_dataset_table = create_table(dataset["train"])
    valid_dataset_table = create_table(dataset["valid"])
    test_dataset_table = create_table(dataset["test"])

    run.log({"train-dataset": train_dataset_table})
    run.log({"valid-dataset": valid_dataset_table})
    run.log({"test-dataset": test_dataset_table})


***
## 🔨 Data processing

1. Resize images
2. Normalize RGB channels using mean and standard deviation

In [None]:
# Pretrained checkpoint to fine-tune; the commented-out lines are alternative
# checkpoints that can be swapped in.
model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
# model_checkpoint = "facebook/convnext-tiny-224"
# model_checkpoint = "google/vit-base-patch16-224-in21k"
# model_checkpoint = "nickmuchi/vit-base-xray-pneumonia"

# The feature extractor supplies the normalization statistics
# (`image_mean` / `image_std`) used by the preprocessing transforms.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)


In [None]:
# Preprocessing pipeline applied to every image: resize to 224x224, convert
# to a tensor, then normalize with the feature extractor's mean and std.
data_transforms = Compose(
    [
        Resize((224, 224)),
        ToTensor(),
        Normalize(mean=feature_extractor.image_mean, std=feature_extractor.image_std),
    ]
)


def add_pixel_values_feature(batch):
    """Add a ``"pixel_values"`` entry with the preprocessed images of a batch.

    Every image in ``batch["image"]`` is converted to RGB and passed through
    the module-level ``data_transforms`` pipeline.

    Args:
        batch: Mapping with an ``"image"`` entry holding a sequence of images.

    Returns:
        The same ``batch`` object, with ``"pixel_values"`` set.
    """
    rgb_images = (picture.convert("RGB") for picture in batch["image"])
    batch["pixel_values"] = list(map(data_transforms, rgb_images))
    return batch


# Rename the target column to "labels", the key that `collate_fn` reads.
datasets_processed = dataset.rename_column("label", "labels")


In [None]:
# Compute a "pixel_values" column for every split, batch by batch.
# NOTE(review): `add_pixel_values_feature` is also registered on these same
# datasets with `set_transform` in the following cell, which presumably
# recomputes "pixel_values" on access. This eager `map` therefore looks
# redundant; confirm whether it is needed.
train_dataset = datasets_processed["train"].map(add_pixel_values_feature, batched=True)
validation_dataset = datasets_processed["valid"].map(
    add_pixel_values_feature, batched=True
)
test_dataset = datasets_processed["test"].map(add_pixel_values_feature, batched=True)

print(f"Length of training dataset: {len(train_dataset)}")
print(f"Length of validation dataset: {len(validation_dataset)}")
print(f"Length of test dataset: {len(test_dataset)}")


In [None]:
# Register the preprocessing function as the transform of each split, so
# accessed examples carry "pixel_values" produced by `data_transforms`
# (tensors, which `collate_fn` stacks with `torch.stack`).
train_dataset.set_transform(add_pixel_values_feature)
validation_dataset.set_transform(add_pixel_values_feature)
test_dataset.set_transform(add_pixel_values_feature)


In [None]:
# Save each processed split to its own local directory.
train_dataset.save_to_disk("train_dataset")
validation_dataset.save_to_disk("validation_dataset")
test_dataset.save_to_disk("test_dataset")


In [None]:
# move the train_dataset, validation_dataset, test_dataset to dataseet folder
!mv train_dataset dataset/
!mv validation_dataset dataset/
!mv test_dataset dataset/

In [None]:
# Upload the processed dataset to W&B in a dedicated run.
with wandb.init(
    project=PROJECT, entity=ENTITY, job_type="upload-processed-dataset"
) as run:
    # The "dataset" directory holds the saved train/validation/test splits.
    artifact = wandb.Artifact("processed-dataset", type="dataset")
    artifact.add_dir("dataset")
    run.log_artifact(artifact)
    # Set the run's display name.
    run.name = "dataset-processed-upload"


***
## Model configuration and training

In [None]:
def init_model():
    """Create a fresh image-classification model from ``model_checkpoint``.

    The module-level ``label2id`` / ``id2label`` mappings are passed to the
    model configuration, and ``ignore_mismatched_sizes=True`` is set on load.

    Returns:
        The loaded model.
    """
    return AutoModelForImageClassification.from_pretrained(
        model_checkpoint,
        ignore_mismatched_sizes=True,
        id2label=id2label,
        label2id=label2id,
    )


In [None]:
def compute_metrics_fn(eval_preds):
    """Compute accuracy and weighted precision, recall and F1.

    Args:
        eval_preds: Object exposing ``predictions`` (per-class scores) and
            ``label_ids`` (reference class ids).

    Returns:
        A dict merging the results of the four metrics.
    """
    metric_names = ("accuracy", "precision", "recall", "f1")
    loaded_metrics = {name: load_metric(name) for name in metric_names}

    # The predicted class is the index of the highest score on the last axis.
    predicted_ids = np.argmax(eval_preds.predictions, axis=-1)
    reference_ids = eval_preds.label_ids

    results = {}
    for name in metric_names:
        # Accuracy is computed without an averaging mode; the other three
        # metrics use the "weighted" average.
        extra_kwargs = {} if name == "accuracy" else {"average": "weighted"}
        results.update(
            loaded_metrics[name].compute(
                predictions=predicted_ids, references=reference_ids, **extra_kwargs
            )
        )
    return results


In [None]:
def collate_fn(examples):
    """Collate individual examples into one batch.

    Args:
        examples: Sequence of mappings, each with a ``"pixel_values"`` tensor
            and a ``"labels"`` value.

    Returns:
        A dict with ``"pixel_values"`` stacked along a new first dimension and
        ``"labels"`` as a single tensor.
    """
    images, targets = [], []
    for sample in examples:
        images.append(sample["pixel_values"])
        targets.append(sample["labels"])
    return {"pixel_values": torch.stack(images), "labels": torch.tensor(targets)}


In [None]:
# Checkpoint name without its namespace prefix (the part after the last "/").
MODEL_NAME = model_checkpoint.rsplit("/", 1)[-1]


In [None]:
from transformers import Trainer, TrainingArguments


def train(config=None):
    """Run one fine-tuning trial inside its own W&B run.

    Trains on ``train_dataset``, evaluates on ``validation_dataset`` after
    every epoch, then predicts on ``test_dataset`` and stores the resulting
    metrics in the run summary.

    Args:
        config: Optional hyperparameter mapping forwarded to ``wandb.init``.
            The values actually used are read back from ``wandb.config``,
            which must provide ``epochs``, ``learning_rate``, ``weight_decay``
            and ``batch_size``.
    """
    with wandb.init(config=config) as run:
        config = wandb.config

        training_args = TrainingArguments(
            output_dir="vit_train",
            report_to="wandb",
            num_train_epochs=config.epochs,
            learning_rate=config.learning_rate,
            weight_decay=config.weight_decay,
            per_device_train_batch_size=config.batch_size,
            per_device_eval_batch_size=16,
            save_strategy="epoch",
            evaluation_strategy="epoch",
            logging_strategy="epoch",
            load_best_model_at_end=True,
            # Keep the "image"/"pixel_values" columns: the transform and
            # `collate_fn` need them even though the model signature does not.
            remove_unused_columns=False,
            # Mixed precision needs a CUDA device; fall back to full precision
            # on CPU-only machines instead of failing at start-up.
            fp16=torch.cuda.is_available(),
        )

        trainer = Trainer(
            model_init=init_model,
            args=training_args,
            data_collator=collate_fn,
            train_dataset=train_dataset,
            eval_dataset=validation_dataset,
            compute_metrics=compute_metrics_fn,
        )

        trainer.train()

        trainer.evaluate()

        # Store the test-split metrics in the run summary so that each sweep
        # trial can also be compared on held-out data.
        test_output = trainer.predict(test_dataset)
        run.summary.update(test_output.metrics)

        trainer.save_state()


#### Sweep configuration

In [None]:
# Objective of the sweep: minimize the evaluation loss.
metric = {"name": "eval/loss", "goal": "minimize"}

# Search space: a fixed number of epochs, categorical batch size and weight
# decay, and a log-uniform learning rate.
parameters_dict = {
    "epochs": {"value": 5},
    "batch_size": {"values": [8, 16, 32, 64]},
    "learning_rate": {"distribution": "log_uniform_values", "min": 1e-5, "max": 1e-3},
    "weight_decay": {"values": [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]},
}

# Bayesian search over the space above.
sweep_config = {
    "method": "bayes",
    "metric": metric,
    "parameters": parameters_dict,
}

PROJECT = "lcbsi-vits-sweeps"
sweep_id = wandb.sweep(sweep_config, project=PROJECT)


In [None]:
# Run the sweep: call `train` for at most 5 trials (`count=5`).
wandb.agent(sweep_id, train, count=5)


## Inference using transformers pipeline

In [None]:
def load_image_from_url(url):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image file.

    Returns:
        The downloaded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the server does not respond within 30 seconds.
    """
    # Imported locally because the file's import cell does not provide
    # `BytesIO`; without it this function raises NameError.
    from io import BytesIO

    response = requests.get(url, timeout=30)
    # Surface HTTP errors directly instead of letting PIL fail on the body of
    # an error page.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


# Sample image used for the inference demo below.
img_url = (
    "https://huggingface.co/spaces/polejowska/LCBSI/resolve/main/95-8-24-1_190_1.jpg"
)
image = load_image_from_url(img_url)


In [None]:
from transformers import pipeline

# Build an image-classification pipeline from the `repo_name` checkpoint.
repo_name = "polejowska/swin-tiny-patch4-window7-224-lcbsi-wbc"
pipe = pipeline("image-classification", repo_name)


In [None]:
# Classify the downloaded sample image; the notebook displays the result.
pipe(image)


In [None]:
# Finish any W&B run that is still active.
wandb.finish()


W&B report: https://wandb.ai/polejowska/vit-classification-lcbsi/reports/Leukocytes-classification-from-blood-smear-images--VmlldzozMTU1NjI0