In [1]:
import os
import pandas as pd
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder

import torch
import torch.nn as nn
import torchvision.models as models
import torch.optim as optim

In [2]:
# Dataset locations (Kaggle "celeba-attribute" input).
celeba_root = "/kaggle/input/celeba-attribute"
image_dir = f"{celeba_root}/img_align_celeba/img_align_celeba"
attr_path = f"{celeba_root}/list_attr_celeba.txt"
split_path = f"{celeba_root}/list_eval_partition.txt"

# Usual ImageNet per-channel normalisation constants.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Preprocessing for ResNet: 178x178 centre crop of the aligned face, resized to
# 224x224, converted to a tensor and normalised per channel.
transform = transforms.Compose([
    transforms.CenterCrop(178),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

In [3]:
# Attribute table. The header line has no name for the filename column, so
# pandas uses the filenames as the index; move them into an `image_id` column.
# A raw string is used for the regex separator ('\s' is an invalid escape).
attr_df = (
    pd.read_csv(attr_path, sep=r'\s+', skiprows=1)
    .rename_axis("image_id")
    .reset_index()
    # Map -1 to 0 (for binary classification)
    .replace(-1, 0)
)

# Evaluation split per image (0/1/2 are filtered into train/val/test later).
eval_df = pd.read_csv(split_path, sep=r'\s+', header=None, names=["image_id", "split"])

# Each image must appear exactly once in both files; fail loudly otherwise.
attr_df = attr_df.merge(eval_df, on="image_id", validate="one_to_one")
attr_df

Unnamed: 0,image_id,5_o_Clock_Shadow,Arched_Eyebrows,Attractive,Bags_Under_Eyes,Bald,Bangs,Big_Lips,Big_Nose,Black_Hair,...,Smiling,Straight_Hair,Wavy_Hair,Wearing_Earrings,Wearing_Hat,Wearing_Lipstick,Wearing_Necklace,Wearing_Necktie,Young,split
0,000001.jpg,0,1,1,0,0,0,0,0,0,...,1,1,0,1,0,1,0,0,1,0
1,000002.jpg,0,0,0,1,0,0,0,1,0,...,1,0,0,0,0,0,0,0,1,0
2,000003.jpg,0,0,0,0,0,0,1,0,0,...,0,0,1,0,0,0,0,0,1,0
3,000004.jpg,0,0,1,0,0,0,0,0,0,...,0,1,0,1,0,1,1,0,1,0
4,000005.jpg,0,1,1,0,0,0,1,0,0,...,0,0,0,0,0,1,0,0,1,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
202594,202595.jpg,0,0,1,0,0,0,1,0,0,...,0,0,0,0,0,1,0,0,1,2
202595,202596.jpg,0,0,0,0,0,1,1,0,0,...,1,1,0,0,0,0,0,0,1,2
202596,202597.jpg,0,0,0,0,0,0,0,0,1,...,1,0,0,0,0,0,0,0,1,2
202597,202598.jpg,0,1,1,0,0,0,1,0,1,...,1,0,1,1,0,1,0,0,1,2


In [4]:
class CelebABinaryHairDataset(Dataset):
    """CelebA images labelled for binary blond-hair classification.

    Args:
        df: attribute frame with at least `image_id`, `Blond_Hair` and `Male`
            columns (expected to hold 0/1 values). Its index is reset.
        image_root: directory that `image_id` filenames are relative to.
        transform: optional callable applied to the RGB PIL image.

    Items are `(image, label, gender)`: `label` is a float32 scalar tensor of
    `Blond_Hair`, `gender` an int64 scalar tensor of `Male`.
    """

    def __init__(self, df, image_root, transform=None):
        self.df = df.reset_index(drop=True)
        self.image_root = image_root
        self.transform = transform
        # Extract the needed columns once: `df.iloc[idx]` in __getitem__ would
        # build a mixed-dtype (object) row Series for every single sample.
        self._image_ids = self.df['image_id'].tolist()
        self._labels = self.df['Blond_Hair'].to_numpy()
        self._genders = self.df['Male'].to_numpy()

    def __len__(self):
        return len(self._image_ids)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_root, self._image_ids[idx])
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image, so it stays valid after the `with` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(self._labels[idx], dtype=torch.float32)
        gender = torch.tensor(self._genders[idx], dtype=torch.int64)
        return image, label, gender

In [5]:
class SyntheticHairDataset(Dataset):
    """Synthetic face images stored in one sub-folder per (hair, gender) subgroup.

    Each sub-folder of `image_dir` named in `SUBGROUP_MAP` contributes its
    .png/.jpg/.jpeg files (case-insensitive) with that folder's fixed
    `(label, gender)` pair (label 1 = blond, gender 1 = male). Other folders,
    plain files and other extensions are ignored.

    Items are `(image, label, gender)` with `label` and `gender` as Python ints.
    """

    # Folder name -> (blond label, male flag). Names are kept exactly as on disk;
    # note they are not uniform ("generated_images_male_*" vs "generated_female_*").
    SUBGROUP_MAP = {
        "generated_images_male_blond": (1, 1),
        "generated_female_blond": (1, 0),
        "generated_images_male_non_blond": (0, 1),
        "generated_female_non_blond": (0, 0),
    }
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, image_dir, transform=None):
        self.image_paths = []
        self.labels = []
        self.genders = []
        self.transform = transform

        # os.listdir order is filesystem-dependent; sorting makes the sample
        # order (and therefore index-based access) reproducible across machines.
        for subgroup in sorted(os.listdir(image_dir)):
            subgroup_path = os.path.join(image_dir, subgroup)
            if subgroup not in self.SUBGROUP_MAP or not os.path.isdir(subgroup_path):
                continue

            label, gender = self.SUBGROUP_MAP[subgroup]

            for img_name in sorted(os.listdir(subgroup_path)):
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(subgroup_path, img_name))
                    self.labels.append(label)
                    self.genders.append(gender)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Close the file handle deterministically; convert() returns a new,
        # fully loaded image.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert("RGB")
        label = self.labels[idx]
        gender = self.genders[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, gender

In [6]:
# Split by the `split` column from list_eval_partition.txt.
train_df = attr_df[attr_df['split'] == 0]
val_df = attr_df[attr_df['split'] == 1]
test_df = attr_df[attr_df['split'] == 2]

SYNTHETIC_DIR = "/kaggle/input/final-new/FINAL"

# Dataset instances
train_set = CelebABinaryHairDataset(train_df, image_dir, transform)
val_set = CelebABinaryHairDataset(val_df, image_dir, transform)
synthetic_dataset = SyntheticHairDataset(SYNTHETIC_DIR, transform)
test_set = CelebABinaryHairDataset(test_df, image_dir, transform)

# Dataloaders. Only the training loaders shuffle; evaluation order does not
# change accuracy, so val/test are read in a fixed, reproducible order.
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=0)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=0)
synthetic_loader = DataLoader(synthetic_dataset, batch_size=32, shuffle=True, num_workers=0)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False, num_workers=0)

