In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, models, transforms
from torch.utils.data import DataLoader, random_split, Subset
from sklearn.model_selection import train_test_split
import copy
import matplotlib.pyplot as plt
import numpy as np
import torchvision

### We check whether we train with GPU

In [2]:
# Define the device
device = torch.device('mps')

from util import to_f32

# Define the data directory
data_dir = '../data'  # Update this path

# Define transforms for the data
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'test': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}

# Load the full dataset
full_dataset = datasets.ImageFolder(data_dir, transform=data_transforms['train'])

# Split the dataset into train, validation, and test sets (70% train, 15% val, 15% test)
train_idx, temp_idx = train_test_split(list(range(len(full_dataset))), test_size=0.3, stratify=full_dataset.targets)
val_idx, test_idx = train_test_split(temp_idx, test_size=0.5, stratify=[full_dataset.targets[i] for i in temp_idx])

# Create subsets for each set
train_dataset = Subset(full_dataset, train_idx)
val_dataset = Subset(full_dataset, val_idx)
test_dataset = Subset(full_dataset, test_idx)

# Apply appropriate transforms to each subset
train_dataset.dataset.transform = data_transforms['train']
val_dataset.dataset.transform = data_transforms['val']
test_dataset.dataset.transform = data_transforms['test']

# Create dataloaders for each set
dataloaders = {
    'train': DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4),
    'val': DataLoader(val_dataset, batch_size=32, shuffle=True, num_workers=4),
    'test': DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)
}

# Get dataset sizes
dataset_sizes = {
    'train': len(train_dataset),
    'val': len(val_dataset),
    'test': len(test_dataset)
}


In [3]:
# Get class names
class_names = full_dataset.classes
class_names

['cocci', 'healthy', 'ncd', 'salmo']

## Showcasing an array of images and their labels

In [None]:
# Define a function to show images
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated

# Get a batch of training data
inputs, classes = next(iter(dataloaders['train']))

# Make a grid from batch with larger padding and bigger images
out = torchvision.utils.make_grid(inputs, padding=20, pad_value=1, scale_each=True)

# Display batch with labels
imshow(out, title=[class_names[x] for x in classes])
plt.show()


In [5]:
import torchvision.models.vgg as vgg

# Load a pretrained MobileNetV2 model
model_ft = vgg.vgg19_bn(weights = vgg.VGG19_BN_Weights.DEFAULT)

model_ft

# Modify the classifier to match the number of classes
num_ftrs = model_ft.classifier[6].in_features
model_ft.classifier[6] = nn.Linear(num_ftrs, len(class_names))

# Move the model to the appropriate device
model_ft = model_ft.to(device)

