# CelebA Multi-Attribute Training with FocalNet Tiny (`focalnet_tiny_srf`)

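This notebook fine-tunes a pretrained FocalNet Tiny model (`focalnet_tiny_srf`, loaded through `timm`) on a random sample of 40,000 images from the CelebA dataset for multi-label prediction of its 40 binary facial attributes.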

In [1]:
# Import Required Libraries
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
import timm
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    roc_auc_score, accuracy_score, f1_score, 
    precision_score, recall_score, confusion_matrix
)
from tqdm import tqdm
import warnings
warnings.filterwarnings('ignore')

# Device configuration: prefer the first CUDA GPU, otherwise run on the CPU
has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if has_cuda else 'cpu')
print(f"Using device: {device}")
if has_cuda:
    # Report which GPU was picked and how much memory it offers (in GiB)
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {gpu_props.total_memory / 1024**3:.2f} GB")

Using device: cuda
GPU: NVIDIA GeForce RTX 5050 Laptop GPU
Memory: 7.96 GB


## Configuration

In [2]:
# Configuration
# Every value a reader is likely to tune lives in this cell.
BACKBONE = 'focalnet_tiny_srf'  # timm model name passed to the model cell
SAMPLE_SIZE = 40000  # Number of images to sample from CelebA

# Training hyperparameters
BATCH_SIZE = 16
NUM_EPOCHS = 3
LEARNING_RATE = 0.0001
NUM_WORKERS = 0
IMAGE_SIZE = 224
VAL_SPLIT = 0.2
RANDOM_SEED = 42

# Seed the global NumPy and PyTorch generators so that DataLoader shuffling,
# dropout and the freshly initialised classifier head are repeatable after
# Restart & Run All. The pandas sampling and the train/validation split
# receive RANDOM_SEED explicitly further down.
# NOTE(review): GPU kernels may still be non-deterministic; seeding alone does
# not guarantee bit-identical results across runs.
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

# Paths
DATA_DIR = './Data/CelebA'
CSV_PATH = os.path.join(DATA_DIR, 'list_attr_celeba.csv')
IMAGES_DIR = os.path.join(DATA_DIR, 'img_align_celeba')
MODEL_SAVE_DIR = './Models'
RESULTS_DIR = './Results'

# Create directories (no-op when they already exist)
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
os.makedirs(RESULTS_DIR, exist_ok=True)

print(f"Selected Backbone: {BACKBONE}")
print(f"Sample Size: {SAMPLE_SIZE} images")
print(f"Training Configuration:")
print(f"  - Batch Size: {BATCH_SIZE}")
print(f"  - Epochs: {NUM_EPOCHS}")
print(f"  - Learning Rate: {LEARNING_RATE}")
print(f"  - Image Size: {IMAGE_SIZE}x{IMAGE_SIZE}")

Selected Backbone: focalnet_tiny_srf
Sample Size: 40000 images
Training Configuration:
  - Batch Size: 16
  - Epochs: 3
  - Learning Rate: 0.0001
  - Image Size: 224x224


## Dataset Loading and Preprocessing

In [3]:
# Custom Dataset Class
class CelebADataset(Dataset):
    """Map-style dataset yielding ``(image, labels)`` pairs for CelebA attributes.

    Args:
        df: DataFrame with an ``image_id`` column holding the file name of each
            image inside ``images_dir``. Every other column is treated as an
            attribute label, with ``-1`` marking a negative.
        images_dir: Directory that contains the image files.
        transform: Optional callable applied to the loaded RGB PIL image.

    Attributes:
        df: Copy of ``df`` with a fresh 0..n-1 index.
        attributes: Names of the attribute columns, in column order.
        num_attributes: Number of attribute columns.
        image_ids: File names, aligned with the rows of ``labels``.
        labels: ``float32`` array of shape ``(len(df), num_attributes)`` in
            which every ``-1`` has been replaced by ``0``.
    """

    def __init__(self, df, images_dir, transform=None):
        self.df = df.reset_index(drop=True)
        self.images_dir = images_dir
        self.transform = transform

        # Get attribute columns (all except image_id)
        self.attributes = [col for col in self.df.columns if col != 'image_id']
        self.num_attributes = len(self.attributes)

        # Build file names and labels once, so __getitem__ does not have to
        # create a row Series and re-cast the labels for every sample.
        self.image_ids = self.df['image_id'].tolist()
        raw_labels = self.df[self.attributes].to_numpy(dtype=np.float32)
        # Convert -1 to 0 for binary classification (CelebA uses -1 for negative)
        self.labels = np.where(raw_labels == -1, 0, raw_labels).astype(np.float32, copy=False)

    def __len__(self):
        """Return the number of samples."""
        return len(self.image_ids)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and its label tensor.

        An image that cannot be read is reported on stdout and replaced by a
        black ``IMAGE_SIZE`` x ``IMAGE_SIZE`` RGB image, so one bad file does
        not abort an epoch.
        """
        img_name = self.image_ids[idx]
        img_path = os.path.join(self.images_dir, img_name)

        # Load image; the context manager releases the file handle as soon as
        # the RGB copy has been created.
        try:
            with Image.open(img_path) as source:
                image = source.convert('RGB')
        except Exception as e:
            print(f"Error loading image {img_name}: {e}")
            # Return a black image as fallback
            image = Image.new('RGB', (IMAGE_SIZE, IMAGE_SIZE), (0, 0, 0))

        if self.transform:
            image = self.transform(image)

        # Copy the row so the returned tensor does not alias the cached matrix
        labels = torch.from_numpy(self.labels[idx].copy())

        return image, labels

# Read the attribute table: one row per image, one column per attribute
print("Loading CelebA dataset...")
df = pd.read_csv(CSV_PATH)
print(f"Original dataset shape: {df.shape}")
print(f"Number of attributes: {df.shape[1] - 1}")

# Every column other than image_id is an attribute label
attribute_cols = list(filter(lambda name: name != 'image_id', df.columns))

# Draw a seeded random subset only when the table holds more than SAMPLE_SIZE rows
if len(df) <= SAMPLE_SIZE:
    print(f"\nDataset has {len(df)} images (less than {SAMPLE_SIZE}), using entire dataset")
else:
    df = df.sample(n=SAMPLE_SIZE, random_state=RANDOM_SEED).reset_index(drop=True)
    print(f"\nSampled {SAMPLE_SIZE} images from the dataset")

print(f"Final dataset shape: {df.shape}")
print(f"\nAttributes ({len(attribute_cols)}):")
print(', '.join(attribute_cols))

# Hold out VAL_SPLIT of the rows for validation (seeded, hence repeatable)
train_df, val_df = train_test_split(
    df,
    test_size=VAL_SPLIT,
    random_state=RANDOM_SEED,
)
print(f"\nTrain set: {len(train_df)} samples")
print(f"Validation set: {len(val_df)} samples")

Loading CelebA dataset...
Original dataset shape: (202599, 41)
Number of attributes: 40

Sampled 40000 images from the dataset
Final dataset shape: (40000, 41)

Attributes (40):
5_o_Clock_Shadow, Arched_Eyebrows, Attractive, Bags_Under_Eyes, Bald, Bangs, Big_Lips, Big_Nose, Black_Hair, Blond_Hair, Blurry, Brown_Hair, Bushy_Eyebrows, Chubby, Double_Chin, Eyeglasses, Goatee, Gray_Hair, Heavy_Makeup, High_Cheekbones, Male, Mouth_Slightly_Open, Mustache, Narrow_Eyes, No_Beard, Oval_Face, Pale_Skin, Pointy_Nose, Receding_Hairline, Rosy_Cheeks, Sideburns, Smiling, Straight_Hair, Wavy_Hair, Wearing_Earrings, Wearing_Hat, Wearing_Lipstick, Wearing_Necklace, Wearing_Necktie, Young

Train set: 32000 samples
Validation set: 8000 samples
Original dataset shape: (202599, 41)
Number of attributes: 40

Sampled 40000 images from the dataset
Final dataset shape: (40000, 41)

Attributes (40):
5_o_Clock_Shadow, Arched_Eyebrows, Attractive, Bags_Under_Eyes, Bald, Bangs, Big_Lips, Big_Nose, Black_Hair, Blo

In [4]:
# Data Transforms (No Augmentation - Simple Resize and Normalize)
def _build_transform():
    """Return a fresh resize -> tensor -> normalize pipeline."""
    return transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])


# Training and validation use the same steps but separate pipeline objects
train_transform = _build_transform()
val_transform = _build_transform()

# Wrap each split in a dataset
train_dataset = CelebADataset(train_df, IMAGES_DIR, transform=train_transform)
val_dataset = CelebADataset(val_df, IMAGES_DIR, transform=val_transform)

# Loader options common to both splits; only the training split is shuffled
loader_options = dict(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

print(f"Train batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

# The output width of the model equals the number of attribute columns
num_attributes = train_dataset.num_attributes
print(f"\nTraining for {num_attributes} attributes")

Train batches: 2000
Validation batches: 500

Training for 40 attributes


## Model Architecture

In [5]:
# Multi-Attribute Classification Model
class MultiAttributeModel(nn.Module):
    """timm backbone followed by a small MLP head that emits one logit per attribute.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        num_attributes: Number of output logits (one per binary attribute).
        pretrained: Forwarded to ``timm.create_model`` to request pretrained
            weights.

    The forward pass returns raw logits of shape ``(batch, num_attributes)``;
    no sigmoid is applied here.
    """

    def __init__(self, backbone_name, num_attributes, pretrained=True):
        super(MultiAttributeModel, self).__init__()
        self.backbone_name = backbone_name

        # Load the requested backbone (whatever architecture `backbone_name` selects)
        self.backbone = timm.create_model(backbone_name, pretrained=pretrained)

        # Strip the backbone's own classifier so it outputs pooled features.
        head = getattr(self.backbone, 'head', None)
        fc = getattr(head, 'fc', None)
        if isinstance(fc, nn.Linear):
            # Head layout with a final `head.fc` Linear: reuse its input width.
            in_features = fc.in_features
            head.fc = nn.Identity()
        else:
            # Other layouts (no `head.fc` Linear) previously raised
            # AttributeError. Fall back to the classifier-reset API and the
            # feature width that timm models expose.
            # NOTE(review): assumes the backbone implements
            # `reset_classifier` and `num_features` - confirm for new models.
            self.backbone.reset_classifier(0)
            in_features = self.backbone.num_features

        # Classification head for multi-attribute prediction
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_attributes)
        )

    def forward(self, x):
        """Map a batch of images to per-attribute logits."""
        features = self.backbone(x)
        output = self.classifier(features)
        return output

# Instantiate the network with pretrained backbone weights and move it to the device
print(f"\nCreating {BACKBONE} model...")
model = MultiAttributeModel(BACKBONE, num_attributes, pretrained=True).to(device)

# Count all parameters and the trainable subset in one pass
total_params = 0
trainable_params = 0
for param in model.parameters():
    count = param.numel()
    total_params += count
    if param.requires_grad:
        trainable_params += count
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


Creating focalnet_tiny_srf model...


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

Total parameters: 28,072,364
Trainable parameters: 28,072,364


## Training Setup

In [6]:
# Loss: binary cross-entropy on raw logits, applied to every attribute output
criterion = nn.BCEWithLogitsLoss()

# Optimizer: AdamW over all model parameters
optimizer = optim.AdamW(params=model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)

# Scheduler: halve the learning rate when the monitored value stops decreasing.
# NOTE(review): with patience=3 and NUM_EPOCHS = 3 this cannot trigger within a run.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

# Per-epoch metric log, filled in by the training loop
history = {metric_name: [] for metric_name in ('train_loss', 'val_loss', 'val_acc')}

print("Training setup complete!")

Training setup complete!


## Training Loop

In [7]:
# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Puts ``model`` in training mode, performs one optimizer step per batch and
    returns the sum of the per-batch losses divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []

    progress = tqdm(dataloader, desc='Training')
    for batch_images, batch_labels in progress:
        batch_images, batch_labels = batch_images.to(device), batch_labels.to(device)

        # Clear stale gradients, then forward -> loss -> backward -> update
        optimizer.zero_grad()
        step_loss = criterion(model(batch_images), batch_labels)
        step_loss.backward()
        optimizer.step()

        loss_value = step_loss.item()
        batch_losses.append(loss_value)
        progress.set_postfix({'loss': f'{loss_value:.4f}'})

    return sum(batch_losses, 0.0) / len(dataloader)

# Validation function
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Returns a ``(mean_loss, accuracy)`` pair. ``mean_loss`` is the sum of the
    per-batch losses divided by ``len(dataloader)``. ``accuracy`` is the
    fraction of individual outputs whose thresholded prediction
    (``sigmoid(logit) > 0.5``) equals the label.
    """
    model.eval()
    batch_losses = []
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        progress = tqdm(dataloader, desc='Validation')
        for batch_images, batch_labels in progress:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_images)
            loss_value = criterion(logits, batch_labels).item()
            batch_losses.append(loss_value)

            # Keep thresholded predictions and labels on the host for scoring
            pred_chunks.append(torch.sigmoid(logits).gt(0.5).cpu().numpy())
            label_chunks.append(batch_labels.cpu().numpy())

            progress.set_postfix({'loss': f'{loss_value:.4f}'})

    # Element-wise agreement over every (sample, output) pair
    stacked_preds = np.vstack(pred_chunks)
    stacked_labels = np.vstack(label_chunks)
    accuracy = np.mean(stacked_preds == stacked_labels)

    return sum(batch_losses, 0.0) / len(dataloader), accuracy

print("Training functions defined!")

Training functions defined!


In [8]:
# Main training loop
# Each epoch: train, validate, step the LR scheduler on the validation loss,
# log the metrics, and checkpoint whenever the validation loss improves.
print(f"\n{'='*60}")
print(f"Starting Training: {BACKBONE} on CelebA")
print(f"{'='*60}\n")

best_val_loss = float('inf')
best_epoch = 0  # 1-based number of the epoch that produced the best checkpoint

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    print("-" * 60)

    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)

    # Validate
    val_loss, val_acc = validate_epoch(model, val_loader, criterion, device)
    # validate_epoch returns the accuracy as a NumPy scalar. Store plain Python
    # floats so the checkpoint holds only tensors and builtin types, and does
    # not need NumPy objects to be unpickled when it is loaded.
    val_loss = float(val_loss)
    val_acc = float(val_acc)

    # Update learning rate
    scheduler.step(val_loss)

    # Store history
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    # Print metrics
    print(f"\nEpoch {epoch+1} Summary:")
    print(f"  Train Loss: {train_loss:.4f}")
    print(f"  Val Loss: {val_loss:.4f}")
    print(f"  Val Accuracy: {val_acc:.4f}")
    print(f"  Learning Rate: {optimizer.param_groups[0]['lr']:.6f}")

    # Save best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_epoch = epoch + 1
        model_path = os.path.join(MODEL_SAVE_DIR, f'celeba_{BACKBONE}_best.pth')
        torch.save({
            # 0-based loop index, i.e. best_epoch - 1
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, model_path)
        print(f"  ✓ Best model saved! (Val Loss: {val_loss:.4f})")

print(f"\n{'='*60}")
print(f"Training Complete!")
print(f"Best Epoch: {best_epoch}")
print(f"Best Val Loss: {best_val_loss:.4f}")
print(f"{'='*60}\n")


Starting Training: focalnet_tiny_srf on CelebA


Epoch 1/3
------------------------------------------------------------


Training: 100%|██████████| 2000/2000 [05:51<00:00,  5.70it/s, loss=0.2089]
Training: 100%|██████████| 2000/2000 [05:51<00:00,  5.70it/s, loss=0.2089]
Validation: 100%|██████████| 500/500 [00:49<00:00, 10.19it/s, loss=0.2296]




Epoch 1 Summary:
  Train Loss: 0.2562
  Val Loss: 0.2151
  Val Accuracy: 0.9055
  Learning Rate: 0.000100
  ✓ Best model saved! (Val Loss: 0.2151)

Epoch 2/3
------------------------------------------------------------
  ✓ Best model saved! (Val Loss: 0.2151)

Epoch 2/3
------------------------------------------------------------


Training: 100%|██████████| 2000/2000 [05:41<00:00,  5.86it/s, loss=0.2134]
Training: 100%|██████████| 2000/2000 [05:41<00:00,  5.86it/s, loss=0.2134]
Validation: 100%|██████████| 500/500 [00:47<00:00, 10.55it/s, loss=0.2221]




Epoch 2 Summary:
  Train Loss: 0.2072
  Val Loss: 0.2010
  Val Accuracy: 0.9123
  Learning Rate: 0.000100
  ✓ Best model saved! (Val Loss: 0.2010)

Epoch 3/3
------------------------------------------------------------
  ✓ Best model saved! (Val Loss: 0.2010)

Epoch 3/3
------------------------------------------------------------


Training: 100%|██████████| 2000/2000 [05:35<00:00,  5.96it/s, loss=0.2019]
Training: 100%|██████████| 2000/2000 [05:35<00:00,  5.96it/s, loss=0.2019]
Validation: 100%|██████████| 500/500 [00:47<00:00, 10.51it/s, loss=0.2091]




Epoch 3 Summary:
  Train Loss: 0.1895
  Val Loss: 0.2000
  Val Accuracy: 0.9126
  Learning Rate: 0.000100
  ✓ Best model saved! (Val Loss: 0.2000)

Training Complete!
Best Epoch: 3
Best Val Loss: 0.2000

  ✓ Best model saved! (Val Loss: 0.2000)

Training Complete!
Best Epoch: 3
Best Val Loss: 0.2000



## Load Best Model and Evaluate

In [9]:
# Restore the best checkpoint written during training
print("Loading best model for evaluation...")
model_path = os.path.join(MODEL_SAVE_DIR, f'celeba_{BACKBONE}_best.pth')

# NOTE(security): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so only load checkpoint files produced by a trusted source.
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Accept either a bare state_dict or the wrapped dict saved by the training loop
is_wrapped = isinstance(checkpoint, dict) and "model_state_dict" in checkpoint
state_dict = checkpoint["model_state_dict"] if is_wrapped else checkpoint

model.load_state_dict(state_dict)
model.eval()  # inference mode for the evaluation cells that follow
print("Model ready for evaluation.")

Loading best model for evaluation...
Model ready for evaluation.


## Per-Attribute Metrics Calculation

In [10]:
# Run the model over the validation set, keeping probabilities, hard
# predictions (probability > 0.5) and ground-truth labels for every batch
print("\nCalculating per-attribute metrics...")
prob_batches, pred_batches, label_batches = [], [], []

with torch.no_grad():
    for images, labels in tqdm(val_loader, desc='Predicting'):
        probs = torch.sigmoid(model(images.to(device)))

        prob_batches.append(probs.cpu().numpy())
        pred_batches.append((probs > 0.5).cpu().numpy())
        label_batches.append(labels.cpu().numpy())

# Stack the per-batch arrays row-wise into one array per quantity
all_probs = np.vstack(prob_batches)
all_preds = np.vstack(pred_batches)
all_labels = np.vstack(label_batches)

print(f"Predictions shape: {all_preds.shape}")
print(f"Labels shape: {all_labels.shape}")


Calculating per-attribute metrics...


Predicting: 100%|██████████| 500/500 [00:47<00:00, 10.48it/s]

Predictions shape: (8000, 40)
Labels shape: (8000, 40)





In [11]:
# Calculate metrics for each attribute
# Column i of all_labels / all_preds / all_probs belongs to attribute_names[i].
attribute_names = train_dataset.attributes
metric_columns = ['auc', 'accuracy', 'macro_f1', 'micro_f1', 'precision', 'recall']
results = []

print("\nCalculating metrics for each attribute...")
for i, attr_name in enumerate(tqdm(attribute_names, desc='Processing attributes')):
    y_true = all_labels[:, i]
    y_pred = all_preds[:, i]
    y_prob = all_probs[:, i]

    # Skip if only one class present
    if len(np.unique(y_true)) < 2:
        print(f"Warning: Attribute '{attr_name}' has only one class in validation set")
        continue

    # Calculate metrics
    try:
        auc = roc_auc_score(y_true, y_prob)
    except Exception as exc:
        # Report the failure and record NaN instead of a made-up 0.0, so a
        # failed attribute is left out of the mean rather than lowering it.
        print(f"Warning: could not compute AUC for '{attr_name}': {exc}")
        auc = np.nan

    accuracy = accuracy_score(y_true, y_pred)

    # F1 scores (for a single binary attribute, micro-F1 equals the accuracy)
    macro_f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    micro_f1 = f1_score(y_true, y_pred, average='micro', zero_division=0)

    # Precision and Recall of the positive class
    precision = precision_score(y_true, y_pred, average='binary', zero_division=0)
    recall = recall_score(y_true, y_pred, average='binary', zero_division=0)

    results.append({
        'attribute': attr_name,
        'auc': auc,
        'accuracy': accuracy,
        'macro_f1': macro_f1,
        'micro_f1': micro_f1,
        'precision': precision,
        'recall': recall
    })

# Create results DataFrame; explicit columns keep the frame well-formed even
# when every attribute was skipped and `results` is empty
results_df = pd.DataFrame(results, columns=['attribute'] + metric_columns)

# Calculate mean metrics (NaN entries are ignored by DataFrame.mean)
mean_metrics = results_df[metric_columns].astype(float).mean()

print("\n" + "="*60)
print("OVERALL METRICS (Mean across all attributes)")
print("="*60)
for metric, value in mean_metrics.items():
    print(f"{metric.upper():20s}: {value:.4f}")
print("="*60)


Calculating metrics for each attribute...


Processing attributes: 100%|██████████| 40/40 [00:00<00:00, 140.51it/s]


OVERALL METRICS (Mean across all attributes)
AUC                 : 0.9396
ACCURACY            : 0.9126
MACRO_F1            : 0.8156
MICRO_F1            : 0.9126
PRECISION           : 0.7743
RECALL              : 0.6585





## Save Results

In [12]:
# Write the per-attribute metrics table to RESULTS_DIR
csv_filename = f'celeba_{BACKBONE}_metrics.csv'
csv_path = os.path.join(RESULTS_DIR, csv_filename)
results_df.to_csv(csv_path, index=False)

print(f"\n✓ Results saved to: {csv_path}")
print(f"\nResults preview:")
print(results_df.head(10))

# Rank attributes by macro F1 and show the ten best and ten worst
ranking_columns = ['attribute', 'macro_f1', 'accuracy', 'auc']
divider = "=" * 60
top_attrs = results_df.nlargest(10, 'macro_f1')[ranking_columns]
bottom_attrs = results_df.nsmallest(10, 'macro_f1')[ranking_columns]

for heading, table in (
    ("TOP 10 ATTRIBUTES (by Macro F1)", top_attrs),
    ("BOTTOM 10 ATTRIBUTES (by Macro F1)", bottom_attrs),
):
    print("\n" + divider)
    print(heading)
    print(divider)
    print(table.to_string(index=False))


✓ Results saved to: ./Results\celeba_focalnet_tiny_srf_metrics.csv

Results preview:
          attribute       auc  accuracy  macro_f1  micro_f1  precision  \
0  5_o_Clock_Shadow  0.959635  0.934000  0.821219  0.934000   0.703145   
1   Arched_Eyebrows  0.910581  0.808375  0.783968  0.808375   0.596464   
2        Attractive  0.910648  0.818750  0.818658  0.818750   0.817541   
3   Bags_Under_Eyes  0.885462  0.842500  0.718132  0.842500   0.684261   
4              Bald  0.995795  0.990250  0.887338  0.990250   0.715026   
5             Bangs  0.986903  0.956250  0.911332  0.956250   0.889091   
6          Big_Lips  0.753168  0.776250  0.626852  0.776250   0.556202   
7          Big_Nose  0.883942  0.835625  0.750160  0.835625   0.724188   
8        Black_Hair  0.956536  0.905375  0.862059  0.905375   0.805134   
9        Blond_Hair  0.982927  0.956375  0.914956  0.956375   0.859518   

     recall  
0  0.656874  
1  0.881063  
2  0.828002  
3  0.433698  
4  0.857143  
5  0.810945  
6

## Training Summary

In [13]:
# Recap of the run: setup, artefact locations, mean metrics and metric ranges
rule = "=" * 60

print("\n" + rule)
print("TRAINING COMPLETE - SUMMARY")
print(rule)
print(f"Dataset: CelebA")
print(f"Backbone Model: {BACKBONE}")
print(f"Total Parameters: {total_params:,}")
print(f"Sample Size: {SAMPLE_SIZE} images")
print(f"Training Samples: {len(train_dataset)}")
print(f"Validation Samples: {len(val_dataset)}")
print(f"Number of Attributes: {num_attributes}")
print(f"Number of Epochs: {NUM_EPOCHS}")
print(f"Best Epoch: {best_epoch}")
print(f"\nModel saved at: {model_path}")
print(f"Results saved at: {csv_path}")

print("\n" + rule)
print("MEAN METRICS")
print(rule)
for metric, value in mean_metrics.items():
    print(f"{metric.upper():20s}: {value:.4f}")
print(rule)

# Spread of the per-attribute scores
print("\n" + rule)
print("ATTRIBUTE STATISTICS")
print(rule)
print(f"Total attributes evaluated: {len(results_df)}")
print(f"\nMetrics Range:")
for label, column in (("AUC", 'auc'), ("Accuracy", 'accuracy'), ("Macro F1", 'macro_f1')):
    scores = results_df[column]
    print(f"  {label}: {scores.min():.4f} - {scores.max():.4f}")
print(rule)


TRAINING COMPLETE - SUMMARY
Dataset: CelebA
Backbone Model: focalnet_tiny_srf
Total Parameters: 28,072,364
Sample Size: 40000 images
Training Samples: 32000
Validation Samples: 8000
Number of Attributes: 40
Number of Epochs: 3
Best Epoch: 3

Model saved at: ./Models\celeba_focalnet_tiny_srf_best.pth
Results saved at: ./Results\celeba_focalnet_tiny_srf_metrics.csv

MEAN METRICS
AUC                 : 0.9396
ACCURACY            : 0.9126
MACRO_F1            : 0.8156
MICRO_F1            : 0.9126
PRECISION           : 0.7743
RECALL              : 0.6585

ATTRIBUTE STATISTICS
Total attributes evaluated: 40

Metrics Range:
  AUC: 0.7532 - 0.9983
  Accuracy: 0.7579 - 0.9955
  Macro F1: 0.6269 - 0.9812
