# Libraries and Imports

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader, TensorDataset, Dataset
from torch import optim, nn
from torch.nn import functional
import os
import time
import csv
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shutil
from torchvision.utils import make_grid
from random import randint
from PIL import Image
import random

from einops import rearrange
from einops.layers.torch import Rearrange

import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from vit_pytorch import SimpleViT as ViT

from TRAM import TRAM
from PatchMergerViT import PatchMergerViT


from tqdm import tqdm

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

from thop import profile

import torch
import time

# Functions

In [None]:
def benchmark_time(
    model,
    device,
    input_size,
    batch_size,
    runs,
    throw_out,
    verbose,
):
    """
    Measure the inference throughput of a model on one random input batch.

    The model is switched to eval mode, moved to ``device`` and called
    ``runs`` times on the same random batch under ``torch.autocast`` and
    ``torch.no_grad``. The first ``int(runs * throw_out)`` calls are treated
    as warm-up: the image counter and the timer are reset once they are done.

    Args:
     - model: the module to benchmark
     - device: the device to use for benchmarking (``torch.device`` or anything
       ``torch.device`` accepts)
     - input_size: the input size to pass to the model (channels, h, w)
     - batch_size: the batch size to use for evaluation
     - runs: the number of total runs to do
     - throw_out: the fraction of runs (e.g. 0.25) to throw out at the start
       of testing
     - verbose: whether or not to use tqdm to print progress / print throughput at end

    Returns:
     - a one-element list holding the throughput in images / second,
       formatted as a string with two decimals
    """
    if not isinstance(device, torch.device):
        device = torch.device(device)
    is_cuda = device.type == "cuda"

    model = model.eval().to(device)
    print(device)
    batch = torch.rand(batch_size, *input_size, device=device)

    warm_up = int(runs * throw_out)
    total = 0
    start = time.perf_counter()

    # A disabled tqdm bar just iterates, so skip the wrapper when not verbose.
    if verbose:
        iterations = tqdm(range(runs), desc="Benchmarking")
    else:
        iterations = range(runs)

    with torch.autocast(device.type), torch.no_grad():
        for i in iterations:
            if i == warm_up:
                # Wait for queued warm-up kernels so they are not billed to
                # the measured section, then restart the measurement.
                if is_cuda:
                    torch.cuda.synchronize()
                total = 0
                start = time.perf_counter()

            model(batch)
            total += batch_size

    # CUDA launches are asynchronous: wait for completion before stopping.
    if is_cuda:
        torch.cuda.synchronize()

    elapsed = time.perf_counter() - start

    # Guard against a zero-length interval on very short benchmarks.
    images_per_second = total / elapsed if elapsed > 0 else float("inf")
    throughput = f'{images_per_second:.2f}'

    if verbose:
        # `throughput` is already a formatted string; applying ':.2f' to it
        # again would raise ValueError.
        print(f"Throughput: {throughput} im/s")

    return [throughput]





def benchmark_flops(
    model,
    device,
    input_size,
    batch_size,
    runs,
    verbose,
):
    """
    Estimate the per-image operation count and parameter count of a model.

    The model is switched to eval mode, moved to ``device`` and profiled with
    ``thop.profile`` ``runs`` times on the same random batch. The operation
    counts are summed and divided by the number of images processed.

    NOTE(review): thop documents the first value returned by ``profile`` as
    MACs; it is reported here as "FLOPs" to match the original output.
    TODO confirm which unit is intended.

    Args:
     - model: the module to benchmark
     - device: the device to use for benchmarking (``torch.device`` or anything
       ``torch.device`` accepts)
     - input_size: the input size to pass to the model (channels, h, w)
     - batch_size: the batch size to use for evaluation
     - runs: the number of total runs to do
     - verbose: whether or not to use tqdm to print progress / print results at end

    Returns:
     - a tuple ``([gflops], [params])``: the per-image operation count in units
       of 1e9, formatted as a string with four decimals, and the parameter
       count as returned by the last ``profile`` call

    Raises:
     - ValueError: if ``runs`` or ``batch_size`` is smaller than 1 (the average
       would otherwise divide by zero)
    """
    if runs < 1 or batch_size < 1:
        raise ValueError("runs and batch_size must both be at least 1")

    if not isinstance(device, torch.device):
        device = torch.device(device)

    model = model.eval().to(device)
    print(device)
    batch = torch.rand(batch_size, *input_size, device=device)

    total_images = 0
    total_flops = 0
    params = 0

    # A disabled tqdm bar just iterates, so skip the wrapper when not verbose.
    if verbose:
        iterations = tqdm(range(runs), desc="Benchmarking")
    else:
        iterations = range(runs)

    with torch.autocast(device.type), torch.no_grad():
        for _ in iterations:
            flops, params = profile(model, inputs=(batch, ), verbose=False)
            total_images += batch_size
            total_flops += flops

    flops_per_image = total_flops / total_images
    gflops = f'{flops_per_image / 1e9:.4f}'
    if verbose:
        # `gflops` is already scaled and formatted; dividing the string by
        # 1e9 again would raise TypeError.
        print(f"Total FLOPs: {gflops} GFLOPs")
        print(f"Total Parameters: {params / 1e6:.4f} M")

    return [gflops], [params]



