In [1]:
import csv
import enum
import os
import re
from collections import OrderedDict

import torch
import torch.nn as nn
import torchvision.utils


def get_sparse_weights(weight, N=2, M=4, get_mask=False):
    """Apply N:M structured sparsity to ``weight``.

    The flattened tensor is split into consecutive groups of ``M`` elements.
    Within each group the ``N`` entries of largest magnitude are kept and the
    remaining ``M - N`` entries are zeroed.

    Args:
        weight: tensor to sparsify; ``weight.numel()`` must be a multiple of ``M``.
        N: number of entries kept per group (``0 <= N <= M``).
        M: group size (positive).
        get_mask: if True, also return the 0/1 mask.

    Returns:
        ``mask * weight`` (same shape as ``weight``), or the tuple
        ``(mask * weight, mask)`` when ``get_mask`` is True. The mask is a
        float32 tensor of ones (kept) and zeros (pruned), shaped like
        ``weight`` and on ``weight``'s device.

    Raises:
        ValueError: if ``M`` is not positive, ``weight.numel()`` is not a
            multiple of ``M``, or ``N`` is outside ``[0, M]``.
    """
    length = weight.numel()
    if M <= 0 or length % M != 0:
        raise ValueError(
            f"weight.numel()={length} must be a positive multiple of M={M}"
        )
    if not 0 <= N <= M:
        raise ValueError(f"N={N} must satisfy 0 <= N <= M={M}")

    # Integer division: int(length / M) goes through float and can lose
    # precision for very large tensors.
    group = length // M
    magnitudes = weight.detach().abs().reshape(group, M)

    # argsort is ascending, so the first M-N indices of each row are the
    # smallest magnitudes -- those are the entries to prune.
    prune_index = torch.argsort(magnitudes, dim=1)[:, :M - N]

    mask = torch.ones(magnitudes.shape, device=magnitudes.device)
    mask = mask.scatter_(dim=1, index=prune_index, value=0).reshape(weight.shape)

    sparse_weight = mask * weight
    if get_mask:
        return sparse_weight, mask
    return sparse_weight

  from .autonotebook import tqdm as notebook_tqdm


In [33]:
# Directory holding the per-step checkpoint files (*.pt) to analyse.
# Set the CKPT_FOLDER environment variable to point the notebook elsewhere.
folder = os.environ.get(
    'CKPT_FOLDER',
    '/usr/scratch/abhimanyu/work/google_nm_sparse_training/pytorch-image-models-sparsity/vit_model_full_1105/SRSTE_NM_1_128/',
)


def _natural_key(file_name):
    """Sort key that orders embedded integers numerically ('step_2' < 'step_10')."""
    # re.split with a capturing group alternates text / digit-run tokens, so
    # every odd index is a run of digits.
    tokens = re.split(r'(\d+)', file_name)
    return [int(tok) if i % 2 else tok for i, tok in enumerate(tokens)]


# os.listdir order is arbitrary, but the next cell diffs *consecutive*
# checkpoints, so the files must be processed in step order.
all_files = sorted(
    (f for f in os.listdir(folder) if f.endswith(".pt")),
    key=_natural_key,
)

stat_file = os.path.join(folder, 'model_stats.csv')
print(stat_file)

/usr/scratch/abhimanyu/work/google_nm_sparse_training/pytorch-image-models-sparsity/vit_model_full_1105/SRSTE_NM_1_128/model_stats.csv


In [34]:
# For every pair of consecutive checkpoints, compute per-layer statistics of the
# gradients, of the N:M-sparsified weights, and of how weights, sparsity masks
# and gradients changed since the previous checkpoint. One CSV row per step is
# written to `stat_file`; the first checkpoint only seeds the "previous" state.
SPARSE_N, SPARSE_M = 1, 128  # N:M sparsity pattern applied before comparing weights


