In [1]:
# Fetch the helper code that later cells import as the `modules` package.
# The whole cell is skipped when ./modules already exists, so re-running the
# notebook does not clone the repository again.
from pathlib import Path

# Directory that the helper code is moved into by the shell commands below
module_path = Path('./modules/')

if not module_path.is_dir():
    # Clone the repository, keep only its 06-GoingModular folder (renamed to
    # ./modules), then delete the rest of the clone.
    # NOTE(review): the three shell steps are not checked for failure; if the
    # clone fails, the `mv` and `rm` still run.
    !git clone https://github.com/abhayrokkam/learning-pytorch
    !mv learning-pytorch/06-GoingModular ./modules
    !rm -rf learning-pytorch

In [2]:
import os

import torch
from torch.utils import tensorboard
import torchvision

import torchmetrics

from torchinfo import summary

from modules import data_setup, engine, utils

# Data

In [3]:
# Download the dataset through the project helper, then point at its splits
data_setup.download_data()

data_path = Path('./data/')
train_dir, test_dir = data_path / 'train', data_path / 'test'

# Take the preprocessing transforms that ship with the default EfficientNet-B0
# weights; the same transforms are later used for both train and test data.
weights = torchvision.models.EfficientNet_B0_Weights.DEFAULT
transforms = weights.transforms()

In [4]:
# Getting the dataloaders
batch_size = 32

# Use half of the CPUs reported by the OS as loader workers.
# os.cpu_count() returns None when the count cannot be determined, which made
# the previous `int(os.cpu_count() / 2)` raise a TypeError; fall back to a
# single CPU in that case. For any positive count, `n // 2` equals the old
# `int(n / 2)`, so the result is unchanged on machines where it already worked.
num_workers = (os.cpu_count() or 1) // 2

