In [1]:
%load_ext autoreload
%autoreload 2

import os

os.chdir(globals()["_dh"][0])
os.chdir("..")

In [2]:
import models_mae_hetero
import torch

from torchvision import transforms
from torchvision.transforms import InterpolationMode
from torch.utils.tensorboard import SummaryWriter

from gaussian_noise import AddGaussianNoise
from dataset_classes.pretrain_csi_5g import CSI5G
from dataset_classes.pretrain_csi_wifi import CSIWiFi
from dataset_classes.spectrogram_images import SpectrogramImages

from torch.utils.data import DataLoader, RandomSampler
import torch.nn.functional as F

import util.misc as misc
from util.misc import NativeScalerWithGradNormCount as NativeScaler

import timm.optim.optim_factory as optim_factory

import models_mae_hetero
from engine_pretrain_hetero import train_one_epoch

from tqdm import tqdm

import torch_pruning as tp

import timm
import copy

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
model = 'mae_vit_small_patch16'
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# define the model
model = models_mae_hetero.__dict__[model](norm_pix_loss=False, in_chans=[1, 3, 4])
model.to(device)

model_without_ddp = model
print("Model = %s" % str(model_without_ddp))

Model = MaskedAutoencoderViT(
  (patch_embed): ModuleList(
    (0): PatchEmbed(
      (proj): Conv2d(1, 512, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (1): PatchEmbed(
      (proj): Conv2d(3, 512, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (2): PatchEmbed(
      (proj): Conv2d(4, 512, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
  )
  (blocks): ModuleList(
    (0-11): 12 x Block(
      (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (attn): Attention(
        (qkv): Linear(in_features=512, out_features=1536, bias=True)
        (q_norm): Identity()
        (k_norm): Identity()
        (attn_drop): Dropout(p=0.0, inplace=False)
        (proj): Linear(in_features=512, out_features=512, bias=True)
        (proj_drop): Dropout(p=0.0, inplace=False)
      )
      (ls1): Identity()
      (drop_path1): Identity()
      (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
     

In [4]:
# model.patch_embed = model.patch_embed[0]

In [5]:
# model.decoder_pred = model.decoder_pred[0]
# model

In [4]:
augmentation = True

data_path = ['/home/ict317-3/Mohammad/Tiny-WFMs/pretraining_datasets/spectrogram_dataset',
             '/home/ict317-3/Mohammad/Tiny-WFMs/pretraining_datasets/spectrogram_iqengine_dataset',
             '/home/ict317-3/Mohammad/Tiny-WFMs/pretraining_datasets/5G_CFR',
             '/home/ict317-3/Mohammad/Tiny-WFMs/pretraining_datasets/NTU-Fi-HumanID']

log_dir = './output_dir'

transform_train = transforms.Compose([
        transforms.functional.pil_to_tensor,
        transforms.Lambda(lambda x: 10 * torch.log10(x + 1e-12)),
        transforms.Lambda(lambda x: (x + 120) / (-0.5 + 120)),
        transforms.Resize((224, 224), antialias=True,
                          interpolation=InterpolationMode.BICUBIC),  # Resize
        transforms.Normalize(mean=[0.451], std=[0.043])  # Normalize
    ])

dataset_train_one = SpectrogramImages(data_path[:-2], transform=transform_train)

augment_transforms = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.8, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    AddGaussianNoise(mean=0.0, std=0.05)]
)

if augmentation:
        dataset_train_two = CSI5G(data_path[-2], augment_transforms=augment_transforms)
        dataset_train_three = CSIWiFi(data_path[-1], augment_transforms=augment_transforms)
else:
    dataset_trainlog_dir_two = CSI5G(data_path[-2])
    dataset_train_three = CSIWiFi(data_path[-1])

print(dataset_train_one, dataset_train_two, dataset_train_three)

sampler_train_one = RandomSampler(dataset_train_one)
sampler_train_two = RandomSampler(dataset_train_two)
sampler_train_three = RandomSampler(dataset_train_three)

os.makedirs(log_dir, exist_ok=True)
log_writer = SummaryWriter(log_dir=log_dir)



<dataset_classes.spectrogram_images.SpectrogramImages object at 0x7e8ba6dc0dd0> <dataset_classes.pretrain_csi_5g.CSI5G object at 0x7e8b6f00a5d0> <dataset_classes.pretrain_csi_wifi.CSIWiFi object at 0x7e8b6f00a690>


In [5]:
batch_size = 16
num_workers = 10
pin_mem = True
csi_subsampling = False

data_loader_train_one = DataLoader(
        dataset_train_one, sampler=sampler_train_one,
        batch_size= batch_size,
        num_workers= num_workers,
        pin_memory= pin_mem,
        drop_last=True)

if  csi_subsampling:
    data_loader_train_two = DataLoader(
        dataset_train_two, sampler=sampler_train_two,
        batch_size= batch_size // 2,
        num_workers= num_workers,
        pin_memory= pin_mem,
        drop_last=True)

    data_loader_train_three = DataLoader(
        dataset_train_three, sampler=sampler_train_three,
        batch_size= batch_size // 2,
        num_workers= num_workers,
        pin_memory= pin_mem,
        drop_last=True)
else:
    data_loader_train_two = DataLoader(
        dataset_train_two, sampler=sampler_train_two,
        batch_size= batch_size,
        num_workers= num_workers,
        pin_memory= pin_mem,
        drop_last=True)

    data_loader_train_three = DataLoader(
        dataset_train_three, sampler=sampler_train_three,
        batch_size= batch_size,
        num_workers= num_workers,
        pin_memory= pin_mem,
        drop_last=True)


In [6]:
accum_iter = 1
lr = None
blr = 1e-3


eff_batch_size =  batch_size *accum_iter
    
if lr is None:  # only base_lr is specified
     lr = blr * eff_batch_size / 256

print("base lr: %.2e" % ( lr * 256 / eff_batch_size))
print("actual lr: %.2e" %  lr)

print("accumulate grad iterations: %d" %accum_iter)
print("effective batch size: %d" % eff_batch_size)


base lr: 1.00e-03
actual lr: 6.25e-05
accumulate grad iterations: 1
effective batch size: 16


In [7]:
weight_decay = 0.05

# following timm: set wd as 0 for bias and norm layers
param_groups = optim_factory.param_groups_weight_decay(model_without_ddp, weight_decay)
optimizer = torch.optim.AdamW(param_groups, lr=lr, betas=(0.9, 0.95))
print(optimizer)
loss_scaler = NativeScaler()

ckpt_path = '/home/ict317-3/Mohammad/Tiny-WFMs/checkpoints/pretrained_all_data.pth'
pretrained = torch.load(ckpt_path, map_location=device, weights_only=False)['model']
model.load_state_dict(pretrained, strict=False)


AdamW (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 6.25e-05
    maximize: False
    weight_decay: 0.0

Parameter Group 1
    amsgrad: False
    betas: (0.9, 0.95)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 6.25e-05
    maximize: False
    weight_decay: 0.05
)


  self._scaler = torch.cuda.amp.GradScaler()


<All keys matched successfully>

In [10]:
# def compute_baseline_performance(model, mask_ratio, kernel_size, dataset_id, device):

#     accuracy = 0
#     total_loss = 0

#     model.eval()

#     model = model.to(device)
        
#     data = {
#         1: (dataset_train_one,   data_loader_train_one),
#         2: (dataset_train_two,   data_loader_train_two),
#         3: (dataset_train_three, data_loader_train_three),
#     }
#     assert dataset_id in data, f"dataset_id must be 1, 2, or 3, got {dataset_id}"


#     with torch.no_grad():
       
#         for k, (images, _) in enumerate(tqdm(data[dataset_id][1], desc="Batches", leave=False)):
#             images = images.to(device)
#             loss, reconstructed, mask = model(images, mask_ratio=mask_ratio / 100)
#             images = torch.einsum('nchw->nhwc', images)
#             reconstructed = torch.einsum('nchw->nhwc', model.unpatchify(reconstructed))
#             mask = model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, 16 ** 2 * 1))
#             mask = torch.einsum('nchw->nhwc', mask)
#             reconstructed = (1 - mask) * images + mask * reconstructed

#             pooled_images = F.avg_pool2d(images.permute(0, 3, 1, 2), kernel_size=kernel_size, stride=kernel_size)
#             pooled_reconstructed = F.avg_pool2d(reconstructed.permute(0, 3, 1, 2), kernel_size=kernel_size, stride=kernel_size)
#             mu, std = torch.mean(pooled_images, dim=(1, 2, 3)), torch.std(pooled_images, dim=(1, 2, 3))
#             threshold = mu + 0.5 * std
#             threshold = threshold.view(-1, 1, 1, 1).repeat((1, 1, pooled_images.shape[2], pooled_images.shape[3]))
#             pooled_images = pooled_images > threshold
#             pooled_reconstructed = pooled_reconstructed > threshold

#             accuracy += (pooled_images == pooled_reconstructed).sum().item()
#             total_loss += loss.item() 

#     accuracy /= (len(data[dataset_id][0]) * (224 // kernel_size) ** 2)
#     avg_loss = total_loss / len(data[dataset_id][1])

#     print(accuracy)
#     return avg_loss

In [8]:
def compute_loss_per_loader(model, device, data_loader, mask_ratio=75, kernel_size=3):
    model.eval()
    model.to(device)

    # map id -> (dataset, loader)
    # data = {
    #     1: data_loader_train_one,
    #     2: data_loader_train_two,
    #     3: data_loader_train_three,
    # }
    # assert dataset_id in data, f"dataset_id must be 1, 2, or 3, got {dataset_id}"

    # loader = data[dataset_id]

    total_equal = 0
    total_cells_seen = 0
    total_loss = 0.0

    # Try to infer patch size and channels for robust unpatchify of the mask
    # Fallback to 16 and 1 if the model doesn't expose it
    try:
        p = getattr(model, "patch_embed").patch_size[0]  # e.g., 16
    except Exception:
        p = 16
    # We'll infer channels per batch from the images

    with torch.no_grad():
        for images, _ in tqdm(data_loader, desc="Batches", leave=False):
            images = images.to(device)  # NCHW expected here from your dataset pipeline

            # forward
            loss, reconstructed, mask = model(images, mask_ratio=mask_ratio / 100.0)

            # Unpatchify outputs (assumes MAE-like shapes)
            # reconstructed: (N, L, p*p*C), mask: (N, L)
            # Infer channels from the reconstructed payload if possible
            # C = (reconstructed.shape[-1] // (p * p)) if reconstructed.ndim == 3 else images.shape[1]
            if reconstructed.ndim == 3:
                C = reconstructed.shape[-1] // (p * p)
            else:
                C = images.shape[1]

            reconstructed = model.unpatchify(reconstructed)  # -> (N, C, H, W)
            # Broadcast mask from (N, L) to (N, p*p*C, L) before unpatchify
            mask_expanded = mask.unsqueeze(-1).repeat(1, 1, p * p * C)
            mask_img = model.unpatchify(mask_expanded)  # (N, C, H, W)

            # Blend original and reconstructed only on masked regions
            blended = (1 - mask_img) * images + mask_img * reconstructed  # all NCHW

            # Pool down to grids
            pooled_images = F.avg_pool2d(images, kernel_size=kernel_size, stride=kernel_size)
            pooled_blended = F.avg_pool2d(blended, kernel_size=kernel_size, stride=kernel_size)

            # Per-sample thresholds from original pooled images
            # mean/std over (C,H,W)
            mu = pooled_images.mean(dim=(1, 2, 3))
            std = pooled_images.std(dim=(1, 2, 3), unbiased=False)  # avoid NaNs when only 1 cell
            threshold = (mu + 0.5 * std).view(-1, 1, 1, 1)

            # Binarize
            bin_images = pooled_images > threshold
            bin_blended = pooled_blended > threshold

            # Count equal cells
            equal_cells = (bin_images == bin_blended).sum().item()
            total_equal += equal_cells

            # Track how many cells we actually processed
            n, c, h, w = bin_images.shape
            total_cells_seen += n * c * h * w

            # Accumulate batch loss
            total_loss += float(loss.item())

    # Safe aggregations
    accuracy = total_equal / max(1, total_cells_seen)
    avg_loss = total_loss / max(1, len(data_loader))

    return avg_loss, accuracy


In [15]:
# mask_ratio = 75
# kernel_size = 3
# dataset_id = 3


# loss = compute_loss_per_loader(model, data_loader_train_one, device)
# print(loss)

In [9]:
def compute_baseline_performance(model, device, data_loaders):

    overall_loss = sum(compute_loss_per_loader(model, device, loader)[0] for loader in data_loaders) / len(data_loaders)
    
    return overall_loss

In [17]:
# def forward_encoder_new(self, x, mask_ratio):
#         # embed patches
#         x = self.patch_embed(x)

#         # add pos embed w/o cls token
#         x = x + self.pos_embed[:, 1:, :]

#         # masking: length -> length * mask_ratio
#         x, mask, ids_restore = self.random_masking(x, mask_ratio)

#         # append cls token
#         cls_token = self.cls_token + self.pos_embed[:, :1, :]
#         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
#         x = torch.cat((cls_tokens, x), dim=1)

#         # apply Transformer blocks
#         for blk in self.blocks:
#             x = blk(x)
#         x = self.norm(x)

#         return x, mask, ids_restore


# def forward_decoder_new(self, x, ids_restore):
#         # embed tokens
#         x = self.decoder_embed(x)

#         # append mask tokens to sequence
#         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
#         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
#         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
#         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

#         # add pos embed
#         x = x + self.decoder_pos_embed

#         # apply Transformer blocks
#         for blk in self.decoder_blocks:
#             x = blk(x)
#         x = self.decoder_norm(x)

#         # predictor projection
#         x = self.decoder_pred(x)

#         # remove cls token
#         x = x[:, 1:, :]

#         return x

# def forward_new(self, imgs, mask_ratio=0.75):
#     latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)  # <- renamed
#     pred = self.forward_decoder(latent, ids_restore)                    # <- renamed
#     loss = self.forward_loss(imgs, pred, mask)
#     return loss, pred, mask

# import types

# model.forward_encoder = types.MethodType(forward_encoder_new, model)
# model.forward_decoder = types.MethodType(forward_decoder_new, model)
# model.forward          = types.MethodType(forward_new, model)

In [None]:
# data_loaders = [data_loader_train_one, data_loader_train_two, data_loader_train_three]

# loss = compute_baseline_performance(model, 'cuda', data_loaders)

# loss

In [11]:

def get_blocks(model):
    ignored_blocks = []
    model_blocks = []
    num_heads = {}
    bottleneck = False

    for m in model.modules():
        if isinstance(m, timm.models.vision_transformer.Attention):
            num_heads[m.qkv] = m.num_heads
        if bottleneck and isinstance(m, timm.models.vision_transformer.Mlp):
            ignored_blocks.append(m.fc2)

    for name, layer in model.named_children():
        if name != 'blocks':
            if len(list(layer.children())) > 0:
                ignored_blocks.extend(layer.children())
            else:
                ignored_blocks.append(layer)
        else:
            if len(list(layer.children())) > 0:
                model_blocks.extend(layer.children())
            else:
                model_blocks.append(layer)

    return model_blocks, ignored_blocks, num_heads


In [13]:
from engine_pretrain_hetero import RoundRobinLoader

def selective_block_pruning(trained_model, device, pruning_ratios, prune_method, data_loaders, mask_ratio):
    model = copy.deepcopy(trained_model)
    for param in model.parameters():
        param.requires_grad = True

    model_blocks, ignored_blocks, num_heads = get_blocks(model)
    model.to(device)

    pruning_info = {
        i: {"block": model_blocks[i], "pruning_ratio": ratio}
        for i, ratio in enumerate(pruning_ratios)
    }

    if prune_method == 'channel_pruning_Taylor_importance':
        imp = tp.importance.GroupTaylorImportance()

        combined_iter = RoundRobinLoader(data_loaders)

        if isinstance(imp, tp.importance.GroupTaylorImportance):
            model.zero_grad()
            model.train(True)

            print("Accumulating gradients for pruning...")
            for data_iter_step, (key, samples, labels) in enumerate(combined_iter):
                samples, labels = samples.to(device), labels.to(device)
                if data_iter_step >= 60: # 20 samples of each dataloader for now
                    break
                loss, _, _ = model(samples, mask_ratio=mask_ratio / 100)
                loss.backward()

        original_macs, original_nparams = tp.utils.count_ops_and_params(model, samples)

        for i, info in pruning_info.items():
            pruning_ratio = info["pruning_ratio"]
            if pruning_ratio == 0:
                continue

            ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]
            combined_ignored_layers = ignored_blocks + ignored_layers_block
            # print('combined_ignored_layers:   ', combined_ignored_layers)
            
            count = 0
            while True:
                pruner = tp.pruner.MetaPruner(
                    model,
                    example_inputs=samples,
                    importance=imp,
                    pruning_ratio=pruning_ratio,
                    ignored_layers=combined_ignored_layers,
                    num_heads=num_heads,
                    prune_num_heads=False,
                    prune_head_dims=True
                )
                for g in pruner.step(interactive=True):
                    g.prune()

                # print(f'A pruning process has been performed here with a {pruning_ratio} pruning ratio ...  ')
                
                for m in model.modules():
                    if isinstance(m, timm.models.vision_transformer.Attention):
                        m.num_heads = pruner.num_heads[m.qkv]
                        m.head_dim = m.qkv.out_features // (3 * m.num_heads)

                macs, nparams = tp.utils.count_ops_and_params(model, samples)

                if original_nparams - nparams == 0:
                    count += 1
                    if count == 1:
                        pruning_ratio = 0.5
                    else:
                        break

                original_nparams = nparams

        del samples, labels
        torch.cuda.empty_cache()

    return model, macs, nparams

In [14]:
def perplexity_analysis_with_contributions(original_model, device, metric, measure_performance, data_loaders, mask_ratio=75):
    model_blocks, _, _ = get_blocks(original_model)
    blocks_number = len(model_blocks)

    total_block_performance = [0.0 for _ in range(blocks_number)]
    params_reduction = []
    macs_reduction = []

    original_model.to(device)

    print(f"Computing baseline {metric} without block replacement...")
    baseline_performance = measure_performance(original_model, device, data_loaders)
    print(f"Baseline {metric}: {baseline_performance}")

    example_inputs = next(iter(data_loaders[0]))[0].to(device)
    original_macs, original_nparams = tp.utils.count_ops_and_params(original_model, example_inputs)

    for block_idx in range(blocks_number):
        print(f"Replacing block {block_idx}")
        pruning_ratios = (np.eye(blocks_number) * 0.85)[block_idx]
        pruned_model, macs, nparams = selective_block_pruning(
            original_model, device, pruning_ratios, 'channel_pruning_Taylor_importance', data_loaders, mask_ratio
        )

        params_reduction.append((original_nparams - nparams)/original_nparams * 100)
        macs_reduction.append((original_macs - macs) / original_macs * 100)

        pruned_model.to(device)
        block_performance = measure_performance(pruned_model, device, data_loaders)
        total_block_performance[block_idx] = block_performance
        print(f'The {metric} after pruning this block is: {block_performance}')

    total_degradation_in_performance = 0.0
    block_degradation = []
    total_params_reduction = 0.0
    total_macs_reduction = 0.0

    for block_idx in range(blocks_number):
        degradation = np.abs(total_block_performance[block_idx] - baseline_performance)
        print(f"Degradation in {metric} is: {degradation}")
        block_degradation.append(degradation)
        total_degradation_in_performance += degradation
        total_params_reduction += params_reduction[block_idx]
        total_macs_reduction += macs_reduction[block_idx]

    relative_contributions = []
    weighted_importance_scores = []

    print(f"\nRelative contribution of each block to total {metric} degradation and parameter reduction:")
    for block_idx in range(blocks_number):
        rel_perf = (block_degradation[block_idx] / total_degradation_in_performance) * 100
        rel_params = 100 - params_reduction[block_idx]
        rel_macs = 100 - macs_reduction[block_idx]

        weighted_importance = (0.7 * rel_perf) + (0.2 * rel_params) + (0.1 * rel_macs)
        print(f'Block {block_idx} contributes {rel_perf:.2f}% to the total degradation in {metric} and reduces {params_reduction[block_idx]:.2f}% of parameters.')
        print(f'Weighted importance score for Block {block_idx}: {weighted_importance:.2f}')

        relative_contributions.append(rel_perf)
        weighted_importance_scores.append(weighted_importance)

    return weighted_importance_scores

In [None]:
def forward(self, x):
    """https://github.com/huggingface/pytorch-image-models/blob/054c763fcaa7d241564439ae05fbe919ed85e614/timm/models/vision_transformer.py#L79"""
    B, N, C = x.shape
    qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)
    q, k = self.q_norm(q), self.k_norm(k)

    if self.fused_attn:
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p,
        )
    else:
        q = q * self.scale
        attn = q @ k.transpose(-2, -1)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = attn @ v

    x = x.transpose(1, 2).reshape(B, N, -1) # original implementation: x = x.transpose(1, 2).reshape(B, N, C)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x

# Override attention forward
for m in model.modules():
    if isinstance(m, timm.models.vision_transformer.Attention):
        m.forward = forward.__get__(m, timm.models.vision_transformer.Attention)
        
# Analyze contributions
relative_contribution = perplexity_analysis_with_contributions(
    model,
    device=device,
    metric='loss',
    measure_performance=compute_baseline_performance,
    data_loaders=[data_loader_train_one, data_loader_train_two, data_loader_train_three]
)

In [16]:
relative_contribution = [np.float64(68.08877057598382),
 np.float64(37.40626909822759),
 np.float64(32.281119143100256),
 np.float64(30.637165097452176),
 np.float64(30.020183495139563),
 np.float64(29.474453270385183),
 np.float64(29.440264146737952),
 np.float64(29.357529826182528),
 np.float64(29.954904338606084),
 np.float64(29.66567846836903),
 np.float64(30.0643648149045),
 np.float64(31.607907928914102)]

In [24]:
# [np.float64(67.84490239830508),
#  np.float64(37.50249339853453),
#  np.float64(32.31163056372445),
#  np.float64(30.57015590482878),
#  np.float64(29.96990099259839),
#  np.float64(29.514538402508663),
#  np.float64(29.43987769508422),
#  np.float64(29.401976151389682),
#  np.float64(29.970814767218833),
#  np.float64(29.648641127295463),
#  np.float64(30.145338707694115),
#  np.float64(31.678340094820577)]

In [17]:
def calculate_pruning_ratios(contributions, max_pruning_ratio=0.9, k=5):
    """
    Calculate pruning ratios based on intense nonlinear scaling (exponential decay) of the relative contributions.

    Parameters:
    - contributions (list): List of relative contributions (in percentages) of each block to total loss increase.
    - max_pruning_ratio (float): Maximum pruning ratio to be assigned to the least important layer. Default is 0.9.
    - k (int): Factor controlling the intensity of the scaling (larger k makes the ratio more intense).

    Returns:
    - pruning_ratios (list): List of pruning ratios for each block.
    """
    # Normalize the contributions to get values between 0 and 1
    total_contribution = sum(contributions)
    normalized_contributions = [contribution / total_contribution for contribution in contributions]

    # Apply exponential decay to magnify the effect for less important blocks
    pruning_factors = [np.exp(-k * nc) for nc in normalized_contributions]

    # Normalize the pruning factors so they stay within the max pruning ratio
    max_factor = max(pruning_factors)
    normalized_factors = [pf / max_factor for pf in pruning_factors]

    # Scale by the maximum pruning ratio
    pruning_ratios = [max_pruning_ratio * nf for nf in normalized_factors]

    pruning_ratios = [round(num, 2) for num in pruning_ratios]

    return pruning_ratios

In [18]:
def prune_model(trained_model, device, pruning_ratios, prune_method, data_loaders, mask_ratio=75):
    model = copy.deepcopy(trained_model)

    for param in model.parameters():
        param.requires_grad = True

    model_blocks, ignored_blocks, num_heads = get_blocks(model)
    model.to(device)

    pruning_info = {
        i: {"block": model_blocks[i], "pruning_ratio": ratio}
        for i, ratio in enumerate(pruning_ratios)
    }

    if prune_method == 'channel_pruning_Taylor_importance':
        imp = tp.importance.GroupTaylorImportance()

        combined_iter = RoundRobinLoader(data_loaders)

        if isinstance(imp, tp.importance.GroupTaylorImportance):
            model.zero_grad()
            model.train(True)

            print("Accumulating gradients for pruning...")
            for data_iter_step, (key, samples, labels) in enumerate(combined_iter):
                samples, labels = samples.to(device), labels.to(device)
                if data_iter_step >= 60: # 20 samples of each dataloader for now
                    break
                loss, _, _ = model(samples, mask_ratio=mask_ratio / 100)
                loss.backward()

        original_macs, original_nparams = tp.utils.count_ops_and_params(model, samples)

        for i, info in pruning_info.items():
            pruning_ratio = info["pruning_ratio"]
   

            ignored_layers_block = [pruning_info[j]["block"] for j in range(len(pruning_info)) if j != i]
            combined_ignored_layers = ignored_blocks + ignored_layers_block
            # print('combined_ignored_layers:   ', combined_ignored_layers)
            
            print(f"Pruning block {i} with pruning ratio: {pruning_ratio}")

            pruner = tp.pruner.MetaPruner(
                model,
                example_inputs=samples,
                importance=imp,
                pruning_ratio=pruning_ratio,
                ignored_layers=combined_ignored_layers,
                num_heads=num_heads,
                prune_num_heads=False,
                prune_head_dims=True
            )
            for g in pruner.step(interactive=True):
                g.prune()

            # print(f'A pruning process has been performed here with a {pruning_ratio} pruning ratio ...  ')
            
            for m in model.modules():
                if isinstance(m, timm.models.vision_transformer.Attention):
                    m.num_heads = pruner.num_heads[m.qkv]
                    m.head_dim = m.qkv.out_features // (3 * m.num_heads)

            macs, nparams = tp.utils.count_ops_and_params(model, samples)

            print(f"MACs: {macs / 1e9:.2f} G, #Params: {nparams / 1e3:.2f} K")
            print(f"Parameter reduction: {((original_nparams - nparams) / original_macs * 100):.2f}%")
            print(f"MACs reduction: {((original_macs - macs) / original_macs * 100):.2f}%")
        

                
        del samples, labels
        torch.cuda.empty_cache()

    return model, macs, nparams

In [19]:
max_pruning_ratio = 0.99 # Maximum pruning ratio (99%) 
k = 5 # Controls the intensity of the scaling

pruning_ratios = calculate_pruning_ratios(relative_contribution, max_pruning_ratio, k)

# Print the pruning ratios for each block
for i, ratio in enumerate(pruning_ratios):
    print(f"Block {i} Pruning Ratio: {ratio:.4f}")

Block 0 Pruning Ratio: 0.6200
Block 1 Pruning Ratio: 0.9000
Block 2 Pruning Ratio: 0.9600
Block 3 Pruning Ratio: 0.9700
Block 4 Pruning Ratio: 0.9800
Block 5 Pruning Ratio: 0.9900
Block 6 Pruning Ratio: 0.9900
Block 7 Pruning Ratio: 0.9900
Block 8 Pruning Ratio: 0.9800
Block 9 Pruning Ratio: 0.9900
Block 10 Pruning Ratio: 0.9800
Block 11 Pruning Ratio: 0.9600


In [20]:
data_loaders = [data_loader_train_one, data_loader_train_two, data_loader_train_three]
pruned_model, macs, nparams = prune_model(model, device, pruning_ratios, 'channel_pruning_Taylor_importance', data_loaders)

Accumulating gradients for pruning...
Pruning block 0 with pruning ratio: 0.62


 Torch-Pruning will prune the last non-singleton dimension of these parameters. If you wish to change this behavior, please provide an unwrapped_parameters argument.


MACs: 3.27 G, #Params: 44049.99 K
Parameter reduction: 0.06%
MACs reduction: 3.04%
Pruning block 1 with pruning ratio: 0.9
MACs: 3.13 G, #Params: 41208.23 K
Parameter reduction: 0.14%
MACs reduction: 7.35%
Pruning block 2 with pruning ratio: 0.96
MACs: 2.97 G, #Params: 38174.76 K
Parameter reduction: 0.23%
MACs reduction: 11.93%
Pruning block 3 with pruning ratio: 0.97
MACs: 2.82 G, #Params: 35104.38 K
Parameter reduction: 0.32%
MACs reduction: 16.56%
Pruning block 4 with pruning ratio: 0.98
MACs: 2.66 G, #Params: 32012.48 K
Parameter reduction: 0.41%
MACs reduction: 21.22%
Pruning block 5 with pruning ratio: 0.99
MACs: 2.56 G, #Params: 29933.78 K
Parameter reduction: 0.48%
MACs reduction: 24.30%
Pruning block 6 with pruning ratio: 0.99
MACs: 2.45 G, #Params: 27855.08 K
Parameter reduction: 0.54%
MACs reduction: 27.38%
Pruning block 7 with pruning ratio: 0.99
MACs: 2.35 G, #Params: 25776.38 K
Parameter reduction: 0.60%
MACs reduction: 30.45%
Pruning block 8 with pruning ratio: 0.98
MAC

In [21]:
data_loaders = [data_loader_train_one, data_loader_train_two, data_loader_train_three]

loss = compute_baseline_performance(pruned_model, 'cuda', data_loaders)

loss

                                                          

6.404765754683404

In [23]:
images = next(iter(data_loaders[0]))[0].to(device)
model.to(device)

base_macs, base_nparams = tp.utils.count_ops_and_params(model, images)
macs, nparams = tp.utils.count_ops_and_params(pruned_model, images)

print(f"MACs: {base_macs/1e9} G -> {macs/1e9} G, #Params: {base_nparams/1e6} M -> {nparams/1e6} M")
p_ratio = ((base_nparams - nparams) / base_nparams * 100)

print(f"Overall parameter reduction: {(p_ratio):.2f}%")
print(f"Overall MACs reduction: {((base_macs - macs) / base_macs * 100):.2f}%")
            

MACs: 3.376390688 G -> 1.7748810898125 G, #Params: 46.008064 M -> 14.480397 M
Overall parameter reduction: 68.53%
Overall MACs reduction: 47.43%


In [25]:
torch.save(pruned_model, f'our_pruned_models/pruned_autoencoder/Vit_pruned_{(p_ratio):.2f}%.pth')

In [8]:
def compare_model_weights(model_a, model_b, verbose=True):
    """
    Compare the weights of two PyTorch models.

    Args:
        model_a (torch.nn.Module): First model.
        model_b (torch.nn.Module): Second model.
        verbose (bool): If True, prints differences.

    Returns:
        differences (list): List of tuples (param_name, max_diff) where differences were found.
    """
    differences = []

    state_dict_a = model_a.state_dict()
    state_dict_b = model_b.state_dict()

    keys_a = set(state_dict_a.keys())
    keys_b = set(state_dict_b.keys())

    # Check for missing parameters
    if keys_a != keys_b:
        missing_in_b = keys_a - keys_b
        missing_in_a = keys_b - keys_a

        if missing_in_b:
            print(f"Parameters in model_a but missing in model_b: {missing_in_b}")
        if missing_in_a:
            print(f"Parameters in model_b but missing in model_a: {missing_in_a}")

        # Only compare common keys
        common_keys = keys_a.intersection(keys_b)
    else:
        common_keys = keys_a

    # Compare weights
    for key in common_keys:
        tensor_a = state_dict_a[key]
        tensor_b = state_dict_b[key]

        if not torch.allclose(tensor_a, tensor_b, atol=1e-6):
            max_diff = (tensor_a - tensor_b).abs().max().item()
            differences.append((key, max_diff))
            if verbose:
                print(f"[Value mismatch] {key} | Max diff: {max_diff:.2e}")

    if not differences and verbose:
        print("✅ Models are identical (within tolerance)!")

    return differences


original_model = torch.load('/home/ict317-3/Mohammad/Tiny-WFMs/our_pruned_models/pruned_autoencoder/Vit_pruned_68.53%.pth', weights_only=False)
pruned_model = torch.load('/home/ict317-3/Mohammad/Tiny-WFMs/pruned_results/sig_identification/best_model.pth', weights_only=False)

original_model.to('cuda')
pruned_model.to('cuda')
compare_model_weights(original_model, pruned_model)

Parameters in model_a but missing in model_b: {'decoder_blocks.6.norm2.bias', 'decoder_blocks.6.norm1.weight', 'decoder_blocks.4.norm2.bias', 'decoder_blocks.4.norm1.weight', 'decoder_blocks.3.attn.proj.bias', 'decoder_blocks.7.norm1.bias', 'decoder_blocks.1.norm1.bias', 'decoder_blocks.7.norm1.weight', 'decoder_pred.0.weight', 'decoder_embed.bias', 'decoder_blocks.7.attn.proj.weight', 'decoder_blocks.5.mlp.fc1.bias', 'patch_embed.2.proj.bias', 'decoder_blocks.1.attn.proj.bias', 'decoder_blocks.6.norm2.weight', 'decoder_embed.weight', 'decoder_pos_embed', 'decoder_blocks.1.norm2.bias', 'decoder_blocks.6.attn.proj.weight', 'decoder_blocks.4.attn.qkv.weight', 'decoder_blocks.4.norm2.weight', 'patch_embed.1.proj.bias', 'decoder_pred.1.weight', 'decoder_pred.2.weight', 'decoder_norm.bias', 'patch_embed.2.proj.weight', 'decoder_blocks.3.mlp.fc1.weight', 'decoder_blocks.6.attn.qkv.weight', 'decoder_blocks.5.attn.proj.weight', 'decoder_blocks.2.attn.qkv.weight', 'decoder_blocks.5.norm2.weight

[('pos_embed', 1.0822968482971191),
 ('cls_token', 0.06781171262264252),
 ('norm.weight', 0.6712888479232788),
 ('norm.bias', 0.13735242187976837)]