In [7]:
# Training helpers (tqdm progress bar per epoch).
def train_epoch(model, dataloader, criterion, optimizer, device='cuda'):
    """Run one supervised training epoch and return the mean per-sample loss.

    Args:
        model: classifier returning logits of shape (B, num_classes).
        dataloader: yields `(images, labels, genders)`; genders are ignored.
        criterion: loss taking `(logits, long labels)`, e.g. CrossEntropyLoss.
        optimizer: optimizer over `model`'s parameters.
        device: device images and labels are moved to.

    Returns:
        Sum of per-batch mean losses weighted by batch size, divided by the
        number of samples actually seen (0.0 if no batch was produced).
    """
    model.train()
    running_loss = 0.0
    num_samples = 0

    for images, labels, _genders in tqdm(dataloader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.long().to(device)  # CrossEntropyLoss needs class indices

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        running_loss += loss.item() * batch_size
        num_samples += batch_size

    # Divide by samples seen, not len(dataset): correct with drop_last and
    # safe for an empty loader.
    return running_loss / max(1, num_samples)


def train_epoch_with_val(model, dataloader, criterion, optimizer, device='cuda'):
    """Alias of `train_epoch`, kept for existing callers.

    Despite the name, no validation is performed; it was a verbatim copy of
    `train_epoch`.
    """
    return train_epoch(model, dataloader, criterion, optimizer, device=device)

In [8]:
def evaluate_groupwise(model, dataloader, device='cuda'):
    """Print and return accuracy broken down by (hair colour, gender) subgroup.

    Args:
        model: classifier whose output is either class logits of shape
            (B, num_classes) or a tuple/list whose first element is those logits
            (e.g. a DANN returning `(class_logits, domain_logits)`).
        dataloader: yields `(images, labels, genders)`. Label 1 counts as blond
            and gender 1 as male; any other value counts as non-blond / female.
        device: device the images are moved to for the forward pass.

    Returns:
        Overall accuracy in percent over all samples (0.0 for an empty loader).
    """
    model.eval()
    group_correct = {
        "Blond Male": 0,
        "Blond Female": 0,
        "Non-Blond Male": 0,
        "Non-Blond Female": 0
    }
    group_total = {k: 0 for k in group_correct}

    with torch.no_grad():
        for images, labels, genders in tqdm(dataloader, desc="Evaluating", leave=False):
            outputs = model(images.to(device))
            # Accept multi-head models: the class logits come first.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = torch.argmax(outputs, dim=1).cpu()
            labels = labels.cpu()
            genders = genders.cpu()

            # Vectorised per-group counting instead of a Python loop per sample.
            correct = preds == labels
            is_blond = labels == 1
            is_male = genders == 1
            masks = {
                "Blond Male": is_blond & is_male,
                "Blond Female": is_blond & ~is_male,
                "Non-Blond Male": ~is_blond & is_male,
                "Non-Blond Female": ~is_blond & ~is_male,
            }
            for group, mask in masks.items():
                group_total[group] += int(mask.sum())
                group_correct[group] += int((correct & mask).sum())

    for group in group_correct:
        correct = group_correct[group]
        total = group_total[group]
        acc = 100 * correct / max(1, total)
        print(f"{group}: Accuracy = {acc:.2f}% ({correct}/{total})")

    overall_acc = 100 * sum(group_correct.values()) / max(1, sum(group_total.values()))
    print(f"\nOverall Accuracy = {overall_acc:.2f}%")
    return overall_acc

In [47]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ImageNet-pretrained ResNet-18 with a new 2-way head (index 1 = blond) for CE.
# NOTE: torchvision >= 0.13 prefers `weights=models.ResNet18_Weights.DEFAULT`
# over the deprecated `pretrained=True`.
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 2)
model = model.to(device)  # follow `device` so the cell also runs without CUDA

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

