# Setup

In [None]:
pip install traker

In [2]:
from trak import TRAKer

def get_trak_matrix(
    train_dl, val_dl, model, ckpts, train_set_size, val_set_size, **kwargs
):
    """Compute a TRAK score matrix for ``val_dl`` targets w.r.t. ``train_dl``.

    Args:
        train_dl: iterable of training batches; each batch is a tuple/list of
            tensors (inputs first, then labels).
        val_dl: iterable of target (validation) batches, same layout.
        model: the model handed to ``TRAKer``.
        ckpts: sequence of checkpoints (state dicts); one TRAK model id is
            used per checkpoint.
        train_set_size: number of training examples, forwarded to ``TRAKer``.
        val_set_size: number of targets, forwarded as ``num_targets``.
        **kwargs: extra keyword arguments for ``TRAKer``. ``task`` defaults to
            ``"image_classification"`` when missing or ``None``.

    Returns:
        Whatever ``TRAKer.finalize_scores`` returns for experiment ``"test"``.
    """
    # Always remove "task" from kwargs: leaving task=None behind would pass
    # the keyword twice to TRAKer below.
    task = kwargs.pop("task", None)
    if task is None:
        task = "image_classification"

    traker = TRAKer(model=model, task=task, train_set_size=train_set_size, **kwargs)

    # Pass 1: featurize the training set once per checkpoint.
    for model_id, checkpoint in enumerate(ckpts):
        traker.load_checkpoint(checkpoint, model_id=model_id)
        for batch in train_dl:
            batch = [x.cuda() for x in batch]
            # batch should be a tuple/list of inputs and labels
            traker.featurize(batch=batch, num_samples=batch[0].shape[0])

    traker.finalize_features()

    # Pass 2: score the targets. The target loop must run inside the
    # checkpoint loop so that every checkpoint is scored, not only the last
    # one that was started.
    for model_id, checkpoint in enumerate(ckpts):
        traker.start_scoring_checkpoint(
            exp_name="test",
            checkpoint=checkpoint,
            model_id=model_id,
            num_targets=val_set_size,
        )
        for batch in val_dl:
            batch = [x.cuda() for x in batch]
            traker.score(batch=batch, num_samples=batch[0].shape[0])

    scores = traker.finalize_scores(exp_name="test")
    return scores


In [3]:
import torch
import numpy as np
from torch.nn import functional as F

class DDA:
    """Pick training examples to keep or discard from TRAK attribution scores.

    The workflow implemented by :meth:`debias`:

    1. take a score matrix (given as ``trak_scores`` or computed in
       ``__init__`` with ``get_trak_matrix``);
    2. compute the mean validation loss of every group;
    3. for each training example, weight its mean score on each group by that
       group's loss and average over groups (the "group alignment score");
    4. keep the training indices with a non-negative score, or drop the
       ``num_to_discard`` lowest-scoring ones.

    The code indexes the score matrix as ``scores[train_index, val_index]``
    and assumes group ids are the integers ``0 .. n_groups - 1``.
    """

    def __init__(
        self,
        model,
        checkpoints,
        train_dataloader,
        val_dataloader,
        group_indices,
        train_set_size=None,
        val_set_size=None,
        trak_scores=None,
        trak_kwargs=None,
        device="cuda",
    ) -> None:
        """Store the inputs and make sure ``self.trak_scores`` is available.

        Args:
            model: model evaluated on the validation loader.
            checkpoints: checkpoints forwarded to ``get_trak_matrix``.
            train_dataloader: training loader (used only to compute scores).
            val_dataloader: validation loader yielding ``(inputs, labels)``.
            group_indices: group id of every validation example, in loader
                order (list, NumPy array or tensor).
            train_set_size: fallback size when the train loader has no
                ``dataset`` attribute.
            val_set_size: fallback size when the val loader has no
                ``dataset`` attribute.
            trak_scores: precomputed score matrix; skips the TRAK computation.
            trak_kwargs: extra keyword arguments for ``get_trak_matrix``.
            device: device the validation batches are moved to.

        Raises:
            ValueError: if the sizes cannot be read from the loaders and were
                not given explicitly.
        """
        self.model = model
        self.checkpoints = checkpoints
        self.dataloaders = {"train": train_dataloader, "val": val_dataloader}
        self.group_indices = group_indices
        self.device = device

        if trak_scores is not None:
            self.trak_scores = trak_scores
            return

        try:
            self.train_set_size = len(train_dataloader.dataset)
            self.val_set_size = len(val_dataloader.dataset)
        except AttributeError as e:
            print(
                f"No dataset attribute found in train_dataloader or val_dataloader. {e}"
            )
            if train_set_size is None or val_set_size is None:
                raise ValueError(
                    "train_set_size and val_set_size must be specified if "
                    "train_dataloader and val_dataloader do not have a "
                    "dataset attribute."
                ) from e
            self.train_set_size = train_set_size
            self.val_set_size = val_set_size

        # Step 1: compute TRAK scores
        self.trak_scores = get_trak_matrix(
            train_dl=self.dataloaders["train"],
            val_dl=self.dataloaders["val"],
            model=self.model,
            ckpts=self.checkpoints,
            train_set_size=self.train_set_size,
            val_set_size=self.val_set_size,
            **(trak_kwargs or {}),
        )

    @staticmethod
    def _group_array(group_indices):
        """Return ``group_indices`` (list, ndarray or tensor) as a 1-D array."""
        if torch.is_tensor(group_indices):
            group_indices = group_indices.detach().cpu().numpy()
        return np.asarray(group_indices).reshape(-1)

    def get_group_losses(self, model, val_dl, group_indices) -> list:
        """Return the mean cross-entropy loss of each group on ``val_dl``.

        Args:
            model: model to evaluate (switched to eval mode).
            val_dl: iterable of ``(inputs, labels)`` batches.
            group_indices: group id per validation example, in loader order.

        Returns:
            List of 0-d tensors, one per group id ``0 .. n_groups - 1``. A
            group id without any example yields NaN.

        Raises:
            ValueError: if ``group_indices`` and the loader disagree in length.
        """
        per_batch = []
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_dl:
                outputs = model(inputs.to(self.device))
                per_batch.append(
                    F.cross_entropy(
                        outputs, labels.to(self.device), reduction="none"
                    )
                )
        losses = torch.cat(per_batch)

        # Comparing a Python list with an int gives a single False, which
        # selected nothing and produced NaN; always compare element-wise.
        groups = self._group_array(group_indices)
        if len(groups) != len(losses):
            raise ValueError(
                f"group_indices has {len(groups)} entries but the loader "
                f"produced {len(losses)} losses."
            )

        n_groups = len(np.unique(groups))
        group_losses = []
        for i in range(n_groups):
            mask = torch.from_numpy(groups == i).to(losses.device)
            group_losses.append(losses[mask].mean())
        return group_losses

    def compute_group_alignment_scores(self, trak_scores, group_indices, group_losses):
        """Return one alignment score per training example.

        For every group ``i`` the mean score of each training example over the
        validation examples of that group is multiplied by ``group_losses[i]``;
        the result is averaged over groups.

        Returns:
            1-D NumPy array with one entry per row of ``trak_scores``.
        """
        groups = self._group_array(group_indices)
        n_groups = len(np.unique(groups))
        S = np.array(trak_scores)
        weighted = [
            float(group_losses[i]) * S[:, groups == i].mean(axis=1)
            for i in range(n_groups)
        ]
        return np.stack(weighted).mean(axis=0)

    def get_debiased_train_indices(
        self, group_alignment_scores, use_heuristic=True, num_to_discard=None
    ):
        """Return the training indices to keep.

        Args:
            group_alignment_scores: one score per training example.
            use_heuristic: if true, keep every index whose score is ``>= 0``.
            num_to_discard: otherwise, drop this many lowest-scoring indices.

        Returns:
            List of kept indices (ascending score order when
            ``use_heuristic`` is false, index order otherwise).

        Raises:
            ValueError: if ``use_heuristic`` is false and ``num_to_discard``
                is ``None``.
        """
        if use_heuristic:
            return [i for i, score in enumerate(group_alignment_scores) if score >= 0]

        if num_to_discard is None:
            raise ValueError("num_to_discard must be specified if not using heuristic.")

        sorted_indices = sorted(
            range(len(group_alignment_scores)),
            key=lambda i: group_alignment_scores[i],
        )
        return sorted_indices[num_to_discard:]

    def debias(self, use_heuristic=True, num_to_discard=None):
        """Run the full pipeline and return the training indices to keep."""
        group_losses = self.get_group_losses(
            model=self.model,
            val_dl=self.dataloaders["val"],
            group_indices=self.group_indices,
        )

        group_alignment_scores = self.compute_group_alignment_scores(
            self.trak_scores, self.group_indices, group_losses
        )

        debiased_train_inds = self.get_debiased_train_indices(
            group_alignment_scores,
            use_heuristic=use_heuristic,
            num_to_discard=num_to_discard,
        )

        return debiased_train_inds


