# MODEL 3: UNET Training from scratch

### Part 0: Setting up the environment

In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision.io import read_image
import numpy as np
import time, os, json
from tqdm import tqdm
import glob
import kagglehub
import matplotlib.pyplot as plt
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


### Part 1: Configurating datasets and functions

In [2]:
# Device configuration
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Paths configuration
base_path = kagglehub.dataset_download("triminhtran/csce670-segmentation-dataset")
image_dir = os.path.join(base_path, "JPEGImages")
mask_dir = os.path.join(base_path, "SegmentationClass")



### Part 2: Loading dataset and model

In [3]:
class SegmentationDataset(Dataset):
    def __init__(self, image_dir, mask_dir, resize=(448,640)):
        self.image_paths = sorted(glob.glob(os.path.join(image_dir, '*.jpg')))
        self.mask_paths = sorted(glob.glob(os.path.join(mask_dir, '*.png')))
        self.resize = T.Resize(resize, interpolation=T.InterpolationMode.NEAREST)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        mask = Image.open(self.mask_paths[idx]).convert('L')

        img = self.resize(img)
        mask = self.resize(mask)

        img = T.functional.to_tensor(img)
        mask = (T.functional.to_tensor(mask) > 0.05).long().squeeze(0)

        return img, mask

# Metrics calculation
def calculate_metrics(pred, target):
    pred = pred.cpu().numpy()
    target = target.cpu().numpy()

    intersection = np.logical_and(target, pred).sum()
    union = np.logical_or(target, pred).sum()
    pred_sum = pred.sum()
    target_sum = target.sum()

    iou = intersection / union if union != 0 else np.nan
    pixel_acc = (pred == target).mean()
    precision = intersection / pred_sum if pred_sum != 0 else np.nan
    recall = intersection / target_sum if target_sum != 0 else np.nan

    return iou, pixel_acc, precision, recall

In [4]:
# Create dataloaders
dataset = SegmentationDataset(image_dir, mask_dir)
train_size = int(0.7 * len(dataset))
val_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - val_size
train_data, val_data, test_data = random_split(dataset, [train_size, val_size, test_size])

train_loader = DataLoader(train_data, batch_size=8, shuffle=True)
val_loader = DataLoader(val_data, batch_size=8)
test_loader = DataLoader(test_data, batch_size=1)

### Part 3: Training

In [5]:
class UNet(nn.Module):
    def __init__(self, in_channels=3, out_channels=1):
        super(UNet, self).__init__()

        self.enc1 = self.conv_block(in_channels, 64)
        self.enc2 = self.conv_block(64, 128)
        self.enc3 = self.conv_block(128, 256)
        self.enc4 = self.conv_block(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = self.conv_block(512, 1024)

        self.upconv4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = self.conv_block(1024, 512)
        self.upconv3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = self.conv_block(512, 256)
        self.upconv2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = self.conv_block(256, 128)
        self.upconv1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = self.conv_block(128, 64)

        self.conv_final = nn.Conv2d(64, out_channels, kernel_size=1)

    def conv_block(self, in_channels, out_channels):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        enc1 = self.enc1(x)
        enc2 = self.enc2(self.pool(enc1))
        enc3 = self.enc3(self.pool(enc2))
        enc4 = self.enc4(self.pool(enc3))

        bottleneck = self.bottleneck(self.pool(enc4))

        dec4 = self.dec4(torch.cat((self.upconv4(bottleneck), enc4), dim=1))
        dec3 = self.dec3(torch.cat((self.upconv3(dec4), enc3), dim=1))
        dec2 = self.dec2(torch.cat((self.upconv2(dec3), enc2), dim=1))
        dec1 = self.dec1(torch.cat((self.upconv1(dec2), enc1), dim=1))

        return self.conv_final(dec1)

model = UNet(in_channels=3, out_channels=2).to(device)

In [6]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

num_epochs = 100
train_losses, val_losses = [], []
train_accs, val_accs = [], []

# Training loop
for epoch in range(num_epochs):
    model.train()
    total_loss, total_acc = 0, 0

    for images, masks in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        images, masks = images.to(device), masks.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        preds = outputs.argmax(1)
        total_acc += (preds == masks).float().mean().item()

    avg_train_loss = total_loss / len(train_loader)
    avg_train_acc = total_acc / len(train_loader)
    train_losses.append(avg_train_loss)
    train_accs.append(avg_train_acc)

    model.eval()
    val_loss, val_acc = 0, 0
    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)
            loss = criterion(outputs, masks)
            val_loss += loss.item()
            preds = outputs.argmax(1)
            val_acc += (preds == masks).float().mean().item()

    avg_val_loss = val_loss / len(val_loader)
    avg_val_acc = val_acc / len(val_loader)
    val_losses.append(avg_val_loss)
    val_accs.append(avg_val_acc)

    print(f"Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Train Acc: {avg_train_acc:.4f}, Val Acc: {avg_val_acc:.4f}")


