# Frame Classification Model for Fetal Anatomic Structures

## Setup

### Libraries

In [None]:
import os
from collections import defaultdict
from typing import Optional, Iterator
import hashlib
import json
from dataclasses import dataclass

import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

import torchvision.models as models
from torchvision.models.mobilenetv2 import MobileNet_V2_Weights

import albumentations as A

import numpy as np
from numpy.typing import NDArray

from PIL import Image

import matplotlib.pyplot as plt

from tqdm.auto import tqdm

import mlflow
from mlflow.data.dataset import Dataset as MLFLowDataset

from dotenv import load_dotenv


### Configuration

In [None]:
# Load environment variables from a .env file (if one is found) so the
# os.getenv() lookups for DATASET_DIR, MLFLOW_URI and MLFLOW_USER below can see them.
load_dotenv()

In [None]:
# IPython magic: select matplotlib's Tk backend for the figures shown by this notebook.
%matplotlib tk

In [None]:
# Seed NumPy's global RNG so the fold shuffles and random sample picks are
# reproducible. The seed function must be *called*: assigning to
# `np.random.seed` would replace the function and seed nothing.
np.random.seed(0)

In [None]:
# Model configuration
TARGET_HEIGHT = 224
TARGET_WIDTH = 224
# Dataset configuration
DATASET_DIR = os.getenv("DATASET_DIR")
if DATASET_DIR is None:
    # Fail with a clear message instead of a TypeError from os.path.join below.
    raise RuntimeError("DATASET_DIR environment variable is not set.")
TRUE_LABELS_DIR = "Standard"
FALSE_LABELS_DIR = "Non Standard"
VERSIONING_FILE_NAME = "Versioning.xlsx"
VERSIONING_PATH = os.path.join(DATASET_DIR, VERSIONING_FILE_NAME)
DATASET_VERSION = "V1"
DATASET_TRAIN_VAL_COL = "Train + Val Filenames"
DATASET_TEST_COL = "Test Filenames"
# Experiment logging
MLFLOW_URI = os.getenv("MLFLOW_URI")
MLFLOW_EXPERIMENT_NAME = f"Frame_Classifier_{TARGET_HEIGHT}x{TARGET_WIDTH}"
MLFLOW_USER = os.getenv("MLFLOW_USER")
MODEL_NAME = f"frame_classifier_{TARGET_HEIGHT}x{TARGET_WIDTH}"
TB_LOG_DIR = r"./tb_logs"
# Training parameters
TRAIN_SPLIT, VAL_SPLIT = 0.8, 0.2
K_FOLDS = 5
BATCH_SIZE = 8
NUM_EPOCHS = 100
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
EARLY_STOPPING_PATIENCE = 10
# Temp dirs
CHECKPOINTS_DIR = r"./checkpoints"
# exist_ok avoids the check-then-create race of a separate os.path.exists() test.
os.makedirs(CHECKPOINTS_DIR, exist_ok=True)

### Experiment Setup

In [None]:
# Logging with TensorBoard and MLflow.
# Validate the tracking URI first so a missing setting fails before any
# log directory or TensorBoard writer is created.
if MLFLOW_URI is None:
    raise RuntimeError("MLFLOW_URI environment variable is not set.")
# exist_ok=True already tolerates an existing directory; no separate check needed.
os.makedirs(TB_LOG_DIR, exist_ok=True)
TB_WRITER = SummaryWriter(log_dir=TB_LOG_DIR)
mlflow.set_tracking_uri(MLFLOW_URI)
MLFLOW_EXPERIMENT = mlflow.set_experiment(MLFLOW_EXPERIMENT_NAME)

In [None]:
# Select the compute device: CUDA when available, otherwise DirectML on
# Windows (if it can be imported or installed), otherwise the CPU default.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
# No CUDA: on Windows (os.name == 'nt') try the torch-directml backend instead.
elif os.name == 'nt':
    try:
        import torch_directml
        DEVICE = torch_directml.device()
    except ImportError:
        # torch-directml is not installed: install the pinned version through the
        # notebook shell escape below and retry the import.
        # type: ignore
        try:
            ! pip install torch-directml==0.2.4.dev240913
            import torch_directml
            DEVICE = torch_directml.device()
        except Exception as e: # pylint: disable=broad-except
            # A failed install/import is re-raised unchanged and aborts the cell.
            raise e
    except Exception as e: # pylint: disable=broad-except
        # Any other DirectML failure is only reported; DEVICE keeps its "cpu" default.
        print(f"Error occurred while setting up DirectML: {e}")
