<a href="https://colab.research.google.com/github/TheS1n233/Project3-Automatic-Subgroup-Identifcation-andMitigation-of-Biases-of-ML-Models/blob/main/P3_baseline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install fairlearn

Collecting fairlearn
  Downloading fairlearn-0.12.0-py3-none-any.whl.metadata (7.0 kB)
Downloading fairlearn-0.12.0-py3-none-any.whl (240 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m240.0/240.0 kB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: fairlearn
Successfully installed fairlearn-0.12.0


**1. Setup and Imports**

In [None]:
import json
import os
import random
from collections import defaultdict

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from fairlearn.metrics import MetricFrame
from google.colab import drive
from PIL import Image
from sklearn.metrics import accuracy_score
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import DataLoader
from torchvision import models
from tqdm.auto import tqdm
# Mount Google Drive at /content/drive; the dataset paths below live under it.
drive.mount('/content/drive')

# Use the GPU when the runtime exposes one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Args:
        seed: Integer passed unchanged to Python's ``random``, NumPy, and
            PyTorch (CPU generator and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

set_seed()

# Single source of truth for the dataset location: the metadata CSV sits directly
# inside the image root, so derive its path instead of repeating the literal.
image_root = "/content/drive/MyDrive/waterbirds_data/waterbirds_v1.0"
csv_path = os.path.join(image_root, "metadata.csv")

Mounted at /content/drive
Using device: cuda


**2. Dataset and Subgroup Labels**

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import os

class WaterbirdsDataset(Dataset):
    """Map-style dataset that serves images listed in a metadata DataFrame.

    Each item is a ``(image, label, subgroup)`` triple read from the columns
    ``img_filename``, ``y`` and ``subgroup_id`` of the frame.
    """

    def __init__(self, df, image_root, transform=None):
        """Store the metadata frame (re-indexed from 0), image root and transform."""
        self.transform = transform
        self.image_root = image_root
        # Positional lookups in __getitem__ need a clean 0..n-1 index.
        self.df = df.reset_index(drop=True)

    def __len__(self):
        """Number of rows in the metadata frame."""
        return self.df.shape[0]

    def __getitem__(self, idx):
        """Load the image for row ``idx`` as RGB and return it with its label and subgroup."""
        record = self.df.iloc[idx]
        full_path = os.path.join(self.image_root, record['img_filename'])
        image = Image.open(full_path).convert("RGB")
        image = self.transform(image) if self.transform else image
        return image, record['y'], record['subgroup_id']


# **3. DataLoader + Transforms**

In [None]:
def load_waterbirds_splits(csv_path):
    """Read the metadata CSV and return its train / validation / test partitions.

    A ``filename`` column, if present, is renamed to ``img_filename``. A
    ``subgroup_id`` column is added as ``2 * y + place``. Rows are then
    partitioned on the ``split`` column (0 = train, 1 = val, 2 = test).

    Args:
        csv_path: Path to the metadata CSV.

    Returns:
        Tuple ``(train_df, val_df, test_df)`` of independent DataFrame copies.
    """
    meta = pd.read_csv(csv_path).rename(columns={'filename': 'img_filename'})
    meta['subgroup_id'] = 2 * meta['y'] + meta['place']

    train_df, val_df, test_df = (
        meta[meta['split'] == split_code].copy() for split_code in (0, 1, 2)
    )

    print(f"Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")
    return train_df, val_df, test_df


# Training-time preprocessing: fixed 224x224 resize plus random horizontal flips.
# NOTE(review): no mean/std normalisation is applied even though get_model() loads
# ImageNet-pretrained weights — consider whether transforms.Normalize should be added.
transform_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])
# Evaluation uses the same resize but no augmentation.
transform_eval = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_df, val_df, test_df = load_waterbirds_splits(csv_path)

BATCH_SIZE = 64
NUM_WORKERS = 4
PIN_MEMORY = True

# Settings shared by all three loaders; only the data, transform and shuffling differ.
_loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
)

train_loader = DataLoader(
    WaterbirdsDataset(train_df, image_root, transform_train),
    shuffle=True,
    **_loader_kwargs,
)

# Validation split.
eval_loader = DataLoader(
    WaterbirdsDataset(val_df, image_root, transform_eval),
    shuffle=False,
    **_loader_kwargs,
)

test_loader = DataLoader(
    WaterbirdsDataset(test_df, image_root, transform_eval),
    shuffle=False,
    **_loader_kwargs,
)



Train: 4795, Val: 1199, Test: 5794




**4. Model: ResNet-18**

In [None]:
def get_model():
    """Build a two-class ResNet-18 classifier on the notebook's ``device``.

    The network is created with the ``IMAGENET1K_V1`` weights and its final
    fully connected layer is replaced by a fresh ``nn.Linear`` with 2 outputs.

    Returns:
        The model, already moved to ``device``.
    """
    net = models.resnet18(weights='IMAGENET1K_V1')
    feature_dim = net.fc.in_features
    net.fc = nn.Linear(feature_dim, 2)
    net = net.to(device)
    return net