Epoch 1/100: 100%|██████████| 47/47 [00:22<00:00,  2.13it/s]


Epoch 1 - Train Loss: 0.3778, Val Loss: 0.1218, Train Acc: 0.7854, Val Acc: 0.9761


Epoch 2/100: 100%|██████████| 47/47 [00:20<00:00,  2.26it/s]


Epoch 2 - Train Loss: 0.0970, Val Loss: 0.1125, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 3/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 3 - Train Loss: 0.0948, Val Loss: 0.1101, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 4/100: 100%|██████████| 47/47 [00:20<00:00,  2.26it/s]


Epoch 4 - Train Loss: 0.0914, Val Loss: 0.1074, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 5/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 5 - Train Loss: 0.0889, Val Loss: 0.1054, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 6/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 6 - Train Loss: 0.0868, Val Loss: 0.1044, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 7/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 7 - Train Loss: 0.0866, Val Loss: 0.1033, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 8/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 8 - Train Loss: 0.0848, Val Loss: 0.1025, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 9/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 9 - Train Loss: 0.0850, Val Loss: 0.1023, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 10/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 10 - Train Loss: 0.0842, Val Loss: 0.1020, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 11/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 11 - Train Loss: 0.0840, Val Loss: 0.1024, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 12/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 12 - Train Loss: 0.0831, Val Loss: 0.1022, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 13/100: 100%|██████████| 47/47 [00:20<00:00,  2.24it/s]


Epoch 13 - Train Loss: 0.0833, Val Loss: 0.1021, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 14/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 14 - Train Loss: 0.0830, Val Loss: 0.1030, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 15/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 15 - Train Loss: 0.0828, Val Loss: 0.1032, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 16/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 16 - Train Loss: 0.0824, Val Loss: 0.1020, Train Acc: 0.9807, Val Acc: 0.9761


Epoch 17/100: 100%|██████████| 47/47 [00:20<00:00,  2.24it/s]


Epoch 17 - Train Loss: 0.0814, Val Loss: 0.1002, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 18/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 18 - Train Loss: 0.0809, Val Loss: 0.1028, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 19/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 19 - Train Loss: 0.0827, Val Loss: 0.1024, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 20/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 20 - Train Loss: 0.0804, Val Loss: 0.1001, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 21/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 21 - Train Loss: 0.0798, Val Loss: 0.0986, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 22/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 22 - Train Loss: 0.0793, Val Loss: 0.0990, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 23/100: 100%|██████████| 47/47 [00:20<00:00,  2.25it/s]


Epoch 23 - Train Loss: 0.0804, Val Loss: 0.0973, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 24/100: 100%|██████████| 47/47 [00:21<00:00,  2.21it/s]


Epoch 24 - Train Loss: 0.0796, Val Loss: 0.0991, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 25/100: 100%|██████████| 47/47 [00:21<00:00,  2.18it/s]


Epoch 25 - Train Loss: 0.0789, Val Loss: 0.0968, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 26/100: 100%|██████████| 47/47 [00:21<00:00,  2.16it/s]


