In [3]:
import torch as th
import numpy as np
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
sys.path.insert(0, "/home/mint/Dev/SkelAg/DiffSynth-Studio")

from diffsynth.models.wan_video_dit import DiTBlock, SelfAttention, rearrange, precompute_freqs_cis_3d, precompute_freqs_cis, sinusoidal_embedding_1d
from diffsynth.models.wan_video_vae import Decoder3d, CausalConv3d, VideoVAE_, WanVideoVAE

n_timesteps = 1000
use_time = [1000]
n_blocks = 30
use_block = [30]
B = 1
D = 1536
F = 256    # Flattened spatial-temporal dimension (1+T/4 * H/16 * W/16); H, W = 128
feats_dim = (B, F, D)

feats = np.zeros((n_timesteps, n_blocks, B, F, D), dtype=np.uint8)
print(feats.shape)

use_feats = []
for t in use_time:
    for b in use_block:
        # print(f"t: {t}, b: {b}")
        use_feats.append(feats[t-1, b-1])
        # print(use_feats[0].shape)

use_feats = th.tensor(np.stack(use_feats, axis=0))
print("use_feats shape:", use_feats.shape)


(1000, 30, 1, 256, 1536)
use_feats shape: torch.Size([1, 1, 256, 1536])


In [4]:
from einops import repeat, reduce
from PIL import Image
def vae_output_to_video(vae_output, pattern="B C T H W", min_value=-1, max_value=1):
    # Transform a torch.Tensor to list of PIL.Image
    if pattern != "T H W C":
        vae_output = reduce(vae_output, f"{pattern} -> T H W C", reduction="mean")
    video = [vae_output_to_image(image, pattern="H W C", min_value=min_value, max_value=max_value) for image in vae_output]
    return video
def vae_output_to_image(vae_output, pattern="B C H W", min_value=-1, max_value=1):
    # Transform a torch.Tensor to PIL.Image
    if pattern != "H W C":
        vae_output = reduce(vae_output, f"{pattern} -> H W C", reduction="mean")
    image = ((vae_output - min_value) * (255 / (max_value - min_value))).clip(0, 255)
    image = image.to(device="cpu", dtype=th.uint8)
    image = Image.fromarray(image.numpy())
    return image

num_heads = 8
D = 1536
J = 5
head_dim = D // num_heads  # 192
print("head_dim:", head_dim)

x_flat = rearrange(use_feats, 'n b f d -> b (n f) d').cuda().float()   # [1, 4096, 1536]
print("x_flat shape:", x_flat.shape)

self_attn = SelfAttention(dim=D, num_heads=num_heads).to(x_flat.device)
# decoder = Decoder3d(dim=96, z_dim=16, dim_mult=[1, 2, 4, 4], num_res_blocks=2,
#                             attn_scales=[], temporal_upsample=[False, True, True])
decoder = Decoder3d().cuda()

# 1D RoPE over sequence length
freqs = precompute_freqs_cis(head_dim, end=x_flat.shape[1], theta=10000.0).to(x_flat.device)
print("SelfAttention:", self_attn)
print("freqs.shape:", freqs.shape)
print("x_flat.shape:", x_flat.shape)

freqs = freqs[None, :, None, :]
out_flat = self_attn(x_flat, freqs=freqs)       # [1, 4096, 1536]
print("out_flat shape:", out_flat.shape)
out = rearrange(out_flat, 'b (n f) d -> n b f d', n=len(use_time) * len(use_block), f=F)
print("out shape:", out.shape)

grid_size = [16, 4, 4]
patch_size = [1, 2, 2]

head = th.nn.Linear(D, 64).to(out.device)
out = head(out)
print("out shape after head:", out.shape)

joint_head = th.nn.Linear(F, F * J).to(out.device)
out = rearrange(out, 'n b f d -> n b d f')
print(out.shape)
out = joint_head(out)
print("out shape after joint_head:", out.shape)
out_joint_combined = rearrange(out, 'n b d (j f) -> n b j f d', j=J)
print("out_joint_combined shape:", out_joint_combined.shape)


