In [1]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import torch
from torch.utils.data.dataset import Subset
from torch.utils.data import DataLoader

# Allowlist `Subset` for torch's restricted (weights_only=True) unpickler.
# NOTE(review): both loads below pass weights_only=False, i.e. the full pickle
# unpickler, which does not consult this allowlist. Full unpickling can execute
# arbitrary code, so only load trusted files here.
torch.serialization.add_safe_globals([Subset])

# Load the pickled test and validation subsets (in that order).
test_subset, val_subset = (
    torch.load(path, weights_only=False)
    for path in ("data/cifar100_selected_test.pt", "data/cifar100_extended_val.pt")
)

# Test set: one sample per step in a fixed order. Validation: shuffled batches of 32.
testloader = DataLoader(dataset=test_subset, batch_size=1, shuffle=False)
valloader = DataLoader(dataset=val_subset, batch_size=32, shuffle=True)

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

def compute_grad_cosine_similarity(source_model, target_models, loader, device='cuda', num_samples=1000,
                                   show_progress=True):
    """
    Compute the average input-gradient cosine similarity between a source model
    and each of several target models.

    For every sample, the gradient of the cross-entropy loss w.r.t. the input is
    taken for the source model and for each target model. The two gradients are
    flattened per sample, L2-normalised and dotted, and the per-sample values are
    averaged over exactly the samples that were used.

    Side effects: every model is switched to eval mode and moved to ``device``.

    Args:
        source_model: nn.Module whose input gradients are the reference.
        target_models: sequence of nn.Modules compared against the source.
        loader: iterable of ``(imgs, labels)`` batches, e.g. a DataLoader.
        device: device that the models and batches are moved to.
        num_samples: maximum number of samples to use. The final batch is
            truncated so that exactly this many are used when available.
            ``None`` uses everything the loader yields.
        show_progress: wrap the loader in a tqdm progress bar.

    Returns:
        dict mapping ``"model_<i>"`` (i = index into ``target_models``) to the
        mean cosine similarity as a Python float. If no samples were seen,
        every value is NaN.
    """
    source_model.eval().to(device)
    for tm in target_models:
        tm.eval().to(device)

    criterion = nn.CrossEntropyLoss()

    def _unit_input_grad(model, imgs, labels):
        # Input gradient of the loss, flattened to one row per sample and L2-normalised.
        loss = criterion(model(imgs), labels)
        grad = torch.autograd.grad(loss, imgs)[0]
        grad = grad.reshape(grad.size(0), -1)
        return grad / (grad.norm(dim=1, keepdim=True) + 1e-8)

    batches = loader
    if show_progress:
        total = len(loader)
        batch_size = getattr(loader, "batch_size", None)
        if num_samples is not None and batch_size:
            # Ceiling division: a partial final batch is still one step.
            total = min(total, -(-num_samples // batch_size))
        batches = tqdm(loader, total=total, desc="Grad-Cos Sim")

    # Running sum of per-sample cosine similarities, one slot per target model.
    sim_sums = torch.zeros(len(target_models), device=device)
    count = 0

    for imgs, labels in batches:
        if num_samples is not None:
            remaining = num_samples - count
            if remaining <= 0:
                break
            # Truncate the final batch so that no more than num_samples are used.
            imgs, labels = imgs[:remaining], labels[:remaining]

        # detach() yields a fresh leaf, so requires_grad_ does not flip the flag
        # on the tensor owned by the loader.
        imgs = imgs.to(device).detach().requires_grad_(True)
        labels = labels.to(device)

        grad_src = _unit_input_grad(source_model, imgs, labels)
        for i, tm in enumerate(target_models):
            grad_tgt = _unit_input_grad(tm, imgs, labels)
            # Accumulate the SUM of per-sample cosines (not the batch mean), so a
            # short final batch is weighted by its true size.
            sim_sums[i] += (grad_src * grad_tgt).sum()
        count += imgs.size(0)

        if num_samples is not None and count >= num_samples:
            break

    if count == 0:
        # No samples seen: the mean is undefined.
        return {f"model_{i}": float("nan") for i in range(len(target_models))}

    sim_means = (sim_sums / count).detach().cpu()
    return {f"model_{i}": sim_means[i].item() for i in range(len(target_models))}

In [5]:
import model_helper as helper

# Preferred GPU index. On machines with fewer GPUs, fall back to the highest
# visible index instead of failing with an invalid-device-ordinal error on the
# first `.to(device)`.
CUDA_DEVICE_INDEX = 2
if torch.cuda.is_available():
    device = f"cuda:{min(CUDA_DEVICE_INDEX, torch.cuda.device_count() - 1)}"
else:
    device = "cpu"

# Target architectures. `target_models` is built in this order, so entry i here
# is reported as "model_i" by compute_grad_cosine_similarity.
target_models_args = [
    "resnetv2_101x1_bitm",
    "resnet152",
    "regnety_160",
    "vit_base_patch16_224",
    "deit_base_patch16_224",
    "swin_base_patch4_window7_224",
    "convmixer_768_32",
]

target_models = []
for arch_name in target_models_args:
    target_model = helper.load_model_hub(arch_name).to(device)
    target_models.append(target_model.eval())


🔹 Loading vit_base_patch16_224 from Models/target/vit_base.pth.tar
Load result: <All keys matched successfully>

🔹 Loading deit_base_patch16_224 from Models/target/deit_base.pth.tar
Load result: <All keys matched successfully>

🔹 Loading swin_base_patch4_window7_224 from Models/target/swin_base.pth.tar
Load result: <All keys matched successfully>


In [16]:
import model_helper as helper


# Candidate source models; the one at the chosen index is loaded below.
ens_models_args = [
    "resnet18",
    "inception_v3",
    "deit_tiny_patch16_224",
    "vit_tiny_patch16_224",
    "efficientnet_b0",
    "gcvit_tiny",
]

model_name = ens_models_args[5]  # "gcvit_tiny"; change the index to pick another source model
model = helper.load_model_hub(model_name).to(device)


In [17]:
# Mean input-gradient cosine similarity of the loaded source model against each target model.
compute_grad_cosine_similarity(
    source_model=model,
    target_models=target_models,
    loader=testloader,
    device=device,
    num_samples=1000,
)

Grad-Cos Sim: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–Š| 999/1000 [09:22<00:00,  1.78it/s]


{'model_0': 0.11069872230291367,
 'model_1': 0.05222869664430618,
 'model_2': 0.09057532995939255,
 'model_3': 0.07674083113670349,
 'model_4': 0.13004176318645477,
 'model_5': 0.10210533440113068,
 'model_6': 0.04814634844660759}

In [None]:
# NOTE(review): this cell is a hard-coded result dict, not a computation, and it
# has never been executed. The source model that produced these numbers is not
# recorded here; TODO label it. The keys presumably follow the `target_models`
# order, as in the live call's output; verify.
{'model_0': 0.1027129516005516,
 'model_1': 0.03547651320695877,
 'model_2': 0.08607926219701767,
 'model_3': 0.049814414232969284,
 'model_4': 0.07185643911361694,
 'model_5': 0.062440186738967896,
 'model_6': 0.047719795256853104}