def train_epoch(model, loader, opt, criterion):
    """Train ``model`` for one full pass over ``loader``.

    Args:
        model: Network to optimise; called on the image batch and expected to
            return per-class scores.
        loader: Sized iterable yielding ``(images, labels, subgroups)`` batches.
            The subgroup entry is ignored here.
        opt: Optimizer holding ``model``'s parameters.
        criterion: Loss called as ``criterion(outputs, labels)`` and returning
            a scalar tensor.

    Returns:
        Tuple ``(mean_loss, accuracy)`` computed over all samples of the epoch.

    Raises:
        ValueError: If ``loader`` yields no samples.
    """
    model.train()

    # Run on whatever device the model lives on rather than a notebook-level global,
    # so inputs can never end up on a different device than the weights.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    tot = correct = 0
    loss_sum = 0.0
    num_batches = len(loader)

    for i, (x, y, _subgroup) in enumerate(loader):
        x, y = x.to(run_device), y.to(run_device)

        out = model(x)
        loss = criterion(out, y)

        opt.zero_grad()
        loss.backward()
        opt.step()

        # The criterion returns a batch mean; weight it by batch size so the
        # epoch average is per sample even when the last batch is smaller.
        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size
        tot += batch_size
        correct += (out.argmax(1) == y).sum().item()

        # Progress line every 10 batches and on the final batch.
        if i % 10 == 0 or i == num_batches - 1:
            print(f"Batch {i+1}/{num_batches} - Loss: {loss.item():.4f}  - Acc: {correct/tot:.4f}")

    if tot == 0:
        raise ValueError("train_epoch received a loader with no samples")

    return loss_sum / tot, correct / tot


from fairlearn.metrics import MetricFrame
from sklearn.metrics  import accuracy_score
from tqdm import tqdm

def evaluate(model, loader):
    """Measure overall and worst-subgroup accuracy of ``model`` on ``loader``.

    Args:
        model: Classifier returning per-class scores for an image batch.
        loader: Iterable yielding ``(images, labels, subgroups)`` batches.

    Returns:
        Tuple ``(overall_accuracy, worst_group_accuracy, gap)`` where ``gap``
        is the overall accuracy minus the worst subgroup's accuracy.
    """
    model.eval()

    # Run on the model's own device instead of relying on a notebook-level global.
    first_param = next(model.parameters(), None)
    run_device = first_param.device if first_param is not None else torch.device("cpu")

    labels, preds, groups = [], [], []

    with torch.no_grad():
        for x, y, g in tqdm(loader, desc='Eval', leave=False):
            # Only the images are needed on the accelerator; labels and subgroup
            # ids are just collected, so they stay where the loader put them.
            preds.extend(model(x.to(run_device)).argmax(1).cpu().tolist())
            labels.extend(y.tolist())
            groups.extend(g.tolist())

    # MetricFrame's constructor is keyword-only in current fairlearn releases;
    # passing these positionally raises a TypeError.
    mf = MetricFrame(metrics=accuracy_score,
                     y_true=labels,
                     y_pred=preds,
                     sensitive_features=groups)

    worst = mf.by_group.min()
    return mf.overall, worst, mf.overall - worst


# Fine-tune the baseline model, save its weights, and record its metrics on the test split.
EPOCHS = 1
baseline = get_model()
opt = optim.Adam(baseline.parameters(), lr=1e-4)
ce = nn.CrossEntropyLoss()

for ep in range(EPOCHS):
    epoch_no = ep + 1
    print(f"\n🔁 Starting Epoch {epoch_no}/{EPOCHS}")
    epoch_loss, epoch_acc = train_epoch(baseline, train_loader, opt, ce)
    print(f"✅ Epoch {epoch_no} done. Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")

torch.save(baseline.state_dict(), '/content/baseline.pth')

# (overall accuracy, worst-subgroup accuracy, gap) as returned by evaluate().
base_metrics = evaluate(baseline, test_loader)
print('Baseline  overall %.3f  worst %.3f  gap %.3f' % base_metrics)


Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 64.9MB/s]



🔁 Starting Epoch 1/1


KeyboardInterrupt: 

**5. Divergence-Aware Reweighting**

In [None]:
# Per-subgroup error rate ("divergence") of the baseline model.
# It is measured on the validation split: these rates are later used to weight the
# training loss, so computing them on the test split would leak test information
# into the model whose test metrics are reported.
baseline.eval()
preds, labels, groups = [], [], []
with torch.no_grad():
    for x, y, g in eval_loader:  # the dataset yields (image, label, subgroup)
        preds += baseline(x.to(device)).argmax(1).cpu().tolist()
        labels += y.tolist()
        groups += g.tolist()

err_df = pd.DataFrame({'group': groups,
                       'is_err': (np.array(preds) != np.array(labels)).astype(int)})

# Plain int keys / float values keep the dict JSON-serialisable and usable for lookups.
div_dict = {int(k): float(v) for k, v in err_df.groupby('group')['is_err'].mean().items()}

# Context manager so the file is flushed and closed deterministically.
with open('/content/divergence.json', 'w') as fh:
    json.dump(div_dict, fh)
