In [57]:
import os
import copy
import random
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.nn.utils.prune as prune

# For SNIP we need functional for loss computation
import torch.nn.functional as F


# Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) generators
# with the same value so stochastic steps start from a known state.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

# Use the GPU when PyTorch reports one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Root directory that holds the PACS dataset.
# Override it without editing the notebook by exporting DOMAINBED_DATA_DIR;
# the original author's path remains the fallback so existing setups still work.
DATA_DIR = os.environ.get(
    "DOMAINBED_DATA_DIR",
    "/home/rishabh/Anuj_Sem6/CV_S6/DomainBed/domainbed/data",
)

# Import DomainBed utilities (ensure DomainBed is in your Python path)
from domainbed import algorithms, datasets, hparams_registry


In [59]:
def load_algorithm_model(model_path):
    """
    Rebuild a DomainBed algorithm from a checkpoint file.

    The checkpoint is expected to be a dict providing the keys read below:
    "args" (with an "algorithm" entry), "model_input_shape",
    "model_num_classes", "model_num_domains", "model_hparams" and "model_dict".

    Returns:
      (algorithm, checkpoint): the algorithm instance with its weights loaded,
      moved to `device` and switched to eval mode, plus the raw checkpoint dict.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)

    # The algorithm class is looked up by the name stored at training time.
    algorithm_cls = algorithms.get_algorithm_class(checkpoint["args"]["algorithm"])
    constructor_keys = (
        "model_input_shape",
        "model_num_classes",
        "model_num_domains",
        "model_hparams",
    )
    restored = algorithm_cls(*[checkpoint[key] for key in constructor_keys])

    restored.load_state_dict(checkpoint["model_dict"])
    restored.to(device)
    restored.eval()
    return restored, checkpoint

def evaluate_model_on_pacs(algorithm, target_env):
    """
    Measure classification accuracy of `algorithm` on every PACS environment.

    Each environment returned by `datasets.PACS` is evaluated in full; this
    function does not carve out any train/test split itself, so the source
    accuracy may include images the model was trained on.

    Args:
      algorithm: object exposing `predict(X)` that returns per-class logits.
        The caller is responsible for putting it in eval mode.
      target_env: integer index of the held-out (target) environment.

    Returns:
      - target_acc: accuracy on the `target_env` environment.
      - avg_source_acc: unweighted mean of the accuracies on all other environments.

    Raises:
      ValueError: if an environment yields no samples, if `target_env` is not
        one of the evaluated environment indices, or if there is no source
        environment to average over.
    """
    # Default PACS hyperparameters for this algorithm, with augmentation switched off.
    hparams = hparams_registry.default_hparams(algorithm.__class__.__name__, "PACS")
    hparams["data_augmentation"] = False
    pacs_dataset = datasets.PACS(root=DATA_DIR, test_envs=[target_env], hparams=hparams)

    accuracies = {}
    # No gradients are needed anywhere in the evaluation loop.
    with torch.no_grad():
        for env_idx, env_dataset in enumerate(pacs_dataset):
            loader = DataLoader(env_dataset, batch_size=hparams["batch_size"], shuffle=False)
            total_correct = 0
            total_samples = 0
            for X, y in loader:
                X = X.to(device)
                y = y.to(device)
                preds = algorithm.predict(X).argmax(dim=1)
                total_correct += (preds == y).sum().item()
                total_samples += len(y)
            if total_samples == 0:
                # Fail with a clear message instead of a bare ZeroDivisionError.
                raise ValueError(
                    f"PACS environment {env_idx} yielded no samples; check DATA_DIR ({DATA_DIR!r})."
                )
            accuracies[env_idx] = total_correct / total_samples

    if target_env not in accuracies:
        raise ValueError(
            f"target_env={target_env!r} is not a valid environment index; "
            f"expected one of {sorted(accuracies)}."
        )
    target_acc = accuracies[target_env]
    source_accs = [acc for env, acc in accuracies.items() if env != target_env]
    if not source_accs:
        raise ValueError("No source environments available to average over.")
    avg_source_acc = sum(source_accs) / len(source_accs)
    return target_acc, avg_source_acc

def get_sample_batch(algorithm_name, target_domain):
    """
    Draw one shuffled mini-batch from the target-domain PACS environment.

    The batch is what the SNIP-style pruning routine uses to compute gradients.
    Batch size comes from the algorithm's default PACS hyperparameters and
    data augmentation is disabled.
    """
    eval_hparams = hparams_registry.default_hparams(algorithm_name, "PACS")
    eval_hparams["data_augmentation"] = False
    environments = datasets.PACS(
        root=DATA_DIR,
        test_envs=[target_domain],
        hparams=eval_hparams,
    )
    # The PACS object is indexed per environment; pick the target one.
    target_environment = environments[int(target_domain)]
    batch_iterator = iter(
        DataLoader(
            target_environment,
            batch_size=eval_hparams["batch_size"],
            shuffle=True,
        )
    )
    return next(batch_iterator)


In [60]:
def apply_unstructured_pruning(model, prune_level):
    """
    Globally prune the smallest-magnitude weights (L1 criterion).

    The `weight` tensors of every Conv2d and Linear layer are pooled and the
    `prune_level` fraction with the lowest absolute value is zeroed. The
    pruning is then made permanent, so no mask/re-parametrization is left on
    the modules. The model is modified in place and also returned.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = [
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, prunable_types)
    ]

    prune.global_unstructured(
        targets,
        pruning_method=prune.L1Unstructured,
        amount=prune_level,
    )

    # Bake each mask into its weight and drop the pruning hooks.
    for layer, tensor_name in targets:
        prune.remove(layer, tensor_name)
    return model

