In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '4'

%matplotlib inline
%load_ext autoreload
%autoreload 2

import numpy as np
import matplotlib.pyplot as plt
from os import path

from tqdm import tqdm

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from torch.optim import AdamW, Adam
from torch import autocast, GradScaler

from omegaconf import OmegaConf, open_dict

torch.backends.cuda.matmul.allow_tf32 = False
torch.set_float32_matmul_precision('high')

In [2]:
from hiera import Hiera
tiny_hiera = Hiera(input_size=(36, 64),
                        num_heads=1,
                        embed_dim=96,
                        stages=(2, 1,), # 3 transformer layers 
                        q_pool=1, 
                        in_chans=1,
                        q_stride=(1, 1,),
                        mask_unit_size=(8, 8),
                        patch_kernel=(5, 5),
                        patch_stride=(2, 2),
                        patch_padding=(2, 2),
                        sep_pos_embed=False, # True for 3D
                        drop_path_rate=1,
                        mlp_ratio=4,)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from tqdm import tqdm

tiny_hiera = tiny_hiera.cuda().to(torch.bfloat16);
example_input = torch.ones(256, 1, 36,64).to("cuda", torch.bfloat16)

for i in tqdm(range(2000)):
    out = tiny_hiera(example_input, return_intermediates=True);

 42% 833/2000 [00:10<00:15, 76.60it/s]

KeyboardInterrupt



In [9]:
from tqdm import tqdm

#tiny_hiera = torch.compile(tiny_hiera).cuda().to(torch.bfloat16);
example_input = torch.ones(256, 1, 36,64).to("cuda", torch.bfloat16)

for i in tqdm(range(2000)):
    out = tiny_hiera(example_input, return_intermediates=True);

100% 2000/2000 [00:16<00:00, 124.71it/s]


In [15]:
# pip install hiera-transformer
from hiera import Hiera
tiny_hiera = Hiera(input_size=(60,36,64),
                     num_heads=3,
                     embed_dim=96,
                     stages=(2, 1,), # 3 transformer layers 
                     q_pool=1, 
                     in_chans= 6,
                     q_stride=(1, 1, 1,),
                     mask_unit_size=(1, 8, 8),
                     patch_kernel=(5, 5, 5),
                     patch_stride=(3, 2, 2),
                     patch_padding=(1, 2, 2),
                     sep_pos_embed=True,
                     drop_path_rate=0,
                     mlp_ratio=4,)

tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(8, 6, 60, 36,64).to("cuda", torch.float32)
out = tiny_hiera(example_input, return_intermediates=True);

hiera_output = out[-1][-1]
hiera_output.shape # (b, t, h, w, c): (8, 4, 9, 16, 192)

torch.Size([8, 20, 18, 32, 192])

In [101]:
from tqdm import tqdm

tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(8, 6, 60, 36,64).to("cuda", torch.float32)

#with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
for i in tqdm(range(2000)):
    out = tiny_hiera(example_input, return_intermediates=True);

  0% 0/2000 [00:00<?, ?it/s]


RuntimeError: Expected 3D (unbatched) or 4D (batched) input to conv2d, but got input of size: [8, 6, 60, 36, 64]

In [None]:
sdpa_kernel(SDPBackend.CUDNN_ATTENTION)

In [17]:
%debug

