# 🌿📱 HPDC-Net: Edge Device-Friendly CNN for Plant Leaf Disease Classification 🌾🧠

Welcome to **HPDC-Net** 👨‍🔬🧪, a **High-Performance Disease Classification Network** designed for real-time, on-device plant disease detection from leaf images.  

---

## 🌟 Key Features of HPDC-Net

🔹 **Lightweight Architecture** – Perfect for edge deployment on mobile or IoT devices (Raspberry Pi, Jetson Nano).  
🔹 **High Accuracy** – Designed for robust performance on small datasets of plant leaf images.  
🔹 **Optimized Inference** – Fast and efficient with low memory and power usage.  
🔹 **Smart Feature Extraction** – Includes novel modules for capturing disease-specific patterns in leaf texture and color.  




## 📂 Dataset Loading and Preparation 🌿🧪

In this section, we load and prepare the **Plant Leaf Disease Dataset** for training and evaluation with **PyTorch**. The full dataset covers healthy and diseased leaves from several crops, but this notebook keeps only the tomato classes (10 in total) and splits them into training, validation, and test sets. 🍅🍇🌾
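
Before training, it can help to check how many images each class folder holds. The snippet below is a minimal sketch, assuming the one-folder-per-class layout that `ImageFolder` expects; the helper name `count_images_per_class`, the extension list, and the demo folder contents are illustrative assumptions, not part of the original notebook.

```python
import tempfile
from pathlib import Path

# Assumed set of image extensions; adjust to match the actual dataset files.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Return {class_folder_name: number_of_image_files} for a folder-per-class tree."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1
                for f in class_dir.iterdir()
                if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
            )
    return counts


# Demo on a throwaway directory with empty placeholder files (not real data).
with tempfile.TemporaryDirectory() as tmp:
    for name, n in [("Tomato___healthy", 3), ("Tomato___Early_blight", 2)]:
        class_dir = Path(tmp) / name
        class_dir.mkdir()
        for i in range(n):
            (class_dir / f"img_{i}.jpg").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

In the notebook, the same helper could be pointed at the dataset directory after the non-tomato folders have been removed.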


In [None]:
!wget "https://data.mendeley.com/public-files/datasets/tywbtsjrjv/files/d5652a28-c1d8-4b76-97f3-72fb80f94efc/file_downloaded"

--2025-05-18 11:13:37--  https://data.mendeley.com/public-files/datasets/tywbtsjrjv/files/d5652a28-c1d8-4b76-97f3-72fb80f94efc/file_downloaded
Resolving data.mendeley.com (data.mendeley.com)... 162.159.133.86, 162.159.130.86
Connecting to data.mendeley.com (data.mendeley.com)|162.159.133.86|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/d29ed9b2-8a5d-4663-8a82-c9174f2c7066 [following]
--2025-05-18 11:13:38--  https://prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com/d29ed9b2-8a5d-4663-8a82-c9174f2c7066
Resolving prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com (prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com)... 3.5.71.147, 52.92.19.186, 52.218.105.131, ...
Connecting to prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com (prod-dcd-datasets-public-files-eu-west-1.s3.eu-west-1.amazonaws.com)|3.5.71.147|:443..

In [None]:
!unzip "/content/file_downloaded";

In [None]:
# import shutil
# shutil.rmtree('tomato')

In [None]:
import os
import shutil

# Define the directory containing the folders
directory = "/content/Plant_leave_diseases_dataset_without_augmentation"

# Get a list of all folders in the directory
folders = [
    f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))
]

# Delete every folder whose name does not contain "tomato" (case-insensitive)
for folder in folders:
    if "tomato" not in folder.lower():
        folder_path = os.path.join(directory, folder)
        print(f"Deleting folder: {folder_path}")
        shutil.rmtree(folder_path)

In [None]:
# Split the dataset into train/val/test folders (70% / 20% / 10%)
!pip install split-folders
import splitfolders

base_dir = '/content/Plant_leave_diseases_dataset_without_augmentation'
splitfolders.ratio(base_dir, output='/content/tomato', seed = 1314, ratio = (.7, .2, .1))

train_dir = os.path.join('/content/tomato', 'train')
validation_dir = os.path.join('/content/tomato', 'val')
test_dir = os.path.join('/content/tomato', 'test')
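
To illustrate what a seeded 70/20/10 split does, here is a minimal sketch that shuffles item indices with a fixed seed and slices them into three disjoint parts. It is not the `splitfolders` implementation (which works on files inside each class folder); the function name `split_indices` is a hypothetical helper used only for illustration.

```python
import random


def split_indices(n_items, ratios=(0.7, 0.2, 0.1), seed=1314):
    """Shuffle indices deterministically and cut them into train/val/test parts."""
    indices = list(range(n_items))
    random.Random(seed).shuffle(indices)  # same seed -> same order on every run
    n_train = int(n_items * ratios[0])
    n_val = int(n_items * ratios[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # remainder goes to the test split
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))
```

Because the shuffle is seeded, rerunning the split reproduces the same partition, which is what makes results comparable between runs.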

In [None]:
import os
import random
import torch
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


# Function to set seed for reproducibility
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


# Set seed
seed = 42
set_seed(seed)

# Define the root directory of the dataset
data_dir = "/content/tomato"

# Define transformations for training, validation, and testing sets
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

valid_test_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

# Load datasets with ImageFolder
train_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "train"), transform=train_transforms
)
valid_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "val"), transform=valid_test_transforms
)
test_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "test"), transform=valid_test_transforms
)

# Define data loaders
batch_size = 32

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Print the number of classes
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Class names: {train_dataset.classes}")

Number of classes: 10
Class names: ['Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy']


In [None]:
!pip install thop

Collecting thop
  Downloading thop-0.1.1.post2209072238-py3-none-any.whl.metadata (2.7 kB)
Downloading thop-0.1.1.post2209072238-py3-none-any.whl (15 kB)
Installing collected packages: thop
Successfully installed thop-0.1.1.post2209072238


🌱🧠 **HPDCNet** is a lightweight convolutional neural network built from depthwise adaptive residual blocks, hybrid (max + average) pooling, multi-scale feature fusion, and feature recalibration attention. A conditional feature enhancement block is also defined in the cell below, although the `HPDCNet` class as written does not use it in its forward pass.

### 🔬 Potential Applications:

✅ Smart Farming  
✅ Mobile Disease Detection Apps  
✅ Precision Agriculture  
✅ Drone Crop Monitoring  
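
Much of the "lightweight" character of such a network comes from replacing standard convolutions with a depthwise convolution followed by a 1×1 pointwise convolution. The arithmetic below is a rough sketch of the saving for one layer; the helper `conv2d_params` and the 64 → 128 channel example are illustrative choices, not measurements of HPDCNet as a whole.

```python
def conv2d_params(in_channels, out_channels, kernel_size, groups=1, bias=True):
    """Number of learnable parameters in a 2D convolution with square kernels."""
    weights = (in_channels // groups) * out_channels * kernel_size * kernel_size
    return weights + (out_channels if bias else 0)


in_ch, out_ch = 64, 128

# Standard 3x3 convolution: every output channel sees every input channel.
standard = conv2d_params(in_ch, out_ch, kernel_size=3)

# Depthwise 3x3 (one filter per input channel) followed by a 1x1 pointwise mix.
depthwise = conv2d_params(in_ch, in_ch, kernel_size=3, groups=in_ch)
pointwise = conv2d_params(in_ch, out_ch, kernel_size=1)
separable = depthwise + pointwise

print(standard, separable, round(standard / separable, 1))
```

For this layer size the separable form needs roughly an eighth of the parameters of the standard convolution.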


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# --- Block Definitions ---

class DepthwiseAdaptiveResidualBlock(nn.Module):
    """Depthwise separable convolution with adaptive residual connections."""
    def __init__(self, in_channels, out_channels, stride=1):
        super(DepthwiseAdaptiveResidualBlock, self).__init__()
        self.depthwise = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=stride, padding=1, groups=in_channels
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
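        # Project the shortcut whenever the channel count or the spatial size changes,
        # so that `x + identity` always has matching shapes
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
            if in_channels != out_channels or stride != 1
            else nn.Identity()
        )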

    def forward(self, x):
        identity = self.shortcut(x)
        x = self.depthwise(x)
        x = self.pointwise(x)
        x = self.bn(x)
        return F.relu(x + identity)


class FeatureRecalibrationAttention(nn.Module):
    """Recalibrates feature importance through channel-wise attention."""
    def __init__(self, in_channels):
        super(FeatureRecalibrationAttention, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        attention = self.sigmoid(self.conv1(x))
        return x * attention


class MultiScaleFusionBlock(nn.Module):
    """Fuses multi-scale features for enhanced representation."""
    def __init__(self, in_channels):
        super(MultiScaleFusionBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.conv3 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.conv5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2)

    def forward(self, x):
        x1 = self.conv1(x)
        x3 = self.conv3(x)
        x5 = self.conv5(x)
        return x1 + x3 + x5


class HybridPoolingLayer(nn.Module):
    """Combines max and average pooling for robust feature extraction."""
    def __init__(self, kernel_size=2, stride=2):
        super(HybridPoolingLayer, self).__init__()
        self.max_pool = nn.MaxPool2d(kernel_size, stride)
        self.avg_pool = nn.AvgPool2d(kernel_size, stride)

    def forward(self, x):
        max_pooled = self.max_pool(x)
        avg_pooled = self.avg_pool(x)
        return (max_pooled + avg_pooled) / 2


class ConditionalFeatureEnhancement(nn.Module):
    """Enhances features conditionally based on global context."""
    def __init__(self, in_channels, reduction=16):
        super(ConditionalFeatureEnhancement, self).__init__()
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, in_channels // reduction)
        self.fc2 = nn.Linear(in_channels // reduction, in_channels)

    def forward(self, x):
        b, c, _, _ = x.size()
        weights = self.global_avg_pool(x).view(b, c)
        weights = F.relu(self.fc1(weights))
        weights = torch.sigmoid(self.fc2(weights)).view(b, c, 1, 1)
        return x * weights


# --- HPDC-Net ---

class HPDCNet(nn.Module):
    def __init__(self, num_classes=10):
        super(HPDCNet, self).__init__()
        # Stage 1
        self.layer1 = DepthwiseAdaptiveResidualBlock(3, 16)
        self.attn1 = FeatureRecalibrationAttention(16)
        self.pool1 = HybridPoolingLayer()

        # Stage 2
        self.layer2 = DepthwiseAdaptiveResidualBlock(16, 32)
        self.attn2 = FeatureRecalibrationAttention(32)
        self.pool2 = HybridPoolingLayer()

        # Stage 3
        self.layer3 = DepthwiseAdaptiveResidualBlock(32, 64)
        self.fusion3 = MultiScaleFusionBlock(64)
        self.attn3 = FeatureRecalibrationAttention(64)
        self.pool3 = HybridPoolingLayer()

        # Stage 4
        self.layer4 = DepthwiseAdaptiveResidualBlock(64, 128)
        self.fusion4 = MultiScaleFusionBlock(128)
        self.attn4 = FeatureRecalibrationAttention(128)

        # Classifier
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        # Stage 1
        x = self.layer1(x)
        x = self.attn1(x)
        x = self.pool1(x)

        # Stage 2
        x = self.layer2(x)
        x = self.attn2(x)
        x = self.pool2(x)

        # Stage 3
        x = self.layer3(x)
        x = self.fusion3(x)
        x = self.attn3(x)
        x = self.pool3(x)

        # Stage 4
        x = self.layer4(x)
        x = self.fusion4(x)
        x = self.attn4(x)

        # Classifier
        x = self.global_pool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


# Model Summary
if __name__ == "__main__":
    model = HPDCNet(num_classes=10)  # match the 10 tomato classes used in this notebook
    sample_input = torch.randn(1, 3, 224, 224)
    output = model(sample_input)
    print("Output shape:", output.shape)
    print("Model Summary:")
    from thop import profile

    flops, params = profile(model, inputs=(sample_input,))
    print(f"GFLOPs: {flops / 1e9:.2f}")
    print(f"Number of parameters (in millions): {params / 1e6:.2f}")
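
As a sanity check on the architecture above, the following sketch traces the spatial size of a 224×224 input through the four stages. It assumes, as in the code, that the residual, fusion, and attention blocks preserve resolution (stride 1 with "same" padding) and that only the three 2×2, stride-2 hybrid pooling layers change it; the helper `pool_out` is introduced here for illustration.

```python
def pool_out(size, kernel_size=2, stride=2):
    """Output size of an unpadded pooling layer along one spatial dimension."""
    return (size - kernel_size) // stride + 1


size = 224
trace = []
# (output channels, whether the stage ends with a pooling layer)
stages = [(16, True), (32, True), (64, True), (128, False)]
for stage, (channels, pooled) in enumerate(stages, start=1):
    if pooled:
        size = pool_out(size)
    trace.append((stage, channels, size))

for stage, channels, size_after in trace:
    print(f"Stage {stage}: {channels} channels, {size_after}x{size_after}")
```

The last stage therefore hands a 128-channel, 28×28 feature map to the global average pooling layer, which reduces it to a 128-dimensional vector for the classifier.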

In [None]:
!pip install tqdm torchmetrics


### 🏋️ Model Training Procedure

The `HPDCNet` model was trained for 200 epochs using the Adam optimizer with a learning rate of `0.001` and CrossEntropyLoss.  
Training and validation were conducted on a multi-class classification task with 10 classes.

**Training Pipeline:**
- Device: GPU when available, otherwise CPU
- Data: Loaded via `train_loader` and `valid_loader`
- Metrics: Accuracy, Precision, Recall, and F1 Score (macro average using TorchMetrics)
- Best model checkpoint saved based on highest validation accuracy
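
To make "macro average" concrete, the sketch below computes macro precision, recall, and F1 by hand from a small confusion matrix, assuming the usual definition (the unweighted mean of the per-class scores). The function `macro_metrics` and the 2×2 matrix are made up for illustration; TorchMetrics may treat edge cases such as classes with no samples differently.

```python
def macro_metrics(cm):
    """cm[i][j] = number of samples with true class i predicted as class j."""
    n = len(cm)
    precisions, recalls, f1s = [], [], []
    for k in range(n):
        tp = cm[k][k]
        fp = sum(cm[i][k] for i in range(n)) - tp  # predicted k, but true class differs
        fn = sum(cm[k]) - tp                       # true k, but predicted otherwise
        p = tp / (tp + fp) if tp + fp else 0.0
        r = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * p * r / (p + r) if p + r else 0.0
        precisions.append(p)
        recalls.append(r)
        f1s.append(f1)
    # Macro average: every class counts equally, regardless of its size.
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n


# Toy 2-class confusion matrix (illustrative numbers only).
toy_cm = [[8, 2],
          [1, 9]]
macro_p, macro_r, macro_f1 = macro_metrics(toy_cm)
print(f"precision={macro_p:.4f} recall={macro_r:.4f} f1={macro_f1:.4f}")
```

Macro averaging is a reasonable choice when class sizes differ, since a rare disease class weighs as much as a common one.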


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
import torchmetrics

# train_loader and valid_loader are defined in the data-loading cell above

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the model
num_classes = len(train_dataset.classes)  # 10 tomato classes
model = HPDCNet(num_classes).to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training and validation loop
num_epochs = 200

# Metrics
train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(
    device
)
train_precision = torchmetrics.Precision(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)
train_recall = torchmetrics.Recall(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)
train_f1_score = torchmetrics.F1Score(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)

valid_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(
    device
)
valid_precision = torchmetrics.Precision(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)
valid_recall = torchmetrics.Recall(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)
valid_f1_score = torchmetrics.F1Score(
    task="multiclass", num_classes=num_classes, average="macro"
).to(device)

# Variable to track the best validation accuracy
best_val_accuracy = 0.0
best_epoch = 0

for epoch in range(num_epochs):
    # Training
    model.train()
    train_loss = 0
    train_accuracy.reset()
    train_precision.reset()
    train_recall.reset()
    train_f1_score.reset()

    for inputs, labels in tqdm(
        train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"
    ):
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_loss += loss.item() * inputs.size(0)
        train_accuracy.update(outputs, labels)
        train_precision.update(outputs, labels)
        train_recall.update(outputs, labels)
        train_f1_score.update(outputs, labels)

    train_loss /= len(train_loader.dataset)
    train_accuracy_value = train_accuracy.compute().item()
    train_precision_value = train_precision.compute().item()
    train_recall_value = train_recall.compute().item()
    train_f1_value = train_f1_score.compute().item()

    # Validation
    model.eval()
    valid_loss = 0
    valid_accuracy.reset()
    valid_precision.reset()
    valid_recall.reset()
    valid_f1_score.reset()

    with torch.no_grad():
        for inputs, labels in tqdm(
            valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}"
        ):
            inputs, labels = inputs.to(device), labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            valid_loss += loss.item() * inputs.size(0)
            valid_accuracy.update(outputs, labels)
            valid_precision.update(outputs, labels)
            valid_recall.update(outputs, labels)
            valid_f1_score.update(outputs, labels)

    valid_loss /= len(valid_loader.dataset)
    valid_accuracy_value = valid_accuracy.compute().item()
    valid_precision_value = valid_precision.compute().item()
    valid_recall_value = valid_recall.compute().item()
    valid_f1_value = valid_f1_score.compute().item()

    # Save the model if validation accuracy improves
    if valid_accuracy_value > best_val_accuracy:
        best_val_accuracy = valid_accuracy_value
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "best_model.pth")
        print(
            f"New best model saved at epoch {epoch+1} with validation accuracy: {best_val_accuracy:.4f}"
        )

    print(f"Epoch {epoch+1}/{num_epochs}")
    print(f"Training Loss: {train_loss:.4f} | Validation Loss: {valid_loss:.4f}")
    print(
        f"Training Accuracy: {train_accuracy_value:.4f} | Validation Accuracy: {valid_accuracy_value:.4f}"
    )
    print(
        f"Training Precision: {train_precision_value:.4f} | Validation Precision: {valid_precision_value:.4f}"
    )
    print(
        f"Training Recall: {train_recall_value:.4f} | Validation Recall: {valid_recall_value:.4f}"
    )
    print(
        f"Training F1 Score: {train_f1_value:.4f} | Validation F1 Score: {valid_f1_value:.4f}"
    )

print(f"Best validation accuracy: {best_val_accuracy:.4f} at epoch {best_epoch}")

### 🧪 Model Evaluation on CPU


In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchmetrics
import time
import matplotlib.pyplot as plt
import numpy as np
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from sklearn.metrics import confusion_matrix
import seaborn as sns
import random

# Define the test dataset with ImageFolder to get class names
test_dataset = datasets.ImageFolder(
    os.path.join(data_dir, "test"), transform=valid_test_transforms
)
class_names = test_dataset.classes  # Extract class names from the folder names

# Force CPU usage
device = torch.device("cpu")

# Load the trained model
num_classes = len(class_names)
model = HPDCNet(num_classes).to(device)
model_path = "/content/best_model.pth"
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Define the test data loader
# test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Metrics
test_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
test_precision = torchmetrics.Precision(task="multiclass", num_classes=num_classes, average="macro").to(device)
test_recall = torchmetrics.Recall(task="multiclass", num_classes=num_classes, average="macro").to(device)
test_f1_score = torchmetrics.F1Score(task="multiclass", num_classes=num_classes, average="macro").to(device)
test_confusion_matrix = torchmetrics.ConfusionMatrix(task="multiclass", num_classes=num_classes).to(device)

# Function to perform inference and calculate metrics
def perform_inference(device):
    model.to(device)
    test_accuracy.to(device)
    test_precision.to(device)
    test_recall.to(device)
    test_f1_score.to(device)
    test_confusion_matrix.to(device)

    inference_times = []
    all_labels = []
    all_predictions = []

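    for i in range(5):
        test_accuracy.reset()
        test_precision.reset()
        test_recall.reset()
        test_f1_score.reset()
        test_confusion_matrix.reset()
        # Start each timed pass with empty lists so that labels and predictions
        # are not duplicated five times in the returned values
        all_labels.clear()
        all_predictions.clear()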

        start_time = time.time()
        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs, 1)

                test_accuracy.update(outputs, labels)
                test_precision.update(outputs, labels)
                test_recall.update(outputs, labels)
                test_f1_score.update(outputs, labels)
                test_confusion_matrix.update(predicted, labels)

                all_labels.extend(labels.cpu().numpy())
                all_predictions.extend(predicted.cpu().numpy())

        end_time = time.time()
        inference_times.append((end_time - start_time) * 1000)  # ms

    avg_inference_time = np.mean(inference_times) / len(test_loader.dataset)
    avg_accuracy = test_accuracy.compute().item()
    avg_precision = test_precision.compute().item()
    avg_recall = test_recall.compute().item()
    avg_f1 = test_f1_score.compute().item()
    cm = test_confusion_matrix.compute().cpu().numpy()

    return avg_accuracy, avg_precision, avg_recall, avg_f1, avg_inference_time, cm, all_labels, all_predictions

# Perform inference on CPU
(
    avg_accuracy_cpu,
    avg_precision_cpu,
    avg_recall_cpu,
    avg_f1_cpu,
    avg_inference_time_cpu,
    cm_cpu,
    labels_cpu,
    preds_cpu,
) = perform_inference(device)

fps_cpu = 1000 / avg_inference_time_cpu
print(f"CPU Average Test Accuracy: {avg_accuracy_cpu:.4f}")
print(f"CPU Average Test Precision: {avg_precision_cpu:.4f}")
print(f"CPU Average Test Recall: {avg_recall_cpu:.4f}")
print(f"CPU Average Test F1 Score: {avg_f1_cpu:.4f}")
print(f"CPU Average Inference Time per Image: {avg_inference_time_cpu:.6f} ms")
print(f"CPU Frames per Second: {fps_cpu:.2f}")

# Display confusion matrix
def plot_confusion_matrix(cm, class_names, title):
    plt.figure(figsize=(10, 8))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=class_names,
        yticklabels=class_names,
    )
    plt.xlabel("Predicted Labels")
    plt.ylabel("True Labels")
    plt.title(title)
    plt.show()

plot_confusion_matrix(cm_cpu, class_names, "CPU Confusion Matrix")
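
The evaluation cell above derives per-image latency by dividing the mean wall-clock time of several full passes by the number of test images, and then converts it to frames per second. The sketch below isolates that calculation on a dummy workload; `time_per_item_ms` and `fps_from_latency` are hypothetical helper names, and `time.perf_counter` is used here as an assumption because it is a monotonic, high-resolution timer.

```python
import time


def time_per_item_ms(fn, items, repeats=5):
    """Average time per item in milliseconds over several full passes."""
    durations_ms = []
    for _ in range(repeats):
        start = time.perf_counter()
        for item in items:
            fn(item)
        durations_ms.append((time.perf_counter() - start) * 1000)
    return sum(durations_ms) / repeats / len(items)


def fps_from_latency(ms_per_item):
    """Convert a per-item latency in milliseconds to items per second."""
    return 1000.0 / ms_per_item


# Dummy workload standing in for a model forward pass.
latency_ms = time_per_item_ms(lambda x: sum(range(1000)), items=list(range(200)))
print(f"{latency_ms:.6f} ms per item, {fps_from_latency(latency_ms):.2f} items/s")
```

Note that a measurement taken this way, like the one in the notebook, includes data loading and metric updates, so it is an upper bound on the model's own inference time.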


## 🧪 Visualizing Model Predictions on Sample Test Images

To better understand how the model performs on real data, we select 20 test images: two per class when available, topped up with random test images if fewer than 20 are found.

For each image, we run a forward pass and compare the **predicted label** against the **true label**:

- ✅ Correct predictions are shown with **green titles**
- ❌ Incorrect predictions are highlighted in **red**

The images are displayed in a grid layout for quick inspection. This visualization helps:

- Evaluate **class-wise performance**
- Identify **failure patterns**
- Visually inspect individual **mistakes**
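
The selection rule described above can be sketched without any model or image data. The helper `select_per_class` and the toy label list are assumptions made for illustration; the logic mirrors "two per class when available, then top up to the target count".

```python
import random


def select_per_class(labels, per_class=2, total=20, seed=0):
    """Pick up to `per_class` indices for each label, then top up to `total`."""
    rng = random.Random(seed)
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    selected = []
    for label in sorted(by_class):
        indices = by_class[label]
        selected.extend(rng.sample(indices, min(per_class, len(indices))))

    # Fill any shortfall with random indices that were not chosen yet.
    if len(selected) < total:
        chosen = set(selected)
        remaining = [i for i in range(len(labels)) if i not in chosen]
        extra = min(total - len(selected), len(remaining))
        selected.extend(rng.sample(remaining, extra))
    return selected


# Toy labels: nine classes with five samples each, plus one class with a single sample.
toy_labels = [i % 9 for i in range(45)] + [9]
picked = select_per_class(toy_labels)
print(len(picked), sorted({toy_labels[i] for i in picked}))
```

Here the single-sample class contributes one image, so one extra image is drawn at random to reach 20.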


In [None]:
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
from torchvision import transforms

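# Use the GPU when available and keep the model on the same device as the inputs
# (the previous cell moved the model to the CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()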


# Collect images by class
def collect_images_by_class(test_loader):
    class_indices = {i: [] for i in range(num_classes)}
    all_labels = []
    all_images = []

    for inputs, labels in test_loader:
        all_images.extend(inputs)
        all_labels.extend(labels)

    all_images = torch.stack(all_images).to(device)
    all_labels = torch.tensor(all_labels).to(device)

    for idx, label in enumerate(all_labels):
        class_indices[label.item()].append(idx)

    return class_indices, all_images, all_labels


# Get images by class
class_indices, all_images, all_labels = collect_images_by_class(test_loader)

# Select two images per class
selected_indices = []
for indices in class_indices.values():
    if len(indices) >= 2:
        selected_indices.extend(random.sample(indices, 2))
    else:
        selected_indices.extend(indices)

# Ensure we have 20 images
if len(selected_indices) < 20:
    remaining_indices = [i for i in range(len(all_labels)) if i not in selected_indices]
    additional_indices = random.sample(remaining_indices, 20 - len(selected_indices))
    selected_indices.extend(additional_indices)

# Gather the selected images and labels
selected_images = all_images[selected_indices]
selected_labels = all_labels[selected_indices]

# Perform inference on the selected images
with torch.no_grad():
    outputs = model(selected_images)
    _, predicted_labels = torch.max(outputs, 1)


# Plot images
def plot_image_with_labels(image, predicted_label, true_label, ax, class_names):
    image = inv_transform(image.cpu())
    ax.imshow(image)
    if predicted_label == true_label:
        ax.set_title(
            f"Pred: {class_names[predicted_label]}, True: {class_names[true_label]}",
            color="green",
        )
    else:
        ax.set_title(
            f"Pred: {class_names[predicted_label]}, True: {class_names[true_label]}",
            color="red",
        )
    ax.axis("off")


# Inverse transform for displaying images
inv_transform = transforms.Compose(
    [
        transforms.Normalize(
            mean=[-0.485 / 0.229, -0.456 / 0.224, -0.406 / 0.225],
            std=[1 / 0.229, 1 / 0.224, 1 / 0.225],
        ),
        transforms.ToPILImage(),
    ]
)

# Plotting the selected images
fig, axes = plt.subplots(4, 5, figsize=(15, 12))

for idx, ax in zip(range(len(selected_images)), axes.flat):
    plot_image_with_labels(
        selected_images[idx],
        predicted_labels[idx].item(),
        selected_labels[idx].item(),
        ax,
        class_names,
    )

plt.tight_layout()
plt.show()
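
The inverse transform in the cell above works because normalizing with mean `-m/s` and standard deviation `1/s` undoes a normalization with mean `m` and standard deviation `s`: `((x - m)/s + m/s) * s = x`. The sketch below checks this round trip on a single RGB pixel in plain Python; the pixel values are arbitrary.

```python
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize(pixel, mean, std):
    """Channel-wise (x - mean) / std, the operation that Normalize applies."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


# Parameters of the inverse normalization, as used for display in the notebook.
inv_mean = [-m / s for m, s in zip(MEAN, STD)]
inv_std = [1 / s for s in STD]

pixel = [0.2, 0.5, 0.9]  # arbitrary RGB values in [0, 1]
normalized = normalize(pixel, MEAN, STD)
restored = normalize(normalized, inv_mean, inv_std)

print(normalized)
print(restored)
```

Up to floating-point rounding, `restored` equals the original pixel, which is why the displayed images look like the inputs before normalization.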

In [None]:
import zipfile
import os
from google.colab import files

# Path to your model file
model_file_path = "/content/best_model.pth"

# Path to save the zip file
zip_file_path = "/content/best_model.zip"

# Create a zip file and add the model file to it
with zipfile.ZipFile(zip_file_path, "w") as zipf:
    zipf.write(model_file_path, os.path.basename(model_file_path))

# Download the zip file to your local machine
files.download(zip_file_path)
