# Imports

In [1]:
import warnings
import os
import glob
import pickle
import torch
import lightning.pytorch as pl
import numpy as np
import sys
import timm
from argparse import ArgumentParser
from collections import defaultdict
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch import trainer
import torchvision
from torchview import draw_graph

import torchvision.models as models
import torch
from ptflops import get_model_complexity_info


In [2]:
# UNIX
working_dir = "/Users/andry/Documents/GitHub/lus-dl-framework"
dataset_h5_path = "data/iclus/dataset.h5"
hospitaldict_path = "data/iclus/hospitals-patients-dict.pkl"

# Windows
# working_dir = "../."
# dataset_h5_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/dataset.h5"
# hospitaldict_path = "/Users/andry/Documents/GitHub/lus-dl-framework/data/iclus/hospitals-patients-dict.pkl"
libraries_dir = working_dir + "/libraries"


sys.path.append(working_dir)
os.chdir(working_dir)
os.getcwd()

from utils import *
from callbacks import *
from run_model import *
from get_sets import get_sets, get_class_weights

import sys
sys.path.append(working_dir)
from data_setup import HDF5Dataset, FrameTargetDataset, split_dataset, reduce_sets
from lightning_modules.LUSModelLightningModule import LUSModelLightningModule
from lightning_modules.LUSDataModule import LUSDataModule


# Args

In [3]:
from argparse import ArgumentParser
import json
import sys