print(f"Using device: {DEVICE}")

## Dataset Loading and Preprocessing

In [None]:
# File names found on disk for each class, keyed by the class directory name.
# Starts empty; the directory scan in the next cell fills it.
ANNOTATIONS: dict[str, set[str]] = {
    class_dir: set() for class_dir in (TRUE_LABELS_DIR, FALSE_LABELS_DIR)
}

In [None]:
# Record which file names belong to each class. Images live under
# DATASET_DIR/<class directory>/<file name>, the same layout that
# AnatomyDataset.load_data reads from.
for label in (TRUE_LABELS_DIR, FALSE_LABELS_DIR):
    full_dir = os.path.join(DATASET_DIR, label)
    if not os.path.exists(full_dir):
        raise FileNotFoundError(f"Directory {full_dir} does not exist.")
    ANNOTATIONS[label].update(os.listdir(full_dir))

In [None]:
# Read the sheet for the selected dataset version; each split has its own
# column of file names, with empty cells dropped.
VERSIONING_DF = pd.read_excel(VERSIONING_PATH, sheet_name=DATASET_VERSION)
TRAIN_VAL_IMG_NAMES, TEST_IMG_NAMES = (
    VERSIONING_DF[column_name].dropna().tolist()
    for column_name in (DATASET_TRAIN_VAL_COL, DATASET_TEST_COL)
)

In [None]:
# Augmentation pipeline passed to the training subset (the validation subset
# is built with transform=None). Each transform fires independently with its
# own probability `p`.
TRAIN_TRANSFORMS = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.3),
    # Rotation range (-15, 15) — presumably degrees; confirm against the albumentations docs.
    A.Affine(rotate=(-15, 15), p=0.3)
])

In [None]:
@dataclass
class AnatomyDatasetData:
    """Container pairing an array of images with the array of their labels."""

    # Image stack; the datasets allocate it as (n_imgs, TARGET_HEIGHT, TARGET_WIDTH).
    images: NDArray[np.float32]
    # One label per image; the datasets store 1 for TRUE_LABELS_DIR and 0 otherwise.
    labels: NDArray[np.float32]


