In [1]:
import torch 
import time as t
from model.UNet import UNet
from model.AttentionUNet import AttentionUNet 
from model.UNetTransformer import UNetTransformer
from model.PreTrainedUNetSWCA import PreTrainedUNetSWCA

In [2]:
device = torch.device('cuda')

def getModel(model):
    model = model.lower()
    
    if model == 'unet':
        model_dict = {
            'eff_b1': UNet(backbone='eff_b1'),
            'eff_b5': UNet(backbone='eff_b5'),
            'res_50': UNet(backbone='res_50'),
        }
    elif model == 'attn': 
        model_dict = {
            'eff_b1': AttentionUNet(backbone='eff_b1'),
            'eff_b5': AttentionUNet(backbone='eff_b5'),
            'res_50': AttentionUNet(backbone='res_50'),
        }
    elif model == 'utrans':
        model_dict = {
            'eff_b1': UNetTransformer(backbone='eff_b1'),
            'eff_b5': UNetTransformer(backbone='eff_b5'),
            'res_50': UNetTransformer(backbone='res_50'),
        }
    elif model == 'swca':
        model_dict = {
            'eff_b1': PreTrainedUNetSWCA(
                device='cuda', 
                backbone_name='eff_b1', 
                bottleneck_head=8,
                window_sizes=[8, 8, 8, 8], 
                layers=[2, 2, 4, 4], 
                qkv_bias=False, 
                attn_drop_prob=0.2, 
                lin_drop_prob=0.1),
            'eff_b5': PreTrainedUNetSWCA(
                device='cuda', 
                backbone_name='eff_b5', 
                bottleneck_head=32,
                window_sizes=[5, 5, 5, 5], 
                layers=[2, 2, 4, 4], 
                qkv_bias=False, 
                attn_drop_prob=0.2, 
                lin_drop_prob=0.1),
            'res_50': PreTrainedUNetSWCA(
                device='cuda', 
                backbone_name='res_50', 
                bottleneck_head=8,
                window_sizes=[8, 8, 8, 8], 
                layers=[2, 2, 4, 4], 
                qkv_bias=False, 
                attn_drop_prob=0.2, 
                lin_drop_prob=0.1),
        }
    else:
        print('model tidak tersedia')
        return 0 
    
    return model_dict

In [3]:
# unet, attn, utrans, swca
models = getModel('swca')

model_effb1, model_effb5, model_effres50 = models['eff_b1'], models['eff_b5'], models['res_50']

Downloading: "https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth" to C:\Users\User/.cache\torch\hub\checkpoints\efficientnet_b3_rwightman-cf984f9c.pth
100%|██████████| 47.2M/47.2M [01:53<00:00, 438kB/s] 
Downloading: "https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth" to C:\Users\User/.cache\torch\hub\checkpoints\efficientnet_b6_lukemelas-c76e70fd.pth
100%|██████████| 165M/165M [04:48<00:00, 602kB/s]  
Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to C:\Users\User/.cache\torch\hub\checkpoints\densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [01:05<00:00, 492kB/s]
Downloading: "https://download.pytorch.org/models/densenet169-b2777c0a.pth" to C:\Users\User/.cache\torch\hub\checkpoints\densenet169-b2777c0a.pth
100%|██████████| 54.7M/54.7M [01:47<00:00, 531kB/s]
Downloading: "https://download.pytorch.org/models/densenet201-c1103571.pth" to C:\Users\User/.cache\torch\hub\checkpoints\densenet201-c1103571.pth


In [4]:
# Todo: Params check 
# from torchsummary import summary 
# summary(model_effres50.to("cuda"), (3, 256, 256))

# Todo: GFLOPS Check
# from thop import profile
# import warnings
# warnings.filterwarnings("ignore", category=DeprecationWarning) 

# sample = torch.randn((1, 3, 480, 640)).to(device)
# macs, params = profile(model_effb5.to(device), inputs=(sample, ))
# print(f"GFLOPs: {2*macs * 1e-9:.2f}, Params: {params:.2f}")

# sample = torch.randn((1, 3, 256, 256)).to(device)Q
# macs, params = profile(model_effb1.to(device), inputs=(sample, ))
# print(f"GFLOPs: {2*macs * 1e-9:.2f}, Params: {params:.2f}")

# macs, params = profile(model_effres50.to(device), inputs=(sample, ))
# print(f"GFLOPs: {2*macs * 1e-9:.2f}, Params: {params:.2f}")