In [4]:
import os
# Configure PyTorch's CUDA caching allocator through its environment variable.
# NOTE(review): presumably this has to be set before the first CUDA allocation
# in the process to take effect — confirm against the PyTorch docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# CelebA

In [5]:
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from PIL import Image
import pandas as pd
import numpy as np
from tqdm import tqdm

# Paths to CelebA images and metadata
# Directory the image files are read from (joined with the file name taken
# from the first column of the attributes CSV).
celeba_images_path = "/kaggle/input/celeba-dataset/img_align_celeba/img_align_celeba"
# CSV whose "partition" column is compared against 0 (train) and 1 (validation).
partition_file = "/kaggle/input/celeba-dataset/list_eval_partition.csv"
# CSV with one row per image; the "Young" column supplies the class label.
attributes_file = "/kaggle/input/celeba-dataset/list_attr_celeba.csv"

# Function to get DataLoader for CelebA
def get_dataloader(
    batch_size=128, num_workers=4, split="train", shuffle=False, augment=True
):
    """
    Build a DataLoader over a random subset of CelebA for young/old classification.

    One tenth of the rows of the attributes CSV is sampled at random, the
    "young" rows are capped at four times the number of "old" rows, and the
    result is restricted to the requested partition.

    Args:
        batch_size: batch size of the returned loader.
        num_workers: worker processes of the returned loader.
        split: "train" selects partition 0; any other value selects partition 1.
        shuffle: whether the loader shuffles.
        augment: if true, a random horizontal flip is prepended to the
            preprocessing pipeline.

    Returns:
        Tuple ``(loader, dataset)``.
    """
    # Preprocessing shared by both modes; augmentation only adds the flip in front.
    steps = [
        transforms.CenterCrop(178),
        transforms.Resize(128),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    if augment:
        steps.insert(0, transforms.RandomHorizontalFlip())
    transforms_pipeline = transforms.Compose(steps)

    # Metadata tables
    partitions = pd.read_csv(partition_file)
    attributes = pd.read_csv(attributes_file)

    # Every attribute column becomes 1 where the value is 1 and 0 otherwise
    attributes.iloc[:, 1:] = attributes.iloc[:, 1:].map(lambda x: 1 if x == 1 else 0)

    # Random tenth of all rows
    reduced_indices = np.random.choice(
        attributes.index, size=len(attributes) // 10, replace=False
    )

    # Rows of the sample, separated by the "Young" attribute
    young_indices = attributes[attributes["Young"] == 1].index.intersection(reduced_indices)
    old_indices = attributes[attributes["Young"] == 0].index.intersection(reduced_indices)

    # At most four young rows per old row
    num_young = min(len(young_indices), len(old_indices) * 4)
    selected_young_indices = np.random.choice(young_indices, num_young, replace=False)
    selected_indices = np.concatenate([selected_young_indices, old_indices])

    # Keep only the rows of the requested partition (0 = train, 1 = validation)
    partition_code = 0 if split == "train" else 1
    in_split = (partitions["partition"] == partition_code) & partitions.index.isin(
        selected_indices
    )
    selected_indices = partitions[in_split].index

    class CelebADataset(torch.utils.data.Dataset):
        """Dataset over the selected rows, yielding ``(image, label)`` pairs."""

        def __init__(self, indices, img_dir, attributes, transform=None):
            self.indices = indices
            self.img_dir = img_dir
            self.attributes = attributes
            self.transform = transform

        def __len__(self):
            return len(self.indices)

        def __getitem__(self, idx):
            row = self.attributes.iloc[self.indices[idx]]

            # First column holds the image file name
            image = Image.open(os.path.join(self.img_dir, row.iloc[0])).convert("RGB")
            if self.transform:
                image = self.transform(image)

            # Binary class label: young (1) or old (0)
            return image, torch.tensor(row["Young"], dtype=torch.long)

    dataset = CelebADataset(
        indices=selected_indices,
        img_dir=celeba_images_path,
        attributes=attributes,
        transform=transforms_pipeline,
    )
    loader = DataLoader(
        dataset=dataset, shuffle=shuffle, batch_size=batch_size, num_workers=num_workers
    )
    return loader, dataset

# Load pre-trained model
# ResNet-18 with the default torchvision weights; the final fully connected
# layer is replaced by a fresh 2-way classifier and the model is moved to CUDA.
from torchvision.models import resnet18, ResNet18_Weights
model_before_mitigating = resnet18(weights=ResNet18_Weights.DEFAULT)
model_before_mitigating.fc = nn.Linear(model_before_mitigating.fc.in_features, 2)  # Binary classification (young or old)
model_before_mitigating = model_before_mitigating.cuda()

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model_before_mitigating.parameters(), lr=0.001)