# The same transforms are applied to both the train and the test split.
train_dataloader, test_dataloader, class_names = data_setup.get_dataloaders(
    train_dir=train_dir,
    test_dir=test_dir,
    train_transforms=transforms,
    test_transforms=transforms,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Model

In [5]:
def get_model(model_name: str,
              num_classes: int = len(class_names)) -> torch.nn.Module:
    """
    Retrieves a pre-trained EfficientNet model and modifies its classifier for a specific task.

    This function loads either an EfficientNet-B0 or EfficientNet-B2 model from
    `torchvision.models` with its default pre-trained weights, freezes the feature
    extraction layers (to prevent training on them), and replaces the classifier
    with a fresh Dropout + Linear head that outputs `num_classes` predictions.

    The dropout probability and the input width of the new head are read from the
    pre-trained model's own classifier instead of being hard-coded per
    architecture, so they always match the loaded backbone.

    Args:
        model_name (str): The name of the model to retrieve. Only supports 'effnet_b0' or 'effnet_b2'.
        num_classes (int): The number of output classes for the classification task.
            The default is `len(class_names)` as evaluated once, when this
            function definition is executed (not on every call).

    Returns:
        torch.nn.Module: A modified EfficientNet model with a custom classifier.

    Raises:
        NotImplementedError: If the `model_name` is not 'effnet_b0' or 'effnet_b2'.
    """
    # Builder function and default pre-trained weights for each supported name
    supported_models = {
        'effnet_b0': (torchvision.models.efficientnet_b0,
                      torchvision.models.EfficientNet_B0_Weights.DEFAULT),
        'effnet_b2': (torchvision.models.efficientnet_b2,
                      torchvision.models.EfficientNet_B2_Weights.DEFAULT),
    }

    if model_name not in supported_models:
        raise NotImplementedError("The model name entered can not be retrieved. Only supports 'effnet_b0' or 'effnet_b2'")

    build_model, pretrained_weights = supported_models[model_name]
    model = build_model(weights=pretrained_weights)

    # Freezing the feature extracting layers of the model
    for param in model.features.parameters():
        param.requires_grad = False

    # torchvision's EfficientNet classifier is Sequential(Dropout, Linear).
    # Reusing its settings reproduces the values that used to be hard-coded
    # here (p=0.2 / 1280 features for B0, p=0.3 / 1408 features for B2).
    pretrained_dropout = model.classifier[0]
    pretrained_linear = model.classifier[1]

    # Changing the classifier for our problem; the new head is the only part
    # of the model whose parameters still require gradients.
    model.classifier = torch.nn.Sequential(
        torch.nn.Dropout(p=pretrained_dropout.p, inplace=True),
        torch.nn.Linear(in_features=pretrained_linear.in_features,
                        out_features=num_classes)
    )

    return model

# Experimentation

### Note:

- Start small and scale up.
- This applies to everything: model size, amount of data, number of epochs, and so on.

In [6]:
# Device agnostic: use the GPU when PyTorch can see one, otherwise the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [7]:
# Hyperparameter grid: the experiment loop trains one model for every
# combination of the three lists below.
num_epochs = [8, 16]                      # training lengths, in epochs
model_names = ['effnet_b0', 'effnet_b2']  # names understood by get_model
lrs = [1e-3, 1e-4]                        # learning rates for the optimizer

In [8]:
# Training criterion and evaluation metric; the same two objects are passed
# to every experiment in the loop below.
loss_function = torch.nn.CrossEntropyLoss()
accuracy_function = torchmetrics.Accuracy(
    task='multiclass',
    num_classes=len(class_names),
)

In [9]:
# Experimentation: train, log and save one model for every combination of
# (number of epochs, model architecture, learning rate).
experiment_number = 0

# NOTE: `epoch` here is the total number of epochs for the run, not an index.
for epoch in num_epochs:
    for model_name in model_names:
        for lr in lrs:

            # Printing an update
            experiment_number += 1
            print(f"[INFO] Experiment number: {experiment_number}")
            print(f"[INFO] Model: {model_name}")
            print(f"[INFO] Epoch: {epoch}")
            print(f"[INFO] Learning Rate: {lr}\n")

            # Shared tag so the log directory and the checkpoint file of one
            # run can be matched to each other.
            run_tag = f'epoch_{epoch}_lr_{lr}'

            # Fresh model and optimizer for each run so no experiment starts
            # from another experiment's weights or optimizer state.
            model = get_model(model_name)

            optimizer = torch.optim.Adam(params=model.parameters(),
                                         lr=lr)

            # NOTE(review): the writer is never closed in this cell; confirm
            # that engine.train flushes/closes it.
            writer = utils.create_writer(experiment_name=run_tag,
                                         model_name=model_name)

            engine.train(epochs=epoch,
                         model=model,
                         train_dataloader=train_dataloader,
                         test_dataloader=test_dataloader,
                         loss_function=loss_function,
                         optimizer=optimizer,
                         accuracy_function=accuracy_function,
                         device=device,
                         writer=writer)

            # The file name previously lacked a separator between the model
            # name and the run tag ('model_effnet_b0epoch_8_...').
            utils.save_model(model=model,
                             target_dir='./models/',
                             model_name=f'model_{model_name}_{run_tag}.pth')

            # Clearing up memory for looping the train process.
            # The optimizer holds references to every model parameter, so it
            # has to be dropped together with the model; otherwise the
            # parameters stay alive and empty_cache() has nothing to release.
            del model, optimizer
            torch.cuda.empty_cache()

            print("-"*50 + '\n')

[INFO] Experiment number: 1
[INFO] Model: effnet_b0
[INFO] Epoch: 8
[INFO] Learning Rate: 0.001

[INFO] Created SummaryWriter, saving to: runs/2024-11-11/effnet_b0/epoch_8_lr_0.001...


  0%|          | 0/8 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 0.96  |  Test Loss: 0.68  |  Test Accuracy: 0.88

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 0.70  |  Test Loss: 0.56  |  Test Accuracy: 0.88

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 0.57  |  Test Loss: 0.46  |  Test Accuracy: 0.90

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 0.52  |  Test Loss: 0.43  |  Test Accuracy: 0.92

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.46  |  Test Loss: 0.41  |  Test Accuracy: 0.88

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.41  |  Test Loss: 0.38  |  Test Accuracy: 0.90

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.34  |  Test Loss: 0.34  |  Test Accuracy: 0.91

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.33  |  Test Loss: 0.33 

  0%|          | 0/8 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.11  |  Test Loss: 1.03  |  Test Accuracy: 0.55

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 1.04  |  Test Loss: 0.99  |  Test Accuracy: 0.62

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 1.02  |  Test Loss: 0.96  |  Test Accuracy: 0.70

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 0.98  |  Test Loss: 0.92  |  Test Accuracy: 0.74

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.94  |  Test Loss: 0.89  |  Test Accuracy: 0.79

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.92  |  Test Loss: 0.87  |  Test Accuracy: 0.79

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.91  |  Test Loss: 0.84  |  Test Accuracy: 0.82

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.86  |  Test Loss: 0.82 

Downloading: "https://download.pytorch.org/models/efficientnet_b2_rwightman-c35c1473.pth" to /home/abhayrokkam/.cache/torch/hub/checkpoints/efficientnet_b2_rwightman-c35c1473.pth
100%|██████████| 35.2M/35.2M [00:02<00:00, 15.5MB/s]

[INFO] Created SummaryWriter, saving to: runs/2024-11-11/effnet_b2/epoch_8_lr_0.001...





  0%|          | 0/8 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.02  |  Test Loss: 0.82  |  Test Accuracy: 0.79

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 0.78  |  Test Loss: 0.70  |  Test Accuracy: 0.85

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 0.60  |  Test Loss: 0.59  |  Test Accuracy: 0.88

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 0.54  |  Test Loss: 0.54  |  Test Accuracy: 0.89

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.46  |  Test Loss: 0.50  |  Test Accuracy: 0.89

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.47  |  Test Loss: 0.47  |  Test Accuracy: 0.86

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.44  |  Test Loss: 0.47  |  Test Accuracy: 0.88

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.37  |  Test Loss: 0.44 

  0%|          | 0/8 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.10  |  Test Loss: 1.07  |  Test Accuracy: 0.42

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 1.06  |  Test Loss: 1.04  |  Test Accuracy: 0.53

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 1.04  |  Test Loss: 1.01  |  Test Accuracy: 0.57

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 1.00  |  Test Loss: 0.99  |  Test Accuracy: 0.58

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.98  |  Test Loss: 0.96  |  Test Accuracy: 0.69

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.93  |  Test Loss: 0.94  |  Test Accuracy: 0.71

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.95  |  Test Loss: 0.92  |  Test Accuracy: 0.73

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.90  |  Test Loss: 0.91 

  0%|          | 0/16 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 0.95  |  Test Loss: 0.66  |  Test Accuracy: 0.89

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 0.69  |  Test Loss: 0.53  |  Test Accuracy: 0.93

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 0.59  |  Test Loss: 0.46  |  Test Accuracy: 0.90

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 0.47  |  Test Loss: 0.41  |  Test Accuracy: 0.90

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.40  |  Test Loss: 0.38  |  Test Accuracy: 0.90

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.38  |  Test Loss: 0.36  |  Test Accuracy: 0.91

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.41  |  Test Loss: 0.34  |  Test Accuracy: 0.89

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.40  |  Test Loss: 0.31 

  0%|          | 0/16 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.12  |  Test Loss: 1.05  |  Test Accuracy: 0.48

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 1.07  |  Test Loss: 1.02  |  Test Accuracy: 0.57

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 1.03  |  Test Loss: 0.97  |  Test Accuracy: 0.65

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 1.00  |  Test Loss: 0.94  |  Test Accuracy: 0.72

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.97  |  Test Loss: 0.91  |  Test Accuracy: 0.74

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.94  |  Test Loss: 0.89  |  Test Accuracy: 0.78

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.90  |  Test Loss: 0.86  |  Test Accuracy: 0.81

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.88  |  Test Loss: 0.83 

  0%|          | 0/16 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.00  |  Test Loss: 0.81  |  Test Accuracy: 0.85

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 0.73  |  Test Loss: 0.67  |  Test Accuracy: 0.88

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 0.61  |  Test Loss: 0.58  |  Test Accuracy: 0.87

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 0.57  |  Test Loss: 0.51  |  Test Accuracy: 0.91

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.48  |  Test Loss: 0.48  |  Test Accuracy: 0.88

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.43  |  Test Loss: 0.46  |  Test Accuracy: 0.92

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.40  |  Test Loss: 0.44  |  Test Accuracy: 0.87

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.45  |  Test Loss: 0.40 

  0%|          | 0/16 [00:00<?, ?it/s]


EPOCH: 0 ----------------------------------------------- 

Epoch: 0  |  Loss: 1.10  |  Test Loss: 1.10  |  Test Accuracy: 0.35

EPOCH: 1 ----------------------------------------------- 

Epoch: 1  |  Loss: 1.07  |  Test Loss: 1.07  |  Test Accuracy: 0.44

EPOCH: 2 ----------------------------------------------- 

Epoch: 2  |  Loss: 1.05  |  Test Loss: 1.04  |  Test Accuracy: 0.48

EPOCH: 3 ----------------------------------------------- 

Epoch: 3  |  Loss: 1.00  |  Test Loss: 1.02  |  Test Accuracy: 0.56

EPOCH: 4 ----------------------------------------------- 

Epoch: 4  |  Loss: 0.98  |  Test Loss: 0.99  |  Test Accuracy: 0.64

EPOCH: 5 ----------------------------------------------- 

Epoch: 5  |  Loss: 0.95  |  Test Loss: 0.97  |  Test Accuracy: 0.66

EPOCH: 6 ----------------------------------------------- 

Epoch: 6  |  Loss: 0.95  |  Test Loss: 0.94  |  Test Accuracy: 0.68

EPOCH: 7 ----------------------------------------------- 

Epoch: 7  |  Loss: 0.91  |  Test Loss: 0.93 

- Model with the lowest loss:
    - Model: EffNet-B0
    - Epoch: 16
    - Lr: 0.001
- Model with the best accuracy:
    - Model: EffNet-B2
    - Epoch: 16
    - Lr: 0.001