### Project insights

In [42]:
import os

print(f"Working directory : {os.getcwd()}")
print(f"What's inside : {os.listdir('.')}")

Working directory : c:\Users\Sacha\Documents\vision_project
What's inside : ['data', 'models', 'modularity.ipynb', 'notebooks', 'src', 'train.py', 'train_CNN.py', 'venv']


### Getting data

In [26]:
import os
import requests
import zipfile
from pathlib import Path

# Setup path to data folder
data_path = Path("data/")

# If the image folder doesn't exist, download it and prepare it... 
if data_path.is_dir():
    print(f"{data_path} directory exists.")
else:
    print(f"Did not find {data_path} directory, creating one...")
    data_path.mkdir(parents=True, exist_ok=True)

# Download pizza, steak, sushi data
with open(data_path / "pizza_steak_sushi.zip", "wb") as f:
    request = requests.get("https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip")
    print("Downloading pizza, steak, sushi data...")
    f.write(request.content)

# Unzip pizza, steak, sushi data
with zipfile.ZipFile(data_path / "pizza_steak_sushi.zip", "r") as zip_ref:
    print("Unzipping pizza, steak, sushi data...") 
    zip_ref.extractall(data_path)

# Remove zip file
os.remove(data_path / "pizza_steak_sushi.zip")

data directory exists.
Downloading pizza, steak, sushi data...
Unzipping pizza, steak, sushi data...


#### Making the DataLoader generator

In [62]:
%%writefile src/data_setup.py
"""
File to create PyTorch DataLoaders 
"""

import os

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

NUM_WORKERS = os.cpu_count()

def create_dataloaders(train_dir: str, 
                       test_dir: str,
                       transform: transforms.Compose,
                       batch_size: int,
                       num_workers: int=NUM_WORKERS,
                       dataset_type: str = None)->DataLoader:
    
    
    """
    Create for each dataset a set of iterables batches
    
    Arguments:
    #  - train_dir: Path to the trai data directory
    #  - test_dir: Path to the test data directory
     - transform: Torchvision transforms to perform on data
     - batch_size: Size of each batch (How many images per batch) in DataLoader
     - num_workers: Number of workers per DataLoader. Default all threads available.

    Returns:
     - A tuple of (train_dataloader, test_dataloader, class_names).
     Where class_names is a list of the target classes.
    """
    if dataset_type == "cifar10":
        train_data = datasets.CIFAR10(root="data", train=True, download=True, transform=transform)
        test_data = datasets.CIFAR10(root="data", train=False, download=True, transform=transform)
    else:
        # Ton code actuel pour les dossiers locaux
        train_data = datasets.ImageFolder(root=train_dir, transform=transform)
        test_data = datasets.ImageFolder(root=test_dir, transform=transform)
    
    # Get class names
    class_names = train_data.classes

    # Turn each dataset into a set of iterables batches (DataLoaders)
    train_dataloader = DataLoader(dataset=train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)

    test_dataloader = DataLoader(dataset=test_data,
                                batch_size=batch_size,
                                shuffle=False,
                                num_workers=num_workers)
    
    return (train_dataloader, test_dataloader, class_names)

Overwriting src/data_setup.py


#### Making the CNN model builder

In [61]:
%%writefile src/CNN_builder.py
""" 
File for Pytorch code of the CNN model to instantiate.
"""

import torch
from torch import nn

# CNN Model
class CNN(nn.Module):
    """
    Model architecture from https://poloclub.github.io/cnn-explainer/.
    TinyVGG adaptation.
    """
    def __init__(self,
                 input_shape: int,
                 hidden_units: int,
                 output_shape: int):
        """
        Initializes the CNN model layers.

        Arguments:
          - input_shape: Number of input color channels.
          - hidden_units: Number of hidden units (filters) per convolutional layer.
          - output_shape: Number of output units/classes.
        """
        super().__init__()

        self.conv_block1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.conv_block2 = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            # Assuming input images are 32x32. After two MaxPool2d (size/2), 
            # the feature map size is 8x8.
            nn.Linear(in_features=hidden_units*8*8,
                      out_features=output_shape)
        )

    def forward(self, x):
        x = self.conv_block1(x)
        x = self.conv_block2(x)
        x = self.classifier(x)
        return x

