In [None]:
import os
import clip
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import pandas as pd
from sklearn.metrics import accuracy_score
from tqdm import tqdm

# Label vocabularies for the three attributes. fine_tune_clip converts a label
# to its class id with `<list>.index(label)`, so the ORDER of each list defines
# the class indices of the corresponding head -- do not reorder entries once
# head weights have been saved.
# NOTE: the age entries carry an "age " prefix, while the recorded run output
# shows the CSV age labels as bare ranges (e.g. '50-59'); any lookup against
# this list has to account for that prefix.
age_prompts = ['age 3-9', 'age 50-59', 'age 30-39', 'age 20-29', 'age more than 70', 'age 40-49', 'age 10-19', 'age 60-69', 'age 0-2']
gender_prompts = ['Male', 'Female']
race_prompts = ['East Asian', 'White', 'Latino_Hispanic', 'Southeast Asian', 'Black', 'Indian', 'Middle Eastern']

# Validation label file, loaded eagerly into a module-level DataFrame.
# NOTE(review): `data` is not referenced by any other cell visible in this
# notebook -- confirm whether it is still needed.
csv_file_path = "../fairface/fairface_label_val.csv"
data = pd.read_csv(csv_file_path)

In [None]:
class FairfaceDataset(Dataset):
    """Image/label dataset backed by a label CSV and an image directory.

    The CSV must contain a ``file`` column holding image paths relative to
    ``image_dir``, plus every column named in ``target_cols``. Rows whose
    image file does not exist on disk are dropped at construction time.

    Attributes:
        data: The full DataFrame as read from ``csv_file``.
        valid_data: The subset of ``data`` whose image file exists.
        missing_count: Number of rows dropped because the file was missing.
    """

    def __init__(self, csv_file, image_dir, transforms, target_cols):
        """Load the CSV and keep only rows whose image exists.

        Args:
            csv_file: Path to the label CSV.
            image_dir: Directory that the ``file`` column is relative to.
            transforms: Callable applied to each RGB ``PIL.Image``.
            target_cols: Names of the CSV columns returned as targets.
        """
        self.data = pd.read_csv(csv_file)
        self.image_dir = image_dir
        self.transforms = transforms
        self.target_cols = target_cols

        # Filter valid rows with existing image files
        exists_mask = self.data['file'].apply(
            lambda x: os.path.exists(os.path.join(self.image_dir, x))
        )
        self.valid_data = self.data[exists_mask]
        self.missing_count = len(self.data) - len(self.valid_data)
        if self.missing_count > 0:
            print(f"Warning: {self.missing_count} files are missing and will be skipped.")

    def __len__(self):
        """Return the number of rows that have an image on disk."""
        return len(self.valid_data)

    def __getitem__(self, idx):
        """Return ``(transformed_image, targets)`` for the ``idx``-th valid row.

        ``targets`` is a dict mapping each name in ``target_cols`` to the raw
        CSV value of that column for this row.
        """
        # Positional lookup: valid_data keeps the original (gappy) index labels.
        row = self.valid_data.iloc[idx]
        image_path = os.path.join(self.image_dir, row['file'])

        # Image.open is lazy and keeps the file handle open; convert() forces
        # the pixel data to load into a new image, so the handle can be
        # released right away instead of lingering until garbage collection.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")
        image = self.transforms(image)

        targets = {col: row[col] for col in self.target_cols}
        return image, targets


def _build_label_lookup(prompts, prefix=""):
    """Map each accepted label spelling to its class index (position in ``prompts``).

    Every prompt maps to its own position. When ``prefix`` is given, the prompt
    with that prefix stripped maps to the same position, so both the prompt
    text ('age 50-59') and the raw CSV label ('50-59') resolve to one class.
    """
    lookup = {}
    for index, prompt in enumerate(prompts):
        lookup[prompt] = index
        if prefix and prompt.startswith(prefix):
            lookup.setdefault(prompt[len(prefix):], index)
    return lookup


def _encode_labels(labels, lookup, col, device):
    """Convert a batch of string labels into a LongTensor of class indices on ``device``."""
    try:
        indices = [lookup[label] for label in labels]
    except KeyError as err:
        raise ValueError(
            f"Unknown {col} label {err.args[0]!r}; expected one of {sorted(lookup)}"
        ) from err
    return torch.tensor(indices, dtype=torch.long, device=device)


