In [1]:
from model.pytorch_pretrained_vit import ViT
import dataloader

import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from copy import deepcopy

from utils import AverageMeter, get_weights_copy

import time

import torch.nn as nn
import torch.optim as optim

import torch
import os

In [2]:
def get_prun_idx(layer, sparsity=0.5):
    idxs = None
    lambda_ = None

    with torch.no_grad():
        p = layer.weight.data
        p = torch.linalg.norm(p, dim=1)
        lambda_ = p.sort()[0][int(sparsity * p.shape[0])]
        idxs = torch.ones_like(p)
        idxs[p.abs() < lambda_] = 0
    return idxs, lambda_

def attn_prun_idx(proj, sparsity=0.5, n_heads=12):
    idxs = None
    lambda_ = None

    with torch.no_grad():
        p = proj.weight.data
        dims = p.shape[0]
        p = torch.linalg.norm(p, dim=1)
        idxs = torch.ones((dims))
        lams = []
        for i in range(n_heads):
            head = p[i*(dims//n_heads):i*(dims//n_heads)+(dims//n_heads)]
            lam = head.sort().values[int(sparsity * head.shape[0])]
            indices = head.sort().indices[:int(sparsity * head.shape[0])]
            idxs[indices + i*(dims//n_heads)] = 0
            lams.append(lam)
    return idxs, lams

def conv_prun_idx(layer, sparsity=0.5, structured=True):

    with torch.no_grad():
        if structured:
            p = layer.weight.view(layer.weight.shape[0], -1)
            p = torch.linalg.norm(p, dim=1)
            lambda_ = p.sort()[0][int(sparsity * p.shape[0])]
            conv_idxs = torch.ones_like(p)
            conv_idxs[p.abs() < lambda_] = 0
        else:
            p = layer.weight.abs()
            lambda_ = p.view(-1).sort()[0][int(sparsity * p.view(-1).shape[0])]
            conv_idxs = torch.ones_like(p)
            conv_idxs[p <= lambda_] = 0

    return conv_idxs, lambda_

def fc_prun(layer, idxs, dim='in'):
    assert dim in ['in', 'out']

    with torch.no_grad():
        if dim == 'out':
            layer.weight[idxs == 0] = 0
        elif dim == 'in':
            layer.weight[:, idxs == 0] = 0

        if layer.bias != None and dim == 'out':
            layer.bias[idxs == 0] = 0
            
def pos_prun(layer, idxs):
    with torch.no_grad():
        layer.pos_embedding[:,:, idxs == 0] = 0
        
        
def prune_vit(net, ratio):
    total_idxs = {}
    total_lambdas = {}

    # embedding pruning
    idxs, lambda_ = conv_prun_idx(net.patch_embedding, sparsity=ratio)
    total_idxs['embedding'] = idxs
    total_lambdas['embedding'] = lambda_

    fc_prun(net.patch_embedding, idxs, dim='out')
    pos_prun(net.positional_embedding, idxs)

    total_idxs['blocks'] = []
    total_lambdas['blocks'] = []

    #idxs dim 2
    for layer in net.transformer.blocks:
        block_idxs = {}
        block_lambdas = {}

        # layer norm 1
        fc_prun(layer.norm1, idxs, dim='out')

        # attn
        fc_prun(layer.attn.proj_q, idxs, dim='in')
        fc_prun(layer.attn.proj_k, idxs, dim='in')
        fc_prun(layer.attn.proj_v, idxs, dim='in')
        block_idxs['proj_in'] = idxs

        v_idx, v_lam = attn_prun_idx(layer.attn.proj_v, sparsity=ratio, n_heads=layer.attn.n_heads)
        block_idxs['proj_v'] = v_idx

        fc_prun(layer.attn.proj_v, v_idx, dim='out')

        # projection
        fc_prun(layer.proj, v_idx, dim='in')
        fc_prun(layer.proj, idxs, dim='out')

        # pwff
        pwff_idxs, pwff_lam = get_prun_idx(layer.pwff.fc1,sparsity=ratio)

        fc_prun(layer.pwff.fc1, pwff_idxs, dim='out')
        fc_prun(layer.pwff.fc2, pwff_idxs, dim='in')

        block_idxs['pwff_out'] = pwff_idxs
        block_lambdas['pwff_out'] = pwff_lam
        fc_prun(layer.pwff.fc2, idxs, dim='out')

        # layer norm 2
        fc_prun(layer.norm2, idxs, dim='out')

        total_idxs['blocks'].append(block_idxs)
        total_lambdas['blocks'].append(block_lambdas)

    fc_prun(net.fc, idxs, dim='in')
    
    return total_idxs, total_lambdas

In [3]:
trainset, testset, num_classes = getattr(dataloader, 'imagenet')(batch_size=16)

Using augmented IMAGENET.


In [4]:
for X, y in trainset:
    break

In [None]:
total_idxs['blocks'] = []
total_lambdas['blocks'] = []

In [6]:
model_args = {
        'num_classes' : 10,
        'image_size' : 224,
        'in_channels' : 3,
        'dropout_rate': 0.,
        'pretrained': False
    }

net = ViT('B_16_imagenet1k', **model_args)

In [8]:
net.load_state_dict(torch.load('/workspace/paper_works/work_results/imagenet_vit_B_16_imagenet1k/best_state.ptl')['state_dict'])

<All keys matched successfully>

In [9]:
def split_last(x, shape):
    "split the last dimension to given shape"
    shape = list(shape)
    assert shape.count(-1) <= 1
    if -1 in shape:
        shape[shape.index(-1)] = int(x.size(-1) / -np.prod(shape))
    return x.view(*x.size()[:-1], *shape)

def merge_last(x, n_dims):
    "merge the last n_dims to a dimension"
    s = x.size()
    assert n_dims > 1 and n_dims < len(s)
    return x.view(*s[:-n_dims], -1)

In [10]:
import numpy as np
import torch.nn.functional as F

In [163]:
out = net.patch_embedding(X)
# print(out.shape)
out = out.flatten(2).transpose(1, 2)
# print(out.shape)
out = torch.cat((net.class_token.expand(16, -1, -1), out), dim=1)
out = net.transformer.blocks[0].norm1(out)

# q_idx, q_lam = attn_prun_idx(net.transformer.blocks[0].attn.proj_q)
# k_idx, k_lam = attn_prun_idx(net.transformer.blocks[0].attn.proj_k)
v_idx, v_lam = attn_prun_idx(net.transformer.blocks[0].attn.proj_v)

# fc_prun(net.transformer.blocks[0].attn.proj_q, q_idx, dim='out')
# fc_prun(net.transformer.blocks[0].attn.proj_k, k_idx, dim='out')
fc_prun(net.transformer.blocks[0].attn.proj_v, v_idx, dim='out')

q, k, v = net.transformer.blocks[0].attn.proj_q(out), net.transformer.blocks[0].attn.proj_k(out), net.transformer.blocks[0].attn.proj_v(out)

print(q.shape, k.shape, v.shape)
q, k, v = (split_last(x, (net.transformer.blocks[0].attn.n_heads, -1)).transpose(1, 2) for x in [q, k, v])
print(q.shape, k.shape, v.shape)

torch.Size([16, 197, 768]) torch.Size([16, 197, 768]) torch.Size([16, 197, 768])
torch.Size([16, 12, 197, 64]) torch.Size([16, 12, 197, 64]) torch.Size([16, 12, 197, 64])


In [164]:
scoresscores = q @ k.transpose(-2, -1) / np.sqrt(k.size(-1))
scores = F.softmax(scores, dim=-1)
h = (scores @ v).transpose(1, 2).contiguous()
h = merge_last(h, 2)

In [165]:
for i in range(768):
    print(i//64, i%64)
    print(v[:,i//64,:,i%64].sum())
    print(h[:,:,i].sum())

0 0
tensor(447.9012, grad_fn=<SumBackward0>)
tensor(447.9011, grad_fn=<SumBackward0>)
0 1
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
0 2
tensor(-208.1800, grad_fn=<SumBackward0>)
tensor(-208.1801, grad_fn=<SumBackward0>)
0 3
tensor(-153.6421, grad_fn=<SumBackward0>)
tensor(-153.6420, grad_fn=<SumBackward0>)
0 4
tensor(-362.3188, grad_fn=<SumBackward0>)
tensor(-362.3185, grad_fn=<SumBackward0>)
0 5
tensor(-199.5305, grad_fn=<SumBackward0>)
tensor(-199.5305, grad_fn=<SumBackward0>)
0 6
tensor(417.3267, grad_fn=<SumBackward0>)
tensor(417.3265, grad_fn=<SumBackward0>)
0 7
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
0 8
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
0 9
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
0 10
tensor(114.8449, grad_fn=<SumBackward0>)
tensor(114.8450, grad_fn=<SumBackward0>)
0 11
tensor(256.9119, grad_fn=<SumBackward0>)
tensor(256.9118, grad_fn=<SumBackward0>)
0 12
t

tensor(0., grad_fn=<SumBackward0>)
2 52
tensor(152.1068, grad_fn=<SumBackward0>)
tensor(152.1067, grad_fn=<SumBackward0>)
2 53
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 54
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 55
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 56
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 57
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 58
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 59
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 60
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
2 61
tensor(-350.9072, grad_fn=<SumBackward0>)
tensor(-350.9071, grad_fn=<SumBackward0>)
2 62
tensor(215.5752, grad_fn=<SumBackward0>)
tensor(215.5751, grad_fn=<SumBackward0>)
2 63
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
3 0
tensor(740.7640, grad_f

tensor(-378.3061, grad_fn=<SumBackward0>)
6 3
tensor(586.0024, grad_fn=<SumBackward0>)
tensor(586.0022, grad_fn=<SumBackward0>)
6 4
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
6 5
tensor(832.2426, grad_fn=<SumBackward0>)
tensor(832.2424, grad_fn=<SumBackward0>)
6 6
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
6 7
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
6 8
tensor(-406.2493, grad_fn=<SumBackward0>)
tensor(-406.2491, grad_fn=<SumBackward0>)
6 9
tensor(1058.0688, grad_fn=<SumBackward0>)
tensor(1058.0684, grad_fn=<SumBackward0>)
6 10
tensor(-40.5747, grad_fn=<SumBackward0>)
tensor(-40.5747, grad_fn=<SumBackward0>)
6 11
tensor(-98.8386, grad_fn=<SumBackward0>)
tensor(-98.8385, grad_fn=<SumBackward0>)
6 12
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
6 13
tensor(-889.4926, grad_fn=<SumBackward0>)
tensor(-889.4925, grad_fn=<SumBackward0>)
6 14
tensor(0., grad_fn=<SumBackward0>)
tensor(0.,

tensor(892.9361, grad_fn=<SumBackward0>)
9 29
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 30
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 31
tensor(245.4294, grad_fn=<SumBackward0>)
tensor(245.4294, grad_fn=<SumBackward0>)
9 32
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 33
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 34
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 35
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 36
tensor(-1089.8848, grad_fn=<SumBackward0>)
tensor(-1089.8848, grad_fn=<SumBackward0>)
9 37
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 38
tensor(123.7537, grad_fn=<SumBackward0>)
tensor(123.7537, grad_fn=<SumBackward0>)
9 39
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 40
tensor(0., grad_fn=<SumBackward0>)
tensor(0., grad_fn=<SumBackward0>)
9 41
tensor(-626.46

In [150]:
v.shape, h.shape

(torch.Size([16, 12, 197, 64]), torch.Size([16, 197, 768]))

In [110]:

# (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -trans-> (B, S, H, W)
h = (scores @ v).transpose(1, 2).contiguous()
print(h.shape)
# -merge-> (B, S, D)
h = merge_last(h, 2)
print(h.shape)

torch.Size([16, 197, 12, 64])
torch.Size([16, 197, 768])


In [190]:
net.patch_embedding.weight.shape

torch.Size([768, 3, 16, 16])

In [172]:
out = net.patch_embedding(X)
print(out.shape)
out = out.flatten(2).transpose(1, 2)
print(out.shape)
out = torch.cat((net.class_token.expand(16, -1, -1), out), dim=1)
print(out.shape)
out = net.positional_embedding(out)
print(out.shape)
out = net.transformer.blocks[0].norm1(out)

q, k, v = net.transformer.blocks[0].attn.proj_q(out), net.transformer.blocks[0].attn.proj_k(out), net.transformer.blocks[0].attn.proj_v(out)
q, k, v = (split_last(x, (12, -1)).transpose(1, 2) for x in [q, k, v])


out = net.transformer.blocks[0].attn(out, None)
out = net.transformer.blocks[0].proj(out)
# residual1
out = net.transformer.blocks[0].norm2(out)
out = net.transformer.blocks[0].pwff(out)
# residual2

torch.Size([16, 768, 14, 14])
torch.Size([16, 196, 768])
torch.Size([16, 197, 768])
torch.Size([16, 197, 768])


In [191]:
# patch embedding pruning
idxs, lambda_ = conv_prun_idx(net.patch_embedding, sparsity=0.5)
fc_prun(net.patch_embedding, idxs, dim='out')
print(idxs)
pos_prun(net.positional_embedding, idxs)

tensor([1., 1., 1., 1., 0., 0., 1., 0., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0.,
        1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 0., 0., 0.,
        1., 0., 1., 1., 1., 0., 1., 1., 1., 0., 0., 1., 1., 0., 0., 1., 0., 0.,
        1., 1., 1., 0., 1., 0., 0., 0., 1., 0., 1., 1., 1., 1., 0., 0., 1., 0.,
        0., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 0., 0., 1., 0.,
        1., 1., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1.,
        0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 1., 1., 0., 0.,
        1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 0., 1.,
        0., 0., 1., 0., 1., 0., 0., 0., 0., 0., 0., 0., 1., 1., 1., 1., 1., 0.,
        0., 0., 1., 0., 1., 1., 0., 1., 0., 0., 1., 0., 0., 1., 1., 1., 0., 0.,
        1., 0., 0., 0., 0., 0., 1., 0., 1., 1., 1., 0., 0., 1., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 1., 1., 1., 

In [192]:
out = net.patch_embedding(X)
print(out.shape)

torch.Size([16, 768, 14, 14])


In [200]:
(out.mean(dim=0).mean(dim=1).mean(dim=1) == 0).sum()

tensor(384)

In [201]:
net.fc.weight.shape

torch.Size([10, 768])

In [17]:
out = net.patch_embedding(X)
print(out.shape)
out = out.flatten(2).transpose(1, 2)
print(out.shape)
out = torch.cat((net.class_token.expand(16, -1, -1), out), dim=1)
print(out.shape)
out = net.positional_embedding(out)
print(out.shape)
out = net.transformer.blocks[0].norm1(out)
print(out)
out = net.transformer.blocks[0].attn(out, None)
print(out)
out = net.transformer.blocks[0].proj(out)
print(out)
# residual1
# out = net.transformer.blocks[0].norm2(out)
# print(out)
# out = net.transformer.blocks[0].pwff(out)
# print(out)
# residual2

torch.Size([16, 768, 14, 14])
torch.Size([16, 196, 768])
torch.Size([16, 197, 768])
torch.Size([16, 197, 768])
tensor([[[-2.5086e-03,  3.0151e-02,  0.0000e+00,  ..., -4.8156e-03,
          -2.2017e-02,  0.0000e+00],
         [ 1.0760e-02, -1.9680e-01,  0.0000e+00,  ..., -6.9125e-02,
          -4.9504e-01,  0.0000e+00],
         [ 1.6737e-01, -1.2083e-01,  0.0000e+00,  ...,  1.6194e-02,
          -3.5010e-01,  0.0000e+00],
         ...,
         [-2.4099e-01,  6.9428e-02,  0.0000e+00,  ...,  1.7375e-01,
           3.6804e-03,  0.0000e+00],
         [ 1.7674e-01,  1.9373e-02,  0.0000e+00,  ..., -8.1708e-02,
          -3.0842e-01,  0.0000e+00],
         [ 5.1517e-02, -1.8945e-01,  0.0000e+00,  ...,  2.7713e-02,
          -1.4023e-01,  0.0000e+00]],

        [[-2.5086e-03,  3.0151e-02,  0.0000e+00,  ..., -4.8156e-03,
          -2.2017e-02,  0.0000e+00],
         [ 3.8478e-02, -1.1407e-01,  0.0000e+00,  ..., -2.1078e-01,
          -3.2777e-01,  0.0000e+00],
         [ 2.1428e-01, -3.2662e-0

In [21]:
out.mean(axis=1).mean(axis=0)

tensor([ 3.2373e-02,  1.3596e-01,  0.0000e+00, -3.3496e-02,  0.0000e+00,
         0.0000e+00,  0.0000e+00,  4.1707e-02,  0.0000e+00,  0.0000e+00,
         1.3865e-01, -1.6313e-01,  0.0000e+00, -1.0182e-01,  4.2756e-01,
         0.0000e+00, -3.4757e-01,  0.0000e+00, -1.1232e-01,  1.6594e-05,
         0.0000e+00,  4.2793e-02,  0.0000e+00,  0.0000e+00, -1.0121e-02,
         0.0000e+00, -2.1562e-01,  3.9607e-01,  0.0000e+00,  0.0000e+00,
         1.0195e-01,  0.0000e+00,  0.0000e+00, -5.8820e-01, -1.8823e-01,
         0.0000e+00,  0.0000e+00, -5.7606e-02,  0.0000e+00, -2.4736e-01,
        -7.5979e-02,  1.2016e-01,  0.0000e+00,  2.2287e-02,  0.0000e+00,
         0.0000e+00, -1.2439e-01,  2.1837e-02,  0.0000e+00,  1.9932e-01,
         1.5738e-01,  0.0000e+00,  0.0000e+00, -7.1894e-02, -1.2136e-01,
         0.0000e+00, -2.0392e-01,  4.5554e-03,  0.0000e+00,  0.0000e+00,
        -2.1817e-01, -1.8015e-01,  1.1111e-01,  7.7583e-02,  0.0000e+00,
         1.0447e-02,  0.0000e+00,  0.0000e+00,  0.0

In [13]:
total_idxs = {}
total_lambdas = {}

# embedding pruning
idxs, lambda_ = conv_prun_idx(net.patch_embedding, sparsity=0.5)
total_idxs['embedding'] = idxs
total_lambdas['embedding'] = lambda_

fc_prun(net.patch_embedding, idxs, dim='out')
pos_prun(net.positional_embedding, idxs)

total_idxs['blocks'] = []
total_lambdas['blocks'] = []

#idxs dim 2
for layer in net.transformer.blocks:
    block_idxs = {}
    block_lambdas = {}
    
    # layer norm 1
    fc_prun(layer.norm1, idxs, dim='out')
    
    # attn
    fc_prun(layer.attn.proj_q, idxs, dim='in')
    fc_prun(layer.attn.proj_k, idxs, dim='in')
    fc_prun(layer.attn.proj_v, idxs, dim='in')
    block_idxs['proj_in'] = idxs
    
    v_idx, v_lam = attn_prun_idx(layer.attn.proj_v)
    block_idxs['proj_in'] = v_idx
    
    fc_prun(layer.attn.proj_v, v_idx, dim='out')
    
    # projection
    fc_prun(layer.proj, v_idx, dim='in')
    fc_prun(layer.proj, idxs, dim='out')
    break

{'embedding': tensor([1., 1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 1., 0.,
         1., 1., 0., 1., 0., 0., 1., 0., 1., 1., 0., 0., 1., 0., 0., 1., 1., 0.,
         0., 1., 0., 1., 1., 1., 0., 1., 0., 0., 1., 1., 0., 1., 1., 0., 0., 1.,
         1., 0., 1., 1., 0., 0., 1., 1., 1., 1., 0., 1., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 0., 0., 0., 1., 1.,
         0., 0., 0., 0., 1., 1., 1., 0., 1., 0., 0., 0., 1., 1., 1., 0., 0., 0.,
         0., 0., 0., 1., 0., 1., 0., 0., 1., 0., 1., 1., 1., 0., 0., 0., 1., 0.,
         1., 0., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.,
         0., 0., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1., 1., 0., 1., 0., 1.,
         1., 0., 1., 0., 0., 0., 1., 0., 0., 1., 0., 0., 0., 1., 1., 1., 1., 0.,
         0., 1., 1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 0., 0., 0., 0., 1.,
         1., 0., 0., 0., 1., 1., 0., 0., 0., 0., 0., 1., 1., 0., 1., 1., 1., 1.,
         0., 0.

In [28]:
ratio = .5

In [29]:
total_idxs = {}
total_lambdas = {}

# embedding pruning
idxs, lambda_ = conv_prun_idx(net.patch_embedding, sparsity=ratio)
total_idxs['embedding'] = idxs
total_lambdas['embedding'] = lambda_

fc_prun(net.patch_embedding, idxs, dim='out')
pos_prun(net.positional_embedding, idxs)

total_idxs['blocks'] = []
total_lambdas['blocks'] = []

#idxs dim 2
for layer in net.transformer.blocks:
    block_idxs = {}
    block_lambdas = {}
    
    # layer norm 1
    fc_prun(layer.norm1, idxs, dim='out')
    
    # attn
    fc_prun(layer.attn.proj_q, idxs, dim='in')
    fc_prun(layer.attn.proj_k, idxs, dim='in')
    fc_prun(layer.attn.proj_v, idxs, dim='in')
    block_idxs['proj_in'] = idxs
    
    v_idx, v_lam = attn_prun_idx(layer.attn.proj_v, sparsity=ratio, n_heads=layer.attn.n_heads)
    block_idxs['proj_v'] = v_idx
    
    fc_prun(layer.attn.proj_v, v_idx, dim='out')
    
    # projection
    fc_prun(layer.proj, v_idx, dim='in')
    fc_prun(layer.proj, idxs, dim='out')
    
    # pwff
    pwff_idxs, pwff_lam = get_prun_idx(layer.pwff.fc1,sparsity=ratio)
    
    fc_prun(layer.pwff.fc1, pwff_idxs, dim='out')
    fc_prun(layer.pwff.fc2, pwff_idxs, dim='in')
    
    block_idxs['pwff_out'] = pwff_idxs
    block_lambdas['pwff_out'] = pwff_lam
    fc_prun(layer.pwff.fc2, idxs, dim='out')
    
    # layer norm 2
    fc_prun(layer.norm2, idxs, dim='out')
    
    total_idxs['blocks'].append(block_idxs)
    total_lambdas['blocks'].append(block_lambdas)

fc_prun(net.fc, idxs, dim='in')

In [31]:
total_lambdas

{'embedding': tensor(2.1611),
 'blocks': [{'pwff_out': tensor(2.6624)},
  {'pwff_out': tensor(2.6237)},
  {'pwff_out': tensor(2.6009)},
  {'pwff_out': tensor(2.6089)},
  {'pwff_out': tensor(2.5946)},
  {'pwff_out': tensor(2.5814)},
  {'pwff_out': tensor(2.5279)},
  {'pwff_out': tensor(2.4471)},
  {'pwff_out': tensor(2.3706)},
  {'pwff_out': tensor(2.4565)},
  {'pwff_out': tensor(2.7783)},
  {'pwff_out': tensor(2.4957)}]}