def create_patch_list(total_patches, cut, depth=12, step=3):
    """
    Build the per-layer token budget for a staged patch-reduction schedule.

    The first ``step`` layers keep ``total_patches`` tokens. At every later
    layer whose index is a multiple of ``step`` the budget is multiplied by
    ``cut / 100`` and truncated to an integer.

    Args:
     - total_patches: number of patches available to the first layers
     - cut: percentage of patches kept at each reduction (50 halves the budget)
     - depth: number of layers to generate a budget for (default 12)
     - step: number of consecutive layers sharing one budget (default 3)

    Returns:
     - a list of ``depth`` integers, one budget per layer

    Raises:
     - ValueError: if ``step`` is smaller than 1
    """
    if step < 1:
        raise ValueError("step must be at least 1")

    keep_ratio = cut / 100
    remaining = total_patches
    patch_list = []
    for layer in range(depth):
        # Apply the cut at the start of every stage except the first one.
        if layer and layer % step == 0:
            remaining = int(remaining * keep_ratio)
        patch_list.append(remaining)
    return patch_list

In [None]:
def _benchmark_model(model, label, device, batch_size, input_size, runs_time, runs_flops):
    """Benchmark one model and return its results as a one-row DataFrame."""
    model.to(device)

    throughput = benchmark_time(
        model,
        device=device,
        verbose=False,
        runs=runs_time,
        batch_size=batch_size,
        input_size=input_size,
        throw_out=0.25,
    )

    flops, params = benchmark_flops(
        model,
        device=device,
        verbose=False,
        runs=runs_flops,
        batch_size=batch_size,
        input_size=input_size,
    )

    # Both benchmark functions return one-element lists, so this is one row.
    return pd.DataFrame({'Modello': label, 'Throughput': throughput, 'Flops': flops, 'Params': params})