class AnatomyDataset(Dataset):
    """In-memory dataset of grayscale frames with binary labels.

    Every image is read once at construction, resized to
    TARGET_HEIGHT x TARGET_WIDTH and normalized with mean 0.5 / std 0.5
    (max pixel value 255). A sample is labelled 1 when its file name is listed
    in ANNOTATIONS[TRUE_LABELS_DIR] and 0 otherwise.
    """

    def __init__(
        self,
        img_dir: str,
        img_names: Optional[list[str]] = None,
    ):
        """
        Args:
            img_dir: Root directory that contains one sub-directory per class.
            img_names: File names to load. When omitted, the sorted entries of
                ``img_dir`` are used.
        """
        self.img_dir = img_dir
        self.img_names = sorted(os.listdir(img_dir)) if img_names is None else img_names
        self.n_imgs = len(self.img_names)
        self.data: AnatomyDatasetData = AnatomyDatasetData(
            images=np.zeros((self.n_imgs, TARGET_HEIGHT, TARGET_WIDTH), dtype=np.float32),
            labels=np.zeros(self.n_imgs, dtype=np.float32)
        )
        self.load_data()

    def __len__(self):
        return self.n_imgs

    def load_data(self):
        """Read, resize and normalize every image into the preallocated arrays."""
        # Build the preprocessing pipeline once instead of once per image.
        preprocess = A.Compose([
            A.Resize(height=TARGET_HEIGHT, width=TARGET_WIDTH),
            A.Normalize(mean=(0.5,), std=(0.5,), max_pixel_value=255.0),
        ])
        true_names = ANNOTATIONS[TRUE_LABELS_DIR]
        for index, fname in enumerate(self.img_names):
            # Anything not listed as a true label is treated as a false label.
            img_class = TRUE_LABELS_DIR if fname in true_names else FALSE_LABELS_DIR
            img_path = os.path.join(self.img_dir, img_class, fname)
            # The context manager releases the file handle once the pixels are copied.
            with Image.open(img_path) as pil_image:
                image = np.array(pil_image.convert('L'), dtype=np.float32)
            self.data.images[index] = preprocess(image=image)['image']
            self.data.labels[index] = 1 if img_class == TRUE_LABELS_DIR else 0

    def get_data(self, idx) -> tuple[np.ndarray, np.float32]:
        """Return the raw NumPy image and label stored at ``idx``."""
        return (
            self.data.images[idx],
            self.data.labels[idx]
        )

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` as tensors; the image gains a leading channel axis."""
        return (
            torch.tensor(self.data.images[idx]).unsqueeze(0),
            torch.tensor(self.data.labels[idx])
        )

    def get_class_balance(self) -> dict[str, int]:
        """Count samples per class name; classes with no samples are omitted."""
        class_counts = defaultdict(int)
        for label in self.data.labels:
            if label == 1:
                class_counts[TRUE_LABELS_DIR] += 1
            else:
                class_counts[FALSE_LABELS_DIR] += 1
        return dict(class_counts)


class AnatomyDatasetSubset(AnatomyDataset):
    """Selection of samples from an AnatomyDataset with an optional transform.

    Labels are copied from the base dataset once. Images are produced by running
    the transform over the base images and can be regenerated at any time by
    calling ``transform_data()`` again (e.g. to draw fresh augmentations).
    """

    def __init__(
            self,
            base_dataset: AnatomyDataset,
            indices: list[int] | NDArray[np.integer],
            transform: Optional[A.Compose] = None,
        ):
        """
        Args:
            base_dataset: Fully loaded dataset to take samples from.
            indices: Positions in ``base_dataset`` that make up this subset.
            transform: Transform applied to each image; a no-op when omitted.
        """
        # The parent constructor is not called: every attribute is filled from
        # base_dataset rather than loaded from disk.
        self.base_dataset = base_dataset
        self.transform = A.NoOp() if transform is None else transform
        self.indices = indices
        self.img_names = [base_dataset.img_names[i] for i in indices]
        self.n_imgs = len(indices)
        image_buffer = np.zeros((self.n_imgs, TARGET_HEIGHT, TARGET_WIDTH), dtype=np.float32)
        label_buffer = np.zeros(self.n_imgs, dtype=np.float32)
        for position, source_index in enumerate(indices):
            label_buffer[position] = base_dataset.data.labels[source_index]
        self.data: AnatomyDatasetData = AnatomyDatasetData(
            images=image_buffer,
            labels=label_buffer,
        )
        self.transform_data()

    def transform_data(self):
        """Recompute every image of the subset by applying the transform to the base image."""
        # NOTE(review): base images were already normalized by AnatomyDataset
        # (mean 0.5 / std 0.5) — confirm the augmentations behave as intended on
        # that value range.
        source_images = self.base_dataset.data.images
        for position, source_index in enumerate(self.indices):
            augmented = self.transform(image=source_images[source_index])
            self.data.images[position] = augmented['image']

    def __iter__(self) -> Iterator[tuple[torch.Tensor, torch.Tensor]]:
        """Yield every sample in order, then refresh the transformed images."""
        yield from (self[position] for position in range(len(self)))
        self.transform_data()

In [None]:
class ImageListDataset(MLFLowDataset):
    """MLflow dataset descriptor built from a plain list of image file names."""

    def __init__(
            self,
            names: list[str],
            source: str = VERSIONING_FILE_NAME,
            version: str = DATASET_VERSION
        ):
        """
        Args:
            names: Image file names that make up the dataset.
            source: Value reported as the dataset source.
            version: Dataset version stored in the profile.
        """
        self._names = names
        self._source = source
        self._version = version

    def to_dict(self):
        """Return the dictionary description of this dataset.

        The digest is the MD5 hex digest of the comma-joined file names; the
        profile is a JSON string with the version, image count and file names.
        """
        joined_names = ",".join(self._names)
        digest = hashlib.md5(joined_names.encode()).hexdigest()
        profile = json.dumps({
            "version": self._version,
            "num_images": len(self._names),
            "filenames": self._names
        })
        return {
            "name": "image_list_dataset",
            "digest": digest,
            "source_type": "inline",
            "source": self._source,
            "schema": None,
            "profile": profile,
        }

In [None]:
# Load every Train + Val image named in the versioning sheet into memory.
TRAIN_VAL_DATASET = AnatomyDataset(img_dir=DATASET_DIR, img_names=TRAIN_VAL_IMG_NAMES)

In [None]:
# Load every Test image named in the versioning sheet into memory.
TEST_DATASET = AnatomyDataset(img_dir=DATASET_DIR, img_names=TEST_IMG_NAMES)

In [None]:
# Report how many samples of each class ended up in each split.
train_val_balance = TRAIN_VAL_DATASET.get_class_balance()
print("Train + Val class balance:", train_val_balance)
test_balance = TEST_DATASET.get_class_balance()
print("Test class balance:", test_balance)

In [None]:
def visualize_sample(
        dataset: AnatomyDataset,
        idxs: list[int]):
    """Show the selected dataset images side by side with their name and label.

    Args:
        dataset: Dataset to take the images from.
        idxs: Indices of the samples to display; one subplot is drawn per index.
    """
    # squeeze=False always returns a 2-D array of Axes, so a single index is
    # handled exactly like several (a bare Axes object cannot be indexed).
    fig, axs = plt.subplots(
        1, len(idxs), figsize=(10, 10), num="Dataset Sample", squeeze=False
    )
    for ax, idx in zip(axs[0], idxs):
        img, label = dataset[idx]
        img = img.squeeze(0).numpy()  # Drop the channel axis -> 2D array
        img_name = dataset.img_names[idx]
        ax.imshow(img, cmap='gray')
        ax.set_title(f"{img_name}\nLabel: {TRUE_LABELS_DIR if label.item() == 1 else FALSE_LABELS_DIR}")
    fig.tight_layout()
    fig.show()

In [None]:
# Draw 3 distinct random indices (replace=False) from TRAIN_VAL_DATASET and display them.
indices = np.random.choice(len(TRAIN_VAL_DATASET), size=3, replace=False)
visualize_sample(TRAIN_VAL_DATASET, indices)

## Model Definition

In [None]:
class DetectionHead(nn.Module):
    """Binary classification head: feature vector -> single probability."""

    def __init__(self, in_features):
        """
        Args:
            in_features: Size of the incoming feature vector.
        """
        super().__init__()
        hidden_units = 128
        # Linear -> ReLU -> Linear -> Sigmoid; the sigmoid makes the output a probability.
        layers = [
            nn.Linear(in_features, hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Map features of shape [B, in_features] to probabilities of shape [B, 1]."""
        return self.classifier(features)

In [None]:
class FetusAnatomyFrameClassifier(nn.Module):
    """Frame classifier: pretrained MobileNetV2 features on 1-channel input plus a binary head."""

    def __init__(self):
        super().__init__()
        # Pretrained MobileNetV2; only its feature extractor is kept.
        backbone = models.mobilenet_v2(weights=MobileNet_V2_Weights.IMAGENET1K_V1)
        self.features = backbone.features
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        # Replace the first convolution with a single-channel one whose weights
        # are the pretrained kernels averaged over the input-channel axis.
        rgb_conv = self.features[0][0]
        gray_conv = nn.Conv2d(
            1,
            rgb_conv.out_channels,
            kernel_size=rgb_conv.kernel_size,
            stride=rgb_conv.stride,
            padding=rgb_conv.padding,
            bias=False,
        )
        gray_conv.weight.data = rgb_conv.weight.data.mean(dim=1, keepdim=True)
        self.features[0][0] = gray_conv
        # Classification head on top of the pooled features
        self.head = DetectionHead(in_features=1280)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the head's output for a batch of images ``x``."""
        feature_map = self.features(x)                        # [B, 1280, H', W']
        pooled = self.pool(feature_map)                       # [B, 1280, 1, 1]
        flattened = pooled.view(feature_map.size(0), -1)      # [B, 1280]
        return self.head(flattened)

In [None]:
def detection_loss(
        pred_probs: torch.Tensor,
        gt_probs: torch.Tensor,
    ) -> torch.Tensor:
    """
    Binary cross-entropy between predicted probabilities and ground-truth labels.

    Args:
        pred_probs: Predicted probabilities in [0, 1], e.g. shape [B, 1].
        gt_probs: Ground-truth labels with as many elements as ``pred_probs``
            (shape [B] or [B, 1]).
    Returns:
        Scalar tensor holding the mean binary cross-entropy of the batch.
    """
    # Align the targets with the predictions so both [B] and [B, 1] labels work.
    targets = gt_probs.reshape(pred_probs.shape).to(dtype=pred_probs.dtype)
    return nn.functional.binary_cross_entropy(pred_probs, targets)

## Training and Evaluation

### Metrics

In [None]:
def accuracy(preds: torch.Tensor, labels: torch.Tensor) -> float:
    """Return the fraction of predictions that match the labels.

    Args:
        preds: Predicted probabilities (or hard 0/1 predictions); values >= 0.5
            count as the positive class.
        labels: Ground-truth 0/1 labels with as many elements as ``preds``.
    Returns:
        Accuracy in [0, 1] as a Python float; 0.0 when there are no samples.
    """
    # Flatten both so [B, 1] predictions cannot broadcast against [B] labels.
    flat_preds = preds.reshape(-1)
    flat_labels = labels.reshape(-1)
    if flat_labels.numel() == 0:
        return 0.0
    hard_preds = (flat_preds >= 0.5).float()
    return (hard_preds == flat_labels).float().mean().item()

### Training Functions

In [None]:
def train_epoch(
    model: FetusAnatomyFrameClassifier,
    train_dataloader: DataLoader,
    optimizer: torch.optim.Optimizer # type: ignore
) -> float:
    """
    Trains the model for one epoch.
    Args:
        model: The neural network model to train.
        train_dataloader: DataLoader providing training data batches.
        optimizer: Optimizer for updating model parameters.
    Returns:
        Average training loss per batch for the epoch.
    Raises:
        ValueError: If the dataloader yields no batches.
    """
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("train_dataloader yields no batches; cannot train an epoch.")
    model.train()
    total_loss = 0.0
    # The context manager closes the progress bar even if a batch raises.
    with tqdm(train_dataloader, desc="Training batches", leave=False, position=2) as batch_tqdm:
        for batch in batch_tqdm:
            imgs: torch.Tensor = batch[0].to(DEVICE)
            labels: torch.Tensor = batch[1].to(DEVICE)
            optimizer.zero_grad()
            cls_pred = model(imgs)
            loss = detection_loss(cls_pred, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
    return total_loss / n_batches

def validate(
    model: FetusAnatomyFrameClassifier,
    val_loader: DataLoader,
    epoch: Optional[int] = None
) -> tuple[float, float]:
    """
    Evaluates the model on the validation set.
    Args:
        model: The neural network model to evaluate.
        val_loader: DataLoader providing validation data batches.
        epoch: Current epoch number; when given, loss and accuracy are also
            written to TensorBoard under "val/loss" and "val/accuracy".
    Returns:
        Tuple of (mean validation loss per batch, accuracy), both Python floats.
    Raises:
        ValueError: If the dataloader yields no batches.
    """
    n_batches = len(val_loader)
    if n_batches == 0:
        raise ValueError("val_loader yields no batches; nothing to evaluate.")
    model.eval()
    total_loss = 0.0
    n_samples = len(val_loader.dataset)
    pred_classes = torch.zeros(n_samples, dtype=torch.float32)
    labels = torch.zeros(n_samples, dtype=torch.float32)
    current_index = 0
    with torch.no_grad():
        for batch in val_loader:
            imgs: torch.Tensor = batch[0].to(DEVICE)              # [B, 1, H, W]
            batch_labels: torch.Tensor = batch[1]                 # [B]
            batch_size = batch_labels.size(0)
            batch_probs: torch.Tensor = model(imgs)               # [B, 1]
            loss = detection_loss(batch_probs, batch_labels.to(DEVICE))
            total_loss += loss.item()
            batch_slice = slice(current_index, current_index + batch_size)
            # reshape(-1) gives a 1-D result of length B even for a batch of one.
            pred_classes[batch_slice] = (batch_probs.reshape(-1).cpu() >= 0.5).float()
            labels[batch_slice] = batch_labels.cpu()
            current_index += batch_size
    # float() accepts a Python number as well as a 0-dim tensor, so the
    # annotated return type holds either way.
    acc = float(accuracy(pred_classes, labels))
    mean_loss = total_loss / n_batches
    if epoch is not None:
        TB_WRITER.add_scalar("val/loss", mean_loss, epoch)
        TB_WRITER.add_scalar("val/accuracy", acc, epoch)
    return mean_loss, acc

In [None]:
def train_one_fold(
    model_constructor: type[FetusAnatomyFrameClassifier],
    train_subset: AnatomyDatasetSubset,
    val_subset: AnatomyDatasetSubset,
    checkpoint_path: str,
    fold_idx: int,
    n_epochs: int,
    batch_size: int,
    learning_rate: float,
    weight_decay: float,
    early_stopping_patience: int
) -> float:
    """
    Trains the model on one fold of the dataset.

    When an MLflow run is active, hyper-parameters and per-epoch metrics are
    logged to it and the best checkpoint is exported to ONNX and attached as an
    artifact. Without an active run only TensorBoard logging and checkpointing
    take place.

    Args:
        model_constructor: Constructor for the model to train.
        train_subset: Training subset of the dataset.
        val_subset: Validation subset of the dataset.
        checkpoint_path: Path to save the best model checkpoint.
        fold_idx: Index of the current fold (for logging).
        n_epochs: Number of epochs to train.
        batch_size: Batch size for training.
        learning_rate: Learning rate for the optimizer.
        weight_decay: Weight decay (L2 regularization) for the optimizer.
        early_stopping_patience: Number of epochs with no improvement to wait before stopping.
    Returns:
        Best validation accuracy achieved during training.
    """
    # MLflow run setup: log only when the caller has already started a run.
    mlflow_run_id = None
    active_run = mlflow.active_run()
    if active_run is not None:
        mlflow_run_id = active_run.info.run_id
    if mlflow_run_id is not None:
        mlflow.log_params({
            "num_epochs": n_epochs,
            "batch_size": batch_size,
            "learning_rate": learning_rate,
            "weight_decay": weight_decay,
            "model_type": "FetusAnatomyFrameClassifier",
            "backbone": "MobileNetV2"
        }, run_id=mlflow_run_id)
    # Initialization
    model = model_constructor().to(DEVICE)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay) # type: ignore
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.5)
    # NOTE(review): the training loader keeps a fixed sample order across epochs
    # (shuffle=False); consider shuffle=True if per-epoch reshuffling is wanted.
    train_dataloader = DataLoader(train_subset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    val_dataloader = DataLoader(val_subset, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=True)
    best_val_acc = 0.0
    # Tracks whether a checkpoint file exists, so the first epoch always writes
    # one even when its validation accuracy is exactly 0.0.
    checkpoint_saved = False
    patience_counter = 0
    # Stay None when no epoch runs (n_epochs == 0), so the final logging is skipped.
    val_loss: Optional[float] = None
    val_acc: Optional[float] = None
    epoch_tqdm = tqdm(range(n_epochs), desc=f"Fold {fold_idx + 1} - Epochs", position=1, leave=True)
    for epoch in epoch_tqdm:
        # Training
        train_loss = train_epoch(model, train_dataloader, optimizer)
        # Validation
        val_loss, val_acc = validate(model, val_dataloader, epoch)
        # Plain float for comparisons, formatting and MLflow logging.
        val_acc = float(val_acc)
        # Learning rate used during this epoch (read before the scheduler steps)
        current_lr = optimizer.param_groups[0]['lr']
        # Log to TensorBoard
        TB_WRITER.add_scalar("train/loss", train_loss, epoch)
        TB_WRITER.add_scalar("train/learning_rate", current_lr, epoch)
        # Update learning rate scheduler
        scheduler.step(val_loss)
        # Log metrics to MLflow
        if mlflow_run_id is not None:
            mlflow.log_metrics({
                "train_loss": train_loss,
                "val_loss": val_loss,
                "val_acc": val_acc,
                "learning_rate": current_lr
            }, step=epoch, run_id=mlflow_run_id)
        # Save checkpoint if we have a better model (or none has been saved yet)
        if val_acc > best_val_acc or not checkpoint_saved:
            best_val_acc = val_acc
            patience_counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
        else:
            patience_counter += 1
        epoch_description = \
            f"Fold {fold_idx + 1} - "\
            f"Train Loss: {train_loss:.2f} - " \
            f"Val Loss: {val_loss:.2f} - "\
            f"Val Acc: {val_acc:.2f} - "\
            f"LR: {current_lr} - Epochs"
        epoch_tqdm.set_description(epoch_description)
        # Early stopping check
        if patience_counter >= early_stopping_patience:
            break
        # Re-augment training data for next epoch
        train_subset.transform_data()
    if mlflow_run_id is not None and val_loss is not None:
        mlflow.log_metrics({
            "final_val_loss": val_loss,
            "final_val_acc": val_acc,
            "best_val_acc": best_val_acc
        }, run_id=mlflow_run_id)
    if mlflow_run_id is not None and checkpoint_saved:
        # Export the best checkpoint (not the last epoch's weights) to ONNX on CPU.
        dummy_input = torch.randn(1, 1, TARGET_HEIGHT, TARGET_WIDTH)
        export_path = checkpoint_path.replace(".pth", ".onnx")
        saved_model = model_constructor()
        # map_location keeps the load on CPU regardless of the training device.
        saved_model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))
        # Export in inference mode so the graph does not capture training-time behavior.
        saved_model.eval()
        torch.onnx.export(
            saved_model.cpu(),
            dummy_input,
            export_path,
            input_names=['input'],
            output_names=['output'],
            opset_version=11,
            dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
        )
        mlflow.log_artifact(export_path, run_id=mlflow_run_id)
    return best_val_acc

In [None]:
class dummy_context:
    """No-op context manager: takes any arguments, yields None, suppresses nothing."""

    def __init__(self, *_args, **_kwargs):
        """Accept and discard whatever arguments are passed."""

    def __enter__(self):
        # The `as` target of the with-statement is bound to None.
        return None

    def __exit__(self, exc_type, exc_value, traceback):
        # A falsy result lets any exception raised in the body propagate.
        return None

def train(
    model_constructor: type[FetusAnatomyFrameClassifier],
    dataset: AnatomyDataset,
    train_split: float,
    val_split: float,
    log_runs_to_mlflow: bool,
    n_epochs: int,
    batch_size: int,
    k_folds: int,
    learning_rate,
    weight_decay,
    early_stopping_patience: int
) -> tuple[str, Optional[mlflow.ActiveRun]]:
    """
    Train the model over several randomly re-shuffled train/validation splits.

    Each "fold" draws a fresh random permutation of the dataset, takes the first
    ``train_split`` fraction for training and the following ``val_split``
    fraction for validation, and trains a new model on it.

    Args:
        model_constructor: Constructor for the model to train.
        dataset: Dataset that is split into training and validation subsets.
        train_split: Fraction of the dataset used for training.
        val_split: Fraction of the dataset used for validation.
        log_runs_to_mlflow: Whether each fold is tracked as an MLflow run.
        n_epochs: Maximum number of epochs per fold.
        batch_size: Batch size for training and validation.
        k_folds: Number of folds to train.
        learning_rate: Learning rate for the optimizer.
        weight_decay: Weight decay for the optimizer.
        early_stopping_patience: Epochs without improvement before a fold stops.
    Returns:
        Tuple of (checkpoint path of the best fold, its MLflow run, or None when
        MLflow logging is disabled).
    Raises:
        ValueError: If ``k_folds`` is smaller than 1.
    """
    if k_folds < 1:
        raise ValueError(f"k_folds must be at least 1, got {k_folds}.")
    context = mlflow.start_run if log_runs_to_mlflow else dummy_context
    base_run_name = f"{MLFLOW_USER}_{pd.Timestamp.now().strftime('%Y%m%d_%H%M%S')}"
    n_images = len(dataset)
    train_start, train_end = 0, int(train_split * n_images)
    val_start = train_end
    # When the fractions cover the whole dataset, use every remaining image so
    # integer truncation cannot drop the last one.
    if (train_split + val_split) < 1.0:
        val_end = int((train_split + val_split) * n_images)
    else:
        val_end = n_images
    check_points_paths = []
    runs = []
    validation_accs = []
    folds_tqdm = tqdm(range(k_folds), desc="Folds", leave=True, position=0)
    for fold in folds_tqdm:
        run_name = f"{base_run_name}_fold_{fold + 1}"
        with context(run_name=run_name) as run:
            runs.append(run)
            checkpoint_path = os.path.join(CHECKPOINTS_DIR, f"{run_name}.pth")
            check_points_paths.append(checkpoint_path)
            shuffled_indices = np.random.permutation(np.arange(n_images))
            train_indices = shuffled_indices[train_start:train_end]
            val_indices = shuffled_indices[val_start:val_end]
            train_subset = AnatomyDatasetSubset(
                base_dataset=dataset,
                indices=train_indices,
                transform=TRAIN_TRANSFORMS
            )
            val_subset = AnatomyDatasetSubset(
                base_dataset=dataset,
                indices=val_indices,
                transform=None
            )
            # `run` is None when MLflow logging is disabled; logging inputs in
            # that case would make MLflow start a run implicitly.
            if run is not None:
                mlflow.log_input(
                    ImageListDataset(
                        names=train_subset.img_names,
                        version=DATASET_VERSION,
                        source=VERSIONING_FILE_NAME
                    ),
                    context="Train Set"
                )
                mlflow.log_input(
                    ImageListDataset(
                        names=val_subset.img_names,
                        version=DATASET_VERSION,
                        source=VERSIONING_FILE_NAME
                    ),
                    context="Validation Set"
                )
            val_acc = train_one_fold(
                model_constructor=model_constructor,
                train_subset=train_subset,
                val_subset=val_subset,
                checkpoint_path=checkpoint_path,
                fold_idx=fold,
                n_epochs=n_epochs,
                batch_size=batch_size,
                learning_rate=learning_rate,
                weight_decay=weight_decay,
                early_stopping_patience=early_stopping_patience
            )
            # Store plain floats so argmax and formatting never see tensors.
            validation_accs.append(float(val_acc))
            best_fold = int(np.argmax(validation_accs))
            folds_tqdm.set_description(f"Best fold: {best_fold + 1}, with accuracy: {validation_accs[best_fold]:.3f} - Folds")
    best_fold = int(np.argmax(validation_accs))
    return check_points_paths[best_fold], runs[best_fold]

### Training Loop

In [None]:
# Run the full training with the configured hyper-parameters and MLflow logging
# enabled; keep the best fold's checkpoint path and MLflow run for the
# evaluation and registration cells below.
best_model_checkpoint_path, best_run = train(
    model_constructor=FetusAnatomyFrameClassifier,
    dataset=TRAIN_VAL_DATASET,
    log_runs_to_mlflow=True,
    n_epochs=NUM_EPOCHS,
    batch_size=BATCH_SIZE,
    k_folds=K_FOLDS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    early_stopping_patience=EARLY_STOPPING_PATIENCE,
    train_split=TRAIN_SPLIT,
    val_split=VAL_SPLIT
)

## Evaluation

In [None]:
# Rebuild the architecture in inference mode (a freshly constructed nn.Module
# starts in training mode), then restore the weights of the best fold.
best_model = FetusAnatomyFrameClassifier().to(DEVICE).eval()
best_model.load_state_dict(torch.load(best_model_checkpoint_path))

In [None]:
@torch.no_grad()
def visualize_prediction(
        model: FetusAnatomyFrameClassifier,
        dataset: AnatomyDataset,
        idxs: list[int]
    ):
    """Show the selected images with their true label and the model's prediction.

    The model is run in evaluation mode for the predictions and returned to its
    previous mode afterwards. Each predicted probability is also printed.

    Args:
        model: Trained classifier used for the predictions.
        dataset: Dataset to take the images from.
        idxs: Indices of the samples to display; one subplot is drawn per index.
    """
    # squeeze=False always returns a 2-D array of Axes, so a single index is
    # handled exactly like several.
    fig, axs = plt.subplots(
        1, len(idxs), figsize=(4 * len(idxs), 4),
        num="Dataset Prediction Visualization", squeeze=False
    )
    # Predict in eval mode: in training mode, single-image batches would be
    # normalized with their own statistics and the running statistics updated.
    was_training = model.training
    model.eval()
    try:
        for ax, idx in zip(axs[0], idxs):
            img, label = dataset[idx]
            prob = model(img.unsqueeze(0).to(DEVICE)).item()
            print(f"Predicted probability: {prob:.4f}")
            pred_label = TRUE_LABELS_DIR if prob >= 0.5 else FALSE_LABELS_DIR
            img = img.squeeze(0).numpy()  # Drop the channel axis -> 2D array
            img_name = dataset.img_names[idx]
            ax.imshow(img, cmap='gray')
            ax.set_title(f"{img_name}\nLabel: {TRUE_LABELS_DIR if label.item() == 1 else FALSE_LABELS_DIR}\nPrediction: {pred_label} ({prob:.2f})")
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)
    fig.tight_layout()
    fig.show()

In [None]:
# Draw 5 distinct random test indices (replace=False) and plot the model's predictions for them.
idxs = np.random.choice(len(TEST_DATASET), size=5, replace=False)
visualize_prediction(best_model, TEST_DATASET, idxs)

In [None]:
# Evaluate the best model on the held-out test set, in dataset order.
test_dl = DataLoader(
    TEST_DATASET,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0,
    pin_memory=True,
)
# No epoch is passed, so validate() does not write these values to TensorBoard.
mean_test_loss, mean_test_acc = validate(best_model, test_dl)
print(f"Mean Test Loss: {mean_test_loss:.2f}")
print(f"Mean Test accuracy: {mean_test_acc:.2f}")

## Model Registration

In [None]:
# Register the best fold's ONNX export (only possible when an MLflow run exists).
if best_run is not None:
    # The ONNX artifact has the checkpoint's file name with the extension swapped.
    onnx_checkpoint_path = best_model_checkpoint_path.replace(".pth", ".onnx")
    artifact_name = onnx_checkpoint_path.rpartition(os.sep)[-1]
    registered_model_uri = f"{MLFLOW_EXPERIMENT.artifact_location}/{best_run.info.run_id}/artifacts/{artifact_name}"
    model_version = mlflow.register_model(
        model_uri=registered_model_uri,
        name=MODEL_NAME
    )