# Get DataLoaders
# Each call draws its own random subset; the training loader shuffles and
# augments, the validation loader does neither.
train_loader, train_dataset = get_dataloader(batch_size=32, split="train", shuffle=True)
val_loader, val_dataset = get_dataloader(batch_size=32, split="val", shuffle=False, augment=False)

# Training Loop
num_epochs = 1
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    model_before_mitigating.train()
    epoch_loss = 0.0
    for images, labels in tqdm(train_loader, desc="Training"):
        images = images.cuda()
        labels = labels.cuda()

        # Forward pass
        outputs = model_before_mitigating(images)
        loss = criterion(outputs, labels)
        epoch_loss += loss.item()

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Mean of the per-batch losses over the epoch
    avg_loss = epoch_loss / len(train_loader)
    print(f"Training Loss: {avg_loss:.4f}")

    # Validation
    # Overall accuracy on the validation loader, evaluated after every epoch
    model_before_mitigating.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images = images.cuda()
            labels = labels.cuda()
            outputs = model_before_mitigating(images)
            # Predicted class = index of the largest logit
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total
    print(f"Validation Accuracy: {accuracy:.2f}%")

# Final Output
print("Training and evaluation completed.")

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 187MB/s]



Epoch 1/1


Training: 100%|██████████| 512/512 [00:39<00:00, 13.02it/s]


Training Loss: 0.3867


Validation: 100%|██████████| 63/63 [00:04<00:00, 13.56it/s]

Validation Accuracy: 83.65%
Training and evaluation completed.





In [6]:
from sklearn.metrics import accuracy_score
import numpy as np

def evaluate_worst_group_accuracy(model, val_loader, group_inds, device="cuda"):
    """Compute per-group accuracy on ``val_loader`` and return the worst one.

    ``group_inds[k]`` is taken as the group of the k-th example the loader
    yields; positions are derived from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Tuple ``(worst_group_accuracy, group_accuracies)`` where the second
        item maps each group to its accuracy (0.0 for a group without samples).
    """
    model.eval()  # evaluation mode
    group_ids = set(group_inds)
    preds_by_group = {g: [] for g in group_ids}
    truths_by_group = {g: [] for g in group_ids}

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader, desc="Evaluating WGA")):
            inputs = inputs.to(device)
            labels = labels.to(device)  # labels are class indices, not one-hot

            # Hard predictions and labels of this batch as NumPy arrays
            batch_preds = torch.argmax(model(inputs), dim=1).cpu().numpy()
            batch_truths = labels.cpu().numpy()

            # Groups of the examples in this batch
            first = batch_idx * val_loader.batch_size
            batch_groups = group_inds[first:first + len(labels)]

            for group, pred, truth in zip(batch_groups, batch_preds, batch_truths):
                preds_by_group[group].append(pred)
                truths_by_group[group].append(truth)

    group_accuracies = {}
    for group, collected in preds_by_group.items():
        if not collected or not truths_by_group[group]:
            # Groups with no samples count as 0.0
            group_accuracies[group] = 0.0
        else:
            group_accuracies[group] = accuracy_score(
                np.array(truths_by_group[group]), np.array(collected)
            )

    # Report every group
    for group, acc in group_accuracies.items():
        print(f"Group {group} Accuracy: {acc:.4f}")

    return min(group_accuracies.values()), group_accuracies

In [17]:
# Load attributes
# Re-read the attributes CSV at notebook level and binarize every attribute
# column (1 stays 1, anything else becomes 0), as get_dataloader does internally.
attributes = pd.read_csv(attributes_file)
attributes.iloc[:, 1:] = attributes.iloc[:, 1:].map(lambda x: 1 if x == 1 else 0)

# Define subgroups based on Young and Male attributes
def define_subgroups(row):
    """Name the age/gender subgroup of one attribute row.

    Reads ``row["Young"]`` and ``row["Male"]`` (each expected to be 0 or 1)
    and returns "young-male", "young-female", "old-male" or "old-female".
    Any other value combination yields ``None``.
    """
    young = row["Young"]
    if young != 1 and young != 0:
        return None

    male = row["Male"]
    if male == 1:
        return "young-male" if young == 1 else "old-male"
    if male == 0:
        return "young-female" if young == 1 else "old-female"
    return None

# Assign subgroup labels
attributes["subgroup"] = attributes.apply(define_subgroups, axis=1)

# Map subgroup names to numerical indices
# Names are sorted alphabetically before numbering, so with the four names
# above: old-female=0, old-male=1, young-female=2, young-male=3.
subgroup_mapping = {name: i for i, name in enumerate(sorted(attributes["subgroup"].unique()))}
attributes["group_index"] = attributes["subgroup"].map(subgroup_mapping)

# Align group indices with the reduced validation dataset
# The k-th entry of group_labels/group_inds belongs to the k-th dataset item,
# which matches loader order only because val_loader does not shuffle.
val_indices = val_loader.dataset.indices  # Indices of the validation subset
group_labels = attributes.loc[val_indices, "group_index"].values  # Get subgroup indices for validation set
group_inds = list(group_labels)  # Convert to a list for compatibility if needed