In [None]:
# Fine-tune `model` on the real CelebA training split, one log line per epoch.
# NOTE(review): the checkpoint name mentions "synthetic_sampled", but this loop
# only reads `train_loader` (real images) — confirm the name is intended.
num_epochs = 5
checkpoint_path = "resnet18_blond_classifier_synthetic_sampled_final.pth"

for epoch in range(num_epochs):
    epoch_number = epoch + 1
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device=device)
    print(f"Epoch {epoch_number}: Train Loss = {train_loss:.4f}")

# Persist the final weights (state_dict only).
torch.save(model.state_dict(), checkpoint_path)

In [None]:
import torch.nn as nn
from tqdm import tqdm
import torchvision.models as models

# Reload the baseline checkpoint into a fresh ResNet-18 and report per-subgroup
# test accuracy. Uses `device` so it also works on CPU-only machines
# (map_location='cuda' fails there).
checkpoint_path = "resnet18_blond_classifier_synthetic_sampled_final.pth"

model_test = models.resnet18()
model_test.fc = nn.Linear(model_test.fc.in_features, 2)  # 2 outputs for CE
model_test.load_state_dict(torch.load(checkpoint_path, map_location=device))
model_test = model_test.to(device)

evaluate_groupwise(model_test, test_loader, device=device)

## Drift Visualization

### Embeddings

In [None]:
import torch
from torch.utils.data import DataLoader
import torchvision.models as models
import torchvision.transforms as transforms
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Feature extractor for the drift plots: ImageNet-pretrained ResNet-18 with the
# classification head replaced by Identity, so it returns the pooled backbone
# features. Used for inference only (eval mode). Placed on `device` rather than
# a hard-coded `.cuda()`.
resnet = models.resnet18(pretrained=True)
resnet.fc = torch.nn.Identity()
resnet = resnet.eval().to(device)


In [None]:
# The embedding cells below reuse `train_loader` / `synthetic_loader` from the
# dataset cell. They are deliberately NOT rebuilt here: rebinding
# `synthetic_loader` with batch_size=64 would silently change the batch size
# seen by the DANN training cell further down on a top-to-bottom run (the logged
# DANN run used the batch_size=32 loader; this cell had not been executed).
# Both existing loaders already shuffle, which is all the t-SNE sampling needs.