# ------------------------------ Parse arguments ----------------------------- #
def parse_arguments(args=None):
    if args is None:
        # If no arguments are provided, use the command-line arguments of the notebook
        args = sys.argv[1:]

    # Parse command-line arguments
    parser = ArgumentParser()

    allowed_models = ["google_vit", 
                      "swin_tiny",
                      "swin_custom",
                    "resnet18",
                    "resnet10",
                    "resnet50",
                    "beit", 
                    'timm_bot', 
                    "botnet18", 
                    "botnet50",
                    "vit",
                    "swin_vit",
                    "simple_vit"]

    allowed_modes = ["train", "test", "train_test", "tune"]
    parser.add_argument("--model", type=str, choices=allowed_models)
    parser.add_argument("--mode", type=str, choices=allowed_modes)
    parser.add_argument("--version", type=str)
    parser.add_argument("--working_dir_path", type=str)
    parser.add_argument("--dataset_h5_path", type=str)
    parser.add_argument("--hospitaldict_path", type=str)
    
    parser.add_argument("--trim_data", type=float)
    parser.add_argument("--trim_train", type=float)
    parser.add_argument("--trim_test", type=float)
    parser.add_argument("--trim_val", type=float)
    
    parser.add_argument('--ratios', nargs='+', type=float, help='Sets ratios')


    parser.add_argument("--chkp", type=str)
    parser.add_argument("--rseed", type=int)
    parser.add_argument("--train_ratio", type=float)
    parser.add_argument("--batch_size", type=int, default=16)
    parser.add_argument("--optimizer", type=str, default="sgd")
    parser.add_argument("--lr", type=float, default=0.001)
    parser.add_argument("--weight_decay", type=float, default=0.001)
    parser.add_argument("--momentum", type=float, default=0.001)
    parser.add_argument("--label_smoothing", type=float, default=0.1)
    parser.add_argument("--drop_rate", type=float, default=0.1)
    parser.add_argument("--max_epochs", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--accumulate_grad_batches", type=int, default=4)
    parser.add_argument("--precision", default=32)
    parser.add_argument("--disable_warnings", dest="disable_warnings", action='store_true')
    parser.add_argument("--pretrained", dest="pretrained", action='store_true')
    parser.add_argument("--freeze_layers", type=str)
    parser.add_argument("--test", dest="test", action='store_true')
    parser.add_argument("--mixup", dest="mixup", action='store_true')
    parser.add_argument("--augmentation", dest="augmentation", action='store_true')
    parser.add_argument("--summary", dest="summary", action='store_true')
    

    # Add an argument for the configuration file
    parser.add_argument('--config', type=str, help='Path to JSON configuration file')

    args = parser.parse_args(args)

    # -------------------------------- json config ------------------------------- #

    config_path = 'configs/configs.json'
    selected_config = None
    # If a configuration file was provided, load it
    if args.config:
        with open(config_path, 'r') as f:
            configurations = json.load(f)
        for config in configurations:
            if config['config'] == args.config:
                selected_config = config
                break

        # Override the command-line arguments with the configuration file
        for key, value in selected_config.items():
            if hasattr(args, key):
                setattr(args, key, value)
                
        # Check and set the ratios
        if "ratios" in selected_config:
            ratios = selected_config["ratios"]
            if len(ratios) != 3 or sum(ratios) != 1:
                parser.error('Invalid ratios provided in the configuration file')
            
        
    print(f"args are: {args}")

    return args


In [4]:
args_list = ['--model', 'swin_custom', 
             '--mode', 'train', 
             '--batch_size', '64',
             '--dataset_h5_path', f'{dataset_h5_path}',
             '--hospitaldict_path', f'{hospitaldict_path}',
             '--working_dir_path', f'{working_dir}',
             '--ratios', '0.6', '0.2', '0.2',
             '--rseed', '418',
             '--optimizer', 'adamw',
             '--lr', '0.00002',
             '--weight_decay', '0.0001',
             '--momentum', '0.9',
             '--label_smoothing', '0',
             '--drop_rate', '0.3',
             '--max_epochs', '10',
             '--num_workers', '0',
             '--accumulate_grad_batches', '1',
             '--precision', '32',
             '--disable_warnings',
            #  '--mixup',
             '--augmentation',
            #  '--summary',
             '--pretrained',
            #  '--freeze_layers', '0',
             
             ]

In [5]:
args = parse_arguments(args_list)



In [6]:
import warnings

pl.seed_everything(args.rseed)
if args.disable_warnings: 
    print("Warnings are DISABLED!\n\n")
    warnings.filterwarnings("ignore")
else:
    warnings.filterwarnings("default")

Global seed set to 418






### Pretrained
- resnet50
- botnet50ts
- swin_tiny
- focal_tiny

### Not pretrained
- resnet18 
- botnet26t
- swin_micro
- focal_micro

In [7]:
def print_layers_req_grad(model):
    # Print all layers and their requires_grad status
    for name, param in model.named_parameters():
        print(f'Parameter: {name}, Requires Gradient: {param.requires_grad}')
        
def print_number_of_parameters(model):
    print(f"Total params: {sum(p.numel() for p in model.parameters())}")
    print(f"Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")
    
def parameters_difference(model1, model2):
    n_parameters_model1 = sum(p.numel() for p in model1.parameters())
    n_parameters_model2 = sum(p.numel() for p in model2.parameters())
    return n_parameters_model1 - n_parameters_model2

# Search models

In [49]:
timm.list_models("*attention*", pretrained=True)

[]

In [51]:
model = timm.create_model(f'deit_tiny_patch16_224.fb_in1k',
                                num_classes=4,
                                drop_rate=args.drop_rate)

print("resnet50 parameters:")
print_number_of_parameters(resnet50)
print("\n model parameters:")
print_number_of_parameters(model)
print(f"\n difference in number of parameters: {parameters_difference(model, resnet18)}")

resnet50 parameters:
Total params: 23516228
Trainable params: 23516228

 model parameters:
Total params: 5525188
Trainable params: 5525188


NameError: name 'resnet18' is not defined

In [None]:
model_graph = draw_graph(model, 
                        #  graph_dir='LR',
                         input_size=(1,3,224,224), 
                         depth=3,
                         show_shapes=False,
                         expand_nested=False, 
                         device='meta')
# model_graph.visual_graph

NameError: name 'model' is not defined

# Micro models

### Resnet18 CNN

In [53]:
resnet18 = timm.create_model(f'resnet18.a1_in1k',
                            num_classes=4,
                            drop_rate=args.drop_rate)

### pure ViT

In [54]:
vit_tiny = timm.create_model(f'vit_tiny_patch16_224',
                                num_classes=4,
                                drop_rate=args.drop_rate)

print("resnet18 parameters:")
print_number_of_parameters(resnet18)
print("\n eff_vit parameters:")
print_number_of_parameters(vit_tiny)
print(f"\n difference in number of parameters: {parameters_difference(vit_tiny, resnet18)}")

resnet18 parameters:
Total params: 11178564
Trainable params: 11178564

 eff_vit parameters:
Total params: 5525188
Trainable params: 5525188

 difference in number of parameters: -5653376


### Botnet26ts Hybrid

In [55]:
timm.list_models("*deit*")

['deit3_base_patch16_224',
 'deit3_base_patch16_384',
 'deit3_huge_patch14_224',
 'deit3_large_patch16_224',
 'deit3_large_patch16_384',
 'deit3_medium_patch16_224',
 'deit3_small_patch16_224',
 'deit3_small_patch16_384',
 'deit_base_distilled_patch16_224',
 'deit_base_distilled_patch16_384',
 'deit_base_patch16_224',
 'deit_base_patch16_384',
 'deit_small_distilled_patch16_224',
 'deit_small_patch16_224',
 'deit_tiny_distilled_patch16_224',
 'deit_tiny_patch16_224']

In [56]:
botnet26t = timm.create_model(f'botnet26t_256',
                               img_size=224,
                               fixed_input_size=True,
                               num_classes=4,
                               drop_rate=args.drop_rate)

print("resnet18 parameters:")
print_number_of_parameters(resnet18)
print("botnet26t parameters:")
print_number_of_parameters(botnet26t)
print(f"\n difference in number of parameters: {parameters_difference(botnet26t, resnet18)}")

resnet18 parameters:
Total params: 11178564
Trainable params: 11178564
botnet26t parameters:
Total params: 10445820
Trainable params: 10445820

 difference in number of parameters: -732744


# Tiny Models

### ResNet50 CNN

In [8]:
resnet50 = timm.create_model(f'resnet50.a1_in1k',
                            num_classes=4,
                            drop_rate=args.drop_rate)

### Swin Tiny Pure ViT

In [9]:
swin_tiny = timm.create_model('swin_tiny_patch4_window7_224',
                                # embed_dim = 64,
                                # depths = (2, 2, 6, 2),
                                # num_heads=(4, 8, 16, 32),
                                num_classes=4)

print("resnet50 parameters:")
print_number_of_parameters(resnet50)
print("\nswin_tiny parameters:")
print_number_of_parameters(swin_tiny)
print(f"\n difference in number of parameters: {parameters_difference(swin_tiny, resnet50)}")

resnet50 parameters:
Total params: 23516228
Trainable params: 23516228

swin_tiny parameters:
Total params: 27522430
Trainable params: 27522430

 difference in number of parameters: 4006202


In [59]:
# with torch.device('):
#   net = swin_tiny
#   flops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
#                                            print_per_layer_stat=True, verbose=True)
#   print('{:<30}  {:<8}'.format('Computational complexity: ', flops))
#   print('{:<30}  {:<8}'.format('Number of parameters: ', params))

### BotNet50ts Hybrid

In [10]:
from lightning_modules.BotNet18LightningModule import BotNet

In [12]:
botnet50 = BotNet("bottleneck",
                    [3, 4, 6, 3], 
                    num_classes=4, 
                    resolution=(224, 224), 
                    heads=4,
                    drop_rate=0.1)
print("botnet50 parameters:")
print_number_of_parameters(botnet50)
print(f"\n difference in number of parameters: {parameters_difference(botnet50, resnet50)}")

botnet50 parameters:
Total params: 18849092
Trainable params: 18849092

 difference in number of parameters: -4667136


In [60]:
botnet50ts = timm.create_model(f'botnet50ts_256',
                               img_size=224,
                               fixed_input_size=True,
                               num_classes=4,
                               drop_rate=args.drop_rate)

print("resnet50 parameters:")
print_number_of_parameters(resnet50)
print("botnet50ts parameters:")
print_number_of_parameters(botnet50ts)
print(f"\n difference in number of parameters: {parameters_difference(botnet50ts, resnet50)}")

resnet50 parameters:
Total params: 23516228
Trainable params: 23516228
botnet50ts parameters:
Total params: 20699324
Trainable params: 20699324

 difference in number of parameters: -2816904


# Small models

### ResNet50 CNN

In [61]:
resnet101 = timm.create_model(f'resnet101.a1_in1k',
                            num_classes=4,
                            drop_rate=args.drop_rate)

### Swin Tiny Pure ViT

In [62]:
swin_small = timm.create_model('swin_small_patch4_window7_224',
                                embed_dim = 48,
                                depths=(2, 2, 6, 2),
                                num_heads=(3, 6, 12, 24),
                                num_classes=4)

print("resnet18 parameters:")
print_number_of_parameters(resnet18)
print("\nswin_small parameters:")
print_number_of_parameters(swin_small)
print(f"\n difference in number of parameters: {parameters_difference(swin_small, resnet18)}")

resnet18 parameters:
Total params: 11178564
Trainable params: 11178564

swin_small parameters:
Total params: 6916174
Trainable params: 6916174

 difference in number of parameters: -4262390