# Print subgroup distribution and sample group indices for validation
print("Subgroup Distribution in Validation Set:")
print(attributes.loc[val_indices, "subgroup"].value_counts())
print(f"Sample Group Indices: {group_inds[:10]}")

Subgroup Distribution in Validation Set:
subgroup
young-female    931
young-male      551
old-male        315
old-female      191
Name: count, dtype: int64
Sample Group Indices: [3, 3, 2, 0, 1, 2, 1, 1, 2, 3]


**Calculating Fairness Metrics for Young and Old Groups**

In [18]:
from sklearn.metrics import accuracy_score, confusion_matrix
import numpy as np

# Function to evaluate Group Accuracies
def evaluate_group_accuracies(model, val_loader, group_labels, device="cuda"):
    """Print and return the accuracy of ``model`` on each group.

    ``group_labels[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``; positions come from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Dict mapping each group to its accuracy (0.0 for an empty group).
    """
    model.eval()
    groups = set(group_labels)
    preds_by_group = {g: [] for g in groups}
    truths_by_group = {g: [] for g in groups}

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(val_loader, desc="Evaluating Group Accuracies")):
            images = images.to(device)
            labels = labels.to(device)  # class indices, not one-hot

            batch_preds = torch.argmax(model(images), dim=1).cpu().numpy()
            batch_truths = labels.cpu().numpy()

            # Groups of the examples in this batch
            first = i * val_loader.batch_size
            batch_groups = group_labels[first:first + len(labels)]

            for group, pred, truth in zip(batch_groups, batch_preds, batch_truths):
                preds_by_group[group].append(pred)
                truths_by_group[group].append(truth)

    group_accuracies = {}
    for group, collected in preds_by_group.items():
        if collected:
            group_accuracies[group] = accuracy_score(
                np.array(truths_by_group[group]), np.array(collected)
            )
        else:
            group_accuracies[group] = 0.0

    # Report every group
    for group, acc in group_accuracies.items():
        print(f"Group {group} Accuracy: {acc:.4f}")

    return group_accuracies

# Function to evaluate Demographic Parity (DP)
def evaluate_demographic_parity(model, val_loader, group_labels, device="cuda"):
    """Print and return the mean predicted class index of each group.

    With two classes this mean is the fraction of examples predicted as
    class 1 (the positive prediction rate). ``group_labels[k]`` is taken as
    the group of the k-th example yielded by ``val_loader``.

    Returns:
        Dict mapping each group to its positive prediction rate.
    """
    model.eval()
    preds_by_group = {g: [] for g in set(group_labels)}

    with torch.no_grad():
        for i, (images, _) in enumerate(tqdm(val_loader, desc="Evaluating Demographic Parity")):
            batch_preds = torch.argmax(model(images.to(device)), dim=1).cpu().numpy()

            # Groups of the examples in this batch
            first = i * val_loader.batch_size
            batch_groups = group_labels[first:first + len(batch_preds)]

            for group, pred in zip(batch_groups, batch_preds):
                preds_by_group[group].append(pred)

    ppr_disparities = {
        group: np.mean(collected) for group, collected in preds_by_group.items()
    }

    # Report every group
    for group, ppr in ppr_disparities.items():
        print(f"Group {group} PPR: {ppr:.4f}")

    return ppr_disparities

# Function to evaluate Equal Opportunity (EO)
def evaluate_equal_opportunity(model, val_loader, group_labels, device="cuda"):
    """Print and return the true positive rate (TPR) of each group.

    TPR is computed as TP / (TP + FN) over the examples of a group whose
    label is 1. Averaging a per-sample ratio over all samples would let the
    negatives dilute the rate to TP / N, so counts are accumulated instead.

    ``group_labels[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``; positions come from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Dict mapping each group to its TPR (0.0 when the group has no
        positive example).
    """
    model.eval()
    groups = set(group_labels)
    true_positives = {g: 0 for g in groups}
    actual_positives = {g: 0 for g in groups}

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(val_loader, desc="Evaluating Equal Opportunity")):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            # Convert once per batch instead of once per sample
            truths = labels.cpu().numpy()

            batch_start = i * val_loader.batch_size
            batch_end = batch_start + len(labels)
            batch_groups = group_labels[batch_start:batch_end]

            for group, pred, truth in zip(batch_groups, preds, truths):
                if truth == 1:
                    actual_positives[group] += 1
                    if pred == 1:
                        true_positives[group] += 1

    tpr_disparities = {}
    for group in true_positives:
        positives = actual_positives[group]
        tpr_disparities[group] = true_positives[group] / positives if positives else 0.0

    # Print group TPRs
    for group, tpr in tpr_disparities.items():
        print(f"Group {group} TPR: {tpr:.4f}")

    return tpr_disparities

# Function to evaluate Equalized Odds (EOd)
def evaluate_equalized_odds(model, val_loader, group_labels, device="cuda"):
    """Print and return per-group true and false positive rates.

    TPR is TP / (TP + FN) over the examples with label 1 and FPR is
    FP / (FP + TN) over the examples with label 0. Counts are accumulated per
    group; averaging a per-sample ratio over all samples would divide by the
    group size instead of by the number of positives/negatives.

    ``group_labels[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``; positions come from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Tuple ``(tpr_by_group, fpr_by_group)`` of dicts; a rate is 0.0 when
        its denominator is zero.
    """
    model.eval()
    groups = set(group_labels)
    counts = {g: {"tp": 0, "fp": 0, "pos": 0, "neg": 0} for g in groups}

    with torch.no_grad():
        for i, (images, labels) in enumerate(tqdm(val_loader, desc="Evaluating Equalized Odds")):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            preds = torch.argmax(outputs, dim=1).cpu().numpy()
            # Convert once per batch instead of several times per sample
            truths = labels.cpu().numpy()

            batch_start = i * val_loader.batch_size
            batch_end = batch_start + len(labels)
            batch_groups = group_labels[batch_start:batch_end]

            for group, pred, truth in zip(batch_groups, preds, truths):
                tally = counts[group]
                if truth == 1:
                    tally["pos"] += 1
                    if pred == 1:
                        tally["tp"] += 1
                elif truth == 0:
                    tally["neg"] += 1
                    if pred == 1:
                        tally["fp"] += 1

    tpr_disparities = {}
    fpr_disparities = {}
    for group, tally in counts.items():
        tpr_disparities[group] = tally["tp"] / tally["pos"] if tally["pos"] else 0.0
        fpr_disparities[group] = tally["fp"] / tally["neg"] if tally["neg"] else 0.0

    # Print group TPRs and FPRs
    for group in counts:
        print(f"Group {group} TPR: {tpr_disparities[group]:.4f}, FPR: {fpr_disparities[group]:.4f}")

    return tpr_disparities, fpr_disparities

