In [1]:
import zipfile
import os

zip_file_path = "fairface-img-margin025-trainval.zip"
extract_to_dir = "data"

os.makedirs(extract_to_dir, exist_ok=True)

# with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
#     zip_ref.extractall(extract_to_dir)


In [1]:
import os
import clip
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
from sklearn.metrics import accuracy_score
from tqdm import tqdm

age_prompts = ['3-9', '50-59', '30-39', '20-29', 'more than 70', '40-49', '10-19', '60-69', '0-2']
gender_prompts = ['Male', 'Female']
race_prompts = ['East Asian', 'White', 'Latino_Hispanic', 'Southeast Asian', 'Black', 'Indian', 'Middle Eastern']

csv_file_path = "data/fairface_label_val.csv"
data = pd.read_csv(csv_file_path)

In [4]:
class FairfaceDataset(Dataset):
    def __init__(self, csv_file, image_dir, transforms, target_cols):
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transforms = transforms
        self.target_cols = target_cols
        
        # Filter valid rows with existing image files
        self.valid_data = self.data[self.data['file'].apply(
            lambda x: os.path.exists(os.path.join(self.image_dir, x))
        )]
        self.missing_count = len(self.data) - len(self.valid_data)
        if self.missing_count > 0:
            print(f"Warning: {self.missing_count} files are missing and will be skipped.")

    def __len__(self):
        return len(self.valid_data)

    def __getitem__(self, idx):
        row = self.valid_data.iloc[idx]
        image_path = os.path.join(self.image_dir, row['file'])
        try:
            image = Image.open(image_path).convert("RGB")
            image = self.transforms(image)
        except Exception as e:
            return None

        targets = {col: row[col] for col in self.target_cols}
        return image, targets

def custom_collate_fn(batch):
    # Filter out None values
    batch = [item for item in batch if item is not None]
    if len(batch) == 0:
        return None, None  
    images, targets = zip(*batch)  
    return torch.stack(images), {key: [d[key] for d in targets] for key in targets[0]}

def fine_tune_clip(train_csv, val_csv, image_dir, output_dir, epochs=20, batch_size=15, lr=1e-4):

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on {device}")
    model, preprocess = clip.load("ViT-B/32", device=device)
    model = model.float()
    
    for param in model.visual.parameters():
        param.requires_grad = False
    
    num_classes = {
        "age": len(age_prompts),
        "gender": len(gender_prompts),
        "race": len(race_prompts)
    }
    classification_heads = {
        col: nn.Linear(model.visual.output_dim, num_classes[col]).to(device)
        for col in num_classes
    }

    # Optimizer
    params = [p for head in classification_heads.values() for p in head.parameters()]
    optimizer = torch.optim.Adam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Datasets
    train_dataset = FairfaceDataset(train_csv, image_dir, preprocess, target_cols=["age", "gender", "race"])
    val_dataset = FairfaceDataset(val_csv, image_dir, preprocess, target_cols=["age", "gender", "race"])
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=custom_collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)
    # Train
    for epoch in range(epochs):
        model.train()
        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images = images.to(device)
            losses = []
            for col, head in classification_heads.items():
                labels = torch.tensor([age_prompts.index(label) if col == "age" else
                               gender_prompts.index(label) if col == "gender" else
                               race_prompts.index(label)
                               for label in targets[col]]).to(device)
                head = head
                logits = head(model.visual(images))
                loss = criterion(logits, labels)
                losses.append(loss)

            # Backpropagation
            total_loss = sum(losses)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()


        # Validation
        model.eval()
        val_accuracies = {col: [] for col in classification_heads}
        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc="Validating"):
                images = images.to(device)
                for col, head in classification_heads.items():
                    labels = torch.tensor([age_prompts.index(label) if col == "age" else
                               gender_prompts.index(label) if col == "gender" else
                               race_prompts.index(label)
                               for label in targets[col]]).to(device)

                    head = head
                    logits = head(model.visual(images))
                    preds = logits.argmax(dim=-1)
                    accuracy = accuracy_score(labels.cpu(), preds.cpu())
                    val_accuracies[col].append(accuracy)

        # Validation Result
        print(f"Epoch {epoch+1}/{epochs}")
        for col, accuracies in val_accuracies.items():
            print(f"{col} accuracy: {sum(accuracies)/len(accuracies):.2%}")

    # Save fine-tuned model
    for col, head in classification_heads.items():
        torch.save(head.state_dict(), os.path.join(output_dir, f"{col}_head.pth"))

In [6]:
# Paths
train_csv = "data/fairface_label_train.csv"
val_csv = "data/fairface_label_val.csv"
image_dir = "data"
output_dir = "data/fine_tuned_clip"
os.makedirs(output_dir, exist_ok=True)

fine_tune_clip(train_csv, val_csv, image_dir, output_dir)

In [5]:
def validate_clip(val_csv, image_dir, output_dir, batch_size=32):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Running on {device}")
    model, preprocess = clip.load("ViT-B/32", device=device)
    model = model.float()

    for param in model.visual.parameters():
        param.requires_grad = False

    classification_heads = {
        "age": nn.Linear(model.visual.output_dim, len(age_prompts)).to(device),
        "gender": nn.Linear(model.visual.output_dim, len(gender_prompts)).to(device),
        "race": nn.Linear(model.visual.output_dim, len(race_prompts)).to(device),
    }

    for col, head in classification_heads.items():
        head.load_state_dict(torch.load(os.path.join(output_dir, f"{col}_head.pth")))
        head.eval()

    val_dataset = FairfaceDataset(val_csv, image_dir, preprocess, target_cols=["age", "gender", "race"])
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=custom_collate_fn)

    model.eval()
    val_accuracies = {col: [] for col in classification_heads}
    with torch.no_grad():
        for images, targets in tqdm(val_loader, desc="Validating"):
            if images is None or targets is None:
                continue
            images = images.to(device)
            for col, head in classification_heads.items():
                labels = torch.tensor(
                    [
                        age_prompts.index(label) if col == "age" else
                        gender_prompts.index(label) if col == "gender" else
                        race_prompts.index(label)
                        for label in targets[col]
                    ]
                ).to(device)

                logits = head(model.visual(images))
                preds = logits.argmax(dim=-1)
                accuracy = accuracy_score(labels.cpu(), preds.cpu())
                val_accuracies[col].append(accuracy)

    print("Validation Results:")
    for col, accuracies in val_accuracies.items():
        if accuracies:  
            print(f"{col} accuracy: {sum(accuracies) / len(accuracies):.2%}")
        else:
            print(f"{col}: No valid data to validate.")

validate_clip( val_csv, image_dir, output_dir)

Running on cuda


  head.load_state_dict(torch.load(os.path.join(output_dir, f"{col}_head.pth")))
Validating: 100%|██████████| 343/343 [00:45<00:00,  7.56it/s]

Validation Results:
age accuracy: 59.62%
gender accuracy: 94.70%
race accuracy: 71.49%



