### Training The Model

In [4]:
# Import necessary libraries
import os
import torch
import torchvision
from torch import nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from tqdm.auto import tqdm
from typing import Dict, List, Tuple

# Define constants
TRAIN_DIR = "train/train/"
TEST_DIR = "test/test/"
device = "cuda" if torch.cuda.is_available() else "cpu"
batch_size = 16
learning_rate = 1e-3

# Set random seeds for reproducibility
def set_seeds(seed: int=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

set_seeds()

# Load pre-trained ViT model
pretrained_vit_weights = torchvision.models.ViT_B_16_Weights.DEFAULT 
pretrained_vit = torchvision.models.vit_b_16(weights=pretrained_vit_weights).to(device)

# Freeze pre-trained layers
for parameter in pretrained_vit.parameters():
    parameter.requires_grad = False

# Modify head for classification
pretrained_vit.heads = nn.Linear(in_features=768,out_features=100).to(device)

# Display model summary
from torchinfo import summary
summary(model=pretrained_vit, 
        input_size=(batch_size, 3, 224, 224), 
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"])

# Define data transformations and create data loaders
pretrained_vit_transforms = pretrained_vit_weights.transforms()

def create_dataloaders(
    train_dir: str, 
    transform: transforms.Compose, 
    batch_size: int, 
    validation_split: float = 0.1,
    num_workers: int = os.cpu_count()
):

    dataset = datasets.ImageFolder(train_dir, transform=transform)

    # Split dataset into training and validation sets
    num_train = int(len(dataset) * (1 - validation_split))
    num_val = len(dataset) - num_train
    train_data, val_data = random_split(dataset, [num_train, num_val])

    # Create data loaders for training and validation sets
    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_dataloader = DataLoader(
        val_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    return train_dataloader, val_dataloader

# Create data loaders
train_dataloader_pretrained, val_dataloader_pretrained = create_dataloaders(
    train_dir=TRAIN_DIR, 
    transform=pretrained_vit_transforms, 
    batch_size=batch_size
)

# Define training and evaluation functions
def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    
    model.train()
    train_loss, train_acc = 0, 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        train_loss += loss.item() 
        y_pred_class = torch.argmax(y_pred, dim=1)
        train_acc += (y_pred_class == y).sum().item()/len(y_pred)
    train_loss = train_loss / len(dataloader)
    train_acc = train_acc / len(dataloader)
    return train_loss, train_acc

def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    model.eval() 
    test_loss, test_acc = 0, 0
    with torch.no_grad():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device), y.to(device)
            test_pred_logits = model(X)
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()
            test_pred_labels = torch.argmax(test_pred_logits, dim=1)
            test_acc += (test_pred_labels == y).sum().item()/len(test_pred_labels)
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc

def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader,
          val_dataloader: torch.utils.data.DataLoader,
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device) -> Dict[str, List]:
    
    results = {"train_loss": [],
               "train_acc": [],
               "val_loss": [],
               "val_acc": []}

    model.to(device)
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                           dataloader=train_dataloader,
                                           loss_fn=loss_fn,
                                           optimizer=optimizer,
                                           device=device)
        
        val_loss, val_acc = test_step(model=model,
                                      dataloader=val_dataloader,
                                      loss_fn=loss_fn,
                                      device=device)

        print(
            f"Epoch: {epoch+1} | "
            f"train_loss: {train_loss:.4f} | "
            f"train_acc: {train_acc:.4f} | "
            f"val_loss: {val_loss:.4f} | "
            f"val_acc: {val_acc:.4f} | "
        )

        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)

    return results

optimizer = torch.optim.Adam(params=pretrained_vit.parameters(), lr=learning_rate)
loss_fn = torch.nn.CrossEntropyLoss()

# Train the model
pretrained_vit_results = train(model=pretrained_vit,
                               train_dataloader=train_dataloader_pretrained,
                               val_dataloader=val_dataloader_pretrained,
                               optimizer=optimizer,
                               loss_fn=loss_fn,
                               epochs=10,
                               device=device)



  0%|          | 0/10 [00:00<?, ?it/s]