#TODO: Should be replace with some layers, Pooling, etc. to combined the multiple time/block features
out = out_joint_combined.squeeze(0)  # [B, F, C]
print("out shape after squeeze:", out.shape)
out_unpatch = rearrange(
            out, 'b j (f h w) (x y z c) -> b j c (f x) (h y) (w z)',
            f=grid_size[0], h=grid_size[1], w=grid_size[2], 
            x=patch_size[0], y=patch_size[1], z=patch_size[2], j = J
        )
print("out_unpatch shape:", out_unpatch.shape)

# tiled = False
# tile_size = None
# tile_stride = None
tiled = True,
tile_size = (30, 52)
tile_stride = (15, 26)
video_vae = WanVideoVAE().cpu()
out_unpatch = out_unpatch.cpu()

out_joint = []
print("out to decoding: ", out_unpatch.shape)
print("out to decoding: ", out_unpatch[0:1, 0, ...].shape)
assert False
import tqdm
for i in tqdm.tqdm(range(J), desc="Decoding joints"):
    video = video_vae.decode(out_unpatch[0:1, i, ...], device=out_unpatch.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
    out_joint.append(video)
    print("video shape:", video.shape)
# video = vae_output_to_video(video)
# print("video length:", len(video))


head_dim: 192
x_flat shape: torch.Size([1, 256, 1536])
SelfAttention: SelfAttention(
  (q): Linear(in_features=1536, out_features=1536, bias=True)
  (k): Linear(in_features=1536, out_features=1536, bias=True)
  (v): Linear(in_features=1536, out_features=1536, bias=True)
  (o): Linear(in_features=1536, out_features=1536, bias=True)
  (norm_q): RMSNorm()
  (norm_k): RMSNorm()
  (attn): AttentionModule()
)
freqs.shape: torch.Size([256, 96])
x_flat.shape: torch.Size([1, 256, 1536])
x shape before rope: torch.Size([1, 256, 1536])
x shape after rearrange: torch.Size([1, 256, 8, 192])
x_out shape as complex: torch.Size([1, 256, 8, 96])
freqs:  torch.Size([1, 256, 1, 96])
x_out shape after rope: torch.Size([1, 256, 1536])
x shape before rope: torch.Size([1, 256, 1536])
x shape after rearrange: torch.Size([1, 256, 8, 192])
x_out shape as complex: torch.Size([1, 256, 8, 96])
freqs:  torch.Size([1, 256, 1, 96])
x_out shape after rope: torch.Size([1, 256, 1536])
out_flat shape: torch.Size([1, 256,

AssertionError: 

In [None]:
import matplotlib.pyplot as plt
print(video.shape)
plt.imshow(video[0][:, 0, ...].permute(1, 2, 0).cpu().detach().numpy())
plt.show()


Simple arch
- Upsample + prediction head as 2N

In [None]:
# # upsample block
# if i != len(dim_mult) - 1:
#     mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
#     upsamples.append(Resample(out_dim, mode=mode))
#     scale *= 2.0
temperal_upsample = [True, True, False]
from diffsynth.models.wan_video_vae import Decoder3d, CausalConv3d, VideoVAE_, WanVideoVAE, Resample
dim_mult = [1, 2, 4, 4]
dim = 96
z_dim = 16
tile_size
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
print(dims)
cur_dim = dims[0]  # whatever your starting channels are

upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
    print(f"Upsample block {i}: in_dim={in_dim}, out_dim={out_dim}, cur_dim={cur_dim}")
    if i != len(dim_mult) - 1:
        mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
        upsamples.append(Resample(cur_dim, mode=mode))
        cur_dim = cur_dim // 2   # <-- IMPORTANT: match Conv2d(cur_dim -> cur_dim//2)

upsamples = th.nn.Sequential(*upsamples).cuda()
conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1).cuda()
inp = th.randn(1, 16, 16, 8, 8).float().cuda()
print("inp shape:", inp.shape)
print("conv1(inp).shape:", conv1(inp).shape)
# print(upsamples)
# upsamples[1](upsamples[0](conv1(inp)))
print("upsamples output shape:", upsamples( conv1(inp) ).shape)

#TODO: integrate with tiled decoding


[384, 384, 384, 192, 96]
Upsample block 0: in_dim=384, out_dim=384, cur_dim=384
Upsample block 1: in_dim=384, out_dim=384, cur_dim=192
Upsample block 2: in_dim=384, out_dim=192, cur_dim=96
Upsample block 3: in_dim=192, out_dim=96, cur_dim=48
inp shape: torch.Size([1, 16, 16, 8, 8])
conv1(inp).shape: torch.Size([1, 384, 16, 8, 8])
16
16
16
upsamples output shape: torch.Size([1, 48, 16, 64, 64])


In [18]:
import tqdm
def count_conv3d(model):
    count = 0
    for m in model.modules():
        if isinstance(m, CausalConv3d):
            count += 1
    return count

class decoder_simple(th.nn.Module):
    def __init__(self):
        super().__init__()
        dim_mult = [1, 2, 4, 4]
        dim = 96
        z_dim = 16
        tile_size
        dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
        self.upsampling_factor = 8
        print(dims)
        cur_dim = dims[0]  # whatever your starting channels are
        temperal_upsample = [False, True, True][::-1]

        upsamples = []
        for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
            if i != len(dim_mult) - 1:
                mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
                print(mode)
                upsamples.append(Resample(cur_dim, mode=mode))
                cur_dim = cur_dim // 2   # <-- IMPORTANT: match Conv2d(cur_dim -> cur_dim//2)

        self.resample_module = th.nn.Sequential(*upsamples).cuda()
        self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1).cuda()
        self.conv2 = CausalConv3d(z_dim, z_dim, 1)

    def clear_cache(self):
        self._conv_num = count_conv3d(self.resample_module) + count_conv3d(self.conv1) + count_conv3d(self.conv2)
        self._conv_idx = [0]
        self._feat_map = [None] * self._conv_num
        print("Cleared cache:", self._conv_num, self._conv_idx, len(self._feat_map))
    
    def build_1d_mask(self, length, left_bound, right_bound, border_width):
        x = th.ones((length,))
        if not left_bound:
            x[:border_width] = (th.arange(border_width) + 1) / border_width
        if not right_bound:
            x[-border_width:] = th.flip((th.arange(border_width) + 1) / border_width, dims=(0,))
        return x


    def build_mask(self, data, is_bound, border_width):
        _, _, _, H, W = data.shape
        h = self.build_1d_mask(H, is_bound[0], is_bound[1], border_width[0])
        w = self.build_1d_mask(W, is_bound[2], is_bound[3], border_width[1])

        h = repeat(h, "H -> H W", H=H, W=W)
        w = repeat(w, "W -> H W", H=H, W=W)

        mask = th.stack([h, w]).min(dim=0).values
        mask = rearrange(mask, "H W -> 1 1 1 H W")
        return mask
        
    def forward_pass(self, x, feat_cache=None, feat_idx=[0]):
        x = self.conv1(x)
        print("fw pass (conv1): ", x.shape)
        return self.resample_module(x, feat_cache=feat_cache, feat_idx=feat_idx)
    
    def decode_fn(self, hidden_states, device, tiled=False, tile_size=(34, 34), tile_stride=(18, 16)):
        #NOTE: From WanVideoVAE.decode
        hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]  # batched tensor -> list of tensor
        videos = []
        for hidden_state in hidden_states:
            hidden_state = hidden_state.unsqueeze(0)    # add batch dim
            if tiled:
                print("Using tiled decoding", hidden_state.shape)
                video = self.tiled_decode(hidden_state, device, tile_size, tile_stride)
                print("video shape after tiled decode:", video.shape)
            else:
                video = self.single_decode(hidden_state, device)
            video = video.squeeze(0)
            videos.append(video)
        videos = th.stack(videos)
        return videos
    
    def decode(self, z, tile_size, tile_stride):
        self.clear_cache()
        #NOTE: From VideoVAE_.decode
        print("In VideoVAE_.decode: z's shape =  ", z.shape)
        iter_ = z.shape[2]
        x = self.conv2(z)
        for i in range(iter_):
            self._conv_idx = [0]
            if i == 0:
                out = self.forward_pass(x[:, :, i:i + 1, :, :],
                                   feat_cache=self._feat_map,
                                   feat_idx=self._conv_idx)
            else:
                out_ = self.forward_pass(x[:, :, i:i + 1, :, :],
                                    feat_cache=self._feat_map,
                                    feat_idx=self._conv_idx)
                out = th.cat([out, out_], 2) # may add tensor offload
        # for i in range(iter_):
            # print(f"{i} - {x[:, :, i:i+1, :, :].shape}")
        #     out = self.model(x[:, :, i:i+1, :, :])
        #     all_out.append(out)
        # all_out = th.cat(all_out, dim=2)
        print("In VideoVAE_.decode: all_out's shape = ", out.shape)
        return out
    
    def tiled_decode(self, hidden_states, device, tile_size, tile_stride):
        _, _, T, H, W = hidden_states.shape
        print("T, H, W:", T, H, W)
        size_h, size_w = tile_size
        stride_h, stride_w = tile_stride

        # Split tasks
        tasks = []
        for h in range(0, H, stride_h):
            if (h-stride_h >= 0 and h-stride_h+size_h >= H): continue
            for w in range(0, W, stride_w):
                if (w-stride_w >= 0 and w-stride_w+size_w >= W): continue
                h_, w_ = h + size_h, w + size_w
                tasks.append((h, h_, w, w_))

        data_device = "cpu"
        computation_device = device

        out_T = T * 4 - 3
        weight = th.zeros((1, 1, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
        values = th.zeros((1, 3, out_T, H * self.upsampling_factor, W * self.upsampling_factor), dtype=hidden_states.dtype, device=data_device)
        print("weight shape:", weight.shape)
        print("values shape:", values.shape)

        for h, h_, w, w_ in tqdm.tqdm(tasks, desc="VAE decoding"):
            print(f"Decoding tile: h: {h}-{h_}, w: {w}-{w_}")
            print("hidden_states.shape:", hidden_states.shape)
            hidden_states_batch = hidden_states[:, :, :, h:h_, w:w_].to(computation_device)
            # hidden_states_batch = self.model.decode(hidden_states_batch, self.scale).to(data_device)
            hidden_states_batch = self.decode(hidden_states_batch, tile_size, tile_stride).to(data_device)
            print("decoded hidden_states_batch.shape:", hidden_states_batch.shape)

            mask = self.build_mask(
                hidden_states_batch,
                is_bound=(h==0, h_>=H, w==0, w_>=W),
                border_width=((size_h - stride_h) * self.upsampling_factor, (size_w - stride_w) * self.upsampling_factor)
            ).to(dtype=hidden_states.dtype, device=data_device)
            
            print("mask.shape:", mask.shape)
            print("values.shape:", values.shape)
            print("hidden_states_batch.shape:", hidden_states_batch.shape)
            print("weight.shape:", weight.shape)
            

            target_h = h * self.upsampling_factor
            target_w = w * self.upsampling_factor
            values[
                :,
                :,
                :,
                target_h:target_h + hidden_states_batch.shape[3],
                target_w:target_w + hidden_states_batch.shape[4],
            ] += hidden_states_batch * mask
            weight[
                :,
                :,
                :,
                target_h: target_h + hidden_states_batch.shape[3],
                target_w: target_w + hidden_states_batch.shape[4],
            ] += mask
        values = values / weight
        values = values.clamp_(-1, 1)
        print(values.shape)
        return values
    
    def single_decode(self, hidden_state, device):
        hidden_state = hidden_state.to(device)
        video = self.model.decode(hidden_state, self.scale)
        return video.clamp_(-1, 1)
    

tiled = True,
tile_size = (30, 52)
tile_stride = (15, 26)
decoder = decoder_simple().cuda()
inp = th.randn(1, 16, 16, 8, 8).float().cuda()  # B, C, T, H, W
# print(decoder.upsamples[0].mode)
# decoder.upsamples(conv1(inp))
# print(decoder(inp).shape)
# out = decoder.tiled_decode(inp, device=inp.device, tile_size=tile_size, tile_stride=tile_stride)
out = decoder.decode_fn(inp, device=inp.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)

[384, 384, 384, 192, 96]
upsample3d
upsample3d
upsample2d
Using tiled decoding torch.Size([1, 16, 16, 8, 8])
T, H, W: 16 8 8
weight shape: torch.Size([1, 1, 61, 64, 64])
values shape: torch.Size([1, 3, 61, 64, 64])


VAE decoding:   0%|          | 0/1 [00:00<?, ?it/s]

Decoding tile: h: 0-30, w: 0-52
hidden_states.shape: torch.Size([1, 16, 16, 8, 8])
Cleared cache: 4 [0] 4
In VideoVAE_.decode: z's shape =   torch.Size([1, 16, 16, 8, 8])
fw pass (conv1):  torch.Size([1, 384, 1, 8, 8])





TypeError: Sequential.forward() got an unexpected keyword argument 'feat_cache'

In [None]:
hidden_states = th.randn(10, 16, 16, 8, 8).float().cuda()  # B, C, T, H, W
print("hidden_states shape:", hidden_states.shape)
hidden_states = [hidden_state.to("cpu") for hidden_state in hidden_states]
print(len(hidden_states))
print("hidden_states[0] shape:", hidden_states[0].shape)
_, _, T, H, W = hidden_states[0].shape

In [70]:
# # upsample block
# if i != len(dim_mult) - 1:
#     mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
#     upsamples.append(Resample(out_dim, mode=mode))
#     scale *= 2.0
temperal_upsample = [True, True, False]
from diffsynth.models.wan_video_vae import Decoder3d, CausalConv3d, VideoVAE_, WanVideoVAE, Resample
dim_mult = [1, 2, 4, 4]
dim = 96
z_dim = 16
tile_size
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
print(dims)
cur_dim = dims[0]  # whatever your starting channels are

upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
    print(f"Upsample block {i}: in_dim={in_dim}, out_dim={out_dim}, cur_dim={cur_dim}")
    if i != len(dim_mult) - 1:
        mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
        upsamples.append(Resample(cur_dim, mode=mode))
        cur_dim = cur_dim // 2   # <-- IMPORTANT: match Conv2d(cur_dim -> cur_dim//2)
# print(upsamples)
time_conv = CausalConv3d(16,
                            16 * 2, (3, 1, 1),
                            padding=(1, 0, 0)).cuda()
upsamples = th.nn.Sequential(*upsamples).cuda()
conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1).cuda()
print(conv1)
inp = th.randn(1, 16, 16, 8, 8).float().cuda()
print("inp shape:", inp.shape)
print("conv1(inp).shape:", conv1(inp).shape)
b, c, t, h, w = inp.shape
print("inp: ", inp.shape)
x = time_conv(inp)
# b, c, t, h, w = x.shape
print("inp after time_conv shape:", x.shape)
x = x.reshape(b, 2, c, t, h, w)
print("after reshape: ", x.shape)
x = th.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
                3)
