# 🔬 Leukocyte image classification using Vision Transformers
![img](https://assets.stickpng.com/images/6308b84661b3e2a522f01468.png)
***

Computer vision project _Leukocytes classification from blood smear images - LCBSI_


@AgataPolejowska

## 📁 Set up the environment

#### Installing dependencies

In [None]:
# Install Weights & Biases, Hugging Face `datasets`/`transformers` and Plotly
# Express; `-q` keeps pip's output quiet.
# NOTE(review): scikit-learn is imported later in the notebook but is not
# installed here — presumably it is preinstalled in the runtime; confirm.
!pip install -q wandb
!pip install -q datasets transformers
!pip install -q plotly-express

#### Importing libraries and logging

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
import requests
import seaborn as sns
import torch
from datasets import load_dataset, load_metric
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from tqdm import tqdm
from transformers import (
    AutoFeatureExtractor,
    AutoModelForImageClassification,
    Trainer,
    TrainingArguments,
)

from huggingface_hub import notebook_login
import wandb

In [None]:
# Authenticate with the Hugging Face Hub; required because the training
# arguments set push_to_hub=True and the notebook calls trainer.push_to_hub().
notebook_login()

In [None]:
# Authenticate with Weights & Biases before starting a run.
wandb.login()

In [None]:
# Start the W&B run that the rest of the notebook logs to.
# NOTE(review): the W&B report linked at the end of the notebook lives in
# project "vit-classification-lcbsi", not "vit-test-lcbsi" — confirm which
# project is intended.
run = wandb.init(project="vit-test-lcbsi", entity="polejowska")

## 📚 Loading the dataset



In [None]:
# Alternative loading path (disabled): download the raw image folders from a
# W&B dataset artifact and load them with the "imagefolder" builder instead of
# pulling the prepared dataset from the Hugging Face Hub.
# NOTE(review): the "train" path below is absolute (/content/...) while
# "test" and "valid" are relative — confirm before re-enabling.
# artifact = run.use_artifact("polejowska/lcbsi-wbc-monai-ai/raw_data:v0", type="dataset")
# artifact_dir = artifact.download()
# dataset_path = "artifacts/raw_data-v0"
# dataset = load_dataset("imagefolder", data_files={"train": "/content/artifacts/raw_data-v0/train/**", "test": "artifacts/raw_data-v0/test/**", "valid": "artifacts/raw_data-v0/valid/**"})

In [None]:
# Load the leukocyte dataset from the Hugging Face Hub; the notebook uses its
# "train", "valid" and "test" splits.
dataset = load_dataset("polejowska/lcbsi-wbc-ap")

## 🔍 Explore the dataset

In [None]:
# Quick overview: split structure, training-set size, one raw example, the
# feature schema and the class names of the "label" column.
print(f"Dataset structure: {dataset}\n")
print(f"Number of training examples: {len(dataset['train'])}\n")
print(f"Dataset sample (image, label): {dataset['train'][0]}\n")
print(f"Dataset features: {dataset['train'].features}\n")
print(f"Class labels: {dataset['train'].features['label'].names}\n")


In [None]:
# Class-name <-> integer-id lookups built from the "label" ClassLabel feature;
# an id is the class name's position in the feature's `names` list.
labels = dataset["train"].features["label"].names
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = {idx: name for idx, name in enumerate(labels)}


In [None]:
# Record the class names and the training-split size in the W&B run config.
wandb.config.update({"class_labels": dataset["train"].features["label"].names})
wandb.config.update({"num_train_examples": len(dataset["train"])})

#### Visualize data and display essential information

In [None]:
def plot_class_distribution(dataset, id2label, dataset_name="dataset"):
    """Display and return a histogram of the number of examples per class.

    Args:
        dataset: Anything where ``dataset["label"]`` yields integer label ids,
            e.g. a ``datasets`` split or a plain ``{"label": [...]}`` dict.
        id2label: Mapping from integer label id to class name.
        dataset_name: Text inserted into the plot title.

    Returns:
        The Plotly figure (also shown inline).
    """
    fig = px.histogram(
        x=[id2label[label] for label in dataset["label"]],
        title=f"Distribution of classes in the {dataset_name}",
    )
    fig.update_layout(xaxis_title="Class", yaxis_title="Number of examples")
    fig.show()
    return fig


# Pool the labels of every split so the "entire dataset" plot covers
# train + valid + test. The training split alone is plotted separately later.
all_split_labels = {
    "label": [label for split in dataset.values() for label in split["label"]]
}
entire_dataset_fig = plot_class_distribution(all_split_labels, id2label)
wandb.log({"class distribution in the entire dataset": entire_dataset_fig})


In [None]:
def display_random_images(dataset, label2id, id2label, num_images=4):
    """Show distinct random images with their labels (matplotlib) and log to W&B.

    Args:
        dataset: A split whose rows have an ``"image"`` and an integer
            ``"label"``.
        label2id: Unused; kept so existing calls keep working.
        id2label: Mapping from integer label id to class name.
        num_images: How many distinct images to show; capped at the split size.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    num_images = min(num_images, len(dataset))
    if num_images == 0:
        raise ValueError("Cannot display images from an empty dataset")

    # Sample without replacement so the grid never repeats an image.
    indices = np.random.choice(len(dataset), size=num_images, replace=False)

    n_cols = 2
    n_rows = (num_images + n_cols - 1) // n_cols  # 4 images -> 2x2 grid

    fig = plt.figure(figsize=(10, 10))
    for position, index in enumerate(indices, start=1):
        example = dataset[int(index)]  # read the row once
        label = example["label"]

        ax = fig.add_subplot(n_rows, n_cols, position)
        ax.imshow(example["image"])
        ax.set_title(f"Class: {label} ({id2label[label]})")
        ax.axis("off")
    plt.show()

    wandb.log({"random_images": fig})


display_random_images(dataset["train"], label2id, id2label)


***
## 🔨 Data processing

1. Resize images to the input size expected by the pretrained checkpoint
2. Normalize the RGB channels using the checkpoint's mean and standard deviation

In [None]:
# Alternative checkpoints (disabled). To switch, uncomment one of these and
# comment out the `model_checkpoint` assignment in the next cell.
# model_checkpoint = "microsoft/swin-tiny-patch4-window7-224"
# model_checkpoint = "facebook/convnext-tiny-224"
# model_checkpoint = "google/vit-base-patch16-224-in21k"
# model_checkpoint = "nickmuchi/vit-base-xray-pneumonia"

In [None]:
model_checkpoint = "polejowska/swin-tiny-patch4-window7-224-lcbsi-wbc"
feature_extractor = AutoFeatureExtractor.from_pretrained(model_checkpoint)

# Per-image preprocessing matched to the checkpoint: resize to its expected
# (height, width), turn the PIL image into a float tensor, then normalise each
# channel with the checkpoint's mean and std.
data_transforms = Compose(
    [
        Resize(
            (
                feature_extractor.size["height"],
                feature_extractor.size["width"],
            )
        ),
        ToTensor(),
        Normalize(
            mean=feature_extractor.image_mean,
            std=feature_extractor.image_std,
        ),
    ]
)


def add_pixel_values_feature(batch):
    """Attach a ``"pixel_values"`` list to a batch of examples.

    Every PIL image in ``batch["image"]`` is converted to RGB and run through
    ``data_transforms``. The batch dict is updated in place and returned,
    which suits both ``Dataset.map(batched=True)`` and ``set_transform``.
    """
    rgb_images = (picture.convert("RGB") for picture in batch["image"])
    batch["pixel_values"] = [data_transforms(picture) for picture in rgb_images]
    return batch


# Precompute pixel values for each split.
# NOTE(review): a later cell calls set_transform(add_pixel_values_feature),
# which recomputes pixel_values on every access, so the column stored here
# looks redundant (and large on disk). Confirm before dropping it; the
# class-distribution plots read split["label"] before that transform is set.
train_dataset = dataset["train"].map(add_pixel_values_feature, batched=True)
validation_dataset = dataset["valid"].map(add_pixel_values_feature, batched=True)
test_dataset = dataset["test"].map(add_pixel_values_feature, batched=True)

print(f"Length of training dataset: {len(train_dataset)}")
print(f"Length of validation dataset: {len(validation_dataset)}")
print(f"Length of test dataset: {len(test_dataset)}")

In [None]:
# Inspect the mapped training split (its features now include "pixel_values").
train_dataset

In [None]:
# Class balance of each split, shown inline and logged to W&B.
train_dataset_fig = plot_class_distribution(
    train_dataset, id2label, dataset_name="training dataset"
)
wandb.log({"class distribution in the training dataset": train_dataset_fig})

validation_dataset_fig = plot_class_distribution(
    validation_dataset, id2label, dataset_name="validation dataset"
)
wandb.log({"class distribution in the validation dataset": validation_dataset_fig})

test_dataset_fig = plot_class_distribution(
    test_dataset, id2label, dataset_name="test dataset"
)
wandb.log({"class distribution in the test dataset": test_dataset_fig})


In [None]:
# Record the size of every split in the W&B run config in a single update.
wandb.config.update(
    {
        "num_train_examples": len(train_dataset),
        "num_validation_examples": len(validation_dataset),
        "num_test_examples": len(test_dataset),
    }
)

In [None]:
# Switch every split to on-the-fly preprocessing. From here on, reading rows
# from a split runs add_pixel_values_feature on them (per the `datasets`
# set_transform docs), producing the "pixel_values" tensors collate_fn stacks.
train_dataset.set_transform(add_pixel_values_feature)
validation_dataset.set_transform(add_pixel_values_feature)
test_dataset.set_transform(add_pixel_values_feature)

In [None]:
# Inspect the training split again now that the transform is set.
train_dataset

#### W&B: log dataset tables

In [None]:
def create_table(dataset):
    """Build a W&B table with one row per example: image, label id, class name.

    Args:
        dataset: A split whose rows have an ``"image"`` and an integer
            ``"label"``. Label ids are mapped to names via the global
            ``id2label``.

    Returns:
        The populated ``wandb.Table``.
    """
    table = wandb.Table(columns=["image", "label", "class name"])

    # Read each row exactly once. With set_transform active, every row access
    # re-runs the image preprocessing, so indexing the same row twice doubled
    # that cost.
    for example in tqdm(dataset, total=len(dataset)):
        label = example["label"]
        table.add_data(wandb.Image(example["image"]), label, id2label[label])

    return table


In [None]:
# Log an image/label table for each split to W&B.
# NOTE(review): these splits already have set_transform active, so building
# the tables also runs the full preprocessing on every image even though only
# the raw image and label are logged. Passing the untransformed
# dataset["train"/"valid"/"test"] splits would avoid that — confirm.
train_table = create_table(train_dataset)
validation_table = create_table(validation_dataset)
test_table = create_table(test_dataset)

wandb.log({"train_dataset": train_table})
wandb.log({"validation_dataset": validation_table})
wandb.log({"test_dataset": test_table})


***
## 🚋 Model training

In [None]:
# Load the checkpoint with a classification head configured for our label
# mapping. ignore_mismatched_sizes=True lets from_pretrained re-initialise
# weights whose shapes do not match (e.g. a head with a different number of
# classes) instead of raising.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    label2id=label2id,
    id2label=id2label,
    ignore_mismatched_sizes=True,
)

In [None]:
# Have W&B track the model's gradients and parameters (log="all") and log its
# computation graph during training.
wandb.watch(model, log="all", log_graph=True)

In [None]:
# Hyperparameters; epochs and learning rate are also embedded in the run name.
MODEL_NAME = model_checkpoint.rpartition("/")[2]
NUM_TRAIN_EPOCHS = 3
LEARNING_RATE = 2.562e-4
WEIGHT_DECAY = 0.5
BATCH_SIZE = 32
STRATEGY = "epoch"
wandb.run.name = f"{MODEL_NAME} (epochs: {NUM_TRAIN_EPOCHS}) (lr: {LEARNING_RATE})"


args = TrainingArguments(
    # Output directory for checkpoints and the saved model.
    f"{MODEL_NAME}-new",
    # Keep the raw "image" column. The set_transform callback builds
    # "pixel_values" from it when rows are read, so the Trainer must not drop
    # columns the model's forward() does not name.
    remove_unused_columns=False,
    # Optimisation.
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=0.1,
    num_train_epochs=NUM_TRAIN_EPOCHS,
    # Batching: with 4 accumulation steps the effective training batch is
    # 4 * BATCH_SIZE per device.
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    gradient_accumulation_steps=4,
    # Evaluate and checkpoint every epoch; reload the checkpoint with the best
    # validation accuracy when training ends.
    evaluation_strategy=STRATEGY,
    save_strategy=STRATEGY,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # Logging and publishing.
    logging_steps=10,
    report_to="wandb",
    push_to_hub=True,
)

In [None]:
def collate_fn(batches):
    """Turn a list of preprocessed examples into the model's input batch.

    Args:
        batches: Examples, each holding a ``"pixel_values"`` tensor and an
            integer ``"label"``.

    Returns:
        ``{"pixel_values": stacked tensor, "labels": tensor of label ids}``.
    """
    pixel_list = [example["pixel_values"] for example in batches]
    pixel_batch = torch.stack(pixel_list)
    label_batch = torch.tensor([example["label"] for example in batches])
    return dict(pixel_values=pixel_batch, labels=label_batch)

# Load the metric implementations once, when this cell runs, rather than on
# every evaluation call.
_ACCURACY_METRIC = load_metric("accuracy")
_PRECISION_METRIC = load_metric("precision")
_RECALL_METRIC = load_metric("recall")
_F1_METRIC = load_metric("f1")


def compute_metrics(eval_preds):
    """Compute accuracy plus weighted precision, recall and F1 for the Trainer.

    Args:
        eval_preds: The evaluation output passed by ``Trainer``.
            ``predictions`` holds per-class logits and ``label_ids`` the true
            label ids.

    Returns:
        A dict merging the four metrics' results. Its ``"accuracy"`` key is
        what ``metric_for_best_model="accuracy"`` selects on. Precision,
        recall and F1 use ``average="weighted"`` (per-class scores weighted by
        class support).
    """
    preds = np.argmax(eval_preds.predictions, axis=-1)
    labels = eval_preds.label_ids

    metrics = dict(_ACCURACY_METRIC.compute(predictions=preds, references=labels))
    for metric in (_PRECISION_METRIC, _RECALL_METRIC, _F1_METRIC):
        metrics.update(
            metric.compute(predictions=preds, references=labels, average="weighted")
        )
    return metrics

In [None]:
# Wire model, arguments, data, collator and metrics into the Trainer.
# The feature extractor is passed as `tokenizer` so the Trainer saves (and
# pushes) it together with the model.
trainer = Trainer(
    model,
    args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    tokenizer=feature_extractor,
    data_collator=collate_fn,
    compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune the model; the returned object's `.metrics` are logged below.
trainer_results = trainer.train()

In [None]:
# Save the final model (the best checkpoint, since load_best_model_at_end=True)
# to the Trainer's output directory.
trainer.save_model()

In [None]:
# Print the training metrics, write them to the output directory, and save the
# Trainer state alongside them.
trainer.log_metrics("train", trainer_results.metrics)
trainer.save_metrics("train", trainer_results.metrics)

trainer.save_state()

In [None]:
# save model to W&B
model_artifact = wandb.Artifact(
    f"{MODEL_NAME}-lcbsi-wbc",
    type="model",
    description=f"model trained on {MODEL_NAME} for {NUM_TRAIN_EPOCHS} epochs",
)

wandb.log_artifact(model_artifact)

## 📈 Model evaluation

In [None]:
# Evaluate the final model on the Trainer's eval_dataset (the validation
# split), then print and save the resulting metrics.
metrics = trainer.evaluate()

trainer.log_metrics("eval", metrics)
trainer.save_metrics("eval", metrics)


In [None]:
# Push the trained model (and the feature extractor saved with it) to the
# Hugging Face Hub.
trainer.push_to_hub()

In [None]:
# NOTE: the W&B run is deliberately kept open here. The following cells still
# log the test-set prediction table, its artifact and the confusion matrices
# to this run, and wandb.log fails once the run has been finished.
# wandb.finish() is called in the notebook's final cell instead.

In [None]:
def create_table_with_predictions(dataset, predictions):
    """Build a W&B table pairing every example with the model's prediction.

    Args:
        dataset: A split whose rows have an ``"image"`` and an integer
            ``"label"``. Ids are mapped to names via the global ``id2label``.
        predictions: Predicted label ids, aligned index-by-index with
            ``dataset``.

    Returns:
        A ``wandb.Table`` with the image, true id/name and predicted id/name.
    """
    table = wandb.Table(
        columns=["image", "label", "class name", "prediction", "prediction class name"]
    )

    # Read each row exactly once. With set_transform active, every row access
    # re-runs the image preprocessing.
    for i, example in enumerate(tqdm(dataset, total=len(dataset))):
        label = example["label"]
        prediction = predictions[i]
        table.add_data(
            wandb.Image(example["image"]),
            label,
            id2label[label],
            prediction,
            id2label[prediction],
        )

    return table

In [None]:
# Predict on the test split and keep the arg-max class id for each image.
test_predictions = np.argmax(trainer.predict(test_dataset).predictions, axis=1)
# True label ids, read from the untransformed split that test_dataset was
# mapped from (same rows, same order). Indexing test_dataset[:] would run the
# image preprocessing over the whole split just to extract the labels.
test_labels = dataset["test"]["label"]

test_table_with_predictions = create_table_with_predictions(
    test_dataset, test_predictions
)

wandb.log({"test_table_with_predictions": test_table_with_predictions})


# Also version the prediction table as a W&B artifact.
test_table_with_predictions_artifact = wandb.Artifact(
    name="test_table_with_predictions",
    type="test_table_with_predictions",
    description="A table with predictions on the test dataset",
    metadata={
        "num_test_examples": len(test_dataset),
    },
)

test_table_with_predictions_artifact.add(
    test_table_with_predictions, "test_table_with_predictions"
)
wandb.log_artifact(test_table_with_predictions_artifact)

# Named distinctly from sklearn's `confusion_matrix`, which a later cell imports.
wandb_confusion_matrix = wandb.plot.confusion_matrix(
    probs=None,
    y_true=test_labels,
    preds=test_predictions,
    class_names=list(id2label.values()),
)

wandb.log({"confusion_matrix": wandb_confusion_matrix})

In [None]:
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, class_names):
    """Display and return an annotated Plotly heatmap of a confusion matrix.

    Args:
        cm: Square 2-D array with rows as true labels and columns as predicted
            labels, both in ``class_names`` order.
        class_names: Tick labels, one per class.

    Returns:
        The Plotly figure (also shown inline).
    """
    fig = go.Figure(data=go.Heatmap(z=cm, x=class_names, y=class_names))
    fig.update_layout(
        title="Confusion Matrix",
        xaxis_title="Predicted label",
        yaxis_title="True label",
        # One text annotation per cell. cm.flatten() is row-major, so the
        # column (predicted) name cycles fastest (tile) and the row (true)
        # name repeats (repeat).
        annotations=[
            go.layout.Annotation(
                text=str(round(z, 2)), x=x, y=y, font_size=14, showarrow=False
            )
            for x, y, z in zip(
                np.tile(class_names, len(class_names)),
                np.repeat(class_names, len(class_names)),
                cm.flatten(),
            )
        ],
    )
    fig.show()
    return fig


# Count matrix over *all* classes. Passing `labels` keeps it
# len(id2label) x len(id2label) even when a class never appears in the test
# labels or predictions, so rows and columns stay aligned with the tick labels.
# Labels are read from the untransformed test split to avoid re-running the
# image preprocessing over every test image.
cm = confusion_matrix(
    dataset["test"]["label"], test_predictions, labels=list(id2label)
)

# Row-normalise to the fraction of each true class. Rows with no examples
# stay 0 instead of becoming NaN from a 0/0 division.
row_totals = cm.sum(axis=1, keepdims=True)
cm = np.divide(
    cm, row_totals, out=np.zeros(cm.shape, dtype=float), where=row_totals != 0
)
cm_plot = plot_confusion_matrix(cm, list(id2label.values()))

wandb.log({"confusion_matrix (plotly)": cm_plot})


In [None]:
# accuracy_score is not imported anywhere else in the notebook; import it here
# next to its only use.
from sklearn.metrics import accuracy_score

# Overall test accuracy: the fraction of test images whose predicted class id
# matches the true label. Labels come from the untransformed test split to
# avoid running the image preprocessing over every test image.
test_accuracy = accuracy_score(dataset["test"]["label"], test_predictions)
print("Accuracy: {:.2f}%".format(test_accuracy * 100))

## Inference using the transformers pipeline

In [None]:
from io import BytesIO


def load_image_from_url(url, timeout=30):
    """Download an image and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded image converted to RGB.

    Raises:
        requests.HTTPError: If the server responds with an error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")


img_url = (
    "https://huggingface.co/spaces/polejowska/LCBSI/resolve/main/95-8-24-1_190_1.jpg"
)
image = load_image_from_url(img_url)


In [None]:
from transformers import pipeline

# Build an image-classification pipeline for quick inference.
# NOTE(review): `repo_name` is the same checkpoint training started from
# (`model_checkpoint`), not the model trained in this notebook, which was
# written to f"{MODEL_NAME}-new". Confirm which model inference should use.
repo_name = "polejowska/swin-tiny-patch4-window7-224-lcbsi-wbc"
pipe = pipeline("image-classification", repo_name)


In [None]:
# Classify the sample image downloaded above.
pipe(image)


In [None]:
# Close the W&B run now that everything has been logged.
wandb.finish()


W&B report: https://wandb.ai/polejowska/vit-classification-lcbsi/reports/Leukocytes-classification-from-blood-smear-images--VmlldzozMTU1NjI0