def apply_structured_pruning(model, prune_level):
    """
    Zero whole output channels of every Conv2d layer, layer by layer.

    Within each Conv2d, the `prune_level` fraction of filters (dim 0 of the
    weight) with the smallest L1 norm is removed, and the pruning is made
    permanent right away. Linear layers are left untouched. The model is
    modified in place and also returned.
    """
    conv_layers = [
        layer for layer in model.modules() if isinstance(layer, torch.nn.Conv2d)
    ]
    for conv in conv_layers:
        prune.ln_structured(conv, name='weight', amount=prune_level, n=1, dim=0)
        # Fold the mask into the weight so the module looks like a plain Conv2d again.
        prune.remove(conv, 'weight')
    return model

def apply_random_pruning(model, prune_level):
    """
    Globally zero a random `prune_level` fraction of Conv2d/Linear weights.

    All `weight` tensors of Conv2d and Linear layers are pooled and the
    entries to drop are chosen at random across the pool. The pruning is then
    made permanent. The model is modified in place and also returned.
    """
    prunable_types = (torch.nn.Conv2d, torch.nn.Linear)
    targets = [
        (layer, 'weight')
        for layer in model.modules()
        if isinstance(layer, prunable_types)
    ]

    prune.global_unstructured(
        targets,
        pruning_method=prune.RandomUnstructured,
        amount=prune_level,
    )

    # Bake each mask into its weight and drop the pruning hooks.
    for layer, tensor_name in targets:
        prune.remove(layer, tensor_name)
    return model

def apply_snip_pruning(model, prune_level, sample_batch):
    """
    Apply a SNIP-inspired global pruning step.

    Each Conv2d/Linear weight gets a saliency score |w * dL/dw|, where L is the
    cross-entropy loss of `model.predict` on one mini-batch. The
    `prune_level` fraction of weights with the lowest scores (pooled over all
    eligible layers) is set to zero in place.

    Args:
      model: module exposing `predict(inputs)` that returns logits.
      prune_level: fraction of eligible weights to remove, in [0, 1]. The
        number removed is `round(prune_level * total)`, the same rounding
        `torch.nn.utils.prune` applies to fractional amounts.
      sample_batch: tuple `(inputs, labels)` used for the single
        forward/backward pass. It is not modified.

    Returns:
      The same `model` object, pruned, with gradients cleared and in eval mode.

    Raises:
      ValueError: if `prune_level` is outside [0, 1], if the model has no
        parameters, or if no Conv2d/Linear weight received a gradient.
    """
    if not 0.0 <= prune_level <= 1.0:
        raise ValueError(f"prune_level must be in [0, 1], got {prune_level!r}")

    inputs, labels = sample_batch
    try:
        # Follow the model's own device instead of relying on a notebook global.
        model_device = next(model.parameters()).device
    except StopIteration:
        raise ValueError("model has no parameters to prune") from None
    inputs = inputs.to(model_device)
    labels = labels.to(model_device)

    # Gradients do not require train mode. Scoring in eval mode keeps the batch
    # from updating BatchNorm running statistics and keeps dropout inactive, so
    # the only change made to the model is the weight masking below.
    model.eval()
    model.zero_grad()

    # One forward/backward pass with cross-entropy to populate .grad.
    logits = model.predict(inputs)
    loss = F.cross_entropy(logits, labels)
    loss.backward()

    # Eligible layers: Conv2d/Linear whose weight actually received a gradient.
    eligible = [
        layer
        for layer in model.modules()
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Linear))
        and layer.weight.grad is not None
    ]
    if not eligible:
        model.zero_grad()
        raise ValueError("No Conv2d/Linear weight received a gradient; nothing to prune.")

    with torch.no_grad():
        flat_scores = torch.cat(
            [(layer.weight * layer.weight.grad).abs().reshape(-1) for layer in eligible]
        )
        total = flat_scores.numel()
        # round() rather than int((1 - p) * n): the latter truncates float error,
        # e.g. int((1 - 0.9) * 10) == 0 would wipe every weight.
        num_to_prune = min(total, max(0, round(prune_level * total)))

        if num_to_prune > 0:
            # Picking indices with topk removes exactly num_to_prune weights even
            # when several scores tie; a ">= threshold" test would keep them all.
            flat_mask = torch.ones_like(flat_scores)
            lowest = torch.topk(flat_scores, k=num_to_prune, largest=False).indices
            flat_mask[lowest] = 0

            # Slice the pooled mask back into per-layer shapes, in pooling order.
            offset = 0
            for layer in eligible:
                count = layer.weight.numel()
                layer_mask = flat_mask[offset:offset + count].reshape(layer.weight.shape)
                layer.weight.mul_(layer_mask)
                offset += count

    # Leave no stale gradients behind and hand the model back in eval mode.
    model.zero_grad()
    model.eval()
    return model