In [19]:
# Evaluate the trained model on the validation loader: worst-group accuracy,
# then positive prediction rates, TPRs, and TPR/FPR pairs per subgroup.
wga, group_accuracies = evaluate_worst_group_accuracy(model_before_mitigating, val_loader, group_labels)
dp_rates = evaluate_demographic_parity(model_before_mitigating, val_loader, group_labels)
eo_tprs = evaluate_equal_opportunity(model_before_mitigating, val_loader, group_labels)
tpr_disparities, fpr_disparities = evaluate_equalized_odds(model_before_mitigating, val_loader, group_labels)

Evaluating WGA: 100%|██████████| 63/63 [00:02<00:00, 27.02it/s]


Group 0 Accuracy: 0.4450
Group 1 Accuracy: 0.8190
Group 2 Accuracy: 0.9452
Group 3 Accuracy: 0.7985


Evaluating Demographic Parity: 100%|██████████| 63/63 [00:01<00:00, 31.75it/s]


Group 0 PPR: 0.5550
Group 1 PPR: 0.1810
Group 2 PPR: 0.9452
Group 3 PPR: 0.7985


Evaluating Equal Opportunity: 100%|██████████| 63/63 [00:02<00:00, 31.32it/s]


Group 0 TPR: 0.0000
Group 1 TPR: 0.0000
Group 2 TPR: 0.9452
Group 3 TPR: 0.7985


Evaluating Equalized Odds: 100%|██████████| 63/63 [00:02<00:00, 31.23it/s]

Group 0 TPR: 0.0000, FPR: 0.5550
Group 1 TPR: 0.0000, FPR: 0.1810
Group 2 TPR: 0.9452, FPR: 0.0000
Group 3 TPR: 0.7985, FPR: 0.0000





# Debiasing with D3M

In [20]:
print('YOYO')
# A single "checkpoint": the state dict of the model trained above.
ckpts = [model_before_mitigating.state_dict()]
# No precomputed trak_scores are passed, so the DDA constructor computes the
# TRAK score matrix from the two loaders.
dda = DDA(model_before_mitigating, ckpts, train_loader, val_loader, group_inds)

YOYO


Finalizing features for all model IDs..: 100%|██████████| 1/1 [00:00<00:00, 4405.78it/s]
Finalizing scores for all model IDs..: 100%|██████████| 1/1 [00:00<00:00,  6.70it/s]


In [22]:
# debiased_inds = dda.debias(use_heuristic=False, num_to_discard=100)
# With use_heuristic=True, debias() returns the training indices whose group
# alignment score is >= 0, i.e. the examples that are kept.
debiased_inds = dda.debias(use_heuristic=True)
# Number of kept training indices (shown as the cell output)
len(debiased_inds)

16275

In [16]:
import copy

# Deep copy of the trained model; later cells evaluate this copy and copy it
# again before modifying weights in place.
deep_copy_model = copy.deepcopy(model_before_mitigating)

# Equal Opportunity

In [26]:
def calculate_tpr(labels, preds):
    """Return the true positive rate TP / (TP + FN) for binary labels.

    Args:
        labels: array-like of ground-truth labels (1 = positive).
        preds: array-like of predicted labels, same length as ``labels``.

    Returns:
        The TPR, or 0.0 when ``labels`` contains no positive example.
    """
    # Coerce to arrays: comparing a plain list with 1 yields a single False,
    # which would silently produce a rate of 0.
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    tp = np.sum((preds == 1) & (labels == 1))
    fn = np.sum((preds == 0) & (labels == 1))
    return tp / (tp + fn) if (tp + fn) > 0 else 0.0

def evaluate_equal_opportunity(model, val_loader, group_inds, device="cuda"):
    """Return per-group true positive rates and their max-min gap.

    ``group_inds[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``; positions come from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Tuple ``(group_tprs, tpr_disparity)``.
    """
    model.eval()
    group_ids = set(group_inds)
    preds_by_group = {g: [] for g in group_ids}
    truths_by_group = {g: [] for g in group_ids}

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_preds = torch.argmax(model(inputs), dim=1).cpu().numpy()
            batch_truths = labels.cpu().numpy()

            # Groups of the examples in this batch
            first = batch_idx * val_loader.batch_size
            batch_groups = group_inds[first:first + len(labels)]

            for group, pred, truth in zip(batch_groups, batch_preds, batch_truths):
                preds_by_group[group].append(pred)
                truths_by_group[group].append(truth)

    group_tprs = {
        group: calculate_tpr(np.array(truths_by_group[group]), np.array(collected))
        for group, collected in preds_by_group.items()
    }

    rates = group_tprs.values()
    tpr_disparity = max(rates) - min(rates)
    return group_tprs, tpr_disparity

In [27]:
group_tprs, tpr_disparity = evaluate_equal_opportunity(deep_copy_model, val_loader, group_inds)
print(f"Group TPRs: {group_tprs}")
print(f"TPR Disparity: {tpr_disparity:.4f}")

Group TPRs: {0: 0.9152671755725191, 1: 0.9867026802410139, 2: 0.0, 3: 0.0}
TPR Disparity: 0.9867


# Equal Odds

In [28]:
def calculate_fpr(labels, preds):
    """Return the false positive rate FP / (FP + TN) for binary labels.

    Args:
        labels: array-like of ground-truth labels (0 = negative).
        preds: array-like of predicted labels, same length as ``labels``.

    Returns:
        The FPR, or 0.0 when ``labels`` contains no negative example.
    """
    # Coerce to arrays: comparing a plain list with 1 yields a single False,
    # which would silently produce a rate of 0.
    labels = np.asarray(labels)
    preds = np.asarray(preds)
    fp = np.sum((preds == 1) & (labels == 0))
    tn = np.sum((preds == 0) & (labels == 0))
    return fp / (fp + tn) if (fp + tn) > 0 else 0.0