Overwriting src/CNN_builder.py


### Making the ViT model builder

In [60]:
%%writefile src/ViT_builder.py
""" 
File for Pytorch code of the ViT model to instantiate.
"""

import torch
from torch import nn
from torch import functional as F
from torch.utils.data import DataLoader

import torchvision
from torchvision import transforms, datasets

COLOR_CHANNELS = 3
IMG_SIZE = 224
PATCH_SIZE = 16
EMBEDDING_DIM = 768
BATCH_SIZE = 12
NUM_HEADS = 12
MLP_SIZE = 3072
TRANSFORMER_LAYER_NUM = 12
EMBEDDING_DROPOUT = 0.1
CLASSES_NUM=1000
PATCH_NUMBER = ( IMG_SIZE * IMG_SIZE ) // PATCH_SIZE ** 2

class PatchEmbedding(nn.Module):
  """
  Turns a 2D input image into a 1D set of embedded patches.
  """
  def __init__(self,
               in_channels=COLOR_CHANNELS,
               patch_size=PATCH_SIZE,
               embedding_dim=EMBEDDING_DIM):
    """
    Arguments:
      - in_channels = Number of color channel for the input image. Default 3.
      - patch_size = Size of the patches to convert input image into. Default 16.
      - embedding_dim = Size of the embedding vector to turn image into. Default 768.
    """
    super().__init__()

    self.patch_size = patch_size

    self.patcher = nn.Conv2d(in_channels=in_channels,
                             out_channels=embedding_dim,
                             kernel_size=patch_size,
                             stride=patch_size,
                             padding=0) # No padding here

    self.flatter = nn.Flatten(start_dim=2, end_dim=3)

  def forward(self, x):
    # Prior size verification
    img_res = x.shape[-1]
    assert img_res % self.patch_size == 0, "Image resolution must be divisible by the patch size"

    x_patches = self.patcher(x)
    x_flattened = self.flatter(x_patches)
    x_embedded = x_flattened.permute(0, 2, 1)
    return x_embedded
  
class MultiHeadAttentionBlock(nn.Module):
  """
  Implements the multi head self attention block of the trasformer encoder.
  """
  def __init__(self,
                embedding_dim=EMBEDDING_DIM,
                num_heads=NUM_HEADS,
                attn_dropout:float=0):
    """
    Arguments:
      -embedding_dim: The constant latent vector size D used throughout the Transformer.
      -num_heads: Number of attention heads (k).
      -attn_dropout: Dropout probability applied to the attention weights.
    """
    super().__init__()

    self.normalizer = nn.LayerNorm(normalized_shape=embedding_dim)

    self.multihead_attn = nn.MultiheadAttention(embed_dim=embedding_dim,
                                                 num_heads=num_heads,
                                                 dropout=attn_dropout,
                                                 batch_first=True)

  def forward(self, x):
    x = self.normalizer(x)
    attn_output, _ = self.multihead_attn(query=x,
                                         key=x,
                                         value=x)
    return attn_output

class MLPBlock(nn.Module):
  """
  Implements the MLP block of the transformer encoder.
  """
  def __init__(self,
               embedding_dim=EMBEDDING_DIM,
               mlp_size=MLP_SIZE,
               mlp_dropout:float=0):
    super().__init__()

    self.normalizer = nn.LayerNorm(normalized_shape=embedding_dim)

    self.mlp = nn.Sequential(
        nn.Linear(in_features=embedding_dim, out_features=mlp_size),
        nn.GELU(),
        nn.Dropout(p=mlp_dropout),
        nn.Linear(in_features=mlp_size, out_features=embedding_dim),
        nn.Dropout(p=mlp_dropout))

  def forward(self, x):
    x = self.normalizer(x)
    x = self.mlp(x)
    return x