In [63]:
def prune_all_methods(algorithm_name, target_domain, prune_levels):
    """
    Prune one pre-trained model with every method at every requested level.

    The checkpoint is read from
      plain/{algorithm_name}/{target_domain}/model.pkl
    and its baseline accuracy is printed. Then, for each of the four pruning
    methods (unstructured, structured, random, snip) and each level, a fresh
    deep copy of the model is pruned, evaluated with
    `evaluate_model_on_pacs`, reported, and saved to
      pruned/{algorithm_name}/{target_domain}/{method}_{prune_percent}/model.pkl

    Args:
      algorithm_name: DomainBed algorithm name, also used as a folder name.
      target_domain: index of the held-out PACS domain (int or int-like).
      prune_levels: iterable of fractions in [0, 1] to prune.

    Raises:
      ValueError: if a method name has no registered pruning function.
    """
    # Load the original model and its checkpoint metadata.
    model_path = os.path.join("plain", algorithm_name, str(target_domain), "model.pkl")
    model, dump_info = load_algorithm_model(model_path)

    # Report baseline (no pruning) performance.
    baseline_target_acc, baseline_source_acc = evaluate_model_on_pacs(model, int(target_domain))
    print(f"Baseline [No Pruning] -> {algorithm_name} Domain: {target_domain} | "
          f"Target Acc: {baseline_target_acc:.2%}, Source Avg Acc: {baseline_source_acc:.2%}")

    # Methods that only need the model and the prune level.
    batch_free_methods = {
        "unstructured": apply_unstructured_pruning,
        "structured": apply_structured_pruning,
        "random": apply_random_pruning,
    }
    pruning_methods = ["unstructured", "structured", "random", "snip"]

    # The SNIP batch is drawn lazily and only once, so every SNIP level is scored
    # on the same data and the dataset is not rebuilt per level.
    snip_batch = None

    for method in pruning_methods:
        for level in prune_levels:
            # round(), not int(): int(0.29 * 100) == 28 would mislabel the run
            # and write it to the wrong directory.
            percent = round(level * 100)

            # Always start from an untouched copy of the original model.
            model_to_prune = copy.deepcopy(model)

            if method == "snip":
                if snip_batch is None:
                    snip_batch = get_sample_batch(algorithm_name, target_domain)
                model_to_prune = apply_snip_pruning(model_to_prune, level, snip_batch)
            elif method in batch_free_methods:
                model_to_prune = batch_free_methods[method](model_to_prune, level)
            else:
                raise ValueError("Unknown pruning method. Choose one of the supported methods.")

            # Evaluate the pruned model.
            target_acc, avg_source_acc = evaluate_model_on_pacs(model_to_prune, int(target_domain))
            print(f"[{method.capitalize()} {percent}% Pruned] {algorithm_name} Domain: {target_domain} | "
                  f"Target Acc: {target_acc:.2%}, Source Avg Acc: {avg_source_acc:.2%}")

            # Save the pruned model in the same checkpoint layout it was loaded from.
            save_dir = os.path.join("pruned", algorithm_name, str(target_domain), f"{method}_{percent}")
            os.makedirs(save_dir, exist_ok=True)
            save_path = os.path.join(save_dir, "model.pkl")
            pruned_dump = {
                "args": dump_info["args"],
                "model_input_shape": dump_info["model_input_shape"],
                "model_num_classes": dump_info["model_num_classes"],
                "model_num_domains": dump_info["model_num_domains"],
                "model_hparams": dump_info["model_hparams"],
                "model_dict": model_to_prune.state_dict()
            }
            torch.save(pruned_dump, save_path)