def evaluate_equalized_odds(model, val_loader, group_inds, device="cuda"):
    """Return per-group TPR and FPR together with their max-min gaps.

    ``group_inds[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``; positions come from the batch number and
    ``val_loader.batch_size``, so the loader must not shuffle.

    Returns:
        Tuple ``(group_tprs, group_fprs, tpr_disparity, fpr_disparity)``.
    """
    model.eval()
    group_ids = set(group_inds)
    preds_by_group = {g: [] for g in group_ids}
    truths_by_group = {g: [] for g in group_ids}

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader):
            inputs, labels = inputs.to(device), labels.to(device)
            batch_preds = torch.argmax(model(inputs), dim=1).cpu().numpy()
            batch_truths = labels.cpu().numpy()

            # Groups of the examples in this batch
            first = batch_idx * val_loader.batch_size
            batch_groups = group_inds[first:first + len(labels)]

            for group, pred, truth in zip(batch_groups, batch_preds, batch_truths):
                preds_by_group[group].append(pred)
                truths_by_group[group].append(truth)

    group_tprs, group_fprs = {}, {}
    for group, collected in preds_by_group.items():
        group_pred_arr = np.array(collected)
        group_true_arr = np.array(truths_by_group[group])
        group_tprs[group] = calculate_tpr(group_true_arr, group_pred_arr)
        group_fprs[group] = calculate_fpr(group_true_arr, group_pred_arr)

    tpr_values = group_tprs.values()
    fpr_values = group_fprs.values()
    tpr_disparity = max(tpr_values) - min(tpr_values)
    fpr_disparity = max(fpr_values) - min(fpr_values)

    return group_tprs, group_fprs, tpr_disparity, fpr_disparity

In [29]:
group_tprs, group_fprs, tpr_disparity, fpr_disparity = evaluate_equalized_odds(deep_copy_model, val_loader, group_inds)
print(f"Group TPRs: {group_tprs}")
print(f"Group FPRs: {group_fprs}")
print(f"TPR Disparity: {tpr_disparity:.4f}")
print(f"FPR Disparity: {fpr_disparity:.4f}")

Group TPRs: {0: 0.9152671755725191, 1: 0.9867026802410139, 2: 0.0, 3: 0.0}
Group FPRs: {0: 0.0, 1: 0.0, 2: 0.2816989381636477, 3: 0.6818181818181818}
TPR Disparity: 0.9867
FPR Disparity: 0.6818


# Demographic Parity

In [30]:
def calculate_ppr(preds):
    """Return the positive prediction rate, i.e. the mean of binary ``preds``.

    Args:
        preds: array-like of 0/1 predictions.

    Returns:
        The mean of ``preds``, or 0.0 for empty input (instead of NaN plus a
        RuntimeWarning), matching the zero fallback of the other rate helpers.
    """
    preds = np.asarray(preds)
    if preds.size == 0:
        return 0.0
    return np.mean(preds)


def evaluate_demographic_parity(model, val_loader, group_inds, device="cuda"):
    """
    Return per-group positive prediction rates and their max-min gap.

    The rate of a group is the share of its examples predicted as class 1
    (0.0 for a group without predictions). ``group_inds[k]`` is taken as the
    group of the k-th example yielded by ``val_loader``.

    Returns:
        Tuple ``(group_pprs, ppr_disparity)``.
    """
    model.eval()
    preds_by_group = {g: [] for g in set(group_inds)}

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(val_loader):
            batch_preds = torch.argmax(model(inputs.to(device)), dim=1).cpu().numpy()

            # Groups of the examples in this batch
            first = batch_idx * val_loader.batch_size
            batch_groups = group_inds[first:first + len(labels)]

            # File every prediction under its group
            for group, pred in zip(batch_groups, batch_preds):
                preds_by_group[group].append(pred)

    group_pprs = {}
    for group, collected in preds_by_group.items():
        flat = np.array(collected).flatten()
        n_total = len(flat)
        n_positive = (flat == 1).sum()
        if n_total > 0:
            group_pprs[group] = n_positive / n_total
        else:
            group_pprs[group] = 0.0

    # Gap between the largest and smallest rate
    rates = group_pprs.values()
    ppr_disparity = max(rates) - min(rates)

    return group_pprs, ppr_disparity

In [31]:
group_pprs, ppr_disparity = evaluate_demographic_parity(deep_copy_model, val_loader, group_inds)
print(f"Group PPRs: {group_pprs}")
print(f"PPR Disparity: {ppr_disparity:.4f}")

Group PPRs: {0: 0.9152671755725191, 1: 0.9867026802410139, 2: 0.2816989381636477, 3: 0.6818181818181818}
PPR Disparity: 0.7050


In [32]:
from sklearn.metrics import confusion_matrix
import numpy as np
from tqdm import tqdm

def calculate_fnr_fpr(model, val_loader, group_inds, device="cuda"):
    """
    Compute, print and return per-group false negative and false positive rates.

    A confusion-matrix tally (FN, FP, TP, TN) is kept per group;
    ``group_inds[k]`` is taken as the group of the k-th example yielded by
    ``val_loader``.

    Returns:
        Tuple ``(group_fnr_fpr, fnr_disparity, fpr_disparity)`` where the
        first item maps each group to ``{"FNR": ..., "FPR": ...}`` and the
        disparities are max minus min over groups.
    """
    model.eval()  # evaluation mode
    group_metrics = {g: {"FN": 0, "FP": 0, "TP": 0, "TN": 0} for g in set(group_inds)}

    # (label, prediction) -> confusion-matrix cell
    cell_of = {(1, 0): "FN", (0, 1): "FP", (1, 1): "TP", (0, 0): "TN"}

    with torch.no_grad():
        for batch_idx, (inputs, labels) in enumerate(tqdm(val_loader, desc="Calculating FNR and FPR")):
            inputs = inputs.to(device)
            labels = labels.to(device)

            preds = torch.argmax(model(inputs), dim=1).cpu().numpy()
            labels = labels.cpu().numpy()

            # Groups of the examples in this batch
            first = batch_idx * val_loader.batch_size
            batch_groups = group_inds[first:first + len(labels)]

            for group, truth, pred in zip(batch_groups, labels, preds):
                if group not in group_metrics:
                    continue  # unknown group
                cell = cell_of.get((truth, pred))
                if cell is not None:
                    group_metrics[group][cell] += 1

    # Rates per group
    group_fnr_fpr = {}
    for group, tally in group_metrics.items():
        actual_positives = tally["TP"] + tally["FN"]
        actual_negatives = tally["TN"] + tally["FP"]

        if actual_positives > 0:
            fnr = tally["FN"] / (actual_positives + 1e-8)
        else:
            fnr = 0.0
        if actual_negatives > 0:
            fpr = tally["FP"] / (actual_negatives + 1e-8)
        else:
            fpr = 0.0

        group_fnr_fpr[group] = {"FNR": fnr, "FPR": fpr}

    # Report every group
    print("\nGroup FNR and FPR:")
    for group, rates in group_fnr_fpr.items():
        print(f"Group {group}: FNR = {rates['FNR']:.4f}, FPR = {rates['FPR']:.4f}")

    # Max-min gaps
    fnr_values = [rates["FNR"] for rates in group_fnr_fpr.values()]
    fpr_values = [rates["FPR"] for rates in group_fnr_fpr.values()]
    fnr_disparity = max(fnr_values) - min(fnr_values)
    fpr_disparity = max(fpr_values) - min(fpr_values)

    print(f"\nFNR Disparity (Max - Min): {fnr_disparity:.4f}")
    print(f"FPR Disparity (Max - Min): {fpr_disparity:.4f}")

    return group_fnr_fpr, fnr_disparity, fpr_disparity