print("after stack: ", x.shape)
x = x.reshape(b, c, t * 2, h, w)
print("after final reshape: ", x.shape)
#TODO: integrate with tiled decoding
t = x.shape[2]
x = rearrange(x, 'b c t h w -> (b t) c h w')
print("before upsample: ", x.shape)
x = upsamples[0](x)
print("after upsample: ", x.shape)
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
print("after upsample reshape: ", x.shape)



[384, 384, 384, 192, 96]
Upsample block 0: in_dim=384, out_dim=384, cur_dim=384
Upsample block 1: in_dim=384, out_dim=384, cur_dim=192
Upsample block 2: in_dim=384, out_dim=192, cur_dim=96
Upsample block 3: in_dim=192, out_dim=96, cur_dim=48
CausalConv3d(16, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
inp shape: torch.Size([1, 16, 16, 8, 8])
conv1(inp).shape: torch.Size([1, 384, 16, 8, 8])
inp:  torch.Size([1, 16, 16, 8, 8])
inp after time_conv shape: torch.Size([1, 32, 16, 8, 8])
after reshape:  torch.Size([1, 2, 16, 16, 8, 8])
after stack:  torch.Size([1, 16, 16, 2, 8, 8])
after final reshape:  torch.Size([1, 16, 32, 8, 8])
before upsample:  torch.Size([32, 16, 8, 8])


ValueError: not enough values to unpack (expected 5, got 4)

In [114]:
import torch.nn.functional as F
from diffsynth.models.wan_video_vae import Upsample
x = th.randn(1, 16, 16, 8, 8).float().cuda()
b, c, t, h, w = x.shape
x = F.interpolate(x, size=(t*4-3, h, w), mode="trilinear", align_corners=True)
resample = th.nn.Sequential(
        Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
        th.nn.Conv2d(dims[0], dims[0] // 2, 3, padding=1)).cuda()
conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1).cuda()
upsample_dim = 96
dim_mult = [1, 2, 4, 4]
upsample_dims = [upsample_dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
print(upsample_dims)
# Upsample layers
upsamples = []
for i, (in_dim, out_dim) in enumerate(zip(upsample_dims[:-1], upsample_dims[1:])):
    # print(f"Upsample block {i}: in_dim={in_dim}, out_dim={out_dim}, cur_dim={cur_dim}")
        if i == 1 or i == 2 or i == 3:
            in_dim = in_dim // 2
            upsamples.append(CausalConv3d(in_dim, out_dim, 3, padding=1))
        if i != len(dim_mult) - 1:
                upsamples.append(th.nn.Sequential(
                        Upsample(scale_factor=(2., 2.), mode='nearest-exact'),
                        th.nn.Conv2d(out_dim, out_dim // 2, 3, padding=1)))

    # cur_dim = cur_dim // 2   # <-- IMPORTANT: match Conv2d(cur_dim -> cur_dim//2)
upsamples = th.nn.Sequential(*upsamples).cuda()
print(upsamples)
# Joint heatmap final conv
joint_map_conv = th.nn.Conv2d(upsample_dims[-1],  2 * J, 3, padding=1).cuda()
print("inp to conv1: ", x.shape)
x = conv1(x)
print("output from conv1: ", x.shape)
t = x.shape[2]
for layer in upsamples:
     print("inp to rearrange: ", x.shape)
     if not isinstance(layer, CausalConv3d):
        x = rearrange(x, 'b c t h w -> (b t) c h w')
     print("inp to resample: ", x.shape)
     print(layer)
     x = layer(x)
     print("after resample: ", x.shape)
     if not isinstance(layer, CausalConv3d):
        x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
     print(x.shape)

[384, 384, 384, 192, 96]
Sequential(
  (0): Sequential(
    (0): Upsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
    (1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (1): CausalConv3d(192, 384, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (2): Sequential(
    (0): Upsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
    (1): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (3): CausalConv3d(192, 192, kernel_size=(3, 3, 3), stride=(1, 1, 1))
  (4): Sequential(
    (0): Upsample(scale_factor=(2.0, 2.0), mode='nearest-exact')
    (1): Conv2d(192, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (5): CausalConv3d(96, 96, kernel_size=(3, 3, 3), stride=(1, 1, 1))
)
inp to conv1:  torch.Size([1, 16, 61, 8, 8])
output from conv1:  torch.Size([1, 384, 61, 8, 8])
inp to rearrange:  torch.Size([1, 384, 61, 8, 8])
inp to resample:  torch.Size([61, 384, 8, 8])
Sequential(
  (0): Upsample(scale_factor=(2.0, 2.0), mode='neare