### URM

In [64]:
# Sweep settings: each PACS domain index (0-3) takes a turn as the held-out target.
target_domains = list(range(4))

# Algorithm whose checkpoints are pruned, and the fractions to remove
# (10%, 30%, 50%, 70%).
algorithm_name = "URM"
prune_levels = [0.1, 0.3, 0.5, 0.7]

for t_domain in target_domains:
    banner = f"\n===== Processing Target Domain {t_domain} for {algorithm_name} ====="
    print(banner)
    prune_all_methods(algorithm_name, t_domain, prune_levels)



===== Processing Target Domain 0 for URM =====
--> Initializing discriminator <--
--> Modifying encoder output: tanh
Baseline [No Pruning] -> URM Domain: 0 | Target Acc: 80.96%, Source Avg Acc: 99.28%
[Unstructured 10% Pruned] URM Domain: 0 | Target Acc: 80.91%, Source Avg Acc: 99.27%
[Unstructured 30% Pruned] URM Domain: 0 | Target Acc: 80.57%, Source Avg Acc: 99.19%
[Unstructured 50% Pruned] URM Domain: 0 | Target Acc: 79.74%, Source Avg Acc: 99.03%
[Unstructured 70% Pruned] URM Domain: 0 | Target Acc: 71.34%, Source Avg Acc: 98.20%
[Structured 10% Pruned] URM Domain: 0 | Target Acc: 52.15%, Source Avg Acc: 83.10%
[Structured 30% Pruned] URM Domain: 0 | Target Acc: 24.12%, Source Avg Acc: 21.15%
[Structured 50% Pruned] URM Domain: 0 | Target Acc: 21.92%, Source Avg Acc: 15.74%
[Structured 70% Pruned] URM Domain: 0 | Target Acc: 21.92%, Source Avg Acc: 15.74%
[Random 10% Pruned] URM Domain: 0 | Target Acc: 22.75%, Source Avg Acc: 26.67%
[Random 30% Pruned] URM Domain: 0 | Target Acc:

### EQRM

In [56]:
# Sweep settings: each PACS domain index (0-3) takes a turn as the held-out target.
target_domains = list(range(4))

# Algorithm whose checkpoints are pruned, and the fractions to remove
# (10%, 30%, 50%, 70%).
algorithm_name = "EQRM"
prune_levels = [0.1, 0.3, 0.5, 0.7]

for t_domain in target_domains:
    banner = f"\n===== Processing Target Domain {t_domain} for {algorithm_name} ====="
    print(banner)
    prune_all_methods(algorithm_name, t_domain, prune_levels)



===== Processing Target Domain 0 for EQRM =====
Baseline [No Pruning] -> EQRM Domain: 0 | Target Acc: 85.60%, Source Avg Acc: 99.19%
[Unstructured 10% Pruned] EQRM Domain: 0 | Target Acc: 85.74%, Source Avg Acc: 99.21%
[Unstructured 30% Pruned] EQRM Domain: 0 | Target Acc: 85.40%, Source Avg Acc: 99.17%
[Unstructured 50% Pruned] EQRM Domain: 0 | Target Acc: 83.94%, Source Avg Acc: 98.87%
[Unstructured 70% Pruned] EQRM Domain: 0 | Target Acc: 64.70%, Source Avg Acc: 95.47%
[Structured 10% Pruned] EQRM Domain: 0 | Target Acc: 46.04%, Source Avg Acc: 70.53%
[Structured 30% Pruned] EQRM Domain: 0 | Target Acc: 10.25%, Source Avg Acc: 13.98%
[Structured 50% Pruned] EQRM Domain: 0 | Target Acc: 8.98%, Source Avg Acc: 10.79%
[Structured 70% Pruned] EQRM Domain: 0 | Target Acc: 8.98%, Source Avg Acc: 10.79%
[Random 10% Pruned] EQRM Domain: 0 | Target Acc: 24.02%, Source Avg Acc: 40.46%
[Random 30% Pruned] EQRM Domain: 0 | Target Acc: 7.81%, Source Avg Acc: 11.69%
[Random 50% Pruned] EQRM Doma