# Example usage
# Per-subgroup FNR/FPR of the originally trained model on the validation loader.
group_fnr_fpr, fnr_disparity, fpr_disparity = calculate_fnr_fpr(model_before_mitigating, val_loader, group_inds)

Calculating FNR and FPR: 100%|██████████| 312/312 [00:09<00:00, 34.44it/s]


Group FNR and FPR:
Group 0: FNR = 0.0847, FPR = 0.0000
Group 1: FNR = 0.0133, FPR = 0.0000
Group 2: FNR = 0.0000, FPR = 0.2817
Group 3: FNR = 0.0000, FPR = 0.6818

FNR Disparity (Max - Min): 0.0847
FPR Disparity (Max - Min): 0.6818





# Machine Unlearning

In [18]:
# NOTE(review): debiased_inds comes from dda.debias(), which returns the
# training indices that are KEPT (alignment score >= 0), not the discarded
# ones. Confirm whether the complement of this set was intended here.
harmful_indices = debiased_inds

In [35]:
def remove_influence(model, dataloader, harmful_indices, factor, device="cuda"):
    """Apply one scaled gradient step per selected sample, updating ``model`` in place.

    For every index in ``harmful_indices`` the cross-entropy loss of that single
    sample is differentiated with respect to the trainable parameters, and
    ``factor * grad`` is subtracted from them. The steps are sequential: each
    sample's gradient is taken at the weights left by the previous step.

    Args:
        model: Network to update. It is switched to eval mode and modified in
            place; pass a copy if the original weights must be kept.
        dataloader: Loader whose ``.dataset`` is indexed by ``harmful_indices``
            and yields ``(inputs, labels)`` pairs.
        harmful_indices: Dataset indices of the samples to step on. An empty
            sequence leaves the weights unchanged.
        factor: Step size multiplying each gradient.
        device: Device the samples are moved to. The model itself is not moved.

    Returns:
        The same ``model`` object, after the updates.
    """
    # NOTE(review): subtracting the gradient lowers the loss on the selected
    # samples (a descent step). Confirm this sign matches the intended
    # unlearning direction.
    model.eval()
    harmful_dataset = torch.utils.data.Subset(dataloader.dataset, harmful_indices)
    harmful_loader = torch.utils.data.DataLoader(harmful_dataset, batch_size=1)

    # autograd.grad raises on tensors that do not require grad, so frozen
    # parameters are excluded and left untouched.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    if not trainable_params:
        return model

    for inputs, labels in harmful_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        outputs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Each graph is differentiated exactly once, so it is not retained.
        grads = torch.autograd.grad(loss, trainable_params)

        with torch.no_grad():
            for param, grad in zip(trainable_params, grads):
                param -= grad * factor

    return model

# Sweep the update factor and record worst-group accuracy, best-group accuracy
# and the gap between them for every setting.
results = {"factor": [], "model": [], "min": [], "max": [], "gap": []}
factors = np.linspace(0.00001, 0.0001, 2)

for factor in factors:
    # Each factor starts from its own fresh copy of the same base model.
    newdeepmodel = copy.deepcopy(deep_copy_model)
    m = remove_influence(
        newdeepmodel, train_loader, harmful_indices, factor, device="cuda"
    )
    wga, group_accs = evaluate_worst_group_accuracy(
        m, val_loader, group_inds, device="cuda"
    )
    best_group_acc = max(group_accs.values())
    current_gap = best_group_acc - wga

    row = {
        "model": m,
        "min": wga,
        "max": best_group_acc,
        "gap": current_gap,
        "factor": factor,
    }
    for column, value in row.items():
        results[column].append(value)

Evaluating WGA: 100%|██████████| 312/312 [00:10<00:00, 30.69it/s]


Group 0 Accuracy: 0.9080
Group 1 Accuracy: 0.9838
Group 2 Accuracy: 0.7408
Group 3 Accuracy: 0.3734


Evaluating WGA: 100%|██████████| 312/312 [00:09<00:00, 31.40it/s]

Group 0 Accuracy: 0.9466
Group 1 Accuracy: 0.9877
Group 2 Accuracy: 0.6921
Group 3 Accuracy: 0.3690





In [36]:
import pandas as pd

# Tabulate the sweep ordered by factor; the bare name renders the table in the notebook.
results_table = pd.DataFrame(results)
df = results_table.sort_values(by="factor")
df

Unnamed: 0,factor,model,min,max,gap
0,1e-05,"ResNet(\n (conv1): Conv2d(3, 64, kernel_size=...",0.373377,0.983794,0.610417
1,0.0001,"ResNet(\n (conv1): Conv2d(3, 64, kernel_size=...",0.369048,0.987742,0.618694


Now it is time to investigate which approaches to machine unlearning work best and how we can formulate them.

What are the other approaches to machine unlearning?

# Fair Pruning

In [None]:
import torch
import copy