class TransformerEncoder(nn.Module):
  """
  Create Transformer encoder block.
  """
  def __init__(self,
               embedding_dim=EMBEDDING_DIM,
               num_heads=NUM_HEADS,
               mlp_size=MLP_SIZE,
               attn_dropout:float=0,
               mlp_dropout:float=0):
    super().__init__()

    self.msa_block = MultiHeadAttentionBlock(embedding_dim=embedding_dim,
                                             num_heads=num_heads,
                                             attn_dropout=attn_dropout)

    self.mlp_block = MLPBlock(embedding_dim=embedding_dim,
                              mlp_size=mlp_size,
                              mlp_dropout=mlp_dropout)

  def forward(self, x):
    x = self.msa_block(x) + x
    x = self.mlp_block(x) + x
    return x

class ViT(nn.Module):
  """
  Create Vision Transformer architecture model.
  """
  def __init__(self,
               img_size=IMG_SIZE, # Training resolution
               in_channels=COLOR_CHANNELS, # Number of color channels in input image
               patch_size=PATCH_SIZE, # Patch size
               transformer_layer_num=TRANSFORMER_LAYER_NUM, # Number of ViT layers from ViT paper table
               embedding_dim=EMBEDDING_DIM, # Hidden D size from ViT paper table
               mlp_size=MLP_SIZE, # MLP size from ViT paper table
               num_heads=NUM_HEADS, # Number of heads for MSA from ViT paper table
               attn_dropout:float=0, # Dropout for attention from ViT paper table
               mlp_dropout:float=0, # Dropout for MLP layers from ViT paper table
               embedding_dropout=EMBEDDING_DROPOUT, # Dropout for patch and positional embedding
               num_classes=CLASSES_NUM): # Number of classes to predict
    super().__init__()

    # Make sure the image size is divisible by the patch size
    assert img_size % patch_size == 0, "Image resolution must be divisible by the patch size"

    # Number of patches
    self.num_patches = (img_size * img_size) // patch_size ** 2

    # Create learnable class embedding
    self.class_embedding = nn.Parameter(torch.randn(1, 1, embedding_dim))

    # Create learnable positional embedding
    self.positional_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, embedding_dim))

    # Dropout value for patch and positional embedding
    self.embedding_dropout = nn.Dropout(p=embedding_dropout)

    # Create patch embedding layer
    self.patch_embedding = PatchEmbedding(in_channels=in_channels,
                                          patch_size=patch_size,
                                          embedding_dim=embedding_dim)

    # Create Transformer blocks
    self.transformer_layer = nn.Sequential(*[TransformerEncoder(embedding_dim=embedding_dim,
                                                                num_heads=num_heads,
                                                                mlp_size=mlp_size,
                                                                attn_dropout=attn_dropout,
                                                                mlp_dropout=mlp_dropout) for _ in range(transformer_layer_num)])

    # Create classifier head
    self.classifier = nn.Sequential(
      nn.LayerNorm(normalized_shape=embedding_dim),
      nn.Linear(in_features=embedding_dim, out_features=num_classes))

  def forward(self, x):
    # Get batch size
    batch_size = x.shape[0]

    # Create class token embeddding and expand it to the batch size
    class_token = self.class_embedding.expand(batch_size, -1, -1)

    # Apply patch embedding
    x = self.patch_embedding(x)

    # Concatenate class embedding and patch embedding
    x = torch.cat((class_token, x), dim=1)

    # Add positional embedding
    x = x + self.positional_embedding

    # Apply dropout to embedding part
    x = self.embedding_dropout(x)

    # Pass patch, class and positional embedding through the tranformer blocks
    x = self.transformer_layer(x)

    # 0 logit for classifier
    x = self.classifier(x[:, 0])
    
    return x
    

Overwriting src/ViT_builder.py


#### Making the model train and test functions

In [59]:
%%writefile src/engine.py
""" 
File for training and testing our models.
"""

import torch

from torchmetrics import Accuracy

from tqdm.auto import tqdm
from typing import List, Dict, Tuple