In [5]:
fps_counter = {}
inference_counter = {}

i = 0

for model, name in zip([model_effb5, model_effb1, model_effres50], ['eff_b5', 'eff_b1', 'res_50']):
    print(name)
    model.to(device)
    if i == 0:
        random_sample = torch.randn((1, 3, 480, 640)).to(device)
    else:
        random_sample = torch.randn((1, 3, 256, 256)).to(device)
    
    # Todo: Warm up for 10 sec
    print('WARMUP')
    warmup_start = t.time()
    while True: 
        pred = model(random_sample)
        
        if (t.time() - warmup_start) >= 10:
            break
    
    print('Inference time experiment start')
    prev_frame_time = 0
    new_frame_time = 0

    fps_history, inference_history = [], []
    fps_inference_time_counter = t.time()
    while True: 
        inference_start = t.time()
        pred = model(random_sample)
        inference_end = t.time() - inference_start
        
        new_frame_time = t.time() 
        fps = int(1/(new_frame_time-prev_frame_time))
        prev_frame_time = new_frame_time 
        
        fps_history.append(fps)
        inference_history.append(inference_end)
        
        if (t.time() - fps_inference_time_counter) >= 60.:
            break
    
    fps_counter[name] = fps_history
    inference_counter[name] = inference_history
    
    i += 1

print('DONE')

eff_b5
WARMUP
Inference time experiment start
eff_b1
WARMUP
Inference time experiment start
res_50
WARMUP
Inference time experiment start
DONE


In [6]:
import numpy as np 

for key in ['eff_b5', 'eff_b1', 'res_50']: 
    print(f"\t{key}")
    median_fps = np.median(np.array(fps_counter[key]))
    median_inference = np.median(np.array(inference_counter[key]))
    
    mean_fps = np.array(fps_counter[key]).mean()
    mean_inference = np.array(inference_counter[key]).mean()
    
    print(f"\tMedian FPS : {median_fps:.2f}")
    print(f"\tMean FPS : {mean_fps:.2f}\n")
    print(f"\tMedian Inference Time : {median_inference * 1000:.2f} ms")
    print(f"\tMean Inference Time : {mean_inference * 1000:.2f} ms")
    print("---"*20)

	eff_b5
	Median FPS : 1.00
	Mean FPS : 1.08

	Median Inference Time : 649.94 ms
	Mean Inference Time : 651.53 ms
------------------------------------------------------------
	eff_b1
	Median FPS : 9.00
	Mean FPS : 9.39

	Median Inference Time : 102.89 ms
	Mean Inference Time : 102.99 ms
------------------------------------------------------------
	res_50
	Median FPS : 5.00
	Mean FPS : 5.27

	Median Inference Time : 172.50 ms
	Mean Inference Time : 173.68 ms
------------------------------------------------------------


In [5]:
text = print(fps_counter)

{'eff_b1': [0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 3, 2, 1, 2, 2, 2, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 1, 2, 2, 3, 2, 2, 2, 2, 2, 2, 1, 2, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2, 2, 1], 'eff_b5': [0, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'res_50': [0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1

In [6]:
print(inference_counter)

{'eff_b1': [0.3541300296783447, 0.3659789562225342, 0.4429311752319336, 0.44008398056030273, 0.3707888126373291, 0.35350751876831055, 0.3901689052581787, 0.49134325981140137, 0.42220234870910645, 0.3678269386291504, 0.33605074882507324, 0.37355732917785645, 0.39514946937561035, 0.38775205612182617, 0.3827958106994629, 0.47191810607910156, 0.4002096652984619, 0.4470503330230713, 0.35376858711242676, 0.41971516609191895, 0.4315376281738281, 0.41928577423095703, 0.4111814498901367, 0.4188542366027832, 0.3571195602416992, 0.3790731430053711, 0.3622558116912842, 0.4354875087738037, 0.4077119827270508, 0.3890554904937744, 0.4358656406402588, 0.4286673069000244, 0.34864306449890137, 0.3975381851196289, 0.4259822368621826, 0.39708590507507324, 0.36244726181030273, 0.3510870933532715, 0.455951452255249, 0.4703195095062256, 0.4074070453643799, 0.4669816493988037, 0.4410727024078369, 0.4424281120300293, 0.43280839920043945, 0.42140913009643555, 0.43094491958618164, 0.41110944747924805, 0.41684222