In [1]:
import numpy as np
import torch
from architectures import *
from time import time
from  matplotlib import pyplot as plt
import timm
import monai
from tqdm import tqdm
import os
import math
import pickle
import sys

sys.path.append('/home/hoopersm/long_context_paper/FMImaging/model')
sys.path.append('/home/hoopersm/long_context_paper/FMImaging/model/imaging_attention/')
sys.path.append('/home/hoopersm/long_context_paper/FMImaging/model/backbone/')
sys.path.append('/home/hoopersm/long_context_paper/FMImaging')
sys.path.append('/home/hoopersm/long_context_paper/FMImaging/setup')
sys.path.append('/home/hoopersm/long_context_paper/FMImaging/setup/setup_base')

from backbone import *
from model import *
from model.backbone import *
from setup.setup_base import parse_config



ModuleNotFoundError: No module named 'torch'

## Helper functions

In [None]:
# Helper functions

def count_parameters(model):
    """Return the total element count of the trainable parameters of ``model``.

    Only parameters whose ``requires_grad`` flag is truthy are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def time_run(input_shape, model, reps, padding=20):
    """Time repeated forward passes of ``model`` in training mode.

    Args:
        input_shape: Object passed to ``model`` on every call. Despite the
            name, the callers in this notebook pass the input tensor itself.
        model: Callable that also exposes ``train()``; it is invoked as
            ``model(input_shape)``.
        reps: Number of timed repetitions to report.
        padding: Number of extra warm-up repetitions that are run first and
            excluded from the reported timings.

    Returns:
        List with one wall-clock duration (in seconds) per reported
        repetition; the first ``padding`` measurements are dropped.
    """
    # Imported locally on purpose: the notebook's top-level
    # ``from time import time`` binds ``time`` to a function, so the previous
    # ``time.time()`` calls raised AttributeError.
    from time import perf_counter

    # CUDA kernels are launched asynchronously; without synchronizing, the
    # timer would stop before the GPU has finished the forward pass.
    if torch.cuda.is_available():
        sync = torch.cuda.synchronize
    else:
        def sync():
            return None

    model.train()
    time_list = []
    for _ in tqdm(range(reps+padding)):
        sync()
        start = perf_counter()
        # Use the argument rather than the notebook-level global ``input``.
        model(input_shape)
        sync()
        end = perf_counter()
        time_list.append(end-start)
    print(f"\tAverage run time: {np.mean(time_list[padding:])} seconds")
    return time_list[padding:]
    

## ViT tests

Load timing dict to track results over multiple runs

In [None]:

# Resume from the ViT results saved by earlier runs when the pickle exists;
# otherwise start with an empty results dict.
if not os.path.exists('/home/hoopersm/hyena/vit_speedup_dict.pkl'):
    speedup_dict = {}
    key_count = 0
    print('Initialized speedup_dict, no entries yet')
else:
    with open('/home/hoopersm/hyena/vit_speedup_dict.pkl', 'rb') as f:
        speedup_dict = pickle.load(f)
    # The number of stored entries doubles as the next integer key.
    key_count = len(speedup_dict)
    print('Loaded speedup_dict with {} entries'.format(key_count))

Create data

In [None]:
# Benchmark configuration for the ViT comparison.
batch, in_ch, out_ch = 2, 3, 1
t, h, w = 1, 1024, 1024
reps = 100

vit_patch = (1,16,16) # (t,h,w)

# All-zero dummy input laid out as (batch, channels, t, h, w), moved to the GPU.
input = torch.zeros((batch, in_ch, t, h, w)).to(device='cuda')


Compute context length for ViT

In [None]:
# ViT context length: number of patches along each of (t, h, w), rounded up,
# multiplied together.
context_length = math.prod(
    np.ceil(dim/patch) for dim, patch in zip((t, h, w), vit_patch)
)
print('ViT context length: {}'.format(context_length))

Running timing tests

In [None]:
# Build the attention-based ViT on the GPU, report its size and time it.
print('Creating attention-based ViT model...')
vit_attention_model = ViT(
    in_channels=in_ch,
    num_classes=out_ch,
    img_size=(t,h,w),
    patch_size=vit_patch,
    spatial_dims=3,
    classification=True,
    use_hyena=False,
).to(device='cuda')
attn_num_params = count_parameters(vit_attention_model)
print(f"\tNumber of parameters with attention: {attn_num_params/1e6} million")

print('Running attention-based ViT timing test...')
attn_time = time_run(input, vit_attention_model, reps)

# Release GPU memory before the next model is created: move the weights off
# the device, drop the reference, and empty the CUDA cache around the delete.
vit_attention_model.to(device='cpu')
torch.cuda.empty_cache()
del vit_attention_model
torch.cuda.empty_cache()

: 

In [None]:
# Build the hyena-based ViT with the same settings as the attention variant
# (only use_hyena differs), report its size and time it.
print('Creating hyena-based ViT model...')
vit_hyena_model = ViT(
    in_channels=in_ch,
    num_classes=out_ch,
    img_size=(t,h,w),
    patch_size=vit_patch,
    spatial_dims=3,
    classification=True,
    use_hyena=True,
).to(device='cuda')
hyena_num_params = count_parameters(vit_hyena_model)
print(f"\tNumber of parameters with hyena: {hyena_num_params/1e6} million")

print('Running hyena-based ViT timing test...')
hyena_time = time_run(input, vit_hyena_model, reps)

# Release GPU memory: move the weights off the device, drop the reference,
# and empty the CUDA cache around the delete.
vit_hyena_model.to(device='cpu')
torch.cuda.empty_cache()
del vit_hyena_model
torch.cuda.empty_cache()

Print results

In [None]:
# Per-repetition speedup: attention time divided by hyena time
# (values above 1 mean the hyena model ran faster).
speedup = [attn_t/hyena_t for attn_t, hyena_t in zip(attn_time, hyena_time)]
print(f"Average speedup: {np.mean(speedup)}x")
# Outliers can pull the mean away from the median in either direction, so the
# threshold is applied to the absolute gap. The comparison must sit outside
# np.abs(); previously it was inside, so abs() was applied to a boolean and the
# warning never fired when the median exceeded the mean.
if np.abs(np.mean(speedup)-np.median(speedup)) > 0.1:
    print("\tWARNING: Mean and median are very different")
    print(f"\tMedian speedup: {np.median(speedup)}x")

# Plot speedup per repetition
plt.figure(figsize=(4,3))
plt.plot(speedup)
plt.xlabel("Repetition")
plt.ylabel("Speedup (attention/hyena)")
plt.show()
plt.close()

# Plot raw timing data for both models
plt.figure(figsize=(4,3))
plt.plot(attn_time,'r',label='Attention')
plt.plot(hyena_time,'b',label='Hyena')
plt.xlabel("Repetition")
plt.ylabel("Time (s)")
plt.legend()
plt.show()
plt.close()





Save results into dict

In [None]:
# Record this run's summary under the next integer key, then persist the
# whole results dict back to the ViT pickle.
speedup_dict[key_count] = dict(
    image_size=(batch,in_ch,h,w,t),
    context_length=context_length,
    speedup=np.mean(speedup),
    attn_time=np.mean(attn_time),
    hyena_time=np.mean(hyena_time),
    attn_num_param=attn_num_params,
    hyena_num_param=hyena_num_params,
)
key_count = key_count + 1
with open('/home/hoopersm/hyena/vit_speedup_dict.pkl', 'wb') as pkl_file:
    pickle.dump(speedup_dict, pkl_file, pickle.HIGHEST_PROTOCOL)

Plot results from all experiments

In [None]:
# Reload every stored ViT result and scatter mean speedup against context length.
with open('/home/hoopersm/hyena/vit_speedup_dict.pkl', 'rb') as pkl_file:
    speedup_dict = pickle.load(pkl_file)
key_count = len(speedup_dict)

all_speedup = [entry['speedup'] for entry in speedup_dict.values()]
all_context_length = [entry['context_length'] for entry in speedup_dict.values()]

plt.figure(figsize=(4,3))
plt.plot(all_context_length, all_speedup, 'rx')
plt.title("ViT speedup vs. context length")
plt.xlabel("Context length")
plt.ylabel("Speedup")
plt.show()
plt.close()

## Swin tests

Load timing dict to track results over multiple runs

In [None]:

# Resume from the Swin results saved by earlier runs when the pickle exists;
# otherwise start with an empty results dict.
if not os.path.exists('/home/hoopersm/hyena/swin_speedup_dict.pkl'):
    speedup_dict = {}
    key_count = 0
    print('Initialized speedup_dict, no entries yet')
else:
    with open('/home/hoopersm/hyena/swin_speedup_dict.pkl', 'rb') as f:
        speedup_dict = pickle.load(f)
    # The number of stored entries doubles as the next integer key.
    key_count = len(speedup_dict)
    print('Loaded speedup_dict with {} entries'.format(key_count))

Create data

In [None]:
# Benchmark configuration for the Swin comparison.
batch, in_ch, out_ch = 2, 3, 1
t, h, w = 32, 224, 224
reps = 100

swin_patch = (2,2,2) # (t,h,w)
swin_window = (8,8,8) # (t,h,w)

# All-zero dummy input laid out as (batch, channels, t, h, w), moved to the GPU.
input = torch.zeros((batch, in_ch, t, h, w)).to(device='cuda')


Compute context length for Swin

In [None]:
# Max Swin context length: the product of the three window dimensions.
context_length = math.prod(swin_window)
print('Max Swin context length: {}'.format(context_length))

Running timing tests

In [None]:
# Build the attention-based Swin transformer on the GPU, report its size and time it.
print('Creating attention-based Swin model...')

swin_attention_model = SwinTransformer(
    in_chans=in_ch,
    embed_dim=48,
    patch_size=swin_patch,
    window_size=swin_window,
    depths=[2,2,6,2],
    num_heads=[3,6,12,24],
    spatial_dims=3,
    use_hyena=False,
).to(device='cuda')
attn_num_params = count_parameters(swin_attention_model)
print(f"\tNumber of parameters with attention: {attn_num_params/1e6} million")

print('Running attention-based Swin timing test...')
attn_time = time_run(input, swin_attention_model, reps)

# Release GPU memory before the next model is created: move the weights off
# the device, drop the reference, and empty the CUDA cache around the delete.
swin_attention_model.to(device='cpu')
torch.cuda.empty_cache()
del swin_attention_model
torch.cuda.empty_cache()

In [None]:
# Build the hyena-based Swin transformer with the same settings as the
# attention variant (only use_hyena differs), report its size and time it.
print('Creating hyena-based Swin model...')

swin_hyena_model = SwinTransformer(
    in_chans=in_ch,
    embed_dim=48,
    patch_size=swin_patch,
    window_size=swin_window,
    depths=[2,2,6,2],
    num_heads=[3,6,12,24],
    spatial_dims=3,
    use_hyena=True,
).to(device='cuda')
hyena_num_params = count_parameters(swin_hyena_model)
print(f"\tNumber of parameters with hyena: {hyena_num_params/1e6} million")

print('Running hyena-based Swin timing test...')
hyena_time = time_run(input, swin_hyena_model, reps)

# Release GPU memory: move the weights off the device, drop the reference,
# and empty the CUDA cache around the delete.
swin_hyena_model.to(device='cpu')
torch.cuda.empty_cache()
del swin_hyena_model
torch.cuda.empty_cache()

Print results

In [None]:
# Per-repetition speedup: attention time divided by hyena time
# (values above 1 mean the hyena model ran faster).
speedup = [attn_t/hyena_t for attn_t, hyena_t in zip(attn_time, hyena_time)]
print(f"Average speedup: {np.mean(speedup)}x")
# Outliers can pull the mean away from the median in either direction, so the
# threshold is applied to the absolute gap. The comparison must sit outside
# np.abs(); previously it was inside, so abs() was applied to a boolean and the
# warning never fired when the median exceeded the mean.
if np.abs(np.mean(speedup)-np.median(speedup)) > 0.1:
    print("\tWARNING: Mean and median are very different")
    print(f"\tMedian speedup: {np.median(speedup)}x")

# Plot speedup per repetition
plt.figure(figsize=(4,3))
plt.plot(speedup)
plt.xlabel("Repetition")
plt.ylabel("Speedup (attention/hyena)")
plt.show()
plt.close()

# Plot raw timing data for both models
plt.figure(figsize=(4,3))
plt.plot(attn_time,'r',label='Attention')
plt.plot(hyena_time,'b',label='Hyena')
plt.xlabel("Repetition")
plt.ylabel("Time (s)")
plt.legend()
plt.show()
plt.close()



Save results into dict

In [None]:
# Record this run's summary under the next integer key, then persist the
# whole results dict back to the Swin pickle.
speedup_dict[key_count] = dict(
    image_size=(batch,in_ch,h,w,t),
    context_length=context_length,
    speedup=np.mean(speedup),
    attn_time=np.mean(attn_time),
    hyena_time=np.mean(hyena_time),
    attn_num_param=attn_num_params,
    hyena_num_param=hyena_num_params,
)
key_count = key_count + 1
with open('/home/hoopersm/hyena/swin_speedup_dict.pkl', 'wb') as pkl_file:
    pickle.dump(speedup_dict, pkl_file, pickle.HIGHEST_PROTOCOL)

Plot results from all experiments

In [None]:
# Reload every stored Swin result and scatter mean speedup against context length.
with open('/home/hoopersm/hyena/swin_speedup_dict.pkl', 'rb') as pkl_file:
    speedup_dict = pickle.load(pkl_file)
key_count = len(speedup_dict)

all_speedup = [entry['speedup'] for entry in speedup_dict.values()]
all_context_length = [entry['context_length'] for entry in speedup_dict.values()]

plt.figure(figsize=(4,3))
plt.plot(all_context_length, all_speedup, 'rx')
plt.title("Swin speedup vs. context length")
plt.xlabel("Context length")
plt.ylabel("Speedup")
plt.show()
plt.close()

## STCNNT Tests

Load timing dict to track results over multiple runs

In [None]:
# Resume from the STCNNT results saved by earlier runs when the pickle exists;
# otherwise start with an empty results dict.
if not os.path.exists('/home/hoopersm/hyena/stcnnt_speedup_dict.pkl'):
    speedup_dict = {}
    key_count = 0
    print('Initialized speedup_dict, no entries yet')
else:
    with open('/home/hoopersm/hyena/stcnnt_speedup_dict.pkl', 'rb') as f:
        speedup_dict = pickle.load(f)
    # The number of stored entries doubles as the next integer key.
    key_count = len(speedup_dict)
    print('Loaded speedup_dict with {} entries'.format(key_count))

Create data

In [None]:
# Benchmark configuration for the STCNNT comparison.
batch, in_ch, out_ch = 2, 3, 1
t, h, w = 20, 64, 64
reps = 100

# All-zero dummy input laid out as (batch, channels, t, h, w), moved to the GPU.
input = torch.zeros((batch, in_ch, t, h, w)).to(device='cuda')


Create config for STCNNT

In [None]:
# Build the STCNNT configuration: start from whatever parse_config() (imported
# from setup.setup_base) returns, then override individual fields by attribute
# assignment. Several fields are assigned more than once below; the last
# assignment is the one in effect when the model is built.
config = parse_config()

# attention modules
config.kernel_size = 3
config.stride = 1
config.padding = 1
config.stride_t = 2
config.dropout_p = 0.1
# Tie the model's input/output sizes to the benchmark data defined in the "Create data" cell.
config.no_in_channel = in_ch
config.C_out = out_ch
config.height = h
config.width = w
config.batch_size = batch
config.time = t
config.norm_mode = "instance2d"
config.a_type = "conv"
config.is_causal = False
config.n_head = 32
config.interp_align_c = True

# Window and patch sizes are derived from the image height/width by integer division.
config.window_size = [h//8, w//8]
config.patch_size = [h//32, w//32]

config.num_wind =[8, 8]
config.num_patch =[2, 2]

config.window_sizing_method = "mixed"

# losses
config.losses = ["mse"]
config.loss_weights = [1.0]
config.load_path = None

# to be tested
config.residual = True
# NOTE: device is reassigned to 'cuda' further down in this cell.
config.device = None
config.channels = [16,32,64]
config.all_w_decay = True
config.optim = "adamw"
# NOTE: scheduler is reassigned to "ReduceLROnPlateau" further down in this cell.
config.scheduler = "StepLR"

config.complex_i = False

config.summary_depth = 4

# NOTE(review): Namespace is not imported explicitly in this notebook;
# presumably it is provided by one of the star imports — confirm.
config.backbone_hrnet = Namespace()
config.backbone_hrnet.C = 32
config.backbone_hrnet.num_resolution_levels = 4

config.backbone_hrnet.use_interpolation = True

config.cell_type = "sequential"
config.normalize_Q_K = True 
config.att_dropout_p = 0.0
config.att_with_output_proj = True 
config.scale_ratio_in_mixer  = 1.0

config.cosine_att = True
config.att_with_relative_postion_bias = False

config.block_dense_connection = True

# Repeated assignments: optim and all_w_decay keep the values set earlier,
# while scheduler replaces the earlier "StepLR".
config.optim = "adamw"
config.scheduler = "ReduceLROnPlateau"
config.all_w_decay = True

# Replaces the earlier `config.device = None`.
config.device = 'cuda'

config.with_timer = True

config.stride_s = 1
config.separable_conv = True
config.use_einsum = False

config.mixer_kernel_size = 3
config.mixer_stride = 1
config.mixer_padding = 1

config.mixer_type = 'conv'
config.shuffle_in_window = False
config.temporal_flash_attention = False 
config.activation_func = 'prelu'

config.upsample_method = 'linear'

# ---------------------------------------------------------------------

# Same value as the dropout_p assignment near the top of this cell.
config.dropout_p = 0.1

# Five identical block strings for the HRNet backbone.
# NOTE(review): the meaning of the "T1L1G1" code is defined by the STCNNT
# backbone, not in this notebook — confirm before editing.
config.backbone_hrnet.block_str = ["T1L1G1",
                "T1L1G1",
                "T1L1G1",
                "T1L1G1",
                "T1L1G1"]

Compute context length for STCNNT

In [None]:
# # Compute max STCNNT context length
# context_length = ???
# print('Max STCNNT context length: {}'.format(context_length))

Running timing tests

In [None]:
# Build the attention-based STCNNT HRNet from the config above, report its
# size and time it.
print('Creating attention-based STCNNT model...')

hrnet_attention_model = STCNNT_HRnet(config).to(device='cuda')
attn_num_params = count_parameters(hrnet_attention_model)
print(f"\tNumber of parameters with attention: {attn_num_params/1e6} million")

print('Running attention-based HRNET timing test...')
attn_time = time_run(input, hrnet_attention_model, reps)

# Release GPU memory: move the weights off the device, drop the reference,
# and empty the CUDA cache around the delete.
hrnet_attention_model.to(device='cpu')
torch.cuda.empty_cache()
del hrnet_attention_model
torch.cuda.empty_cache()

In [None]:
# TODO: implement the hyena-based STCNNT benchmark.
# This cell used to start with a bare `TODO` name, which stopped execution with
# an opaque NameError. The guard is now explicit. Everything below it is an
# unadapted copy of the Swin hyena cell (it builds a SwinTransformer, not an
# STCNNT model) and must be rewritten before the guard is removed.
raise NotImplementedError(
    "STCNNT hyena benchmark is not implemented yet; "
    "the code below is an unadapted copy of the Swin hyena cell"
)

print('Creating hyena-based Swin model...')

swin_hyena_model = SwinTransformer(
                            use_hyena=True,
                            in_chans=in_ch,
                            embed_dim=48,
                            window_size=swin_window,
                            patch_size=swin_patch,
                            depths=[2,2,6,2],
                            num_heads=[3,6,12,24],
                            spatial_dims=3,
                        ).to(device='cuda')
hyena_num_params = count_parameters(swin_hyena_model)
print(f"\tNumber of parameters with hyena: {hyena_num_params/1e6} million")

print('Running hyena-based Swin timing test...')
hyena_time = time_run(input, swin_hyena_model, reps)

# Clean up model
swin_hyena_model.to(device='cpu')
torch.cuda.empty_cache()
del swin_hyena_model
torch.cuda.empty_cache()

Print results

In [None]:
# TODO: enable once the STCNNT hyena timing cell above is implemented.
# This cell used to start with a bare `TODO` name, which stopped execution with
# an opaque NameError. The guard is now explicit so that stale `hyena_time`
# values from the Swin section are never compared against STCNNT timings.
raise NotImplementedError(
    "STCNNT results cannot be reported until the STCNNT hyena benchmark is implemented"
)

# Per-repetition speedup: attention time divided by hyena time
# (values above 1 mean the hyena model ran faster).
speedup = [attn_t/hyena_t for attn_t, hyena_t in zip(attn_time, hyena_time)]
print(f"Average speedup: {np.mean(speedup)}x")
# The threshold is applied to the absolute mean/median gap. The comparison must
# sit outside np.abs(); previously it was inside, so abs() was applied to a
# boolean and the warning never fired when the median exceeded the mean.
if np.abs(np.mean(speedup)-np.median(speedup)) > 0.1:
    print("\tWARNING: Mean and median are very different")
    print(f"\tMedian speedup: {np.median(speedup)}x")

# Plot speedup per repetition
plt.figure(figsize=(4,3))
plt.plot(speedup)
plt.xlabel("Repetition")
plt.ylabel("Speedup (attention/hyena)")
plt.show()
plt.close()

# Plot raw timing data for both models
plt.figure(figsize=(4,3))
plt.plot(attn_time,'r',label='Attention')
plt.plot(hyena_time,'b',label='Hyena')
plt.xlabel("Repetition")
plt.ylabel("Time (s)")
plt.legend()
plt.show()
plt.close()



Save results into dict

In [None]:
# TODO: enable once the STCNNT hyena benchmark and the STCNNT context-length
# computation (currently commented out in the "Compute context length for
# STCNNT" cell) are implemented; until then `context_length`, `speedup` and
# `hyena_time` would be stale values left over from the Swin section.
# This cell used to start with a bare `TODO` name, which stopped execution with
# an opaque NameError. The guard is now explicit.
raise NotImplementedError(
    "STCNNT results cannot be saved until the STCNNT hyena benchmark is implemented"
)

speedup_dict[key_count] = {'image_size': (batch,in_ch,h,w,t), 
                            'context_length':context_length, 
                            'speedup':np.mean(speedup),
                            'attn_time':np.mean(attn_time),
                            'hyena_time':np.mean(hyena_time), 
                            'attn_num_param':attn_num_params,
                            'hyena_num_param':hyena_num_params}
key_count+=1
# Write to the STCNNT pickle, the file this section loaded speedup_dict from.
# The copied cell pointed at swin_speedup_dict.pkl, which would have
# overwritten the Swin results with the STCNNT dict.
with open('/home/hoopersm/hyena/stcnnt_speedup_dict.pkl', 'wb') as f:
    pickle.dump(speedup_dict, f, pickle.HIGHEST_PROTOCOL)

Plot results from all experiments

In [None]:
# TODO: enable once STCNNT results are actually being saved.
# This cell used to start with a bare `TODO` name, which stopped execution with
# an opaque NameError. The guard is now explicit.
raise NotImplementedError(
    "STCNNT summary plot is disabled until STCNNT results are being saved"
)

# Read the STCNNT pickle. The copied cell read swin_speedup_dict.pkl and
# titled the figure "ViT", so it would have shown Swin data with a ViT label.
with open('/home/hoopersm/hyena/stcnnt_speedup_dict.pkl', 'rb') as f:
    speedup_dict = pickle.load(f)
key_count = len(speedup_dict)

all_speedup = [speedup_dict[key]['speedup'] for key in speedup_dict.keys()]
all_context_length = [speedup_dict[key]['context_length'] for key in speedup_dict.keys()]

plt.figure(figsize=(4,3))
plt.plot(all_context_length,all_speedup,'rx')
plt.title("STCNNT speedup vs. context length")
plt.xlabel("Context length")
plt.ylabel("Speedup")
plt.show()
plt.close()