Epoch 26 - Train Loss: 0.0791, Val Loss: 0.0990, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 27/100: 100%|██████████| 47/47 [00:21<00:00,  2.16it/s]


Epoch 27 - Train Loss: 0.0792, Val Loss: 0.0986, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 28/100: 100%|██████████| 47/47 [00:21<00:00,  2.16it/s]


Epoch 28 - Train Loss: 0.0786, Val Loss: 0.0989, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 29/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 29 - Train Loss: 0.0780, Val Loss: 0.0996, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 30/100: 100%|██████████| 47/47 [00:21<00:00,  2.17it/s]


Epoch 30 - Train Loss: 0.0777, Val Loss: 0.0948, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 31/100: 100%|██████████| 47/47 [00:21<00:00,  2.14it/s]


Epoch 31 - Train Loss: 0.0765, Val Loss: 0.0931, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 32/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 32 - Train Loss: 0.0776, Val Loss: 0.0947, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 33/100: 100%|██████████| 47/47 [00:21<00:00,  2.16it/s]


Epoch 33 - Train Loss: 0.0757, Val Loss: 0.0929, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 34/100: 100%|██████████| 47/47 [00:22<00:00,  2.12it/s]


Epoch 34 - Train Loss: 0.0755, Val Loss: 0.0978, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 35/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 35 - Train Loss: 0.0753, Val Loss: 0.0906, Train Acc: 0.9808, Val Acc: 0.9761


Epoch 36/100: 100%|██████████| 47/47 [00:21<00:00,  2.18it/s]


Epoch 36 - Train Loss: 0.0727, Val Loss: 0.0845, Train Acc: 0.9808, Val Acc: 0.9764


Epoch 37/100: 100%|██████████| 47/47 [00:21<00:00,  2.16it/s]


Epoch 37 - Train Loss: 0.0674, Val Loss: 0.0877, Train Acc: 0.9817, Val Acc: 0.9765


Epoch 38/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 38 - Train Loss: 0.0668, Val Loss: 0.0837, Train Acc: 0.9818, Val Acc: 0.9761


Epoch 39/100: 100%|██████████| 47/47 [00:22<00:00,  2.13it/s]


Epoch 39 - Train Loss: 0.0601, Val Loss: 0.0786, Train Acc: 0.9828, Val Acc: 0.9782


Epoch 40/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 40 - Train Loss: 0.0591, Val Loss: 0.0701, Train Acc: 0.9823, Val Acc: 0.9792


Epoch 41/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 41 - Train Loss: 0.0534, Val Loss: 0.0656, Train Acc: 0.9835, Val Acc: 0.9779


Epoch 42/100: 100%|██████████| 47/47 [00:21<00:00,  2.15it/s]


Epoch 42 - Train Loss: 0.0525, Val Loss: 0.0596, Train Acc: 0.9838, Val Acc: 0.9807


Epoch 43/100: 100%|██████████| 47/47 [00:21<00:00,  2.17it/s]


Epoch 43 - Train Loss: 0.0466, Val Loss: 0.0652, Train Acc: 0.9849, Val Acc: 0.9783


Epoch 44/100: 100%|██████████| 47/47 [00:21<00:00,  2.19it/s]


Epoch 44 - Train Loss: 0.0459, Val Loss: 0.0542, Train Acc: 0.9851, Val Acc: 0.9822


Epoch 45/100: 100%|██████████| 47/47 [00:21<00:00,  2.17it/s]


Epoch 45 - Train Loss: 0.0422, Val Loss: 0.0541, Train Acc: 0.9859, Val Acc: 0.9823


Epoch 46/100: 100%|██████████| 47/47 [00:21<00:00,  2.17it/s]


Epoch 46 - Train Loss: 0.0406, Val Loss: 0.0499, Train Acc: 0.9865, Val Acc: 0.9829