In [None]:
def get_embeddings(model, dataloader, max_samples=2000, device=None):
    """Run `model` over `dataloader` and collect at most `max_samples` embeddings.

    Args:
        model: feature extractor mapping an image batch to (B, D) features. It is
            called under `torch.no_grad()` but is NOT switched to eval mode here.
        dataloader: yields `(images, labels, genders)` batches.
        max_samples: maximum number of rows returned; the last batch read is
            truncated so the limit is never exceeded.
        device: device images are moved to. Defaults to the device of the
            model's first parameter, or CUDA-if-available when it has none.

    Returns:
        `(features, labels, genders)` as NumPy arrays with equal row counts.

    Raises:
        ValueError: if no sample was collected (empty loader or max_samples <= 0).
    """
    if device is None:
        first_param = next(model.parameters(), None)
        if first_param is not None:
            device = first_param.device
        else:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    features, labels, genders = [], [], []
    count = 0

    with torch.no_grad():
        for imgs, lbls, gdrs in dataloader:
            if count >= max_samples:
                break
            take = min(len(imgs), max_samples - count)
            emb = model(imgs[:take].to(device)).cpu().numpy()
            features.append(emb)
            labels.extend(lbls[:take].numpy())
            genders.extend(gdrs[:take].numpy())
            count += take

    if not features:
        raise ValueError("no samples collected: dataloader is empty or max_samples <= 0")
    return np.concatenate(features), np.array(labels), np.array(genders)

In [None]:
# Embed a sample of real (CelebA train) and synthetic images with the same
# ResNet-18 extractor and project them jointly to 2-D with t-SNE.
real_feats, real_lbls, real_genders = get_embeddings(resnet, train_loader)
synth_feats, synth_lbls, synth_genders = get_embeddings(resnet, synthetic_loader)

X = np.concatenate([real_feats, synth_feats], axis=0)
y = np.concatenate([real_lbls, synth_lbls])
g = np.concatenate([real_genders, synth_genders])
domain = ['real'] * len(real_feats) + ['synthetic'] * len(synth_feats)

tsne = TSNE(n_components=2, perplexity=30, random_state=42)
X_2d = tsne.fit_transform(X)

# Colour = domain, marker = gender.
fig, ax = plt.subplots(figsize=(10, 6))
sns.scatterplot(x=X_2d[:, 0], y=X_2d[:, 1], hue=domain, style=g.astype(int), alpha=0.7, ax=ax)
ax.set_title("t-SNE: Real vs Synthetic - Gender Style Overlay")
ax.legend(title="Domain / Gender")
plt.show()


In [None]:
import pandas as pd
# One row per embedded sample: t-SNE coordinates plus metadata.
df = pd.DataFrame({
    'x': X_2d[:, 0],
    'y': X_2d[:, 1],
    'domain': domain,
    'gender': g,
    'label': y,
})

# (hair label, gender) -> readable subgroup name; label 1 = blond, gender 1 = male.
subgroup_map = {
    (0, 0): 'Non-Blond Female',
    (0, 1): 'Non-Blond Male',
    (1, 0): 'Blond Female',
    (1, 1): 'Blond Male',
}
df['subgroup'] = [subgroup_map[key] for key in zip(df['label'], df['gender'])]


def _tsne_scatter(data, hue, title, legend_title, figsize=(10, 6), style=None):
    """Scatter the t-SNE coordinates of `data`, coloured by `hue`, and show it."""
    plt.figure(figsize=figsize)
    sns.scatterplot(data=data, x='x', y='y', hue=hue, style=style, alpha=0.7)
    plt.title(title)
    plt.legend(title=legend_title)
    plt.show()


# --- Plot 1: All Samples Colored by Subgroup (Hair + Gender) ---
_tsne_scatter(df, 'subgroup', "Plot 1: All Samples - Hair + Gender", "Subgroup", style='domain')

# --- Plots 2-3: one domain at a time ---
for plot_no, dom, dom_title in [(2, 'real', 'Real'), (3, 'synthetic', 'Synthetic')]:
    _tsne_scatter(df[df['domain'] == dom], 'subgroup',
                  f"Plot {plot_no}: {dom_title} Only - Hair + Gender", "Subgroup")

