In [1]:
import os
import time

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.models import resnet18, ResNet18_Weights
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchvision import transforms, models, datasets
from PIL import Image
import mediapipe as mp

KeyboardInterrupt: 

In [None]:
# -----------------------------
# Custom Dataset with MediaPipe Cropping
# -----------------------------
class HandCropDataset(Dataset):
    """Image-folder dataset that crops each image to the detected hand before transforming.

    Sample paths and labels come from ``torchvision.datasets.ImageFolder`` (one
    sub-folder per class). For every image, MediaPipe Hands searches for a single
    hand. If one is found, the image is cropped to the landmark bounding box
    enlarged by ``padding`` pixels on each side. Otherwise, or if that box is
    empty after clipping to the image, the largest centred square is used.

    Args:
        root_dir: Root folder passed to ``ImageFolder``.
        transform: Optional callable applied to the cropped RGB ``PIL.Image``.
        mp_confidence: ``min_detection_confidence`` for MediaPipe Hands.
        padding: Pixels added around the landmark box on every side (default 20).
    """

    def __init__(self, root_dir, transform=None, mp_confidence=0.5, padding=20):
        # ImageFolder is used only as a file/label index; images are loaded with OpenCV below.
        self.img_folder = datasets.ImageFolder(root=root_dir, transform=None)
        self.classes = self.img_folder.classes
        self.class_to_idx = self.img_folder.class_to_idx
        self.transform = transform
        self.padding = padding
        self.mp_hands = mp.solutions.hands.Hands(
            static_image_mode=True,
            max_num_hands=1,
            min_detection_confidence=mp_confidence
        )

    def __len__(self):
        return len(self.img_folder)

    def close(self):
        """Close the MediaPipe Hands instance; the dataset cannot be indexed afterwards."""
        self.mp_hands.close()

    @staticmethod
    def _landmark_box(xs, ys, w, h, padding):
        """Return the padded pixel box ``(x_min, y_min, x_max, y_max)`` around normalised
        landmark coordinates, clipped to a ``w`` x ``h`` image, or ``None`` if the
        clipped box is empty."""
        x_min = max(int(min(xs) * w) - padding, 0)
        x_max = min(int(max(xs) * w) + padding, w)
        y_min = max(int(min(ys) * h) - padding, 0)
        y_max = min(int(max(ys) * h) + padding, h)
        # Landmarks can lie outside [0, 1] when the hand is partly out of frame. The
        # clipped box may then have no area, and an empty crop would crash cv2.cvtColor.
        if x_max <= x_min or y_max <= y_min:
            return None
        return x_min, y_min, x_max, y_max

    @staticmethod
    def _center_square_box(w, h):
        """Return the largest centred square ``(x_min, y_min, x_max, y_max)`` in a ``w`` x ``h`` image."""
        side = min(h, w)
        x0 = (w - side) // 2
        y0 = (h - side) // 2
        return x0, y0, x0 + side, y0 + side

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        ``image`` is the hand (or centre) crop after ``transform``. It is a
        ``PIL.Image`` when ``transform`` is None.

        Raises:
            OSError: if OpenCV cannot read/decode the image file.
        """
        path, label = self.img_folder.samples[idx]
        img_bgr = cv2.imread(path)
        if img_bgr is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise OSError(f"Could not read image file: {path}")
        h, w = img_bgr.shape[:2]
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        results = self.mp_hands.process(img_rgb)

        box = None
        if results.multi_hand_landmarks:
            lm = results.multi_hand_landmarks[0].landmark
            box = self._landmark_box([p.x for p in lm], [p.y for p in lm], w, h, self.padding)
        if box is None:
            # No hand detected, or the detected box is empty after clipping.
            box = self._center_square_box(w, h)
        x_min, y_min, x_max, y_max = box

        crop = img_bgr[y_min:y_max, x_min:x_max]
        pil_img = Image.fromarray(cv2.cvtColor(crop, cv2.COLOR_BGR2RGB))
        if self.transform:
            pil_img = self.transform(pil_img)
        return pil_img, label

# -----------------------------
# Paths and Hyperparameters
# -----------------------------
# Dataset root in ImageFolder layout (one sub-folder per class). Set the ASL_TRAIN_DIR
# environment variable to use another location without editing the notebook.
data_dir      = os.environ.get("ASL_TRAIN_DIR", r"C:\Users\myers\Downloads\Train_Alphabet")
batch_size    = 32
learning_rate = 1e-3
epochs        = 10
val_split     = 0.2   # fraction of images held out for validation
seed          = 42

assert 0.0 < val_split < 1.0, "val_split must be strictly between 0 and 1"

# Seed torch's global RNG (all devices). This makes torch-driven randomness such as
# head initialisation and shuffle order repeatable; some CUDA kernels may still be
# non-deterministic.
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# -----------------------------
# Transforms
# -----------------------------
# Both pipelines resize to 224x224 and normalise with the ImageNet channel statistics;
# only the training pipeline adds random augmentation.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std  = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    # Training-only augmentation: lighting, small rotation/shift, mirroring.
    transforms.ColorJitter(brightness=0.3, contrast=0.3),
    transforms.RandomAffine(degrees=10, translate=(0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
])

# -----------------------------
# Dataset and Split
# -----------------------------
# random_split returns Subsets that share ONE underlying dataset. Assigning
# `val_ds.dataset.transform` would therefore also switch the training subset to the
# un-augmented val transform. Instead, build two datasets over the same folder (the
# assert checks they index identical files in identical order) and take the
# validation indices from the one that carries val_transform.
full_dataset = HandCropDataset(data_dir, transform=train_transform)
val_source   = HandCropDataset(data_dir, transform=val_transform)
assert val_source.img_folder.samples == full_dataset.img_folder.samples

total_size   = len(full_dataset)
val_size     = int(val_split * total_size)
train_size   = total_size - val_size
# Dedicated, fixed generator so the train/val partition is identical on every run.
split_generator = torch.Generator().manual_seed(42)
train_ds, val_subset = random_split(full_dataset, [train_size, val_size], generator=split_generator)
val_ds = Subset(val_source, val_subset.indices)

train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=0)
print(f"Total images: {total_size}")
print(f"Training images: {train_size}")
print(f"Validation images: {val_size}")

In [None]:
# -----------------------------
# Model Definition (Transfer Learning)
# -----------------------------
# ImageNet-pretrained ResNet-18 as a frozen feature extractor; only a freshly
# initialised classification head (one output per dataset class) is trainable.
weights = ResNet18_Weights.IMAGENET1K_V1
model = resnet18(weights=weights)
model.requires_grad_(False)  # freeze every pretrained parameter
num_features = model.fc.in_features
model.fc = nn.Linear(in_features=num_features, out_features=len(full_dataset.classes))
model = model.to(device)

# -----------------------------
# Loss, Optimizer, Scheduler
# -----------------------------
criterion = nn.CrossEntropyLoss()
# Adam only sees the new head's parameters, matching the frozen backbone above.
optimizer = optim.Adam(params=model.fc.parameters(), lr=learning_rate)
# Lowers the LR once the stepped metric (validation loss) fails to improve for more
# than `patience` consecutive epochs.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', patience=2)

# -----------------------------
# Training & Validation Functions
# -----------------------------
def train_one_epoch(model, loader, criterion, optimizer, epoch, device=None, log_every=100):
    """Run one training pass over ``loader`` and return the mean per-sample loss.

    Args:
        model: Network to train; switched to ``train()`` mode.
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        criterion: Loss that returns a batch *mean* (e.g. default ``CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        epoch: Epoch number, used only in progress messages.
        device: Device each batch is moved to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).
        log_every: Print the current batch loss every ``log_every`` batches.

    Returns:
        Loss averaged over samples. Batches are weighted by size, so a smaller
        final batch does not skew the mean.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.train()
    running_loss = 0.0
    seen = 0
    for batch_idx, (imgs, labels) in enumerate(loader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        batch_n = labels.size(0)
        # criterion gives a batch mean; scale back to a sum so the epoch mean is per-sample.
        running_loss += loss.item() * batch_n
        seen += batch_n
        if batch_idx % log_every == 0:
            print(f"Epoch [{epoch}], Batch [{batch_idx}/{len(loader)}] Loss: {loss.item():.4f}")
    if seen == 0:
        raise ValueError("train_one_epoch: loader produced no samples")
    return running_loss / seen


def validate(model, loader, criterion, device=None):
    """Evaluate ``model`` on ``loader`` with gradients disabled.

    Args:
        model: Network to evaluate; switched to ``eval()`` mode.
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss that returns a batch *mean* (e.g. default ``CrossEntropyLoss``).
        device: Device each batch is moved to. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters).

    Returns:
        ``(mean_loss, accuracy_percent)``: the loss averaged over samples (batches
        weighted by size) and the top-1 accuracy in percent.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()
    loss_sum = 0.0
    correct = total = 0
    with torch.no_grad():
        for imgs, labels in loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            batch_n = labels.size(0)
            # Undo the batch mean so a smaller final batch is not over-weighted.
            loss_sum += criterion(outputs, labels).item() * batch_n
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += batch_n
    if total == 0:
        raise ValueError("validate: loader produced no samples (is val_split too small?)")
    return loss_sum / total, 100 * correct / total

# -----------------------------
# Run Training Loop
# -----------------------------
# Train for `epochs` epochs and feed each validation loss to the plateau scheduler.
# A checkpoint is written whenever that loss beats the best value seen so far.
checkpoint_path = 'best_asl_resnet_checkpoint.pth'
best_val_loss = float('inf')
for epoch in range(1, epochs + 1):
    print(f"Starting epoch {epoch}/{epochs}")
    start = time.time()

    train_loss = train_one_epoch(model, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc = validate(model, val_loader, criterion)
    scheduler.step(val_loss)

    elapsed = time.time() - start
    print(
        f"Epoch {epoch}/{epochs} -> train_loss: {train_loss:.4f}, "
        f"val_loss: {val_loss:.4f}, val_acc: {val_acc:.2f}% | {elapsed:.1f}s"
    )

    # Only checkpoint on a strict improvement in validation loss.
    if not (val_loss < best_val_loss):
        continue

    best_val_loss = val_loss
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'class_to_idx': full_dataset.class_to_idx,
        'classes': full_dataset.classes,
    }
    torch.save(checkpoint, checkpoint_path)
    print("--> Saved new best model")
print("Training complete.")