def fine_tune_clip(train_csv, val_csv, image_dir, output_dir, epochs=2, batch_size=10, lr=1e-4):
    """Train linear age/gender/race heads on top of a frozen CLIP image encoder.

    Loads CLIP ViT-B/32, freezes its visual tower, and fits one ``nn.Linear``
    head per attribute with Adam on the unweighted sum of the three
    cross-entropy losses. After every epoch the heads are evaluated on the
    validation set and per-attribute accuracy is printed. Finally each head's
    ``state_dict`` is written to ``<output_dir>/<attribute>_head.pth``.

    Class indices follow the order of the module-level ``age_prompts``,
    ``gender_prompts`` and ``race_prompts`` lists.

    Args:
        train_csv: Label CSV used for training.
        val_csv: Label CSV used for validation.
        image_dir: Directory the CSV ``file`` column is relative to.
        output_dir: Directory for the saved head weights (created if missing).
        epochs: Number of passes over the training set.
        batch_size: Batch size for both loaders.
        lr: Adam learning rate.

    Raises:
        ValueError: If the training set is empty or a label is not in the
            vocabulary of its attribute.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    for param in model.visual.parameters():
        param.requires_grad = False
    # The backbone is never updated, so it stays in inference mode throughout.
    model.eval()

    # Accept both the prompt text and the bare CSV label for age: the prompts
    # are 'age 50-59' while the CSV stores '50-59', which made the previous
    # `age_prompts.index(label)` lookup raise ValueError.
    label_lookups = {
        "age": _build_label_lookup(age_prompts, prefix="age "),
        "gender": _build_label_lookup(gender_prompts),
        "race": _build_label_lookup(race_prompts),
    }
    num_classes = {
        "age": len(age_prompts),
        "gender": len(gender_prompts),
        "race": len(race_prompts)
    }
    classification_heads = {
        col: nn.Linear(model.visual.output_dim, num_classes[col]).to(device)
        for col in num_classes
    }

    # Optimizer
    params = [p for head in classification_heads.values() for p in head.parameters()]
    optimizer = torch.optim.Adam(params, lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Datasets
    target_cols = ["age", "gender", "race"]
    train_dataset = FairfaceDataset(train_csv, image_dir, preprocess, target_cols=target_cols)
    val_dataset = FairfaceDataset(val_csv, image_dir, preprocess, target_cols=target_cols)
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found for {train_csv!r} under {image_dir!r}")
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    def extract_features(images):
        # encode_image casts the input to the model's own dtype (CLIP uses
        # half precision on CUDA); the result is cast back to float32 to match
        # the heads. no_grad avoids building a graph for the frozen backbone.
        with torch.no_grad():
            return model.encode_image(images.to(device)).float()

    for epoch in range(epochs):
        # Train
        for head in classification_heads.values():
            head.train()
        for images, targets in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            # One backbone pass per batch, shared by all three heads.
            features = extract_features(images)
            losses = []
            for col, head in classification_heads.items():
                labels = _encode_labels(targets[col], label_lookups[col], col, device)
                losses.append(criterion(head(features), labels))

            # Backpropagation
            total_loss = sum(losses)
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        # Validation
        for head in classification_heads.values():
            head.eval()
        correct = {col: 0 for col in classification_heads}
        seen = 0
        with torch.no_grad():
            for images, targets in tqdm(val_loader, desc="Validating"):
                features = extract_features(images)
                seen += features.size(0)
                for col, head in classification_heads.items():
                    labels = _encode_labels(targets[col], label_lookups[col], col, device)
                    preds = head(features).argmax(dim=-1)
                    # Count hits so that a smaller final batch is weighted
                    # by its true size instead of counting as a full batch.
                    correct[col] += (preds == labels).sum().item()

        # Validation Result
        print(f"Epoch {epoch+1}/{epochs}")
        for col, hits in correct.items():
            if seen:
                print(f"{col} accuracy: {hits/seen:.2%}")
            else:
                print(f"{col} accuracy: n/a (validation set is empty)")

    # Save fine-tuned model
    os.makedirs(output_dir, exist_ok=True)
    for col, head in classification_heads.items():
        torch.save(head.state_dict(), os.path.join(output_dir, f"{col}_head.pth"))

In [None]:
# Destination for the trained head weights; make sure it exists before training.
output_dir = "../fairface/fine_tuned_clip"
os.makedirs(output_dir, exist_ok=True)

# FairFace inputs: the image folder plus the train/validation label files.
image_dir = "../fairface/fairface-img-margin025-trainval"
train_csv = "../fairface/fairface_label_train.csv"
val_csv = "../fairface/fairface_label_val.csv"

# Run training with the default epochs, batch size and learning rate.
fine_tune_clip(
    train_csv=train_csv,
    val_csv=val_csv,
    image_dir=image_dir,
    output_dir=output_dir,
)



Epoch 1/2:   0%|          | 0/10 [00:00<?, ?it/s]

{'age': ['50-59', '20-29', '10-19', '20-29', '0-2', '40-49', '30-39', '60-69', '50-59', '30-39'], 'gender': ['Female', 'Female', 'Female', 'Female', 'Female', 'Male', 'Male', 'Male', 'Male', 'Female'], 'race': ['White', 'Indian', 'Latino_Hispanic', 'East Asian', 'Black', 'Latino_Hispanic', 'Latino_Hispanic', 'East Asian', 'East Asian', 'White']}


Epoch 1/2:  10%|█         | 1/10 [00:20<03:05, 20.64s/it]

{'age': ['30-39', '20-29', '20-29', '10-19', '10-19', '20-29', '30-39', '30-39', '20-29', '3-9'], 'gender': ['Female', 'Female', 'Female', 'Female', 'Male', 'Male', 'Male', 'Female', 'Female', 'Female'], 'race': ['Black', 'Black', 'Black', 'Indian', 'Indian', 'Black', 'East Asian', 'Black', 'Indian', 'Black']}


Epoch 1/2:  20%|██        | 2/10 [00:36<02:22, 17.82s/it]

{'age': ['20-29', '20-29', '40-49', '10-19', '30-39', '20-29', '3-9', '10-19', '20-29', '20-29'], 'gender': ['Female', 'Female', 'Male', 'Male', 'Male', 'Male', 'Female', 'Female', 'Female', 'Male'], 'race': ['Latino_Hispanic', 'Southeast Asian', 'Indian', 'Indian', 'Indian', 'White', 'East Asian', 'Latino_Hispanic', 'Middle Eastern', 'Black']}


Epoch 1/2:  30%|███       | 3/10 [00:57<02:14, 19.20s/it]

{'age': ['30-39', '60-69', '20-29', '30-39', '20-29', '30-39', '10-19', '20-29', '60-69', '50-59'], 'gender': ['Female', 'Female', 'Male', 'Male', 'Male', 'Male', 'Male', 'Female', 'Female', 'Male'], 'race': ['Indian', 'Indian', 'Southeast Asian', 'Latino_Hispanic', 'Southeast Asian', 'Indian', 'Indian', 'Indian', 'Middle Eastern', 'Southeast Asian']}


Epoch 1/2:  40%|████      | 4/10 [01:28<02:24, 24.12s/it]

{'age': ['60-69', '3-9', '50-59', '10-19', '20-29', '20-29', '20-29', '50-59', '30-39', '20-29'], 'gender': ['Male', 'Female', 'Male', 'Female', 'Female', 'Female', 'Male', 'Male', 'Female', 'Female'], 'race': ['Latino_Hispanic', 'Latino_Hispanic', 'White', 'White', 'White', 'Southeast Asian', 'Indian', 'Southeast Asian', 'White', 'White']}


Epoch 1/2:  50%|█████     | 5/10 [01:52<01:59, 23.82s/it]

{'age': ['20-29', '40-49', '60-69', '20-29', '20-29', '40-49', '10-19', '20-29', '0-2', '50-59'], 'gender': ['Male', 'Female', 'Male', 'Female', 'Male', 'Female', 'Female', 'Male', 'Female', 'Male'], 'race': ['Latino_Hispanic', 'Black', 'White', 'Southeast Asian', 'East Asian', 'Latino_Hispanic', 'Southeast Asian', 'Black', 'East Asian', 'Indian']}


Epoch 1/2:  60%|██████    | 6/10 [02:15<01:34, 23.50s/it]

{'age': ['30-39', '50-59', '30-39', '10-19', '20-29', '40-49', '40-49', '30-39', '60-69', '10-19'], 'gender': ['Male', 'Female', 'Male', 'Female', 'Female', 'Male', 'Male', 'Female', 'Female', 'Male'], 'race': ['Latino_Hispanic', 'Black', 'White', 'Latino_Hispanic', 'Black', 'White', 'Latino_Hispanic', 'Indian', 'Indian', 'White']}


Epoch 1/2:  70%|███████   | 7/10 [02:39<01:10, 23.61s/it]

{'age': ['20-29', '20-29', '20-29', '30-39', '50-59', '40-49', '30-39', '40-49', '30-39', '30-39'], 'gender': ['Male', 'Female', 'Male', 'Female', 'Female', 'Male', 'Male', 'Male', 'Female', 'Female'], 'race': ['East Asian', 'Latino_Hispanic', 'White', 'White', 'Black', 'Indian', 'Indian', 'Middle Eastern', 'Latino_Hispanic', 'White']}


Epoch 1/2:  80%|████████  | 8/10 [03:07<00:50, 25.12s/it]

{'age': ['3-9', '30-39', '30-39', '50-59', '10-19', '30-39', '10-19', '3-9', '40-49', '30-39'], 'gender': ['Female', 'Female', 'Male', 'Male', 'Female', 'Female', 'Female', 'Male', 'Male', 'Male'], 'race': ['Middle Eastern', 'White', 'Southeast Asian', 'Southeast Asian', 'Southeast Asian', 'East Asian', 'Black', 'East Asian', 'Middle Eastern', 'Latino_Hispanic']}


Epoch 1/2:  90%|█████████ | 9/10 [03:29<00:24, 24.14s/it]

{'age': ['30-39', '50-59', '30-39', '10-19', '10-19', '30-39', '40-49'], 'gender': ['Female', 'Male', 'Male', 'Female', 'Female', 'Male', 'Male'], 'race': ['Middle Eastern', 'Middle Eastern', 'White', 'East Asian', 'Black', 'White', 'Southeast Asian']}


Epoch 1/2: 100%|██████████| 10/10 [03:49<00:00, 22.95s/it]
Validating: 100%|██████████| 10/10 [00:52<00:00,  5.25s/it]


Epoch 1/2
age accuracy: 25.86%
gender accuracy: 52.14%
race accuracy: 13.86%


Epoch 2/2:   0%|          | 0/10 [00:00<?, ?it/s]

{'age': ['40-49', '60-69', '3-9', '10-19', '40-49', '50-59', '20-29', '50-59', '30-39', '20-29'], 'gender': ['Male', 'Female', 'Male', 'Female', 'Male', 'Male', 'Female', 'Male', 'Female', 'Female'], 'race': ['Middle Eastern', 'Indian', 'East Asian', 'Black', 'Indian', 'White', 'Southeast Asian', 'Southeast Asian', 'Indian', 'Middle Eastern']}


Epoch 2/2:  10%|█         | 1/10 [00:05<00:51,  5.69s/it]

{'age': ['10-19', '50-59', '10-19', '30-39', '0-2', '20-29', '20-29', '60-69', '30-39', '20-29'], 'gender': ['Female', 'Male', 'Male', 'Female', 'Female', 'Female', 'Female', 'Female', 'Female', 'Male'], 'race': ['Black', 'Indian', 'Indian', 'White', 'East Asian', 'Latino_Hispanic', 'Indian', 'Middle Eastern', 'White', 'Latino_Hispanic']}


Epoch 2/2:  20%|██        | 2/10 [00:10<00:42,  5.37s/it]

{'age': ['40-49', '30-39', '50-59', '20-29', '20-29', '30-39', '20-29', '10-19', '3-9', '10-19'], 'gender': ['Female', 'Female', 'Male', 'Female', 'Female', 'Male', 'Female', 'Female', 'Female', 'Female'], 'race': ['Black', 'White', 'Southeast Asian', 'White', 'Southeast Asian', 'White', 'Black', 'East Asian', 'Middle Eastern', 'Latino_Hispanic']}


Epoch 2/2:  30%|███       | 3/10 [00:16<00:38,  5.50s/it]

{'age': ['50-59', '30-39', '3-9', '60-69', '50-59', '20-29', '60-69', '10-19', '30-39', '40-49'], 'gender': ['Male', 'Male', 'Female', 'Male', 'Male', 'Male', 'Female', 'Female', 'Male', 'Male'], 'race': ['Southeast Asian', 'East Asian', 'Latino_Hispanic', 'East Asian', 'Middle Eastern', 'White', 'Indian', 'Southeast Asian', 'Latino_Hispanic', 'White']}


Epoch 2/2:  40%|████      | 4/10 [00:22<00:35,  5.85s/it]

{'age': ['20-29', '20-29', '20-29', '40-49', '30-39', '60-69', '20-29', '30-39', '40-49', '20-29'], 'gender': ['Male', 'Female', 'Male', 'Male', 'Male', 'Male', 'Female', 'Female', 'Male', 'Female'], 'race': ['White', 'Black', 'Southeast Asian', 'Latino_Hispanic', 'Latino_Hispanic', 'White', 'Black', 'Indian', 'Indian', 'Indian']}


Epoch 2/2:  50%|█████     | 5/10 [00:28<00:29,  5.89s/it]

{'age': ['10-19', '30-39', '40-49', '10-19', '30-39', '20-29', '60-69', '3-9', '30-39', '30-39'], 'gender': ['Male', 'Female', 'Female', 'Female', 'Female', 'Female', 'Male', 'Female', 'Female', 'Male'], 'race': ['Indian', 'Latino_Hispanic', 'Latino_Hispanic', 'Latino_Hispanic', 'White', 'East Asian', 'Latino_Hispanic', 'Black', 'White', 'Latino_Hispanic']}


Epoch 2/2:  60%|██████    | 6/10 [00:35<00:24,  6.22s/it]

{'age': ['30-39', '30-39', '20-29', '40-49', '40-49', '50-59', '20-29', '30-39', '20-29', '20-29'], 'gender': ['Female', 'Male', 'Female', 'Male', 'Male', 'Female', 'Female', 'Male', 'Female', 'Male'], 'race': ['Middle Eastern', 'Indian', 'Latino_Hispanic', 'Middle Eastern', 'Latino_Hispanic', 'Black', 'Indian', 'Indian', 'White', 'Southeast Asian']}


Epoch 2/2:  70%|███████   | 7/10 [00:41<00:18,  6.05s/it]

{'age': ['10-19', '10-19', '10-19', '0-2', '20-29', '20-29', '10-19', '30-39', '20-29', '50-59'], 'gender': ['Female', 'Female', 'Male', 'Female', 'Male', 'Male', 'Female', 'Male', 'Male', 'Female'], 'race': ['Latino_Hispanic', 'Indian', 'Indian', 'Black', 'Black', 'Indian', 'Southeast Asian', 'Southeast Asian', 'East Asian', 'White']}


Epoch 2/2:  80%|████████  | 8/10 [00:46<00:11,  5.74s/it]

{'age': ['30-39', '30-39', '30-39', '30-39', '20-29', '30-39', '40-49', '10-19', '3-9', '20-29'], 'gender': ['Male', 'Male', 'Female', 'Male', 'Female', 'Female', 'Male', 'Male', 'Female', 'Male'], 'race': ['White', 'Latino_Hispanic', 'East Asian', 'Indian', 'Southeast Asian', 'Black', 'Southeast Asian', 'White', 'East Asian', 'East Asian']}


Epoch 2/2:  90%|█████████ | 9/10 [00:52<00:05,  5.83s/it]

{'age': ['30-39', '50-59', '30-39', '50-59', '10-19', '20-29', '20-29'], 'gender': ['Male', 'Female', 'Female', 'Male', 'Female', 'Male', 'Male'], 'race': ['White', 'Black', 'Black', 'East Asian', 'White', 'Black', 'Black']}


Epoch 2/2: 100%|██████████| 10/10 [00:55<00:00,  5.51s/it]
Validating: 100%|██████████| 10/10 [01:00<00:00,  6.07s/it]

Epoch 2/2
age accuracy: 28.86%
gender accuracy: 56.14%
race accuracy: 18.86%





'/Users/qiaochufeng'