# Reference 

Please install the following packages first:
```
pip install timm fvcore==0.1.5
```

In [None]:
import numpy as np 
from numbers import Number
from typing import Any, Callable, List, Optional, Union
from fvcore.nn import FlopCountAnalysis, flop_count_table

import torch 
from spanet import spanet_small, spanet_medium, spanet_mediumX, spanet_base, spanet_baseX

In [None]:
def rfft_flop_jit(inputs: List[Any], outputs: List[Any]) -> Number:
    """
    Count flops for the rfft/rfftn operator.

    The cost model is ``B * C * N * ceil(log2(N))`` with ``N = H * W``, i.e.
    one ``N * log2(N)`` transform per (batch element, channel) pair.

    Args:
        inputs: traced inputs of the operator. Only ``inputs[0]`` is read; its
            static shape must have exactly four dimensions, which are unpacked
            as ``(B, H, W, C)``.
        outputs: traced outputs of the operator (unused; present because the
            op-handle signature requires it).

    Returns:
        The estimated operation count as a plain ``int``.

    Raises:
        ValueError: if the shape of ``inputs[0]`` does not have four entries.
    """
    input_shape = inputs[0].type().sizes()
    B, H, W, C = input_shape
    N = H * W
    # (N - 1).bit_length() == ceil(log2(N)) for every N >= 1, computed exactly
    # in integer arithmetic; it also avoids log2(0) = -inf (and the resulting
    # NaN) when the spatial extent is empty.
    log2_n = (N - 1).bit_length() if N > 0 else 0
    # Include the batch dimension so the count scales with the input, as the
    # other per-operator counts do.
    return B * N * C * log2_n

def calc_ofcnet_flops(model, img_size=224, show_details=False):
    """
    Count the operations of ``model`` for one random square RGB image.

    The analysis is done with fvcore's ``FlopCountAnalysis``; the operators
    ``aten::fft_rfft2`` and ``aten::fft_irfft2`` are counted with
    ``rfft_flop_jit``.

    Args:
        model: the module to analyse.
        img_size: side length of the square ``(1, 3, img_size, img_size)``
            input.
        show_details: when true, also print the per-module breakdown.

    Returns:
        The total count divided by 1e9. The value is also printed.
    """
    # Create the input on the device the model lives on instead of assuming
    # CUDA, so the function also works for CPU models / CPU-only machines.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: keep the original preference for CUDA.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    with torch.no_grad():
        x = torch.randn(1, 3, img_size, img_size).to(device)
        fca1 = FlopCountAnalysis(model, x)
        handlers = {
            'aten::fft_rfft2': rfft_flop_jit,
            'aten::fft_irfft2': rfft_flop_jit,
        }
        # Register the custom counter for the two FFT operators so that they
        # contribute to the total.
        fca1.set_op_handle(**handlers)
        flops1 = fca1.total()
        if show_details:
            print(fca1.by_module())
        gflops = flops1 / 1e9
        print("#### GFLOPs: {}".format(gflops))
    return gflops

In [None]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Instantiate `spanet_baseX` and move the model to the selected device.
model = spanet_baseX().to(device)
# Put the model in evaluation mode before it is measured below.
model.eval()

In [None]:
# Spatial size of the dummy image that is fed to the model.
image_size = [224, 224]
# A single random 3-channel image, created on the CPU and then moved to `device`.
dummy_input = torch.rand((1, 3, *image_size)).to(device)

In [None]:
# == Simple way == #
# Please note that FLOP here actually means MAC.
# No custom op handles are registered in this cell, so the FFT operators are
# not counted with `rfft_flop_jit` here.
flop = FlopCountAnalysis(model, dummy_input)
# Per-module table, expanded down to four levels of nesting.
print(flop_count_table(flop, max_depth=4))
# Total count, scaled to units of 1e9.
print('MACs (G):', flop.total()/1e9)

In [None]:
# Measure the forward-pass latency of `model` on `dummy_input`.
import time

# INIT LOGGERS
repetitions = 300
timings = np.zeros((repetitions, 1))
# CUDA events only exist (and are only needed) when the input is on a GPU;
# on the CPU a wall-clock timer is used instead.
use_cuda_events = dummy_input.is_cuda
if use_cuda_events:
    starter = torch.cuda.Event(enable_timing=True)
    ender = torch.cuda.Event(enable_timing=True)
else:
    starter = ender = None

with torch.no_grad():
    # GPU-WARM-UP
    # Done under no_grad as well, so the warm-up runs exactly the same code
    # path as the timed iterations and keeps no autograd state.
    for _ in range(10):
        _ = model(dummy_input)
    if use_cuda_events:
        # Make sure no warm-up work is still queued when timing starts.
        torch.cuda.synchronize()

    # MEASURE PERFORMANCE
    for rep in range(repetitions):
        if use_cuda_events:
            starter.record()
            _ = model(dummy_input)
            ender.record()
            # WAIT FOR GPU SYNC
            torch.cuda.synchronize()
            curr_time = starter.elapsed_time(ender)  # milliseconds
        else:
            t_start = time.perf_counter()
            _ = model(dummy_input)
            curr_time = (time.perf_counter() - t_start) * 1e3  # s -> ms
        timings[rep] = curr_time
mean_syn = np.sum(timings) / repetitions
std_syn = np.std(timings)

print(f"mean: {mean_syn:.03} ms")
print(f"std: {std_syn}")

# MEASURE MACs
# Please note that FLOPs here actually means MACs.
# Set to True to additionally print the per-module breakdown.
show_details = False

with torch.no_grad():
    flops = FlopCountAnalysis(model, dummy_input)
    handlers = {
        'aten::fft_rfft2': rfft_flop_jit,
        'aten::fft_irfft2': rfft_flop_jit,
    }
    # Count the two FFT operators with the custom `rfft_flop_jit` handler.
    flops.set_op_handle(**handlers)

    if show_details:
        print(flops.by_module())

    print(flop_count_table(flops))
    print(f'MACs (G): {flops.total()/1e9:0.3f}G' )