### ERM++

In [48]:
# Sweep settings: each PACS domain index (0-3) takes a turn as the held-out target.
target_domains = list(range(4))

# Algorithm whose checkpoints are pruned, and the fractions to remove
# (10%, 30%, 50%, 70%).
algorithm_name = "ERMPlusPlus"
prune_levels = [0.1, 0.3, 0.5, 0.7]

for t_domain in target_domains:
    banner = f"\n===== Processing Target Domain {t_domain} for {algorithm_name} ====="
    print(banner)
    prune_all_methods(algorithm_name, t_domain, prune_levels)



===== Processing Target Domain 0 for ERMPlusPlus =====
Baseline [No Pruning] -> ERMPlusPlus Domain: 0 | Target Acc: 88.62%, Source Avg Acc: 99.64%
[Unstructured 10% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 88.48%, Source Avg Acc: 99.64%
[Unstructured 30% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 88.33%, Source Avg Acc: 99.65%
[Unstructured 50% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 86.38%, Source Avg Acc: 99.63%
[Unstructured 70% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 73.19%, Source Avg Acc: 99.15%
[Structured 10% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 63.38%, Source Avg Acc: 90.31%
[Structured 30% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 15.82%, Source Avg Acc: 15.85%
[Structured 50% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 14.84%, Source Avg Acc: 11.16%
[Structured 70% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 12.45%, Source Avg Acc: 16.81%
[Random 10% Pruned] ERMPlusPlus Domain: 0 | Target Acc: 22.46%, Source Avg Acc: 54.88%
[Random 30% Pruned] ERMPlusPl

### IRM

In [40]:
# Sweep settings: each PACS domain index (0-3) takes a turn as the held-out target.
target_domains = list(range(4))

# Algorithm whose checkpoints are pruned, and the fractions to remove
# (10%, 30%, 50%, 70%).
algorithm_name = "IRM"
prune_levels = [0.1, 0.3, 0.5, 0.7]

for t_domain in target_domains:
    banner = f"\n===== Processing Target Domain {t_domain} for {algorithm_name} ====="
    print(banner)
    prune_all_methods(algorithm_name, t_domain, prune_levels)



===== Processing Target Domain 0 for IRM =====




Baseline [No Pruning] -> IRM Domain: 0 | Target Acc: 32.42%, Source Avg Acc: 61.85%
[Unstructured 10% Pruned] IRM Domain: 0 | Target Acc: 32.28%, Source Avg Acc: 61.96%
[Unstructured 30% Pruned] IRM Domain: 0 | Target Acc: 33.06%, Source Avg Acc: 62.34%
[Unstructured 50% Pruned] IRM Domain: 0 | Target Acc: 33.30%, Source Avg Acc: 63.60%
[Unstructured 70% Pruned] IRM Domain: 0 | Target Acc: 19.63%, Source Avg Acc: 35.68%
[Structured 10% Pruned] IRM Domain: 0 | Target Acc: 9.81%, Source Avg Acc: 13.21%
[Structured 30% Pruned] IRM Domain: 0 | Target Acc: 14.40%, Source Avg Acc: 10.36%
[Structured 50% Pruned] IRM Domain: 0 | Target Acc: 14.40%, Source Avg Acc: 10.36%
[Structured 70% Pruned] IRM Domain: 0 | Target Acc: 14.40%, Source Avg Acc: 10.36%
[Random 10% Pruned] IRM Domain: 0 | Target Acc: 8.98%, Source Avg Acc: 10.79%
[Random 30% Pruned] IRM Domain: 0 | Target Acc: 10.89%, Source Avg Acc: 11.50%
[Random 50% Pruned] IRM Domain: 0 | Target Acc: 8.98%, Source Avg Acc: 10.79%
[Random 70

### ERM

In [32]:
# Sweep settings: each PACS domain index (0-3) takes a turn as the held-out target.
target_domains = list(range(4))

# Algorithm whose checkpoints are pruned, and the fractions to remove
# (10%, 30%, 50%, 70%).
algorithm_name = "ERM"
prune_levels = [0.1, 0.3, 0.5, 0.7]

for t_domain in target_domains:
    banner = f"\n===== Processing Target Domain {t_domain} for {algorithm_name} ====="
    print(banner)
    prune_all_methods(algorithm_name, t_domain, prune_levels)



===== Processing Target Domain 0 for ERM =====
Baseline [No Pruning] -> ERM Domain: 0 | Target Acc: 84.52%, Source Avg Acc: 99.33%
[Unstructured 10% Pruned] ERM Domain: 0 | Target Acc: 84.47%, Source Avg Acc: 99.33%
[Unstructured 30% Pruned] ERM Domain: 0 | Target Acc: 84.33%, Source Avg Acc: 99.36%
[Unstructured 50% Pruned] ERM Domain: 0 | Target Acc: 82.71%, Source Avg Acc: 99.24%
[Unstructured 70% Pruned] ERM Domain: 0 | Target Acc: 68.80%, Source Avg Acc: 97.75%
[Structured 10% Pruned] ERM Domain: 0 | Target Acc: 45.41%, Source Avg Acc: 87.66%
[Structured 30% Pruned] ERM Domain: 0 | Target Acc: 17.43%, Source Avg Acc: 15.55%
[Structured 50% Pruned] ERM Domain: 0 | Target Acc: 8.69%, Source Avg Acc: 11.37%
[Structured 70% Pruned] ERM Domain: 0 | Target Acc: 8.98%, Source Avg Acc: 10.79%
[Random 10% Pruned] ERM Domain: 0 | Target Acc: 18.46%, Source Avg Acc: 44.83%
[Random 30% Pruned] ERM Domain: 0 | Target Acc: 10.89%, Source Avg Acc: 12.10%
[Random 50% Pruned] ERM Domain: 0 | Targ

In [None]:
# import glob
# import torch
# import os
# import pandas as pd

# def model_stats(model):
#     """
#     Compute the total number of parameters, number of nonzero parameters,
#     and the estimated size (in bytes) of the model.
#     """
#     total_params = sum(p.numel() for p in model.parameters())
#     nonzero_params = sum(torch.count_nonzero(p).item() for p in model.parameters())
#     size_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
#     return total_params, nonzero_params, size_bytes

# # Specify your algorithm name and target domain (as used in folder names)
# algorithm_name = "ERM"  # Change if needed
# target_domain = "0"     # Folder name as string

# # Root directory where pruned models are stored.
# pruned_root = os.path.join("pruned", algorithm_name, target_domain)

# # Define the pruning methods you have
# pruning_methods = ["unstructured", "structured", "random"]

# results = []

# # Loop over methods and all prune levels for each method
# for method in pruning_methods:
#     # The pattern will match any directory like 'unstructured_10', 'structured_30', etc.
#     pattern = os.path.join(pruned_root, f"{method}_*", "model.pkl")
#     for model_path in glob.glob(pattern):
#         # Extract pruning level percentage from directory name.
#         # E.g., if model_path is pruned/ERM/0/unstructured_10/model.pkl, then '10' is extracted.
#         parts = model_path.split(os.path.sep)
#         pruning_dir = parts[-2]           # e.g., "unstructured_10"
#         prune_level_str = pruning_dir.split("_")[1]  # "10"
#         prune_level = int(prune_level_str) / 100.0   # Convert to fraction (e.g., 0.1)

#         # Load the pruned model
#         model, _ = load_algorithm_model(model_path)
#         total_params, nonzero_params, size_bytes = model_stats(model)

#         results.append({
#             "method": method,
#             "prune_level": prune_level,
#             "total_params": total_params,
#             "nonzero_params": nonzero_params,
#             "model_size_MB": size_bytes / (1024 ** 2)
#         })

# # Create and display a summary table using pandas DataFrame
# df = pd.DataFrame(results)
# print(df)




          method  prune_level  total_params  nonzero_params  model_size_MB
0   unstructured          0.1      23522375        21175450       89.73074
1   unstructured          0.5      23522375        11787751       89.73074
2   unstructured          0.7      23522375         7093901       89.73074
3   unstructured          0.3      23522375        16481601       89.73074
4     structured          0.7      23522375         7111280       89.73074
5     structured          0.1      23522375        21172949       89.73074
6     structured          0.5      23522375        11794919       89.73074
7     structured          0.3      23522375        16478558       89.73074
8         random          0.3      23522375        16481601       89.73074
9         random          0.1      23522375        21175450       89.73074
10        random          0.5      23522375        11787751       89.73074
11        random          0.7      23522375         7093901       89.73074
