# **1. Import Libraries**

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# **2. Define Augmentations**

In [None]:
# Training-time pipeline: random geometric and colour augmentation, followed by
# tensor conversion and per-channel normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(size=224),  # random crop, resized to 224
        transforms.RandomHorizontalFlip(),  # random left-right mirror
        transforms.ColorJitter(  # mild colour variation
            brightness=0.2,
            contrast=0.2,
            saturation=0.2,
        ),
        transforms.RandomRotation(degrees=15),  # rotate by up to 15 degrees
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # ImageNet mean
            std=[0.229, 0.224, 0.225],  # ImageNet std
        ),
    ]
)

# Evaluation-time pipeline: deterministic resize + centre crop only, so that
# validation/test metrics are not affected by random augmentation.
val_transforms = transforms.Compose(
    [
        transforms.Resize(size=256),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # same statistics as the training pipeline
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# **3. Load Dataset**

In [None]:
# Folder-per-class datasets: each split gets its own transform pipeline.
train_dataset = datasets.ImageFolder(root="data/train", transform=train_transforms)
val_dataset = datasets.ImageFolder(root="data/val", transform=val_transforms)

# Mini-batch iterators. Only the training split is shuffled; validation keeps
# a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# **4. Load MobileNet-v2**

In [None]:
# Pretrained MobileNet-v2
# NOTE: `pretrained=True` is deprecated in torchvision >= 0.13 in favour of
# `weights=models.MobileNet_V2_Weights.DEFAULT`; it is kept here so the cell
# still runs on older torchvision releases.
model = models.mobilenet_v2(pretrained=True)

# Freeze feature extractor (optional for transfer learning): no gradients are
# computed for the convolutional backbone, only for the classifier head.
model.features.requires_grad_(False)

# One output unit per class folder discovered by ImageFolder in the training
# split (previously `num_classes` was used without ever being defined).
num_classes = len(train_dataset.classes)

# Replace classifier head: keep the existing input width, change the output
# width to the number of target classes.
model.classifier[1] = nn.Linear(model.classifier[1].in_features, num_classes)

# **5. Training Setup**

In [None]:
# Select the GPU when one is available and move the model there *before*
# building the optimizer, so the optimizer is constructed over parameters that
# already live on the training device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Multi-class classification loss computed from raw logits and integer labels.
criterion = nn.CrossEntropyLoss()

# Only the classifier head's parameters are handed to the optimizer.
optimizer = optim.Adam(model.classifier.parameters(), lr=0.001)

# **6. Training Loop**

In [None]:
# Train for 10 epochs and report the mean per-sample training loss of each
# epoch (rather than the loss of whichever mini-batch happened to come last).
for epoch in range(10):
    model.train()  # enable training-mode behaviour (dropout, batch-norm updates)

    running_loss = 0.0  # sum of per-sample losses seen this epoch
    samples_seen = 0

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # `criterion` returns the batch mean, so weight it by the batch size to
        # keep the epoch average correct when the last batch is smaller.
        n_batch = labels.size(0)
        running_loss += loss.item() * n_batch
        samples_seen += n_batch

    # Guard against an empty loader instead of failing on an unbound `loss`.
    epoch_loss = running_loss / samples_seen if samples_seen else float("nan")
    print(f"Epoch {epoch+1}, Loss: {epoch_loss:.4f}")