In [1]:
import os
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader
import torch.nn as nn
from torch.optim import Adam
from PIL import ImageFile

# Allow loading of partially corrupted images
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
DATA_DIR = 'fairface-7class-small'  # Your dataset root
BATCH_SIZE = 32
EPOCHS = 10
MODEL_PATH = 'models/nationality_7class_model.pt'
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


In [3]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [4]:
# 📁 Load dataset with auto-detected classes
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'val'), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# 🏷️ Number of classes based on folders
NUM_CLASSES = len(train_dataset.classes)
print(f"✅ Detected Classes: {train_dataset.classes}")
# 📁 Load dataset with auto-detected classes
train_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(DATA_DIR, 'val'), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# 🏷️ Number of classes based on folders
NUM_CLASSES = len(train_dataset.classes)
print(f"✅ Detected Classes: {train_dataset.classes}")


✅ Detected Classes: ['Black', 'East Asian', 'Indian', 'Latino_Hispanic', 'Middle Eastern', 'Southeast Asian', 'White']
✅ Detected Classes: ['Black', 'East Asian', 'Indian', 'Latino_Hispanic', 'Middle Eastern', 'Southeast Asian', 'White']


In [5]:
from torchvision.models import resnet18, ResNet18_Weights
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, NUM_CLASSES)
model = model.to(DEVICE)

In [7]:
 #⚙️ Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=0.001)


In [8]:
print("🚀 Starting training...")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (inputs, labels) in enumerate(train_loader):
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        if (batch_idx + 1) % 5 == 0:
            print(f"Epoch {epoch+1}, Batch {batch_idx+1}, Loss: {loss.item():.4f}")

    accuracy = 100 * correct / total
    print(f"✅ Epoch [{epoch+1}/{EPOCHS}] - Loss: {running_loss:.4f} - Accuracy: {accuracy:.2f}%")


🚀 Starting training...
Epoch 1, Batch 5, Loss: 1.8138
Epoch 1, Batch 10, Loss: 1.9225
Epoch 1, Batch 15, Loss: 1.9147
Epoch 1, Batch 20, Loss: 1.9497
Epoch 1, Batch 25, Loss: 1.8385
Epoch 1, Batch 30, Loss: 1.5707
Epoch 1, Batch 35, Loss: 1.9778
Epoch 1, Batch 40, Loss: 1.7601
Epoch 1, Batch 45, Loss: 1.6266
Epoch 1, Batch 50, Loss: 1.5234
Epoch 1, Batch 55, Loss: 1.7617
Epoch 1, Batch 60, Loss: 1.6403
Epoch 1, Batch 65, Loss: 1.4392
Epoch 1, Batch 70, Loss: 1.6288
Epoch 1, Batch 75, Loss: 1.6079
Epoch 1, Batch 80, Loss: 1.3807
Epoch 1, Batch 85, Loss: 1.4046
Epoch 1, Batch 90, Loss: 1.4367
Epoch 1, Batch 95, Loss: 1.6998
Epoch 1, Batch 100, Loss: 1.2282
Epoch 1, Batch 105, Loss: 1.4380
Epoch 1, Batch 110, Loss: 1.3115
✅ Epoch [1/10] - Loss: 181.8249 - Accuracy: 34.69%
Epoch 2, Batch 5, Loss: 1.7028
Epoch 2, Batch 10, Loss: 1.4136
Epoch 2, Batch 15, Loss: 1.1938
Epoch 2, Batch 20, Loss: 1.1975
Epoch 2, Batch 25, Loss: 0.9207
Epoch 2, Batch 30, Loss: 1.2512
Epoch 2, Batch 35, Loss: 1.07

In [9]:
os.makedirs("models", exist_ok=True)
torch.save(model, MODEL_PATH)
print(f"\n🎉 Model saved to: {MODEL_PATH}")


🎉 Model saved to: models/nationality_7class_model.pt


In [10]:
MODEL_PATH = 'models/nationality_7class_model.pth'
torch.save(model.state_dict(), MODEL_PATH)
print(f"\n🎉 State dict saved to: {MODEL_PATH}")



🎉 State dict saved to: models/nationality_7class_model.pth