# --- Plots 4–7: Real vs Synthetic for each Subgroup ---
for i, ((label_val, gender_val), name) in enumerate(subgroup_map.items(), start=4):
    subset = df[(df['label'] == label_val) & (df['gender'] == gender_val)]
    _tsne_scatter(subset, 'domain', f"Plot {i}: Real vs Synthetic - {name}", "Domain",
                  figsize=(8, 5))



## Domain adaptation with DANN (Domain-Adversarial Neural Network)

In [9]:
from torch.autograd import Function

class GradReverse(Function):
    """Gradient reversal: identity in the forward pass, gradient multiplied by
    ``-lambd`` in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambd):
        # Keep the scale for backward. `view_as` gives autograd a new tensor
        # object aliasing `x` instead of returning `x` itself.
        ctx.reverse_scale = lambd
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # One gradient per forward input: reversed for `x`, none for `lambd`.
        reversed_grad = grad_output.neg() * ctx.reverse_scale
        return reversed_grad, None


def grad_reverse(x, lambd=1.0):
    """Pass ``x`` through :class:`GradReverse` with reversal strength ``lambd``."""
    return GradReverse.apply(x, lambd)

In [10]:
import torch.nn as nn
import torchvision.models as models

class DANN(nn.Module):
    """Domain-adversarial classifier on a ResNet-18 backbone.

    The shared backbone feeds a linear class head directly, and a small MLP
    domain discriminator through a gradient-reversal layer. The reversal layer
    negates (and scales by `lambd`) the domain-loss gradient that reaches the
    backbone, so the backbone is trained to confuse the discriminator.

    Args:
        feature_dim: backbone output width; must equal the ResNet-18 pooled
            feature size (the original fc layer's `in_features`).
        num_classes: number of class-head outputs.

    Raises:
        ValueError: if `feature_dim` does not match the backbone.
    """

    def __init__(self, feature_dim=512, num_classes=2):
        super().__init__()
        # Randomly initialised ResNet-18 (pretrained=False); callers presumably
        # load trained weights afterwards.
        self.backbone = models.resnet18(pretrained=False)
        backbone_dim = self.backbone.fc.in_features
        if feature_dim != backbone_dim:
            raise ValueError(
                f"feature_dim={feature_dim} does not match the ResNet-18 "
                f"feature size {backbone_dim}"
            )
        self.backbone.fc = nn.Identity()  # expose pooled features, not ImageNet logits
        self.feature_dim = feature_dim

        # Classification head
        self.classifier = nn.Linear(feature_dim, num_classes)

        # Domain discriminator
        self.domain_classifier = nn.Sequential(
            nn.Linear(feature_dim, 100),
            nn.ReLU(),
            nn.Linear(100, 2)  # 2 domains: synthetic, real
        )

    def forward(self, x, lambd=0.0):
        """Return `(class_logits, domain_logits)` for the image batch `x`.

        `lambd` scales the reversed gradient flowing from the domain head into
        the backbone; with the default 0.0 that gradient is zero.
        """
        features = self.backbone(x)
        class_logits = self.classifier(features)
        reversed_features = grad_reverse(features, lambd)
        domain_logits = self.domain_classifier(reversed_features)
        return class_logits, domain_logits


In [11]:
from itertools import cycle
from tqdm import tqdm
import torch.nn.functional as F
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Configuration ---
# Plain ResNet-18 blond classifier used to initialise the DANN backbone and head.
# NOTE(review): this is not the checkpoint saved earlier in this notebook
# ("resnet18_blond_classifier_synthetic_sampled_final.pth") — confirm which is intended.
PRETRAINED_CKPT = ("/kaggle/input/resnet18_blond_classifier_synthetic/pytorch/default/1/"
                   "resnet18_blond_classifier_synthetic2.pth")
GRL_LAMBDA = 1.0          # gradient-reversal strength passed to DANN.forward
DOMAIN_LOSS_WEIGHT = 0.1  # weight of the domain loss in the total loss

# Source domain: labelled synthetic images. Target domain: real CelebA images,
# whose labels are never read below.
# NOTE(review): the target is the *test* split, which is also evaluated
# afterwards, so the test accuracy is transductive (the model has seen those
# images, though not their labels). Use `train_loader` for a held-out test.
target_loader = test_loader


def infinite_batches(loader):
    """Yield batches from `loader` forever, starting a new pass when exhausted.

    Used instead of `itertools.cycle`, which saves a copy of every element it
    yields in order to replay them — for an image DataLoader that keeps an
    epoch's worth of image tensors alive in host memory.
    """
    if len(loader) == 0:
        raise ValueError("target loader yields no batches")
    while True:
        yield from loader


# --- Model initialisation from the plain classifier ---
pretrained = models.resnet18()
pretrained.fc = nn.Linear(pretrained.fc.in_features, 2)
pretrained.load_state_dict(torch.load(PRETRAINED_CKPT, map_location='cpu'))
pretrained_state = pretrained.state_dict()

model = DANN()

# Backbone: every weight except the original fc (the DANN backbone's fc is Identity).
dann_backbone_keys = model.backbone.state_dict().keys()
model.backbone.load_state_dict({
    k: v for k, v in pretrained_state.items()
    if k in dann_backbone_keys and not k.startswith('fc')
})

# Class head: the classifier's fc layer.
model.classifier.load_state_dict({
    "weight": pretrained_state["fc.weight"],
    "bias": pretrained_state["fc.bias"],
})

model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion_cls = nn.CrossEntropyLoss()
criterion_domain = nn.CrossEntropyLoss()

num_epochs = 5

for epoch in range(num_epochs):
    model.train()
    total_loss, total_cls_loss, total_domain_loss = 0.0, 0.0, 0.0
    num_batches = 0

    target_iter = infinite_batches(target_loader)
    progress_bar = tqdm(synthetic_loader, desc=f"Epoch {epoch+1}")

    # zip stops when the source loader is exhausted; the target stream restarts as needed.
    for (xs, ys, _), (xt, _, _) in zip(progress_bar, target_iter):
        xs, ys = xs.to(device), ys.long().to(device)
        xt = xt.to(device)
        n_src = xs.size(0)

        # Domain labels: 0 = source (synthetic), 1 = target (real).
        domain_labels = torch.cat([
            torch.zeros(n_src, dtype=torch.long, device=device),
            torch.ones(xt.size(0), dtype=torch.long, device=device),
        ])

        # Single forward pass over the concatenated source + target batch.
        class_logits, domain_logits = model(torch.cat([xs, xt], dim=0), lambd=GRL_LAMBDA)

        # Only the source rows have class labels.
        loss_cls = criterion_cls(class_logits[:n_src], ys)
        loss_domain = criterion_domain(domain_logits, domain_labels)
        loss = loss_cls + DOMAIN_LOSS_WEIGHT * loss_domain

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_cls_loss += loss_cls.item()
        total_domain_loss += loss_domain.item()
        num_batches += 1

        progress_bar.set_postfix({
            "Last Total Loss": f"{loss.item():.2f}",
            "Last Cls Loss": f"{loss_cls.item():.2f}",
            "Last Domain Loss": f"{loss_domain.item():.2f}"
        })

    # Epoch-wise averaged loss (guarded against an empty source loader).
    denom = max(1, num_batches)
    avg_total_loss = total_loss / denom
    avg_cls_loss = total_cls_loss / denom
    avg_domain_loss = total_domain_loss / denom

    print(f"[Epoch {epoch+1}] Avg Total Loss: {avg_total_loss:.4f}, "
          f"Cls Loss: {avg_cls_loss:.4f}, Domain Loss: {avg_domain_loss:.4f}")

    # Release cached GPU blocks and Python garbage between epochs.
    torch.cuda.empty_cache()
    gc.collect()


Epoch 1: 100%|██████████| 619/619 [11:58<00:00,  1.16s/it, Last Total Loss=0.30, Last Cls Loss=0.25, Last Domain Loss=0.58]


[Epoch 1] Avg Total Loss: 0.2892, Cls Loss: 0.2242, Domain Loss: 0.6493


Epoch 2: 100%|██████████| 619/619 [07:40<00:00,  1.35it/s, Last Total Loss=0.17, Last Cls Loss=0.07, Last Domain Loss=1.08]


[Epoch 2] Avg Total Loss: 0.1959, Cls Loss: 0.1236, Domain Loss: 0.7227


Epoch 3: 100%|██████████| 619/619 [07:35<00:00,  1.36it/s, Last Total Loss=0.10, Last Cls Loss=0.03, Last Domain Loss=0.72]


[Epoch 3] Avg Total Loss: 0.1475, Cls Loss: 0.0741, Domain Loss: 0.7333


Epoch 4: 100%|██████████| 619/619 [07:40<00:00,  1.34it/s, Last Total Loss=0.09, Last Cls Loss=0.02, Last Domain Loss=0.65]


[Epoch 4] Avg Total Loss: 0.1174, Cls Loss: 0.0484, Domain Loss: 0.6897


Epoch 5: 100%|██████████| 619/619 [07:42<00:00,  1.34it/s, Last Total Loss=0.09, Last Cls Loss=0.02, Last Domain Loss=0.64]

[Epoch 5] Avg Total Loss: 0.1118, Cls Loss: 0.0406, Domain Loss: 0.7115





In [13]:
from tqdm import tqdm
import torch

def evaluate_groupwise(model, dataloader, device='cuda'):
    """Print and return accuracy broken down by (hair colour, gender) subgroup.

    Args:
        model: classifier whose output is either class logits of shape
            (B, num_classes) or a tuple/list whose first element is those logits
            (e.g. `DANN`, which returns `(class_logits, domain_logits)`).
        dataloader: yields `(images, labels, genders)`. Label 1 counts as blond
            and gender 1 as male; any other value counts as non-blond / female.
        device: device the images are moved to for the forward pass.

    Returns:
        Overall accuracy in percent over all samples (0.0 for an empty loader).
    """
    model.eval()
    group_correct = {
        "Blond Male": 0,
        "Blond Female": 0,
        "Non-Blond Male": 0,
        "Non-Blond Female": 0
    }
    group_total = {k: 0 for k in group_correct}

    with torch.no_grad():
        for images, labels, genders in tqdm(dataloader, desc="Evaluating", leave=False):
            outputs = model(images.to(device))
            # Accept multi-head models: the class logits come first.
            if isinstance(outputs, (tuple, list)):
                outputs = outputs[0]
            preds = torch.argmax(outputs, dim=1).cpu()
            labels = labels.cpu()
            genders = genders.cpu()

            # Vectorised per-group counting instead of a Python loop per sample.
            correct = preds == labels
            is_blond = labels == 1
            is_male = genders == 1
            masks = {
                "Blond Male": is_blond & is_male,
                "Blond Female": is_blond & ~is_male,
                "Non-Blond Male": ~is_blond & is_male,
                "Non-Blond Female": ~is_blond & ~is_male,
            }
            for group, mask in masks.items():
                group_total[group] += int(mask.sum())
                group_correct[group] += int((correct & mask).sum())

    for group in group_correct:
        correct = group_correct[group]
        total = group_total[group]
        acc = 100 * correct / max(1, total)
        print(f"{group}: Accuracy = {acc:.2f}% ({correct}/{total})")

    overall_acc = 100 * sum(group_correct.values()) / max(1, sum(group_total.values()))
    print(f"\nOverall Accuracy = {overall_acc:.2f}%")
    return overall_acc


In [14]:
# Per-subgroup accuracy of the domain-adapted model on the validation split
# (uses `device` instead of a hard-coded 'cuda').
evaluate_groupwise(model, val_loader, device=device)

                                                             

Blond Male: Accuracy = 89.56% (163/182)
Blond Female: Accuracy = 84.31% (2423/2874)
Non-Blond Male: Accuracy = 51.22% (4239/8276)
Non-Blond Female: Accuracy = 82.78% (7065/8535)

Overall Accuracy = 69.91%




69.91493431318267

In [15]:
# Persist the domain-adapted DANN weights (state_dict only).
dann_checkpoint_path = "resnet18_blond_classifier_domain_adapted.pth"
torch.save(model.state_dict(), dann_checkpoint_path)

In [16]:
# Per-subgroup accuracy of the domain-adapted model on the test split
# (uses `device` instead of a hard-coded 'cuda').
evaluate_groupwise(model, test_loader, device=device)

                                                             

Blond Male: Accuracy = 91.67% (165/180)
Blond Female: Accuracy = 83.27% (2065/2480)
Non-Blond Male: Accuracy = 51.61% (3889/7535)
Non-Blond Female: Accuracy = 85.06% (8308/9767)

Overall Accuracy = 72.27%




72.27231740306583