Epoch 47/100: 100%|██████████| 47/47 [00:21<00:00,  2.19it/s]


Epoch 47 - Train Loss: 0.0448, Val Loss: 0.0499, Train Acc: 0.9852, Val Acc: 0.9835


Epoch 48/100: 100%|██████████| 47/47 [00:21<00:00,  2.21it/s]


Epoch 48 - Train Loss: 0.0370, Val Loss: 0.0518, Train Acc: 0.9873, Val Acc: 0.9835


Epoch 49/100: 100%|██████████| 47/47 [00:20<00:00,  2.24it/s]


Epoch 49 - Train Loss: 0.0352, Val Loss: 0.0509, Train Acc: 0.9879, Val Acc: 0.9841


Epoch 50/100: 100%|██████████| 47/47 [00:20<00:00,  2.24it/s]


Epoch 50 - Train Loss: 0.0327, Val Loss: 0.0447, Train Acc: 0.9886, Val Acc: 0.9853


Epoch 51/100: 100%|██████████| 47/47 [00:21<00:00,  2.21it/s]


KeyboardInterrupt: 

In [None]:
# Plot training metrics
plt.figure(figsize=(12,4))
plt.subplot(1,2,1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.legend()
plt.title('Loss')

plt.subplot(1,2,2)
plt.plot(train_accs, label='Train Acc')
plt.plot(val_accs, label='Val Acc')
plt.legend()
plt.title('Accuracy')
plt.show()

### Part 4: Evaluating

In [None]:
# Model saving
model.eval()
torch.save(model.state_dict(), "unet_model.pth")
storage_size = os.path.getsize("unet_model.pth") / (1024 ** 2)

# Metrics summary (evaluation and inference metrics)
metrics_summary = {"param_count": sum(p.numel() for p in model.parameters() if p.requires_grad), "storage_size_MB": storage_size,
                   "train_losses": train_losses, "val_losses": val_losses, "train_accs": train_accs, "val_accs": val_accs,
                   "test_loss": [], "test_accuracy": [], "iou": [], "pixel_accuracy": [], "precision": [], "recall": [],
                   "gpu_memory_MB": [], "inference_time_per_image": []}

with torch.no_grad():
    for images, masks in tqdm(test_loader):
        images, masks = images.to(device), masks.to(device)
        torch.cuda.reset_peak_memory_stats(device)
        start_time = time.time()
        outputs = model(images)
        inference_time = time.time() - start_time
        gpu_memory = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
        loss = criterion(outputs, masks)
        preds = outputs.argmax(1)

        iou, pixel_acc, precision, recall = calculate_metrics(preds, masks)
        metrics_summary["test_loss"].append(loss.item())
        metrics_summary["test_accuracy"].append((preds == masks).float().mean().item())
        metrics_summary["iou"].append(iou)
        metrics_summary["pixel_accuracy"].append(pixel_acc)
        metrics_summary["precision"].append(precision)
        metrics_summary["recall"].append(recall)
        metrics_summary["gpu_memory_MB"].append(gpu_memory)
        metrics_summary["inference_time_per_image"].append(inference_time)

# Get test loss and test accuracy
avg_test_loss = np.mean(metrics_summary["test_loss"])
avg_test_accuracy = np.mean(metrics_summary["test_accuracy"])
avg_gpu_memory = np.mean(metrics_summary["gpu_memory_MB"])
avg_inference_time = np.mean(metrics_summary["inference_time_per_image"])

print(f"Average Test Loss: {avg_test_loss:.4f}")
print(f"Average Test Accuracy: {avg_test_accuracy:.4f}")
print(f"Average GPU Memory Usage per Image: {avg_gpu_memory:.2f} MB")
print(f"Average Inference Time per Image: {avg_inference_time:.4f} seconds")

# Save Metrics
with open("unet_metrics.json", "w") as f:
    json.dump(metrics_summary, f, indent=4)

print("Metrics saved to unet_metrics.json")

### Part 5: Visualize a batch

In [None]:
# Visualization of sample predictions
sample_images, sample_masks = next(iter(train_loader))
batch_size = sample_images.shape[0]

with torch.no_grad():
    sample_images_gpu = sample_images.float().to(device)
    output = model(sample_images_gpu)['out']
    predicted_masks = torch.argmax(output, dim=1).cpu()

for i in range(batch_size):
    fig, axs = plt.subplots(1, 3, figsize=(12, 4))

    axs[0].imshow(torch.permute(sample_images[i], (1, 2, 0)))
    axs[0].set_title("Input Image")
    axs[0].axis("off")

    axs[1].imshow(sample_masks[i], cmap='gray')
    axs[1].set_title("Ground Truth Mask")
    axs[1].axis("off")

    axs[2].imshow(predicted_masks[i], cmap='gray')
    axs[2].set_title("Predicted Mask")
    axs[2].axis("off")

    plt.tight_layout()
    plt.show()

### Part 6: Visualize the saved metrics

In [None]:
import json
import matplotlib.pyplot as plt
import numpy as np

# Load metrics from JSON file
with open("unet_metrics.json", "r") as f:
    metrics = json.load(f)

# Plot Training and Validation Loss & Accuracy
plt.figure(figsize=(12, 8))

# Loss
plt.subplot(2, 2, 1)
plt.plot(metrics["train_losses"], label='Train Loss', marker='o')
plt.plot(metrics["val_losses"], label='Validation Loss', marker='o')
plt.title('Loss over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# Accuracy
plt.subplot(2, 2, 2)
plt.plot(metrics["train_accs"], label='Train Accuracy', marker='o')
plt.plot(metrics["val_accs"], label='Validation Accuracy', marker='o')
plt.title('Accuracy over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend()
plt.grid(True)

# GPU Memory Usage per Test Image
plt.subplot(2, 2, 3)
plt.plot(metrics["gpu_memory_MB"], marker='o', color='blue', linestyle='None')
plt.title('GPU Memory Usage per Image')
plt.xlabel('Test Image Index')
plt.ylabel('GPU Memory (MB)')
plt.grid(True)

# Inference Time per Test Image
plt.subplot(2, 2, 4)
plt.plot(metrics["inference_time_per_image"], marker='o', color='green', linestyle='None')
plt.title('Inference Time per Image')
plt.xlabel('Test Image Index')
plt.ylabel('Inference Time (seconds)')
plt.grid(True)

plt.tight_layout()
plt.show()

# Plot IoU, Pixel Accuracy, Precision, and Recall for Test Set
metric_names = ["iou", "pixel_accuracy", "precision", "recall"]
plt.figure(figsize=(14, 10))

for i, metric_name in enumerate(metric_names, 1):
    plt.subplot(2, 2, i)
    plt.plot(metrics[metric_name], marker='o', linestyle='None')
    plt.title(f'{metric_name.replace("_", " ").title()} per Test Image')
    plt.xlabel('Test Sample')
    plt.ylabel(metric_name.replace("_", " ").title())
    plt.grid(True)

plt.tight_layout()
plt.show()

# Summarize Test Metrics (using nanmean to handle NaNs gracefully)
print(f"Average Test IoU: {np.nanmean(metrics['iou']):.4f}")
print(f"Average Test Pixel Accuracy: {np.mean(metrics['pixel_accuracy']):.4f}")
print(f"Average Test Precision: {np.nanmean(metrics['precision']):.4f}")
print(f"Average Test Recall: {np.nanmean(metrics['recall']):.4f}")

# Summarize GPU and Inference Time Metrics
print(f"Average GPU Memory Usage: {np.mean(metrics['gpu_memory_MB']):.2f} MB")
print(f"Average Inference Time: {np.mean(metrics['inference_time_per_image']):.4f} seconds")

# Print Model Hardware Metrics
print(f"Total Model Parameter Count: {metrics['param_count']:,}")
print(f"Model Storage Size: {metrics['storage_size_MB']:.2f} MB")
