In [None]:
import numpy as np 
import matplotlib.pyplot as plt 
import glob 
import time 
import h5py
import cv2
import torchinfo 
import torch
import torch.nn as nn
import argparse
import torch.optim as optim
import torch.backends.cudnn as cudnn
from torch.utils.data import DataLoader
from torch.autograd import Variable
import os
from hsi_dataset import TrainDataset, ValidDataset
from architecture import *
from utils import *

import datetime
import wandb
import time 

In [None]:
# Model names passed to `model_generator`, in the order they are swept and plotted.
models = [
    'hscnn_plus',
    'restormer',
    'mst_plus_plus',
    'vitmstpp',
    'vitmstpp_pad',
]

In [None]:
# Sweep every model over several square input sizes and record its statistics.
#   models_dict[f'{model}_{size}'] -> torchinfo fields ('MACs', 'tot_size', 'trainable'),
#       plus my_summary's 'n_params' / 'gmac' for sizes <= 256.
#   macs_dict[model] -> MACs in G, one entry per size in `inputs` (NaN when the summary
#       failed), so every list stays aligned with `inputs` for plotting.
models_dict = {}
inputs = [128, 256, 384, 512, 768, 1024]
macs_dict = {}

# Multiplier that converts a torchinfo mult-adds unit prefix to G.
_UNIT_TO_GIGA = {'K': 1e-6, 'M': 1e-3, 'G': 1.0, 'T': 1e3}


def _macs_field_to_giga(macs_field):
    """Convert a parsed 'Total mult-adds' field such as '(G): 12.34' to G MACs.

    Raises KeyError for an unknown unit and ValueError for a malformed field.
    """
    unit_part, value_part = macs_field.split(': ')
    unit = unit_part.strip('() ')
    return float(value_part.replace(',', '')) * _UNIT_TO_GIGA[unit]


for md in models:
    print(' ')
    print(' ')
    print(md)
    model = model_generator(md).cuda()
    macs_dict[md] = []
    for i in inputs:
        torch.cuda.empty_cache()
        nm = md + '_' + str(i)
        print(nm)
        models_dict[nm] = {}
        print('TORCHINFO')
        try:
            summary_str = str(torchinfo.summary(model, input_size=(2,3,i,i)))
            models_dict[nm]['MACs'] = summary_str.split('Total mult-adds ')[1].split('\n')[0]
            models_dict[nm]['tot_size'] = summary_str.split('Estimated Total Size ')[1].split('\n')[0]
            models_dict[nm]['trainable'] = int(summary_str.split('Trainable params: ')[1].split('\n')[0].replace(',', ''))
            macs = _macs_field_to_giga(models_dict[nm]['MACs'])
        except Exception as err:
            # e.g. CUDA out-of-memory at large sizes. Keep `model` alive for the remaining
            # sizes and record NaN so macs_dict[md] stays aligned with `inputs`.
            torch.cuda.empty_cache()
            print(f'NOT ABLE TO SCALE TO {i}: {type(err).__name__}: {err}')
            print(nm)
            macs_dict[md].append(float('nan'))
            continue
        macs_dict[md].append(macs)

        if i <= 256:
            n_params, gmac = my_summary(model, i, i, 3, 2)
            models_dict[nm]['n_params'] = int(n_params)
            models_dict[nm]['gmac'] = float(gmac)

    del model
    torch.cuda.empty_cache()

In [None]:
# Spot-check my_summary on one model/size. Choose them explicitly instead of relying on
# `md` and `i` leaking out of the sweep loop (which leaves them at the last entries).
spot_model_name = models[-1]
spot_size = inputs[-1]
model = model_generator(spot_model_name).cuda()
my_summary(model, spot_size, spot_size, 3, 2)

In [None]:
# Parameter count reported by my_summary for hscnn_plus (recorded only for sizes <= 256).
models_dict['hscnn_plus_256']['n_params']

In [None]:
# The sweep only stores 'n_params' for sizes <= 256, so the 1024 entry has no such key.
# Every size of a model is summarised with the same model object, so the parameter
# count is the same; read it from the 256px entry.
models_dict['mst_plus_plus_256']['n_params']