def train_step(model: troch.nn.Module,
               dataloader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               acc_fn: torchmetrics.Accuracy,
               optimizer: torch.optim.Optimizer,
               device: torch.device)->Tuple[float, float]:
    """ 
    Trains a PyTorch model for a single epoch.

    Arguments:
        - model: Pytorch model to be trained.
        - dataloder: DataLoader for the model to be trained on.
        - loss_fn: Pytorch criterion to minimize.
        - acc_fn: Pytorch accuracy metric.
        - optimizer: Optimize to help minimize the loss function.
        - device: Target device to compute on.
    
    Returns:
        - Tuple saving train training loss and train accuracy metrics.
    """

    # Put model in training mode
    model.train()

    # Initialization of train loss and accuracy
    train_loss, train_acc = 0, 0

    for batch, (X_train, y_train) in enumerate(dataloader):
        # Send data to target device
        X_train, y_train = X_train.to(device), y_train.to(device)

        # Forward pass
        y_train_pred = model(X_train)

        # Calculate loss and accuracy
        loss = loss_fn(y_train_pred, y_train)
        train_loss += loss

        accuracy = acc_fn(y_train_pred,y_train)
        train_acc += accuracy
        
        # Optimizer zero grad
        optimizer.zero_grad()

        # Loss backward
        loss.backward()

        # Optimizer step
        optimizer.step()
    
    # Adjust metrics to get average loss and accuracy per batch 
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)

    return (train_loss, train_acc)

def test_step(model: troch.nn.Module,
            dataloader: torch.utils.data.DataLoader,
            loss_fn: torch.nn.Module,
            acc_fn: torchmetrics.Accuracy,
            device: torch.device)->Tuple[float, float]:
    """ 
    Tests a PyTorch model for a single epoch.

    Arguments:
        - model: Pytorch model to be tested.
        - dataloder: DataLoader for the model to be tested on.
        - loss_fn: Pytorch criterion to minimize.
        - acc_fn: Pytorch accuracy metric.
        - device: Target device to compute on.
    
    Returns:
        - Tuple saving test loss and test accuracy metrics.
    """

    # Put model in evaluation mode
    model.eval()

    # Initialization of test loss and accuracy
    test_loss, test_acc = 0, 0

    # Disables gradient tracking to save memory and speed up computation during testing
    with torch.inference_mode():
        for batch, (X_test, y_test) in enumerate(dataloader):
            # Send data to target device
            X_test, y_test = X_test.to(device), y_test.to(device)

            # Forward pass
            y_test_pred = model(X_test)

            # Calculate loss and accuracy
            loss = loss_fn(y_test_pred, y_test)
            test_loss += loss

            accuracy = acc_fn(y_test_pred, y_test)
            test_acc += accuracy
            
    # Adjust metrics to get average loss and accuracy per batch 
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)

    return (test_loss, test_acc)

def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          loss_fn: torch.nn.Module,
          acc_fn: torchmetrics.Accuracy,
          optimizer: torch.optim.Optimizer,
          epochs: int,
          device: torch.device)->Dict[str, list]:
    """
    Trains and tests a PyTorch model.

    Passes a target PyTorch models through train_step() and test_step()
    for a number of epochs, training and testing the model
    in the same epoch loop.

    Calculates, prints and stores evaluation metrics throughout.

    Argumentss:
        - model: A PyTorch model to be trained and tested.
        - train_dataloader: DataLoader for the model to be trained on.
        - test_dataloader: DataLoader for the model to be tested on.
        - loss_fn: Pytorch criterion to minimize.
        - acc_fn: Pytorch accuracy metric.
        - epochs: How many epochs to train for.
        - device: Target device to compute on.

    Returns:
        A dictionary of training and testing loss as well as training and
        testing accuracy metrics. Each metric has a value in a list for 
        each epoch.
    """
    # Initialize results dictionary
    results = {"train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": []
    }

    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           acc_fn=acc_fn,
                                           optimizer=optimizer,
                                           device=device)

        test_loss, test_acc = test_step(model=model,
                                        dataloader=test_dataloader,
                                        loss_fn=loss_fn,
                                        acc_fn=acc_fn,
                                        device=device)

        # Print out what's happening
        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"test_loss: {test_loss:.4f} | "
            f"test_acc: {test_acc:.4f}"
        )

        # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)

    # Return the filled results at the end of the epochs
    return results

Overwriting src/engine.py


### Making the utility functions