def fair_pruning(model, dataloader, harmful_indices, threshold=0.01, device="cuda"):
    """Return a copy of ``model`` with weights zeroed where the mean gradient is small.

    The cross-entropy gradient of every selected sample is computed on a deep
    copy of ``model`` and averaged per parameter. Every weight entry whose
    absolute mean gradient is below ``threshold`` is then set to zero in the
    copy.

    Args:
        model: Source network. It is switched to eval mode; its weights are
            not modified.
        dataloader: Loader whose ``.dataset`` is indexed by ``harmful_indices``
            and yields ``(inputs, labels)`` pairs.
        harmful_indices: Dataset indices of the samples whose gradients are
            averaged. An empty sequence returns an unpruned copy.
        threshold: Entries with ``|mean gradient| < threshold`` are zeroed.
        device: Device the samples are moved to. The model itself is not moved.

    Returns:
        The pruned deep copy of ``model``.
    """
    # NOTE(review): this zeroes the weights that react *least* to the selected
    # samples. Confirm that is the intended criterion and not its inverse.
    model.eval()
    pruned_model = copy.deepcopy(model)
    harmful_dataset = torch.utils.data.Subset(dataloader.dataset, harmful_indices)

    num_samples = len(harmful_dataset)
    if num_samples == 0:
        # No gradients to average, so nothing is pruned.
        return pruned_model

    harmful_loader = torch.utils.data.DataLoader(harmful_dataset, batch_size=1)

    # autograd.grad raises on tensors that do not require grad, so frozen
    # parameters are excluded and never pruned.
    trainable_params = [p for p in pruned_model.parameters() if p.requires_grad]
    if not trainable_params:
        return pruned_model

    # Running sums keep memory at one extra copy of the model, instead of one
    # stored gradient set per sample.
    grad_sums = [torch.zeros_like(p) for p in trainable_params]
    for inputs, labels in tqdm(harmful_loader, desc="Calculating Gradients"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = pruned_model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Each graph is differentiated exactly once, so it is not retained.
        grads = torch.autograd.grad(loss, trainable_params)
        for grad_sum, grad in zip(grad_sums, grads):
            grad_sum += grad

    with torch.no_grad():
        for param, grad_sum in zip(trainable_params, grad_sums):
            # batch_size=1 means one gradient per sample, so this is the per-sample mean.
            mean_grad = grad_sum / num_samples
            param[torch.abs(mean_grad) < threshold] = 0.0

    return pruned_model

# Prune against the selected training samples, then evaluate the pruned copy per group.
pruned_model = fair_pruning(
    model_before_mitigating,
    train_loader,
    harmful_indices,
    threshold=0.01,
)

wga, group_accs = evaluate_worst_group_accuracy(
    pruned_model, val_loader, group_inds, device="cuda"
)

# Differentially Private Influence Functions for Unlearning

In [None]:
import torch
import numpy as np
import copy

def dp_influence_unlearning(model, dataloader, harmful_indices, epsilon=1.0, delta=1e-5, device="cuda"):
    """Return a copy of ``model`` updated with noisy per-sample gradient steps.

    For each selected sample, the cross-entropy gradient is computed on a deep
    copy of ``model``, Gaussian noise is added to it, and the sum is subtracted
    from the trainable parameters with a unit step size. The noise standard
    deviation is ``sensitivity * sqrt(2 * ln(1.25 / delta)) / epsilon`` with
    ``sensitivity = 1 / number_of_selected_samples``.

    Args:
        model: Source network. It is switched to eval mode; its weights are
            not modified.
        dataloader: Loader whose ``.dataset`` is indexed by ``harmful_indices``
            and yields ``(inputs, labels)`` pairs.
        harmful_indices: Dataset indices of the samples to step on. An empty
            sequence returns an unmodified copy.
        epsilon: Privacy parameter; must be positive.
        delta: Privacy parameter; must lie strictly between 0 and 1.
        device: Device the samples are moved to. The model itself is not moved.

    Returns:
        The updated deep copy of ``model``.

    Raises:
        ValueError: If ``epsilon`` or ``delta`` is outside its valid range.
    """
    if epsilon <= 0:
        raise ValueError(f"epsilon must be positive, got {epsilon}")
    if not 0 < delta < 1:
        raise ValueError(f"delta must be in (0, 1), got {delta}")

    model.eval()
    updated_model = copy.deepcopy(model)
    harmful_dataset = torch.utils.data.Subset(dataloader.dataset, harmful_indices)

    num_samples = len(harmful_dataset)
    if num_samples == 0:
        # Nothing to step on; this also avoids dividing by zero below.
        return updated_model

    harmful_loader = torch.utils.data.DataLoader(harmful_dataset, batch_size=1)

    # NOTE(review): gradients are not clipped here, so the 1/N sensitivity is
    # an assumption and not a bound enforced by this code. Confirm before
    # relying on the (epsilon, delta) guarantee.
    sensitivity = 1.0 / num_samples
    noise_scale = float(sensitivity * np.sqrt(2 * np.log(1.25 / delta)) / epsilon)

    # autograd.grad raises on tensors that do not require grad, so frozen
    # parameters are excluded and left untouched.
    trainable_params = [p for p in updated_model.parameters() if p.requires_grad]
    if not trainable_params:
        return updated_model

    for inputs, labels in tqdm(harmful_loader, desc="Applying DP Influence Unlearning"):
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = updated_model(inputs)
        loss = torch.nn.functional.cross_entropy(outputs, labels)
        # Each graph is differentiated exactly once, so it is not retained.
        grads = torch.autograd.grad(loss, trainable_params)

        with torch.no_grad():
            for param, grad in zip(trainable_params, grads):
                # Draw the noise on the gradient's own device and dtype so the
                # sum never mixes devices or precisions.
                noise = torch.normal(
                    mean=0.0,
                    std=noise_scale,
                    size=grad.shape,
                    device=grad.device,
                    dtype=grad.dtype,
                )
                param -= grad + noise

    return updated_model

# Run the noisy-gradient update on the selected training samples, then evaluate the result per group.
dp_model = dp_influence_unlearning(
    model_before_mitigating,
    train_loader,
    harmful_indices,
    epsilon=1.0,
    delta=1e-5,
)
wga, group_accs = evaluate_worst_group_accuracy(
    dp_model, val_loader, group_inds, device="cuda"
)