# Training a ResNet-18 model on grayscale images using Auslan Dataset from Kaggle.

# 1. Importing Required Libraries

In this section, we import the necessary libraries for:
- **Loading datasets**: `torchvision.datasets`, `DataLoader`
- **Data augmentation and preprocessing**: `torchvision.transforms`
- **Building and training the model**: `torch`, `torchvision.models`, `optim`, `nn`
- **Progress tracking**: `tqdm` for displaying progress bars during training


In [3]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torchvision import datasets, transforms, models
from torchvision.models import ResNet18_Weights
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from sklearn.metrics import classification_report, confusion_matrix

# Weights and Biases initialisation

Initialize W&B. Log hyperparameters to W&B

In [8]:

# Set your W&B API key (replace with your actual API key from your W&B account)
os.environ["WANDB_API_KEY"] = "40605c764bdfbd3ad65c3f127ed7d9dbdd4ce87f"

# Login again with the new API key
wandb.login(relogin=True)
wandb.init(project="auslan-handsign-classification",

# Log hyperparameters to W&B
     config = {
    "learning_rate": 0.001,
    "epochs": 25,
    "batch_size": 64,
    "architecture": "ResNet18",
    "dataset": "Auslan Hand Signs"
    }
)

# 2. Data Augmentation and Preprocessing

In this section, we define the transformations that will be applied to the images before they are fed into the model:
- **Grayscale conversion**: Convert the images to grayscale.
- **Random augmentations**: Random resized cropping, horizontal flipping, and rotation are applied to the training images to help the model generalize better.
- **Random Erasing**: This helps the model learn to handle occlusions and missing parts in the input images.
- **Normalization**: Normalizing the pixel values between 0 and 1 is important for stabilizing the model's learning process.


In [10]:
# Data augmentation and preprocessing for the training and validation sets

train_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    # transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),  # Convert to tensor
    transforms.RandomErasing(p=0.2),  # Randomly erase part of the tensor
    transforms.Normalize([0.5], [0.5])  # Normalize for grayscale (1 channel)
])

val_test_transforms = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),  # Convert images to grayscale
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize([0.5], [0.5])  # Normalize for grayscale (1 channel)
])


# 3. Loading Datasets and Creating DataLoaders

- **ImageFolder**: Automatically assigns labels based on the subfolder names, which represent the class names.
- **DataLoader**: Loads batches of data from the `train` and `val` folders. It shuffles the training data to introduce randomness in batches, while validation data is loaded in a fixed order.


In [11]:
train_dir = r"C:\Users\zed20\Documents\Auslan_dataset\dataset_split\train"
val_dir = r"C:\Users\zed20\Documents\Auslan_dataset\dataset_split\val"

train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_test_transforms)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)


# 4. Model Setup (ResNet-18 with Grayscale Input Modification)

- We load a pre-trained ResNet-18 model that has been trained on ImageNet.
- Modify the **first convolutional layer** to accept 1-channel (grayscale) images instead of 3-channel (RGB) images.
- Modify the **fully connected (FC) layer** to output the correct number of classes (36 in this case).


In [12]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load ResNet-18 model with ImageNet pre-trained weights
model = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# Modify the first layer to accept grayscale input
model.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)

# Modify the FC layer for 36 output classes (26 letters + 10 digits)
num_classes = 36
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(model.fc.in_features, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, num_classes)
)

model = model.to(device)


# 5. Defining Loss Function and Optimizer

- **CrossEntropyLoss**: This is used for multi-class classification tasks, where the model predicts one out of multiple classes.
- **AdamW Optimizer**: Used for weight decay and stability in optimization.
- **OneCycleLR Scheduler**: This dynamically adjusts the learning rate throughout the training process to help the model converge faster.


In [13]:
criterion = nn.CrossEntropyLoss()

optimizer = optim.AdamW(model.parameters(), lr=0.001)

scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01,
                                                steps_per_epoch=len(train_loader),
                                                epochs=25)


# 6. Validation Function

This function performs validation on the model after each training epoch. It:
- Disables gradient calculation using `torch.no_grad()`.
- Calculates the validation loss and accuracy.