In [44]:
%%writefile src/utils.py
"""
File with utility functions for PyTorch model training and saving.
"""
import torch
from pathlib import Path

def save_model(model: torch.nn.Module,
               target_dir: str,
               model_name: str):
  """
  Saves a PyTorch model to a target directory.

  Arguments:
    - model: Target PyTorch model to save.
    - target_dir: Directory for saving the model to. Default 'model'.
    - model_name: Filename for the saved model. Should include
      either ".pth" or ".pt" as the file extension.
  """
  # Create target directory
  target_dir_path = Path(target_dir)
  target_dir_path.mkdir(parents=True,
                        exist_ok=True)

  # Create model save path
  assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
  model_save_path = target_dir_path / model_name

  # Save the model state_dict()
  print(f"Saving model to: {model_save_path}")
  torch.save(obj=model.state_dict(),
             f=model_save_path)

Overwriting src/utils.py


### Making the CNN train script

In [57]:
%%writefile train_CNN.py
"""
Trains a PyTorch model using device-agnostic code.
Can be controlled via command line arguments for hyperparameter tuning.
"""

import os
import torch
import argparse
import torchmetrics
from torchvision import transforms

# Importing local modules
from src import data_setup, engine, utils, CNN_builder

def main():
    # Setup ArgumentParser
    parser = argparse.ArgumentParser(description="Train a Pytorch model on a choosen dataset.")

    # Add Arguments
    parser.add_argument("--model_name", 
                        type=str, 
                        default="CNN", 
                        help="The filename for the saved model (e.g., 'CNN').")

    parser.add_argument("--train_dir",
                        type=str)

    parser.add_argument("--test_dir", 
                        type=str)
    
    parser.add_argument("--dataset", 
                    type=str, 
                    default=None, 
                    help="Name of the dataset")

    parser.add_argument("--batch_size", 
                        type=int, 
                        default=32, 
                        help="Number of images per batch (default: 32).")

    parser.add_argument("--lr", 
                        type=float, 
                        default=0.001, 
                        help="Learning rate for the optimizer (default: 0.001).")

    parser.add_argument("--num_epochs", 
                        type=int, 
                        default=5, 
                        help="Number of training epochs (default: 5).")

    parser.add_argument("--hidden_units", 
                        type=int, 
                        default=10, 
                        help="Number of hidden units in the neural network (default: 10).")

    args = parser.parse_args()

    # Setup target device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Create transforms
    data_transform = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor()
    ])

    # Create DataLoaders
    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=args.train_dir,
        test_dir=args.test_dir,
        transform=data_transform,
        batch_size=args.batch_size,
        dataset_type=args.dataset
    )

    # Initialize the model
    model = CNN_builder.CNN(
        input_shape=3,
        hidden_units=args.hidden_units,
        output_shape=len(class_names)
    ).to(device)

    # Construct the path to the model file
    model_save_path = os.path.join("models", args.model_name if args.model_name.endswith(".pth") else args.model_name + ".pth")
    
    if os.path.exists(model_save_path):
        print(f"Loading weights from: {model_save_path}")
        model.load_state_dict(torch.load(model_save_path, map_location=device))
    else:
        print(f"{model_save_path} not found. Training from scratch.")

    # Set loss, optimizer and accuracy function
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    acc_fn = torchmetrics.Accuracy(task="multiclass", num_classes=len(class_names)).to(device)

    # Start training
    results = engine.train(model=model,
                           train_dataloader=train_dataloader,
                           test_dataloader=test_dataloader,
                           loss_fn=loss_fn,
                           optimizer=optimizer,
                           acc_fn=acc_fn,
                           epochs=args.num_epochs,
                           device=device)

    # Save the updated model
    utils.save_model(model=model,
                     target_dir="models",
                     model_name=f"trained_{args.model_name}.pth")

if __name__ == "__main__":
    main()

Overwriting train_CNN.py


### Making the ViT train script

In [58]:
%%writefile train_ViT.py
"""
Trains a PyTorch Vision Transformer (ViT) model using device-agnostic code.
Can be controlled via command line arguments for hyperparameter tuning.
"""