def layer_step_stats(name_fc, weight, grad, weight_last, grad_last,
                     n=SPARSE_N, m=SPARSE_M):
    """Return an OrderedDict of metrics for one layer between two checkpoints.

    Keys are ``'<metric>_<name_fc>'`` in a fixed order (this order defines the
    CSV columns); values are Python floats.

    Args:
        name_fc: layer name used as the column-name suffix.
        weight, grad: the layer's weight and gradient from the current checkpoint.
        weight_last, grad_last: the same tensors from the previous checkpoint.
        n, m: N:M pattern passed to ``get_sparse_weights``.
    """
    curr_weights, curr_mask = get_sparse_weights(weight, n, m, get_mask=True)
    prev_weights, prev_mask = get_sparse_weights(weight_last, n, m, get_mask=True)
    weight_diff = curr_weights - prev_weights   # W2 - W1 (both sparsified)
    mask_diff = curr_mask - prev_mask           # entries in {-1, 0, 1}
    grad_diff = grad - grad_last                # G2 - G1

    metrics = (
        ('grad_mean', grad.mean()),
        ('grad_std', grad.std()),
        ('grad_l2norm', torch.linalg.vector_norm(grad)),
        # L-inf norm is the largest *absolute* value; torch.max alone would
        # ignore large negative entries.
        ('grad_linfnorm', grad.abs().max()),
        ('sparse_weight_mean', curr_weights.mean()),
        ('sparse_weight_std', curr_weights.std()),
        ('weight_diff_l2_norm', torch.linalg.vector_norm(weight_diff)),
        ('weight_diff_linf_norm', weight_diff.abs().max()),
        ('weight_diff_std_norm', weight_diff.std()),
        # Masks are 0/1, so SAD_L1 counts mask positions that flipped.
        ('SAD_L1', torch.linalg.vector_norm(mask_diff, ord=1)),
        ('SAD_L2', torch.linalg.vector_norm(mask_diff)),
        ('SAD_std', mask_diff.std()),
        ('grad_diff_l2', torch.linalg.vector_norm(grad_diff)),
        ('grad_diff_mean', grad_diff.mean()),
        ('grad_diff_std', grad_diff.std()),
    )
    return OrderedDict((f'{key}_{name_fc}', value.item()) for key, value in metrics)


prev_weights_array = OrderedDict()
prev_grad_array = OrderedDict()

# Mode 'w' truncates stats from any previous run; the old os.remove() raised
# FileNotFoundError on a fresh folder. newline='' is what the csv module expects.
with open(stat_file, mode='w', newline='') as cf:
    dw = None  # created from the first row's keys, which fixes the header
    for file_name in all_files:
        # map_location='cpu' so checkpoints saved from GPU load without one.
        # NOTE: torch.load unpickles -- only use on trusted checkpoint files.
        state = torch.load(os.path.join(folder, file_name), map_location='cpu')
        current_loss = None  # reset per file so a missing 'loss' is not stale
        curr_weights_array = OrderedDict()
        curr_grad_array = OrderedDict()
        for name in state:
            if name == 'loss':
                current_loss = state['loss'].item()
            elif 'grad' in name:
                # Gradients of the weights (keys like '...weight_grad').
                curr_grad_array[name] = state[name].detach().clone()
            elif 'weight' in name:
                curr_weights_array[name] = state[name].detach().clone()

        if prev_weights_array and prev_grad_array:
            rowd = OrderedDict(loss=current_loss)
            # Weights and grads are paired by insertion order (as before);
            # previous-step tensors are looked up by the same key, so a
            # missing layer raises KeyError instead of being silently mispaired.
            for weight_name, grad_name in zip(curr_weights_array, curr_grad_array):
                name_fc = weight_name.removeprefix('module.').removesuffix('.weight')
                rowd.update(layer_step_stats(
                    name_fc,
                    curr_weights_array[weight_name],
                    curr_grad_array[grad_name],
                    prev_weights_array[weight_name],
                    prev_grad_array[grad_name],
                ))

            if dw is None:
                dw = csv.DictWriter(cf, fieldnames=list(rowd.keys()))
                dw.writeheader()
            dw.writerow(rowd)

        prev_weights_array = curr_weights_array
        prev_grad_array = curr_grad_array

      