## Study the Time Complexity of U-Net and Attention U-Net under Different Configurations

In [None]:
import torch
import time
import matplotlib.pyplot as plt
from models.unet import UNet
from models.attn_unet import AttentionUnet

# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
# (To force CPU timings, replace the line below with: device = "cpu")
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def measure_forward_time(model: torch.nn.Module, input_tensor: torch.Tensor) -> float:
    """Return the elapsed time, in seconds, of one ``model(input_tensor)`` call.

    Timing uses ``time.perf_counter`` (monotonic, high resolution) rather than
    ``time.time``. CUDA kernels are launched asynchronously. So when the input
    lives on a CUDA device, the device is synchronized both before starting and
    before stopping the clock. Without that, only the kernel-launch overhead
    would be measured. The model's output is discarded.

    Args:
        model: Module (or any callable) to time. It is assumed to already live
            on the same device as ``input_tensor``.
        input_tensor: Input passed to ``model``.

    Returns:
        Elapsed seconds as a float.
    """
    on_cuda = input_tensor.is_cuda
    if on_cuda:
        # Drain work queued by earlier calls so it isn't billed to this one.
        torch.cuda.synchronize(input_tensor.device)
    start = time.perf_counter()
    model(input_tensor)
    if on_cuda:
        # Block until the forward pass's kernels have actually finished.
        torch.cuda.synchronize(input_tensor.device)
    return time.perf_counter() - start

In [None]:
def compare_models(unet, attn_unet, num_trials=100, input_shape=(1, 38, 640, 368)):
    """Time repeated inference forward passes of two models and plot them.

    Both models are moved to the module-level ``device`` (``nn.Module.to``
    moves the caller's models in place) and switched to eval mode. They are
    then run alternately on the same random input, ``num_trials`` times each,
    under ``torch.no_grad()``. Each pair of timings is printed as it is
    collected. The first trial is left out of the plot because it usually
    carries one-off warm-up cost that is not representative.

    Args:
        unet: Baseline model (plotted in red as "UNet").
        attn_unet: Model to compare against (plotted in blue as
            "AttentionUNet").
        num_trials: Number of timed forward passes per model.
        input_shape: Shape of the random input tensor. The default matches the
            38-channel models built in this notebook. It is presumably
            (batch, number of slices, height, width) for single-coil data;
            confirm against the data pipeline.
    """
    # Create the input directly on the target device (no host-to-device copy).
    input_tensor = torch.randn(input_shape, device=device)

    unet = unet.to(device)
    attn_unet = attn_unet.to(device)
    unet.eval()
    attn_unet.eval()

    times_unet = []
    times_attn_unet = []

    # Inference only: skip autograd bookkeeping so timings reflect the forward pass.
    with torch.no_grad():
        for _ in range(num_trials):
            # Time both models on the identical input, using the arguments
            # passed in (not notebook globals).
            unet_time = measure_forward_time(unet, input_tensor)
            attn_unet_time = measure_forward_time(attn_unet, input_tensor)

            times_unet.append(unet_time)
            times_attn_unet.append(attn_unet_time)
            print("times: ", (unet_time, attn_unet_time))

    # 1-based trial numbers for the plotted points; trial 1 (warm-up) is omitted.
    trials = range(2, num_trials + 1)

    plt.figure(figsize=(10, 7))
    plt.title('(Eval Mode) Forward Pass Time Comparison')
    plt.xlabel('Trial')
    plt.ylabel('Time (s)')
    plt.plot(trials, times_unet[1:], 'r', label='UNet')
    plt.plot(trials, times_attn_unet[1:], 'b', label='AttentionUNet')
    plt.legend()
    plt.show()

In [None]:
# Hyper-parameters shared by both architectures so the comparison is like-for-like.
shared_hparams = {
    "in_chans": 38,
    "out_chans": 38,
    "chans": 32,
    "num_pool_layers": 4,
    "drop_prob": 0.0,
}
unet = UNet(**shared_hparams)
att_unet = AttentionUnet(**shared_hparams)

In [12]:
print(str(unet))  # dump the UNet module tree

UNet(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(38, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout

In [13]:
print(str(att_unet))  # dump the AttentionUnet module tree

AttentionUnet(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(38, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3): Dropout2d(p=0.0, inplace=False)
        (4): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (5): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (6): LeakyReLU(negative_slope=0.2, inplace=True)
        (7): Dropout2d(p=0.0, inplace=False)
      )
    )
    (1): ConvBlock(
      (layers): Sequential(
        (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
        (2): LeakyReLU(negative_slope=0.2, inplace=True)
        (3)

In [None]:
# Time both architectures over 100 inference passes and plot the results.
compare_models(unet=unet, attn_unet=att_unet, num_trials=100)

## Memory Complexity (Parameter Storage)

In [None]:
def model_memory_required(model, include_buffers=False):
    """Return the storage taken by a model's parameters, in MiB.

    Each tensor's size comes from its actual element size. That means float16
    and float64 models are measured correctly; for float32 this equals
    ``num_params * 4`` bytes. Only tensor storage is counted, not gradients,
    optimizer state or activations.

    Args:
        model: The ``torch.nn.Module`` to measure.
        include_buffers: Also count registered buffers (e.g. normalization
            running statistics). Defaults to False, i.e. parameters only.

    Returns:
        Size in MiB (bytes / 1024**2) as a float.
    """
    tensors = list(model.parameters())
    if include_buffers:
        tensors += list(model.buffers())
    total_bytes = sum(t.numel() * t.element_size() for t in tensors)
    return total_bytes / (1024 ** 2)

def compare_model_memory(unet, attn_unet):
    """Print and plot the parameter memory of two models.

    Prints how many MiB of parameter storage ``attn_unet`` needs beyond
    ``unet`` (negative if it needs less). Then draws a bar chart of both
    totals, as computed by ``model_memory_required``.

    Args:
        unet: Baseline model (red bar, "UNet").
        attn_unet: Model to compare against (blue bar, "AttnUNet").
    """
    memory_unet = model_memory_required(unet)
    memory_attn_unet = model_memory_required(attn_unet)

    models = ['UNet', 'AttnUNet']  # x-axis labels
    memory = [memory_unet, memory_attn_unet]

    diff = memory_attn_unet - memory_unet
    # Label the number and give its units instead of printing a bare float.
    print(f"AttnUNet - UNet parameter memory: {diff:+.6f} MB")

    plt.bar(models, memory, color=['red', 'blue'])
    plt.ylabel('Memory (MB)')
    # Spell the model name the same way as the x-axis label.
    plt.title('Parameter Memory Comparison Between UNet and AttnUNet')
    plt.show()


In [None]:
# Compare the parameter-memory footprint of the two architectures.
compare_model_memory(unet=unet, attn_unet=att_unet)