import os
import torch
import argparse
import torchmetrics
from torchvision import transforms

# Importing local modules from the src/ directory
from src import data_setup, engine, utils, ViT_builder

def main():
    # Setup ArgumentParser
    parser = argparse.ArgumentParser(description="Train a Vision Transformer model on a chosen dataset.")

    # File & Path Arguments
    parser.add_argument("--model_name", 
                        type=str, 
                        default="ViT_Model", 
                        help="The filename for the saved model.")
    
    parser.add_argument("--dataset", 
                    type=str, 
                    default=None, 
                    help="Name of the dataset")

    parser.add_argument("--train_dir", 
                        type=str, 
                        default="data/pizza_steak_sushi/train", 
                        help="Directory path for training data.")

    parser.add_argument("--test_dir", 
                        type=str, 
                        default="data/pizza_steak_sushi/test", 
                        help="Directory path for testing data.")

    # General Hyperparameters
    parser.add_argument("--batch_size", 
                        type=int, 
                        default=32, 
                        help="Number of images per batch.")

    parser.add_argument("--lr", 
                        type=float, 
                        default=0.001, 
                        help="Learning rate for the optimizer.")

    parser.add_argument("--num_epochs", 
                        type=int, 
                        default=5, 
                        help="Number of training epochs.")

    # ViT Specific Hyperparameters (Optimized for 64x64 by default)
    parser.add_argument("--img_size", 
                        type=int, 
                        default=64, 
                        help="Input image resolution.")

    parser.add_argument("--patch_size", 
                        type=int, 
                        default=8, 
                        help="Patch size (must divide img_size).")

    parser.add_argument("--embedding_dim", 
                        type=int, 
                        default=128, 
                        help="Hidden dimension size D.")

    parser.add_argument("--mlp_size", 
                        type=int, 
                        default=512, 
                        help="MLP hidden size.")

    parser.add_argument("--num_heads", 
                        type=int, 
                        default=8, 
                        help="Number of attention heads.")

    parser.add_argument("--num_layers", 
                        type=int, 
                        default=8, 
                        help="Number of transformer encoder layers.")

    parser.add_argument("--dropout", 
                        type=float, 
                        default=0.1, 
                        help="Dropout rate for attention and MLP.")

    # Parse the arguments
    args = parser.parse_args()

    # Setup target device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Create transforms (using the img_size argument)
    data_transform = transforms.Compose([
        transforms.Resize((args.img_size, args.img_size)),
        transforms.ToTensor()
    ])

    # Create DataLoaders with help from data_setup.py
    train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders(
        train_dir=args.train_dir,
        test_dir=args.test_dir,
        transform=data_transform,
        batch_size=args.batch_size,
        dataset_type=args.dataset
    )

    # Initialize ViT model from ViT_builder.py
    model = ViT_builder.ViT(
        img_size=args.img_size,
        in_channels=3,
        patch_size=args.patch_size,
        transformer_layer_num=args.num_layers,
        embedding_dim=args.embedding_dim,
        mlp_size=args.mlp_size,
        num_heads=args.num_heads,
        attn_dropout=args.dropout,
        mlp_dropout=args.dropout,
        embedding_dropout=args.dropout,
        num_classes=len(class_names)
    ).to(device)

    # Set loss, optimizer and accuracy function
    loss_fn = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
    acc_fn = torchmetrics.Accuracy(task="multiclass", num_classes=len(class_names)).to(device)

    # Start training with help from engine.py
    engine.train(model=model,
                 train_dataloader=train_dataloader,
                 test_dataloader=test_dataloader,
                 loss_fn=loss_fn,
                 optimizer=optimizer,
                 acc_fn=acc_fn,
                 epochs=args.num_epochs,
                 device=device)

    # Save the model with help from utils.py
    MODEL_NAME = args.model_name if args.model_name.endswith(".pth") else args.model_name + ".pth"
    utils.save_model(model=model,
                     target_dir="models",
                     model_name=MODEL_NAME)
    print(f"Model saved to: models/{MODEL_NAME}")

if __name__ == "__main__":
    main()

Overwriting train_ViT.py