print('Divergence per group:', div_dict)


**6. Train Function**

In [None]:
# Divergence-aware reweighting: a sample from subgroup k contributes
# (1 + λ * div_dict[k]) times its cross-entropy, so subgroups the baseline
# misclassifies more often weigh more during training.
train_loader_dw = DataLoader(
    WaterbirdsDataset(train_df, image_root, transform_train),
    batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)

mit_model = get_model()
opt_dw = optim.Adam(mit_model.parameters(), lr=1e-4)
# reduction='none' keeps one loss value per sample; with the default mean reduction
# the weights below would collapse into a single batch-level scale factor.
ce_per_sample = nn.CrossEntropyLoss(reduction='none')
λ = 1.0

# Lookup table: subgroup id -> baseline error rate (0 for subgroups absent from div_dict).
n_groups = int(train_df['subgroup_id'].max()) + 1
div_table = torch.tensor([div_dict.get(k, 0.0) for k in range(n_groups)],
                         dtype=torch.float32, device=device)

for ep in range(EPOCHS):
    mit_model.train()
    tot = correct = 0
    loss_sum = 0.0
    for x, y, g in train_loader_dw:  # the dataset yields (image, label, subgroup)
        x, y = x.to(device), y.to(device)
        d = div_table[g.to(device)]  # per-sample divergence of each sample's subgroup
        out = mit_model(x)
        loss = (ce_per_sample(out, y) * (1 + λ * d)).mean()
        opt_dw.zero_grad()
        loss.backward()
        opt_dw.step()
        loss_sum += loss.item() * y.size(0)
        pred = out.argmax(1)
        tot += y.size(0)
        correct += (pred == y).sum().item()
    print(f'[DW] Epoch {ep+1}  loss={loss_sum/tot:.4f}  acc={correct/tot:.3f}')

torch.save(mit_model.state_dict(), '/content/mitigation.pth')
mit_metrics = evaluate(mit_model, test_loader)
print('Mitigation overall %.3f  worst %.3f  gap %.3f' % mit_metrics)


**7. Evaluation**

In [None]:
import matplotlib.pyplot as plt
from captum.attr import LayerGradCam

# Pick one test image that the baseline misclassifies and that belongs to the
# subgroup with the highest error rate in div_dict.
worst_group = max(div_dict, key=div_dict.get)
baseline.eval()
sample_img, true_label = None, None
with torch.no_grad():
    for x, y, g in test_loader:  # the dataset yields (image, label, subgroup)
        pred = baseline(x.to(device)).argmax(1).cpu()
        mask = (pred != y) & (g == worst_group)
        if mask.any():
            idx = torch.where(mask)[0][0].item()
            sample_img, true_label = x[idx], y[idx].item()
            break

if sample_img is None:
    # Nothing to explain: the baseline made no mistakes on the worst subgroup.
    print(f'No misclassified test sample found in subgroup {worst_group}; skipping Grad-CAM.')
else:
    # Attribution is computed outside the no_grad block because Grad-CAM needs gradients.
    lgc = LayerGradCam(baseline, baseline.layer4)
    attr = lgc.attribute(sample_img.unsqueeze(0).to(device),
                         target=true_label,
                         relu_attributions=True)
    # The attribution map has the spatial size of layer4's output, not of the image;
    # upsample it so it can be overlaid pixel-for-pixel.
    attr = nn.functional.interpolate(attr, size=tuple(sample_img.shape[-2:]),
                                     mode='bilinear', align_corners=False)
    heat = attr.squeeze().detach().cpu().numpy()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)

    img = sample_img.permute(1, 2, 0).numpy()  # CHW tensor -> HWC array for imshow
    fig, (ax_img, ax_cam) = plt.subplots(1, 2, figsize=(6, 3))
    ax_img.imshow(img)
    ax_img.axis('off')
    ax_img.set_title('Original')
    ax_cam.imshow(img)
    ax_cam.imshow(heat, cmap='jet', alpha=0.5)
    ax_cam.axis('off')
    ax_cam.set_title('Grad-CAM')
    plt.show()

# Result Table and Histogram
import matplotlib.pyplot as plt
import seaborn as sns

# One row per model; each metrics tuple is (overall, worst-group, gap) from evaluate().
df_res = pd.DataFrame([
    dict(model='Baseline',   overall=base_metrics[0], worst=base_metrics[1], gap=base_metrics[2]),
    dict(model='DivWeight',  overall=mit_metrics[0],  worst=mit_metrics[1],  gap=mit_metrics[2]),
])
display(df_res)

fig, ax = plt.subplots(figsize=(4, 3))
sns.barplot(data=df_res, x='model', y='gap', ax=ax)
ax.set_ylabel('Accuracy Gap')
ax.set_title('Gap ↓ is better')
plt.show()


**8. Run Everything**

In [None]:
# STEP 8 - Run All
model = get_model()
train_model(model, train_loader, val_loader, epochs=10)
print("Validation Set:")
evaluate(model, val_loader)
print("Test Set:")
evaluate(model, test_loader)