# Define the loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU(inplace=True)
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (9): ReLU(inplace=True)
    (10): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (12): ReLU(inplace=True)
    (13): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (14): Conv2d(128, 256

In [6]:
# Function to train and validate the model
from tqdm import tqdm

def train_model(model, criterion, optimizer, num_epochs=25):
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0
            processed_samples = 0

            # Use tqdm to create a progress bar
            with tqdm(dataloaders[phase], unit='batch') as t:
                for inputs, labels in t:
                    inputs = inputs.to(device)
                    labels = labels.to(device)

                    optimizer.zero_grad()

                    with torch.set_grad_enabled(phase == 'train'):
                        outputs = model(inputs)
                        _, preds = torch.max(outputs, 1)
                        loss = criterion(outputs, labels)

                        if phase == 'train':
                            loss.backward()
                            optimizer.step()

                    running_loss += loss.item() * inputs.size(0)
                    running_corrects += torch.sum(preds == labels.data)
                    processed_samples += inputs.size(0)

                    # Update the progress bar description
                    t.set_description(f'{phase} Loss: {running_loss / processed_samples:.4f} Acc: {running_corrects / processed_samples:.4f}')

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    print(f'Best val Acc: {best_acc:.4f}')
    model.load_state_dict(best_model_wts)
    return model


In [7]:
# Train the model
model_ft = train_model(model_ft, criterion, optimizer, num_epochs=25)

# Save the model
torch.save(model_ft.state_dict(), 'vgg.pth')

Epoch 0/24
----------


train Loss: 0.3067 Acc: 0.8987: 100%|██████████| 149/149 [03:09<00:00,  1.27s/batch]


train Loss: 0.3067 Acc: 0.8987


val Loss: 0.0996 Acc: 0.9599: 100%|██████████| 32/32 [00:34<00:00,  1.07s/batch]


val Loss: 0.0996 Acc: 0.9599

Epoch 1/24
----------


train Loss: 0.0632 Acc: 0.9801: 100%|██████████| 149/149 [03:04<00:00,  1.24s/batch]


train Loss: 0.0632 Acc: 0.9801


val Loss: 0.0807 Acc: 0.9726: 100%|██████████| 32/32 [00:34<00:00,  1.09s/batch]


val Loss: 0.0807 Acc: 0.9726

Epoch 2/24
----------


train Loss: 0.0238 Acc: 0.9937: 100%|██████████| 149/149 [03:26<00:00,  1.38s/batch]


train Loss: 0.0238 Acc: 0.9937


val Loss: 0.0658 Acc: 0.9795: 100%|██████████| 32/32 [00:37<00:00,  1.18s/batch]


val Loss: 0.0658 Acc: 0.9795

Epoch 3/24
----------


train Loss: 0.0129 Acc: 0.9973: 100%|██████████| 149/149 [03:28<00:00,  1.40s/batch]


train Loss: 0.0129 Acc: 0.9973


val Loss: 0.0909 Acc: 0.9697: 100%|██████████| 32/32 [00:35<00:00,  1.10s/batch]


val Loss: 0.0909 Acc: 0.9697

Epoch 4/24
----------


train Loss: 0.0098 Acc: 0.9975: 100%|██████████| 149/149 [03:41<00:00,  1.49s/batch]


train Loss: 0.0098 Acc: 0.9975


val Loss: 0.0758 Acc: 0.9785: 100%|██████████| 32/32 [00:36<00:00,  1.15s/batch]


val Loss: 0.0758 Acc: 0.9785

Epoch 5/24
----------


train Loss: 0.0068 Acc: 0.9985: 100%|██████████| 149/149 [03:23<00:00,  1.37s/batch]


train Loss: 0.0068 Acc: 0.9985


val Loss: 0.0758 Acc: 0.9804: 100%|██████████| 32/32 [00:35<00:00,  1.11s/batch]


val Loss: 0.0758 Acc: 0.9804

Epoch 6/24
----------


train Loss: 0.0036 Acc: 0.9994: 100%|██████████| 149/149 [03:28<00:00,  1.40s/batch]


train Loss: 0.0036 Acc: 0.9994


val Loss: 0.0747 Acc: 0.9814: 100%|██████████| 32/32 [00:36<00:00,  1.15s/batch]


val Loss: 0.0747 Acc: 0.9814

Epoch 7/24
----------


train Loss: 0.0027 Acc: 0.9992: 100%|██████████| 149/149 [03:22<00:00,  1.36s/batch]


train Loss: 0.0027 Acc: 0.9992


val Loss: 0.0709 Acc: 0.9824: 100%|██████████| 32/32 [00:36<00:00,  1.13s/batch]


val Loss: 0.0709 Acc: 0.9824

Epoch 8/24
----------


train Loss: 0.0021 Acc: 0.9998: 100%|██████████| 149/149 [03:25<00:00,  1.38s/batch]


train Loss: 0.0021 Acc: 0.9998


val Loss: 0.0674 Acc: 0.9814: 100%|██████████| 32/32 [00:35<00:00,  1.12s/batch]


val Loss: 0.0674 Acc: 0.9814

Epoch 9/24
----------


train Loss: 0.0029 Acc: 0.9992: 100%|██████████| 149/149 [03:21<00:00,  1.35s/batch]


train Loss: 0.0029 Acc: 0.9992


val Loss: 0.0748 Acc: 0.9795: 100%|██████████| 32/32 [00:37<00:00,  1.16s/batch]


val Loss: 0.0748 Acc: 0.9795

Epoch 10/24
----------


train Loss: 0.0027 Acc: 0.9990: 100%|██████████| 149/149 [35:49<00:00, 14.43s/batch]  


train Loss: 0.0027 Acc: 0.9990


val Loss: 0.0823 Acc: 0.9804: 100%|██████████| 32/32 [00:35<00:00,  1.11s/batch]


val Loss: 0.0823 Acc: 0.9804

Epoch 11/24
----------


train Loss: 0.0020 Acc: 0.9990: 100%|██████████| 149/149 [03:21<00:00,  1.35s/batch]


train Loss: 0.0020 Acc: 0.9990


val Loss: 0.0764 Acc: 0.9804: 100%|██████████| 32/32 [00:37<00:00,  1.16s/batch]


val Loss: 0.0764 Acc: 0.9804

Epoch 12/24
----------


train Loss: 0.0027 Acc: 0.9994: 100%|██████████| 149/149 [03:22<00:00,  1.36s/batch]


train Loss: 0.0027 Acc: 0.9994


val Loss: 0.0757 Acc: 0.9775: 100%|██████████| 32/32 [00:41<00:00,  1.29s/batch]


val Loss: 0.0757 Acc: 0.9775

Epoch 13/24
----------


train Loss: 0.0011 Acc: 0.9998: 100%|██████████| 149/149 [03:24<00:00,  1.37s/batch]


train Loss: 0.0011 Acc: 0.9998


val Loss: 0.0740 Acc: 0.9824: 100%|██████████| 32/32 [00:36<00:00,  1.14s/batch]


val Loss: 0.0740 Acc: 0.9824

Epoch 14/24
----------


train Loss: 0.0013 Acc: 0.9998: 100%|██████████| 149/149 [03:27<00:00,  1.39s/batch]


train Loss: 0.0013 Acc: 0.9998


val Loss: 0.0756 Acc: 0.9795: 100%|██████████| 32/32 [00:35<00:00,  1.12s/batch]


val Loss: 0.0756 Acc: 0.9795

Epoch 15/24
----------


train Loss: 0.0010 Acc: 0.9998: 100%|██████████| 149/149 [03:26<00:00,  1.38s/batch]


train Loss: 0.0010 Acc: 0.9998


val Loss: 0.0754 Acc: 0.9814: 100%|██████████| 32/32 [00:37<00:00,  1.16s/batch]


val Loss: 0.0754 Acc: 0.9814

Epoch 16/24
----------


train Loss: 0.0008 Acc: 0.9998: 100%|██████████| 149/149 [03:22<00:00,  1.36s/batch]


train Loss: 0.0008 Acc: 0.9998


val Loss: 0.0765 Acc: 0.9834: 100%|██████████| 32/32 [00:35<00:00,  1.10s/batch]


val Loss: 0.0765 Acc: 0.9834

Epoch 17/24
----------


train Loss: 0.0009 Acc: 0.9998: 100%|██████████| 149/149 [03:26<00:00,  1.39s/batch]


train Loss: 0.0009 Acc: 0.9998


val Loss: 0.0681 Acc: 0.9824: 100%|██████████| 32/32 [00:35<00:00,  1.11s/batch]


val Loss: 0.0681 Acc: 0.9824

Epoch 18/24
----------


train Loss: 0.0008 Acc: 0.9998: 100%|██████████| 149/149 [03:18<00:00,  1.33s/batch]


train Loss: 0.0008 Acc: 0.9998


val Loss: 0.0720 Acc: 0.9814: 100%|██████████| 32/32 [00:35<00:00,  1.12s/batch]


val Loss: 0.0720 Acc: 0.9814

Epoch 19/24
----------


train Loss: 0.0012 Acc: 0.9996: 100%|██████████| 149/149 [03:21<00:00,  1.35s/batch]


train Loss: 0.0012 Acc: 0.9996


val Loss: 0.0711 Acc: 0.9814: 100%|██████████| 32/32 [00:35<00:00,  1.11s/batch]


val Loss: 0.0711 Acc: 0.9814

Epoch 20/24
----------


train Loss: 0.0005 Acc: 1.0000: 100%|██████████| 149/149 [03:19<00:00,  1.34s/batch]


train Loss: 0.0005 Acc: 1.0000


val Loss: 0.0672 Acc: 0.9824: 100%|██████████| 32/32 [00:37<00:00,  1.17s/batch]


val Loss: 0.0672 Acc: 0.9824

Epoch 21/24
----------


train Loss: 0.0005 Acc: 1.0000: 100%|██████████| 149/149 [03:19<00:00,  1.34s/batch]


train Loss: 0.0005 Acc: 1.0000


val Loss: 0.0687 Acc: 0.9804: 100%|██████████| 32/32 [00:34<00:00,  1.07s/batch]


val Loss: 0.0687 Acc: 0.9804

Epoch 22/24
----------


train Loss: 0.0010 Acc: 0.9996: 100%|██████████| 149/149 [02:58<00:00,  1.20s/batch]


train Loss: 0.0010 Acc: 0.9996


val Loss: 0.0817 Acc: 0.9814: 100%|██████████| 32/32 [00:34<00:00,  1.09s/batch]


val Loss: 0.0817 Acc: 0.9814

Epoch 23/24
----------


train Loss: 0.0009 Acc: 0.9998: 100%|██████████| 149/149 [02:59<00:00,  1.20s/batch]


train Loss: 0.0009 Acc: 0.9998


val Loss: 0.0804 Acc: 0.9804: 100%|██████████| 32/32 [00:34<00:00,  1.09s/batch]


val Loss: 0.0804 Acc: 0.9804

Epoch 24/24
----------


train Loss: 0.0007 Acc: 0.9998: 100%|██████████| 149/149 [02:59<00:00,  1.20s/batch]


train Loss: 0.0007 Acc: 0.9998


val Loss: 0.0790 Acc: 0.9775: 100%|██████████| 32/32 [00:34<00:00,  1.08s/batch]


val Loss: 0.0790 Acc: 0.9775

Best val Acc: 0.9834


### Model evaluation

In [8]:
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score

def evaluate_model(model, dataloader):
    model.eval()
    y_true = []
    y_pred = []

    with torch.no_grad():
        with tqdm(total=len(dataloader.dataset), unit=' samples') as progress_bar:
            for inputs, labels in dataloader:
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

                y_true.extend(labels.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())

                progress_bar.update(inputs.size(0))

    return y_true, y_pred

# Evaluate the model on the test set
y_true, y_pred = evaluate_model(model_ft, dataloaders['test'])

# Calculate accuracy
accuracy = accuracy_score(y_true, y_pred)
print(f'Accuracy: {accuracy:.4f}')

# Calculate classification report
class_names = full_dataset.classes
print("Classification Report:")
print(classification_report(y_true, y_pred, target_names=class_names))


100%|██████████| 1022/1022 [00:33<00:00, 30.13 samples/s] 

Accuracy: 0.9843
Classification Report:
              precision    recall  f1-score   support

       cocci       0.99      1.00      1.00       316
     healthy       0.98      0.98      0.98       309
         ncd       0.93      0.96      0.95        56
       salmo       0.99      0.98      0.98       341

    accuracy                           0.98      1022
   macro avg       0.97      0.98      0.98      1022
weighted avg       0.98      0.98      0.98      1022






In [None]:
import torch
import matplotlib.pyplot as plt

def show_examples(model, dataloader, class_names, num_examples=1):
    model.eval()
    device = next(model.parameters()).device

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloader):
            if i >= num_examples:
                break

            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size(0)):
                true_label = class_names[labels[j]]
                predicted_label = class_names[preds[j]]

                # Normalize image
                input_image = inputs[j].cpu()
                image = input_image.permute(1, 2, 0).numpy()

                # Normalize pixel values to [0, 1]
                image = (image - image.min()) / (image.max() - image.min())

                plt.imshow(image)
                plt.title(f'True Label: {true_label}\nPredicted Label: {predicted_label}')
                plt.show()

# Show examples
show_examples(model_ft, dataloaders['test'], class_names, num_examples=1)