In [14]:
def validate_model(model, val_loader, criterion):
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    y_true, y_pred = [], []

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Store for classification report
            y_true.extend(labels.cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    val_loss /= len(val_loader.dataset)
    val_acc = 100 * correct / total

    return val_loss, val_acc, y_true, y_pred


# 7. Training Loop with Early Stopping and Logging confusion matrix

The training loop:
- Trains the model over a number of epochs.
- After each epoch, it performs validation.
- Uses early stopping if the validation accuracy does not improve for a set number of epochs (`patience`).
- Saves the model with the best validation accuracy.


In [15]:
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, patience=5):
    best_val_acc = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        # Training loop
        for inputs, labels in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs} - Training"):
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        train_acc = 100 * correct / total
        val_loss, val_acc, y_true, y_pred = validate_model(model, val_loader, criterion)

        # Log metrics to W&B
        wandb.log({
            "train_loss": running_loss / len(train_loader.dataset),
            "val_loss": val_loss,
            "train_accuracy": train_acc,
            "val_accuracy": val_acc,
            "epoch": epoch + 1
        })

        # Log classification metrics (F1-score, precision, recall)
        classification_report_dict = classification_report(y_true, y_pred, output_dict=True)
        wandb.log({
            "precision": classification_report_dict["macro avg"]["precision"],
            "recall": classification_report_dict["macro avg"]["recall"],
            "f1-score": classification_report_dict["macro avg"]["f1-score"],
            "accuracy": classification_report_dict["accuracy"]
        })

        # Log confusion matrix (optional)
        wandb.log({"confusion_matrix": wandb.plot.confusion_matrix(probs=None,
                                                                  y_true=y_true,
                                                                  preds=y_pred,
                                                                  class_names=[str(i) for i in range(36)])})

        # Step the scheduler
        scheduler.step()

        # Early stopping check
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            torch.save(model.state_dict(), 'best_resnet18_model.pth')
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping triggered")
                break

        print(f"Epoch {epoch + 1}, Train Loss: {running_loss / len(train_loader.dataset):.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%")

    print("Training complete. Best Val Acc: {:.2f}%".format(best_val_acc))


# 8. Training the Model

This is the final step where we train the model using the `train_model` function defined earlier. The model is saved as `resnet18_handsign_final.pth` once training is complete.


In [16]:
train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=25, patience=5)

# Save the final trained model
torch.save(model.state_dict(), 'resnet18_handsign_final.pth')
print("Model saved to resnet18_handsign_final.pth")

# Finalize the W&B run
wandb.finish()

Epoch 1/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:32<00:00,  2.86it/s]


Epoch 1, Train Loss: 2.3058, Train Acc: 30.74%
Val Loss: 0.9557, Val Acc: 68.00%


Epoch 2/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 2, Train Loss: 1.2762, Train Acc: 59.92%
Val Loss: 0.5954, Val Acc: 79.23%


Epoch 3/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:23<00:00,  2.96it/s]


Epoch 3, Train Loss: 0.9749, Train Acc: 69.49%
Val Loss: 0.5422, Val Acc: 81.65%


Epoch 4/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 4, Train Loss: 0.8316, Train Acc: 74.11%
Val Loss: 0.3846, Val Acc: 87.43%


Epoch 5/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 5, Train Loss: 0.7390, Train Acc: 77.45%
Val Loss: 0.2846, Val Acc: 90.30%


Epoch 6/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 6, Train Loss: 0.6539, Train Acc: 79.85%
Val Loss: 0.3135, Val Acc: 89.75%


Epoch 7/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:23<00:00,  2.96it/s]


Epoch 7, Train Loss: 0.6141, Train Acc: 81.02%
Val Loss: 0.2736, Val Acc: 90.87%


Epoch 8/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:23<00:00,  2.96it/s]


Epoch 8, Train Loss: 0.5696, Train Acc: 82.59%
Val Loss: 0.2297, Val Acc: 92.39%


Epoch 9/25 - Training: 100%|█████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 9, Train Loss: 0.5412, Train Acc: 83.41%
Val Loss: 0.2214, Val Acc: 92.76%


Epoch 10/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:23<00:00,  2.96it/s]


Epoch 10, Train Loss: 0.5147, Train Acc: 84.19%
Val Loss: 0.1920, Val Acc: 93.71%


Epoch 11/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 11, Train Loss: 0.4914, Train Acc: 85.01%
Val Loss: 0.1350, Val Acc: 95.73%


Epoch 12/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 12, Train Loss: 0.4724, Train Acc: 85.63%
Val Loss: 0.1359, Val Acc: 95.35%


Epoch 13/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 13, Train Loss: 0.4572, Train Acc: 85.97%
Val Loss: 0.1578, Val Acc: 94.90%


Epoch 14/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.98it/s]


Epoch 14, Train Loss: 0.4342, Train Acc: 86.64%
Val Loss: 0.1115, Val Acc: 96.26%


Epoch 15/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 15, Train Loss: 0.4233, Train Acc: 87.13%
Val Loss: 0.1266, Val Acc: 95.90%


Epoch 16/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Epoch 16, Train Loss: 0.4137, Train Acc: 87.31%
Val Loss: 0.1419, Val Acc: 95.35%


Epoch 17/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:24<00:00,  2.95it/s]


Epoch 17, Train Loss: 0.3980, Train Acc: 87.90%
Val Loss: 0.1155, Val Acc: 96.16%


Epoch 18/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:21<00:00,  2.98it/s]


Epoch 18, Train Loss: 0.3829, Train Acc: 88.15%
Val Loss: 0.1119, Val Acc: 95.99%


Epoch 19/25 - Training: 100%|████████████████████████████████████████████████████████| 780/780 [04:22<00:00,  2.97it/s]


Early stopping triggered
Training complete. Best Val Acc: 96.26%
Model saved to resnet18_handsign_final.pth


wandb: ERROR Control-C detected -- Run data was not synced


KeyboardInterrupt: 