def get_results(batch_size, patch_size, runs_time, runs_flops, input_size, img_size, att_dim, depth, heads, mlp_dim, max_tokens_per_depth, patch_merge_layers, type, df, device):
    """
    Benchmark ViT, TRAM and PatchMergerViT with one shared configuration.

    Each model is built with 10 output classes, then measured with
    ``benchmark_time`` (25% of the runs discarded as warm-up) and
    ``benchmark_flops``. One row per model is appended to ``df``.

    Args:
     - batch_size: batch size used by both benchmarks
     - patch_size: patch size passed to every model
     - runs_time: number of runs for the throughput benchmark
     - runs_flops: number of runs for the FLOPs benchmark (previously ignored
       in favour of ``runs_time``)
     - input_size: per-image input size (channels, h, w)
     - img_size: image size passed to every model as ``image_size``
     - att_dim: passed to every model as ``dim``
     - depth: passed to every model as ``depth``
     - heads: passed to every model as ``heads``
     - mlp_dim: passed to every model as ``mlp_dim``
     - max_tokens_per_depth: passed to TRAM as ``n_patch``
     - patch_merge_layers: passed to PatchMergerViT as ``patch_merge_layers``
     - type: label suffix identifying the configuration (e.g. 'Base');
       the name shadows the builtin but is kept for caller compatibility
     - df: DataFrame with columns Modello / Throughput / Flops / Params
     - device: device on which the models are benchmarked

    Returns:
     - a new DataFrame made of ``df`` followed by the three result rows
    """
    # Constructor arguments common to the three architectures.
    shared = dict(
        image_size=img_size,
        patch_size=patch_size,
        num_classes=10,
        dim=att_dim,
        depth=depth,
        heads=heads,
        mlp_dim=mlp_dim,
    )
    bench = dict(
        device=device,
        batch_size=batch_size,
        input_size=input_size,
        runs_time=runs_time,
        runs_flops=runs_flops,
    )
    rows = []

    # ViT
    print('##### ViT ####')
    rows.append(_benchmark_model(ViT(**shared), f'ViT_net_{type}', **bench))

    # TRAM
    print('##### TRAM ####')
    rows.append(_benchmark_model(
        TRAM(n_patch=max_tokens_per_depth, **shared),
        f'TRAM_{type}',
        **bench,
    ))

    # Patch merger
    print('##### PatchMerger ####')
    rows.append(_benchmark_model(
        PatchMergerViT(patch_merge_layers=patch_merge_layers, **shared),
        f'PatchMerger_{type}',
        **bench,
    ))

    return pd.concat([df, *rows], ignore_index=True)

# Benchmark

In [None]:
# Image geometry shared by every benchmarked model.
img_size = 160
patch_size = 16
input_size = (3, 160, 160)
total_patches = int((img_size / patch_size) ** 2)

# Benchmark workload: images per batch and number of repetitions.
batch_size = 64
runs_time = 10
runs_flops = 10

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Percentage of patches kept at each reduction step.
cut = 50
max_tokens_per_depth = create_patch_list(total_patches, cut)
# Pair layers 2, 5 and 8 with the token budget of the layer that follows them.
patch_merge_layers = [
    (layer, max_tokens_per_depth[layer + 1]) for layer in (2, 5, 8)
]

In [None]:
# Hyper-parameters of the 'Base' configuration.
# NOTE: `type` shadows the builtin of the same name from this cell onwards.
type = 'Base'
att_dim, depth, heads = 768, 12, 12
mlp_dim = 4 * att_dim

# Empty results table; get_results appends one row per model.
df_gpu = pd.DataFrame(
    columns=['Modello', 'Throughput', 'Flops', 'Params'],
)

In [None]:
# Benchmark the three models with the 'Base' configuration.
df_gpu = get_results(
    batch_size=batch_size,
    patch_size=patch_size,
    runs_time=runs_time,
    runs_flops=runs_flops,
    input_size=input_size,
    img_size=img_size,
    att_dim=att_dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    max_tokens_per_depth=max_tokens_per_depth,
    patch_merge_layers=patch_merge_layers,
    type=type,
    df=df_gpu,
    device=device,
)

In [None]:
# Hyper-parameters of the 'Small' configuration.
# NOTE: `type` shadows the builtin of the same name from this cell onwards.
type = 'Small'
att_dim, depth, heads = 384, 12, 6
mlp_dim = 4 * att_dim

# Empty results table; get_results appends one row per model.
df_small_gpu = pd.DataFrame(
    columns=['Modello', 'Throughput', 'Flops', 'Params'],
)

In [None]:
# Benchmark the three models with the 'Small' configuration.
df_small_gpu = get_results(
    batch_size=batch_size,
    patch_size=patch_size,
    runs_time=runs_time,
    runs_flops=runs_flops,
    input_size=input_size,
    img_size=img_size,
    att_dim=att_dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    max_tokens_per_depth=max_tokens_per_depth,
    patch_merge_layers=patch_merge_layers,
    type=type,
    df=df_small_gpu,
    device=device,
)

### Display

In [None]:
# Show each results table with its own statement. Joining the two calls with a
# comma builds the tuple (None, None), which the notebook then echoes as the
# cell's output value.
display(df_gpu)
display(df_small_gpu)