> [0;32m/src/src/hiera/hiera/hiera.py[0m(100)[0;36mforward[0;34m()[0m
[0;32m     98 [0;31m        [0;32mif[0m [0mhasattr[0m[0;34m([0m[0mF[0m[0;34m,[0m [0;34m"scaled_dot_product_attention"[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     99 [0;31m            [0;31m# Note: the original paper did *not* use SDPA, it's a free boost![0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 100 [0;31m            [0mx[0m [0;34m=[0m [0mF[0m[0;34m.[0m[0mscaled_dot_product_attention[0m[0;34m([0m[0mq[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mk[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;34m)[0m[0;34m,[0m [0mv[0m[0;34m.[0m[0msqueeze[0m[0;34m([0m[0;34m)[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    101 [0;31m[0;34m[0m[0m
[0m[0;32m    102 [0;31m        [0;32melse[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  what = q
ipdb>  what.shape


torch.Size([8, 3, 180, 64, 32])


ipdb>  180*64


11520


ipdb>  q


In [78]:
query = torch.randn(8, 3, 180, 64, 32, device="cuda", dtype=torch.float32)

In [79]:
key = torch.randn(8, 3, 180, 64, 32, device="cuda", dtype=torch.float32)
value = torch.randn(8, 3, 180, 64, 32, device="cuda", dtype=torch.float32)

In [80]:
out1 = F.scaled_dot_product_attention(query, key, value)
out1 = out1.view(8,3,-1,32)

In [81]:
out1.shape

torch.Size([8, 3, 11520, 32])

In [87]:
out1[0,0,0]

tensor([-0.2637, -0.2344,  0.0146,  0.1025, -0.2236,  0.0581, -0.1719,  0.0815,
        -0.0195, -0.3320,  0.0136, -0.0479,  0.3066,  0.2227, -0.0270, -0.0486,
        -0.0483,  0.0420,  0.1221, -0.2275, -0.1650,  0.0359, -0.1006,  0.0698,
        -0.0742, -0.1050,  0.1230, -0.0354, -0.1387,  0.0840,  0.2217, -0.0145],
       device='cuda:0', dtype=torch.bfloat16)

In [88]:
out1.shape

torch.Size([8, 3, 11520, 32])

In [89]:
q = query.view(8,3,-1,32)
k = key.view(8,3,-1,32)
v = value.view(8,3,-1,32)

In [90]:
out2 = F.scaled_dot_product_attention(q, k, v)

In [91]:
out2.shape

torch.Size([8, 3, 11520, 32])

In [92]:
out2[0,0,0]

tensor([-0.0189, -0.0203, -0.0181, -0.0305,  0.0162, -0.0062,  0.0065, -0.0192,
         0.0204,  0.0107, -0.0048,  0.0121,  0.0064,  0.0053, -0.0002, -0.0018,
         0.0081, -0.0159, -0.0095,  0.0129, -0.0050,  0.0137,  0.0077, -0.0032,
         0.0454, -0.0144, -0.0198, -0.0021, -0.0165,  0.0101,  0.0009, -0.0031],
       device='cuda:0', dtype=torch.bfloat16)

In [93]:
out2.shape

torch.Size([8, 3, 11520, 32])

In [49]:
d = q @ k.transpose(-2, -1)

In [55]:
from torch.nn import functional as F

In [7]:
tiny_hiera = Hiera(embed_dim=96, num_heads=1, stages=(1, 2, 7, 2))

In [9]:
tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(32, 3, 224,224).to("cuda", torch.float32)
out = tiny_hiera(example_input, return_intermediates=True);

hiera_output = out[-1][-1]
hiera_output.shape # (b, t, h, w, c): (8, 4, 9, 16, 192)

torch.Size([32, 7, 7, 768])

In [10]:
from tqdm import tqdm

tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(128, 3, 224,224).to("cuda", torch.float32)

for i in tqdm(range(500)):
    out = tiny_hiera(example_input, return_intermediates=True);

100% 500/500 [00:29<00:00, 17.13it/s]


In [16]:
from tqdm import tqdm

tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(128, 1, 36,64).to("cuda", torch.float32)

for i in tqdm(range(2000)):
    out = tiny_hiera(example_input, return_intermediates=True);

100% 2000/2000 [00:16<00:00, 120.24it/s]


In [15]:
121*128

15488

In [4]:
# pip install hiera-transformer
from hiera import Hiera
tiny_hiera = Hiera(input_size=(36, 64),
                     num_heads=3,
                     embed_dim=96,
                     stages=(2, 1,), # 3 transformer layers 
                     q_pool=1, 
                     in_chans= 1,
                     q_stride=(1, 1, 1,),
                     mask_unit_size=(1, 8, 8),
                     patch_kernel=(5, 5, 5),
                     patch_stride=(3, 2, 2),
                     patch_padding=(1, 2, 2),
                     sep_pos_embed=True,
                     drop_path_rate=0,
                     mlp_ratio=4,)

tiny_hiera = tiny_hiera.cuda().to(torch.float32);
example_input = torch.ones(32, 1, 36,64).to("cuda", torch.float32)
out = tiny_hiera(example_input, return_intermediates=True);

hiera_output = out[-1][-1]
hiera_output.shape # (b, t, h, w, c): (8, 4, 9, 16, 192)

IndexError: list index out of range