In [1]:
import torch
import torch.nn.functional as F
from einops import rearrange

In [2]:
L = 1024
h = 64
c = 256
k = torch.randn((c,h,L)) 

In [3]:
k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)

In [4]:
k0.shape

torch.Size([128, 64, 1024])

In [5]:
F.pad(k0, (0, L)).shape

torch.Size([128, 64, 2048])

In [6]:
F.pad(k1.flip(-1), (L, 0)).shape

torch.Size([128, 64, 2048])

In [75]:
def prop(x, k):
   #k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
   b, c, L = x.shape
   x_new = torch.cat([x, x.flip(-1)], dim = -2)
   #print(x_new.shape)
   k_f = torch.fft.rfft(k, n=2 * L)  # (H L)
   x_f = torch.fft.rfft(x_new, n=2 * L)  # (B H L)
   y = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L] 
   y = y[:, :c, :] + y[:, c:, :].flip(-1)
   
   return y

   #model = model.to("cpu")

In [108]:
def prop_dub(x, k):
   #k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
   b, c, L = x.shape
   k_f = torch.fft.rfft(k, n=2 * L)  # (H L)
   x_f = torch.fft.rfft(x, n=2 * L)  # (B H L)
   y1 = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L] 
   
   k_f = torch.fft.rfft(k, n=2 * L)  # (H L)
   x_f = torch.fft.rfft(x.flip(-1), n=2 * L)  # (B H L)
   y2 = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L] 
   
   y = y1 + y2.flip(-1)
   
   return y

In [107]:
def paper(x,k,og=True):
   #k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
   b, c, L = x.shape
   
   k0, k1 = rearrange(k, '(s c) h l -> s c h l', s=2)
   k = F.pad(k0, (0, L))  + F.pad(k1.flip(-1), (L, 0))
   #k = k0+ k1.flip(-1)
   
   k_f = torch.fft.rfft(k, n=2 * L)  # (H L)
   x_f = torch.fft.rfft(x, n=2 * L)  # (B H L)
   if og:
      y_f = torch.einsum('bhl,chl->bchl', x_f, k_f)
      y = torch.fft.irfft(y_f, n=2*L)[..., :L] # (B C H L)
   else:
      y = torch.fft.irfft(x_f * k_f, n=2 * L)[..., :L]  # (B H L)
   return y
    

In [116]:
from functools import partial
# model_throughput(partial(paper, og=False), name="paper", reps=200)# 593, 80
model_throughput(paper, name="paper", reps=200)# 593, 80
model_throughput(paper, name="paper", reps=200)# 593, 80
model_throughput(prop_dub, name="prop_dub", reps=200)# 255, 60
model_throughput(prop, name="prop", reps=200)# 261, 62


##################################################
paper
far/back mem GB: 0.7, 0.8
far/back speed b/s: 545.5, 67.2
##################################################
paper
far/back mem GB: 0.7, 0.8
far/back speed b/s: 548.1, 78.4
##################################################
prop_dub
far/back mem GB: 0.8, 0.9
far/back speed b/s: 257.8, 68.6
##################################################
prop
far/back mem GB: 0.9, 1.1
far/back speed b/s: 261.9, 62.1


In [117]:
from torch.optim import AdamW
import time
def model_throughput(func_, name="paper", reps=50):

   d = "cuda"
   L = 1024
   c = 256
   b = 16
   if name == "paper":
      k = torch.randn((2, c,L), requires_grad=True).to(d)
   
   elif name == "prop_dub":
      k = torch.randn((1,c,L), requires_grad=True).to(d)
      
   else:
      k = torch.randn((1,c*2,L), requires_grad=True).to(d)
   
   
   batch = torch.randn((b,c,L), requires_grad=True).to(d)
   #warm up
   #func_ = func_.eval()
   for _ in range(int(reps/5)):
      out = func_(batch, k)
   torch.cuda.synchronize()

   # farward
   torch.cuda.reset_peak_memory_stats()
   with torch.no_grad():
      t0 = time.perf_counter()
      for _ in range(reps):
         func_(batch, k)
      torch.cuda.synchronize()
      t1 = (time.perf_counter() - t0) / reps
      mem1 = torch.cuda.max_memory_allocated()

   # backward
   #func_ = func_.train()
   torch.cuda.reset_peak_memory_stats()
   t0 = time.perf_counter()
   for _ in range(int(reps)):
      (func_(batch, k)).sum().backward()
      #opt.step()
      #opt.zero_grad()
   torch.cuda.synchronize()
   t2 = (time.perf_counter() - t0) / reps
   mem2 = torch.cuda.max_memory_allocated()

   print("##################################################")
   print(name)
   print(f"far/back mem GB: {mem1/1e9:.1f}, {mem2/1e9:.1f}")
   print(f"far/back speed b/s: {1/t1:.1f}, {1/t2:.1f}")
