In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
from sklearn.model_selection import train_test_split
import os
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# Configuration
DATA_DIR = './asl_data/asl_alphabet_train/asl_alphabet_train' 
MODEL_SAVE_PATH = 'asl_new_resnet50.pth'
BATCH_SIZE = 32
NUM_EPOCHS = 6  
LEARNING_RATE = 0.001
IMG_SIZE = 224
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print(f"Using device: {DEVICE}")

Using device: cuda


In [7]:
TRAIN_DIR = Path(DATA_DIR)
class_names = sorted([p.name for p in TRAIN_DIR.iterdir() if p.is_dir()])
class_to_idx = {name: i for i, name in enumerate(class_names)}

print(f"Found {len(class_names)} classes.")

all_image_paths = []
all_labels = []

print("Scanning for images...")
for class_name in class_names:
    class_dir = TRAIN_DIR / class_name
    images = list(class_dir.rglob("*.[jJ][pP][gG]")) + \
             list(class_dir.rglob("*.[jJ][pP][eE][gG]")) + \
             list(class_dir.rglob("*.png"))
             
    for img_path in images:
        all_image_paths.append(str(img_path))
        all_labels.append(class_to_idx[class_name])

print(f"Total images found: {len(all_image_paths)}")

# Stratified split to ensure equal representation in Val set
train_paths, val_paths, train_labels, val_labels = train_test_split(
    all_image_paths, all_labels, test_size=0.2, random_state=42, stratify=all_labels
)

Found 29 classes.
Scanning for images...
Total images found: 87000


In [8]:
class SimpleASLDataset(Dataset):
    def __init__(self, image_paths, labels, transform=None):
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self): return len(self.image_paths)

    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

# Added Normalization to match ResNet standards
train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

train_dataset = SimpleASLDataset(train_paths, train_labels, transform=train_transform)
val_dataset = SimpleASLDataset(val_paths, val_labels, transform=val_transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)}")

Train: 69600 | Val: 17400


In [9]:
model = models.resnet50(weights='IMAGENET1K_V1')

num_ftrs = model.fc.in_features
model.fc = nn.Sequential(
    nn.Dropout(0.5),
    nn.Linear(num_ftrs, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, len(class_names)) 
)

model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

In [10]:
best_acc = 0.0
train_losses, val_losses = [], []
train_accs, val_accs = [], []

print(f"Starting training on {DEVICE}...")

try:
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss, correct, total = 0.0, 0, 0
        loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS}", leave=True) 
        
        for images, labels in loop:
            images, labels = images.to(DEVICE), labels.to(DEVICE)
            
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            loop.set_postfix(loss=loss.item())

        epoch_val_loss, correct, total = 0.0, 0, 0
        model.eval()
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(DEVICE), labels.to(DEVICE)
                outputs = model(images)
                loss = criterion(outputs, labels)
                epoch_val_loss += loss.item()
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        val_acc = 100 * correct / total
        val_loss = epoch_val_loss / len(val_loader)
        
        scheduler.step(val_loss)
        print(f"Val Acc: {val_acc:.2f}% | Val Loss: {val_loss:.4f}")

        if val_acc > best_acc:
            best_acc = val_acc
            # CRITICAL: Saving weights AND metadata
            checkpoint = {
                'model_state_dict': model.state_dict(),
                'class_names': class_names,
                'class_to_idx': class_to_idx,
                'accuracy': val_acc
            }
            torch.save(checkpoint, MODEL_SAVE_PATH)
            print(f"New best model saved! ({best_acc:.2f}%)")

except KeyboardInterrupt:
    print("\nTraining interrupted.")

Starting training on cuda...


Epoch 1/6:   0%|          | 0/2175 [00:00<?, ?it/s]

Epoch 1/6: 100%|██████████| 2175/2175 [10:10<00:00,  3.56it/s, loss=0.291] 


Val Acc: 96.92% | Val Loss: 0.0943
New best model saved! (96.92%)


Epoch 2/6: 100%|██████████| 2175/2175 [10:21<00:00,  3.50it/s, loss=0.18]   


Val Acc: 98.88% | Val Loss: 0.0322
New best model saved! (98.88%)


Epoch 3/6: 100%|██████████| 2175/2175 [10:23<00:00,  3.49it/s, loss=0.00173] 


Val Acc: 99.32% | Val Loss: 0.0229
New best model saved! (99.32%)


Epoch 4/6: 100%|██████████| 2175/2175 [10:23<00:00,  3.49it/s, loss=0.0958]  


Val Acc: 99.56% | Val Loss: 0.0152
New best model saved! (99.56%)


Epoch 5/6: 100%|██████████| 2175/2175 [10:21<00:00,  3.50it/s, loss=0.135]   


Val Acc: 99.14% | Val Loss: 0.0328


Epoch 6/6: 100%|██████████| 2175/2175 [10:19<00:00,  3.51it/s, loss=0.0072]  


Val Acc: 99.96% | Val Loss: 0.0020
New best model saved! (99.96%)