Epoch: 1 | train_loss: 3.6588 | train_acc: 0.1732 | val_loss: 2.8407 | val_acc: 0.2232 | 
Epoch: 2 | train_loss: 2.0031 | train_acc: 0.5208 | val_loss: 2.3145 | val_acc: 0.3036 | 
Epoch: 3 | train_loss: 1.3841 | train_acc: 0.7149 | val_loss: 2.1063 | val_acc: 0.4107 | 
Epoch: 4 | train_loss: 1.0419 | train_acc: 0.8037 | val_loss: 1.8986 | val_acc: 0.5089 | 
Epoch: 5 | train_loss: 0.8266 | train_acc: 0.8783 | val_loss: 1.7827 | val_acc: 0.5357 | 
Epoch: 6 | train_loss: 0.6620 | train_acc: 0.9232 | val_loss: 1.7146 | val_acc: 0.5357 | 
Epoch: 7 | train_loss: 0.5548 | train_acc: 0.9452 | val_loss: 1.6949 | val_acc: 0.5446 | 
Epoch: 8 | train_loss: 0.4702 | train_acc: 0.9682 | val_loss: 1.7071 | val_acc: 0.5625 | 
Epoch: 9 | train_loss: 0.4061 | train_acc: 0.9638 | val_loss: 1.6538 | val_acc: 0.5179 | 
Epoch: 10 | train_loss: 0.3508 | train_acc: 0.9803 | val_loss: 1.6188 | val_acc: 0.5268 | 


### Storing the Weights

In [5]:
# Define the path to save the model weights
model_weights_path = "pretrained_vit_weights.pth"

# Save the trained model weights
torch.save(pretrained_vit.state_dict(), model_weights_path)

### Sample Code to for prediction using the stored, Weights

In [1]:
#This is a map, between the labels, I used for different classes and the actuall class names in the train/train directoryu
my_dict = {
    0: 0,
    1: 1,
    10: 2,
    11: 3,
    12: 4,
    13: 5,
    14: 6,
    15: 7,
    16: 8,
    17: 9,
    18: 10,
    19: 11,
    2: 12,
    20: 13,
    21: 14,
    22: 15,
    23: 16,
    24: 17,
    25: 18,
    26: 19,
    27: 20,
    28: 21,
    29: 22,
    3: 23,
    30: 24,
    31: 25,
    32: 26,
    33: 27,
    34: 28,
    35: 29,
    36: 30,
    37: 31,
    38: 32,
    39: 33,
    4: 34,
    40: 35,
    41: 36,
    42: 37,
    43: 38,
    44: 39,
    45: 40,
    46: 41,
    47: 42,
    48: 43,
    49: 44,
    5: 45,
    50: 46,
    51: 47,
    52: 48,
    53: 49,
    54: 50,
    55: 51,
    56: 52,
    57: 53,
    58: 54,
    59: 55,
    6: 56,
    60: 57,
    61: 58,
    62: 59,
    63: 60,
    64: 61,
    65: 62,
    66: 63,
    67: 64,
    68: 65,
    69: 66,
    7: 67,
    70: 68,
    71: 69,
    72: 70,
    73: 71,
    74: 72,
    75: 73,
    76: 74,
    77: 75,
    78: 76,
    79: 77,
    8: 78,
    80: 79,
    81: 80,
    82: 81,
    83: 82,
    84: 83,
    85: 84,
    86: 85,
    87: 86,
    88: 87,
    89: 88,
    9: 89,
    90: 90,
    91: 91,
    92: 92,
    93: 93,
    94: 94,
    95: 95,
    96: 96,
    97: 97,
    98: 98,
    99: 99
}

# Exchange keys and values using dictionary comprehension
exchanged_dict = {v: k for k, v in my_dict.items()}

In [None]:
from PIL import Image
#Replace this with the actuall path
model_weights_path = "pretrained_vit_weights.pth"
test_image_path = "test\\test\\0.jpg"
def preprocess_image_error(image_path):
    image = Image.open(image_path).convert('RGB')
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = transform(image).unsqueeze(0)  # Add batch dimension
    return image.to(device)
def preprocess_image(image_path):
    try:
        transform = pretrained_vit_transforms
        image = Image.open(image_path)
        image = transform(image).unsqueeze(0)  # Add batch dimension
        return image.to(device)
    except:
        return preprocess_image_error(image_path)

def predict_image_class(image_path, model):
    image = preprocess_image(image_path)
    model.eval()
    with torch.no_grad():
        output = model(image)
    probabilities = torch.softmax(output, dim=1)[0]
    predicted_class_index = torch.argmax(probabilities).item()
    return predicted_class_index, probabilities

model = pretrained_vit


model.load_state_dict(torch.load(model_weights_path))


predicted_class_index_map, probabilities = predict_image_class(test_image_path, model)

predicted_class_index = exchanged_dict[predicted_class_index_map]
print("Predicted Class:", predicted_class_index)
print("Probabilities:", probabilities)
