In [4]:
import matplotlib.pyplot as plt
import torch
import torchvision
from torch import nn
from torchvision import transforms
from torchinfo import summary
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm

from datetime import datetime
import wandb

from PIL import Image
import os
from pathlib import Path
from typing import Tuple, Dict, List

print(f"pytorch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

pytorch version: 2.5.1+cu118
torchvision version: 0.20.1+cpu


In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [6]:
def set_seed(seed: int=42):
    """ Sets random seed for torch operations.

    Args:
        seed (int) : random seed to set
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

In [8]:
data_path = Path("../dataset/")
image_path = data_path / "data"
image_path

WindowsPath('../dataset/data')

In [9]:
train_dir = image_path / "train_dir"
test_dir = image_path / "test_dir"

train_dir, test_dir

(WindowsPath('../dataset/data/train_dir'),
 WindowsPath('../dataset/data/test_dir'))

In [10]:
def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """ Finds subdirectories in the given directory and maps class names to indices.

    Args:
        directory (str):  Path to the directory

    Returns:
        Tuple[List[str], Dict[str, int]]
        * A sorted list of class names
        * A dictionary mapping class names to unique indices
    """
    classes = sorted(entry.name for entry in os.scandir(directory) if entry.is_dir())

    if not classes:
        raise FileNotFoundError(f"Couldn't find any classes in {directory}.")

    class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

    return classes, class_to_idx

In [11]:
classes, class_to_idx = find_classes(train_dir)
classes, class_to_idx

(['daisy', 'dandelion', 'rose', 'sunflower', 'tulip'],
 {'daisy': 0, 'dandelion': 1, 'rose': 2, 'sunflower': 3, 'tulip': 4})

In [12]:
paths = list(Path(train_dir).glob("*/*.jpg"))
paths[:10]

[WindowsPath('../dataset/data/train_dir/daisy/100080576_f52e8ee070_n.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10172379554_b296050f82_n.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10172567486_2748826a8b.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/102841525_bd6628ae3c.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10300722094_28fa978807_n.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/1031799732_e7f4008c03.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10391248763_1d16681106_n.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10437754174_22ec990b77_m.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10437770546_8bb6f7bdd3_m.jpg'),
 WindowsPath('../dataset/data/train_dir/daisy/10437929963_bc13eebe0c.jpg')]

In [13]:
class ImageFolder(Dataset):
    def __init__(self, target_dir: str, transform=None) -> None:
        self.paths = list(Path(target_dir).glob("*/*.jpg"))
        self.transform = transform
        self.classes, self.class_to_idx = find_classes(target_dir)

    def load_image(self, index: int) -> Image.Image:
        image_path = self.paths[index]
        return Image.open(image_path)

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int):
        img = self.load_image(index)
        class_name = self.paths[index].parent.name
        class_idx = self.class_to_idx[class_name]

        if self.transform:
            return self.transform(img), class_idx
        else:
            return img, class_idx

In [14]:
config = dict(
    epochs=5,
    classes=5,
    batch_size=32,
    learning_rate=1e-3,
 )

wandb.init(project='pretained_vit_classification', entity='aysenurciftci', config=config)

wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: aysenurciftcieee (aysenurciftci). Use `wandb login --relogin` to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011288888886984852, max=1.0…

In [19]:
def create_dataloaders(train_dir: str,
                      test_dir: str,
                      transform: transforms.Compose,
                      batch_size: int):
                        
    """Creates training and testing DataLoaders.

    Args:
        train_dir (str) : path to training directory
        test_dir (str) : path to testing directory
        transform : torchvision transforms to perform on training and testing data.
        batch_size: Number of samples per batch
        num_workers : An integer for number of workers per DataLoader.
   
    Returns:
        A tuple of (train_dataloader, test_dataloader, class_names).
    """

    train_data = ImageFolder(train_dir, transform=transform)
    test_data = ImageFolder(test_dir, transform=transform)

    class_names = train_data.classes

    train_dataloader = DataLoader(train_data,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  pin_memory=True)

    test_dataloader = DataLoader(test_data,
                                 batch_size=batch_size,
                                 shuffle=True,
                                 pin_memory=True)

    return train_dataloader, test_dataloader, class_names

In [21]:
pretrained_vit_weights = torchvision.models.ViT_B_32_Weights.DEFAULT

pretrained_vit = torchvision.models.vit_b_32(weights=pretrained_vit_weights).to(device)

for parameter in pretrained_vit.parameters():
    parameter.requires_grad=False

#set_seed()
pretrained_vit.heads = nn.Linear(in_features=768,
                                 out_features=len(classes)).to(device)

Downloading: "https://download.pytorch.org/models/vit_b_32-d86f8d99.pth" to C:\Users\aysen/.cache\torch\hub\checkpoints\vit_b_32-d86f8d99.pth

00%|███████████████████████████████████████████████████████████████████████████████| 337M/337M [01:12<00:00, 4.87MB/s]

In [22]:
pretrained_vit_transforms = pretrained_vit_weights.transforms()
print(pretrained_vit_transforms)

ImageClassification(
    crop_size=[224]
    resize_size=[256]
    mean=[0.485, 0.456, 0.406]
    std=[0.229, 0.224, 0.225]
    interpolation=InterpolationMode.BILINEAR
)


In [23]:
train_dataloader, test_dataloader, class_names = create_dataloaders(train_dir=train_dir,
                                                                   test_dir=test_dir,
                                                                   transform=pretrained_vit_transforms,
                                                                   batch_size=32)

In [24]:
summary(model=pretrained_vit,
        input_size=(32, 3, 224, 224),
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]) 

Layer (type (var_name))                                      Input Shape          Output Shape         Param #              Trainable
VisionTransformer (VisionTransformer)                        [32, 3, 224, 224]    [32, 5]              768                  Partial
├─Conv2d (conv_proj)                                         [32, 3, 224, 224]    [32, 768, 7, 7]      (2,360,064)          False
├─Encoder (encoder)                                          [32, 50, 768]        [32, 50, 768]        38,400               False
│    └─Dropout (dropout)                                     [32, 50, 768]        [32, 50, 768]        --                   --
│    └─Sequential (layers)                                   [32, 50, 768]        [32, 50, 768]        --                   False
│    │    └─EncoderBlock (encoder_layer_0)                   [32, 50, 768]        [32, 50, 768]        (7,087,872)          False
│    │    └─EncoderBlock (encoder_layer_1)                   [32, 50, 768]        [32, 

In [25]:
def train_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              optimizer: torch.optim.Optimizer,
              device: torch.device):

    """ Perform a single training step for a pytorch model.

    Args:
        model : the neural network model to train.
        dataloader : dataloader providing the training data in batches.
        loss_fn : the loss function to evaluate the model's predictions.
        optimizer : the optimizer to update the model's parameters.
        device: target device
    Returns:
        tuple: 
            *train_loss (float) :  The average loss over the training set.
            *train_acc (float) : The average accuracy over the training set.
    """

    #put model in train mode
    model.train()

    train_loss, train_acc = 0, 0

    for batch, (X, y) in enumerate(dataloader):
        #send data to target device
        X, y = X.to(device), y.to(device)

        #forward pass
        y_pred = model(X)

        # calculate loss
        loss = loss_fn(y_pred, y)
        train_loss += loss.item()

        #optimizer zero grad
        optimizer.zero_grad()

        #loss backward
        loss.backward()

        #optimizer step
        optimizer.step()

        #calculate and acc metric
        y_pred_class = torch.argmax(torch.softmax(y_pred, dim=1), dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)

    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)

    return train_loss, train_acc

In [26]:
def test_step(model: torch.nn.Module,
              dataloader: torch.utils.data.DataLoader,
              loss_fn : torch.nn.Module,
              device: torch.device):

    """ Performs a single evaluation step for a PyTorch model.

    Args:
        model : the neural network model to evaluate.
        dataloader : DataLoader providing the test/validation data in batches.
        loss_fn : The loss function to evaluate the model's predictions.
        device: the target device.

    Returns: 
        tuple : 
            * test_loss (float) : The average loss over the test set.
            * test_acc (float) : The average accuracy over the test set.
    """

    
    #put model in eval mode
    model.eval()

    test_loss, test_acc = 0, 0

    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            #send data to target device
            X, y = X.to(device), y.to(device)

            #forward pass
            test_pred_logits = model(X)

            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()

            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))

        test_loss = test_loss / len(dataloader)
        test_acc = test_acc / len(dataloader)

        return test_loss, test_acc

In [28]:
def train(model: torch.nn.Module,
          train_dataloader: torch.utils.data.DataLoader,
          test_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device):
    """Trains and tests a PyTorch model.
    """

    results = {"train_loss" : [],
              "train_acc" : [],
              "test_loss" : [],
              "test_acc": []
              }

    #loop through training and testing steps for a number of epochs
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn, 
                                           optimizer=optimizer,
                                           device=device)
        test_loss, test_acc = test_step(model=model,
                                       dataloader=test_dataloader,
                                       loss_fn=loss_fn,
                                       device=device)

        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
          )


        wandb.log({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "test_loss": test_loss,
            "test_acc": test_acc
        })

          # Update results dictionary
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)


    wandb.save("model.onnx")

    return results   

In [30]:
torch.manual_seed(42)

set_seed()

optimizer = torch.optim.Adam(params=pretrained_vit.parameters(),
                             lr=1e-3)

loss_fn = torch.nn.CrossEntropyLoss()

pretrained_vit_results = train(model=pretrained_vit,
                              train_dataloader=train_dataloader,
                              test_dataloader=test_dataloader,
                              optimizer=optimizer,
                              loss_fn=loss_fn,
                              epochs=5,
                              device=device,
                              )

  0%|          | 0/5 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 0.5769 | train_acc: 0.8271 | test_loss: 0.3159 | test_acc: 0.9074
Epoch: 2 | train_loss: 0.3101 | train_acc: 0.9111 | test_loss: 0.2564 | test_acc: 0.9190
Epoch: 3 | train_loss: 0.2508 | train_acc: 0.9287 | test_loss: 0.2330 | test_acc: 0.9190
Epoch: 4 | train_loss: 0.2167 | train_acc: 0.9398 | test_loss: 0.2203 | test_acc: 0.9213
Epoch: 5 | train_loss: 0.1886 | train_acc: 0.9470 | test_loss: 0.2176 | test_acc: 0.9282