In [None]:
def plot_macs(inputs, macs_dict):
    """Plot MACs (in G) against input resolution, one line per model.

    Args:
        inputs: input side lengths in pixels (x-axis values).
        macs_dict: mapping of model name -> list of MACs (G), one value per entry of
            `inputs`.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_ylabel('MACs (G)', fontsize=16)
    ax.set_xlabel('Input Size (px)', fontsize=16)
    for name, macs in macs_dict.items():
        # x = input size, y = MACs, matching the axis labels; label each line for the legend.
        ax.plot(np.array(inputs), np.array(macs), 'o-', label=name)
    ax.legend()
    plt.show()

In [None]:
# Model names that have MAC measurements, in sweep order.
list(macs_dict)

In [None]:
# Linear-scale MACs (G) vs input size for every swept model.
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_ylabel('MACs (G)', fontsize=16)
ax.set_xlabel('Input Size (px)', fontsize=16)
x_vals = np.array(inputs)
for model_name in macs_dict:
    ax.plot(x_vals, np.array(macs_dict[model_name]), 'o-')
ax.legend(list(macs_dict.keys()))
plt.show()

In [None]:
# Natural-log MACs (G) vs input size, saved to disk.
# NOTE(review): absolute machine-specific output directory — consider making it configurable.
save_path = '/mnt/datassd/mst_toolbox/mst-vitmstpp_ntire/'
# Models left out of this figure (the output filename marks it as the "no-pad" version).
excluded_models = {'vitmstpp_pad'}

plt.figure(figsize=(10,8))
plt.ylabel('Log(MACs (G))', fontsize=18)
plt.xlabel('Input Size (px)', fontsize=18)
for k in macs_dict:
    if k in excluded_models:
        continue
    # Label each drawn line so the legend cannot be shifted by skipped models.
    plt.plot(np.array(inputs), np.log(np.array(macs_dict[k])), 'o-', label=k)

plt.legend(fontsize=14)
plt.xticks(inputs, fontsize=12)
plt.title('Architecture scaling')
plt.savefig(save_path + 'figure_logmacs_no-pad.png', dpi=800)
plt.show()

In [None]:
# Display all per-(model, input size) statistics collected by the sweep.
models_dict

In [None]:
# Release the model built by the my_summary spot-check cell.
del model

In [None]:
# One-off probe: torchinfo summary of restormer at 1024x1024 with batch size 1.
# The summary text is kept in `summary_str` for inspection.
md = 'restormer'
model = model_generator(md).cuda()
try:
    summary_str = str(torchinfo.summary(model, input_size=(1,3,1024,1024)))
finally:
    # Free the model even when the summary raises (e.g. CUDA out-of-memory), so later
    # cells do not inherit an occupied GPU.
    del model
    torch.cuda.empty_cache()

In [None]:
# Cleanup cell. On a linear run `model` was already deleted by the restormer probe cell,
# so an unconditional `del model` raises NameError; only delete it if it still exists.
if 'model' in globals():
    del model
torch.cuda.empty_cache()

In [None]:
# Print a torchinfo summary (batch 2, 3x256x256 input) for every model in `models`.
# The last model built stays bound to `model` for the inspection cells below.
for md in models:
    print(md)
    # Drop the previous iteration's model before building the next one, so its GPU
    # memory can be released instead of coexisting with the new model.
    model = None
    torch.cuda.empty_cache()
    model = model_generator(md).cuda()
    print('TORCHINFO')
    print(torchinfo.summary(model, input_size=(2,3,256,256)))

In [None]:
# NOTE(review): not referenced by any visible cell — the summary calls hard-code their batch size (2 or 1).
batch_size = 2

In [None]:
# torchinfo.summary returns a ModelStatistics object, not a dict, so it has no .keys();
# list the names of its attributes instead.
list(vars(torchinfo.summary(model, input_size=(2,3,256,256))).keys())

In [None]:
# torchinfo summary of the current `model` for a single 3x256x256 input.
single_input_shape = (1, 3, 256, 256)
torchinfo.summary(model, input_size=single_input_shape)

In [None]:
# my_summary(model, H, W, C, N): the sweep calls my_summary(model, i, i, 3, 2) alongside
# input_size=(2, 3, i, i), so the 4th argument is the channel count. The models take
# 3-channel input, so use C=3 with batch 1 to match the torchinfo cell above.
my_summary(model, 256, 256, 3, 1)

In [None]:
# Summarise with an explicit example tensor passed via `input_data`; the builtin `input`
# function is not a valid `input_size`. The shape matches the batch-1 summary above, and
# the tensor is created on the same device as the model's parameters.
example_input = torch.randn(1, 3, 256, 256, device=next(model.parameters()).device)
torchinfo.summary(model, input_data=example_input)