In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Demo configuration for the shape walk-through below.
BATCH_SIZE = 2    # number of samples in the dummy input batch
FEAT_DIM   = 22   # feature channels per time step (Conv1d in/out channels of the model)
HIDDEN     = 64   # channel width used inside the U-Net
SEQ_LEN    = 60   # time length of the raw dummy input
TARGET_LEN = 100  # time length the input is padded/truncated to before the model
LEVELS     = 3    # number of downsample/upsample stages in the U-Net
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def pad_time(x: torch.Tensor, tgt_len: int):
    """Force the last (time) axis of a 4-D tensor to exactly ``tgt_len`` steps.

    Shorter inputs are zero-padded on the right, longer inputs are cut off
    on the right, and inputs that already match are returned untouched.
    A one-line description of the action taken is printed.

    Args:
        x: Tensor with exactly four dimensions; the last one is time.
        tgt_len: Desired length of the time axis.

    Returns:
        A ``(tensor, mask)`` pair. ``mask`` has shape
        ``(x.shape[0], 1, 1, tgt_len)`` and the dtype/device of ``x``; it is
        1 where a time step comes from the input and 0 where it is padding.
    """
    batch, _, _, cur_len = x.shape
    deficit = tgt_len - cur_len

    if deficit > 0:
        # Too short: append zeros and mark them as invalid in the mask.
        print(f"(pad_time) right-pad from L={cur_len} to {tgt_len} (+{deficit})")
        padded = F.pad(x, (0, deficit))
        valid = padded.new_ones((batch, 1, 1, cur_len))
        filler = padded.new_zeros((batch, 1, 1, deficit))
        return padded, torch.cat([valid, filler], dim=-1)

    if deficit < 0:
        # Too long: keep only the leading tgt_len steps, all of them valid.
        print(f"(pad_time) truncate from L={cur_len} to {tgt_len}")
        clipped = x[..., :tgt_len]
        return clipped, clipped.new_ones((batch, 1, 1, tgt_len))

    # Already the right length: hand back the very same tensor object.
    full_mask = x.new_ones((batch, 1, 1, cur_len))
    print(f"(pad_time) exact, no pad. L={cur_len}")
    return x, full_mask

class SimpleResBlock1D(nn.Module):
    def __init__(self, channels: int, kernel_size: int = 3):
        super().__init__()
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, padding=padding, bias=False)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, padding=padding, bias=False)
    def forward(self, x: torch.Tensor, gamma: torch.Tensor | None = None, beta: torch.Tensor | None = None):
        y = self.conv1(x)
        y = self.conv2(y)
        return y + x

class GistUNet1D(nn.Module):
    """1-D U-Net along the time axis with additive skip connections.

    The network takes a ``(batch, 1, feat_dim, length)`` tensor, drops the
    singleton axis, lifts the features to ``hidden`` channels with a 1x1
    convolution, runs ``levels`` encoder stages (residual block followed by a
    stride-2 convolution), two bottleneck residual blocks, ``levels`` decoder
    stages (stride-2 transposed convolution, skip addition, residual block),
    and finally projects back to ``feat_dim`` channels. The tensor shape is
    printed after every stage.

    Args:
        feat_dim: Channel count of the input and of the output.
        hidden: Channel count used inside the network.
        levels: Number of encoder stages and of decoder stages.
    """

    def __init__(self, feat_dim=FEAT_DIM, hidden=HIDDEN, levels=LEVELS):
        super().__init__()
        self.levels = levels

        def res_stack():
            # One residual block per level.
            return nn.ModuleList([SimpleResBlock1D(hidden) for _ in range(levels)])

        def resample_stack(conv_cls):
            # One stride-2, kernel-4 resampling convolution per level.
            return nn.ModuleList([
                conv_cls(hidden, hidden, kernel_size=4, stride=2, padding=1, bias=False)
                for _ in range(levels)
            ])

        # Modules are created in a fixed order so that seeded weight
        # initialisation stays reproducible.
        self.in_proj = nn.Conv1d(feat_dim, hidden, kernel_size=1, bias=False)
        self.out_proj = nn.Conv1d(hidden, feat_dim, kernel_size=1, bias=False)
        self.down_blocks = res_stack()
        self.downsample = resample_stack(nn.Conv1d)
        self.mid1 = SimpleResBlock1D(hidden)
        self.mid2 = SimpleResBlock1D(hidden)
        self.upsample = resample_stack(nn.ConvTranspose1d)
        self.up_blocks = res_stack()
        self.out_conv = nn.Conv1d(hidden, hidden, kernel_size=3, padding=1, bias=False)

    @staticmethod
    def _fit_length(h, want):
        """Crop or right-zero-pad the last axis of ``h`` to ``want`` steps."""
        have = h.size(-1)
        if have > want:
            print(f"Align (crop)        : {tuple(h.shape)} -> {want}")
            return h[..., :want]
        if have < want:
            short = want - have
            print(f"Align (pad)         : {tuple(h.shape)} -> +{short} to {want}")
            return F.pad(h, (0, short))
        return h

    def forward(self, x):
        """Run the U-Net and return a tensor with a singleton axis at dim 1.

        Args:
            x: Input whose dim 1 is removed by ``squeeze(1)`` before the
                first convolution.

        Returns:
            The projected output with a singleton axis re-inserted at dim 1.
        """
        h = self.in_proj(x.squeeze(1))
        print(f"After bt1/in_proj    : {tuple(h.shape)}")

        # Encoder: remember each pre-downsample activation for the decoder.
        saved = []
        for stage, (res, shrink) in enumerate(zip(self.down_blocks, self.downsample), start=1):
            h = res(h, None, None)
            print(f"Down {stage} (resblock)  : {tuple(h.shape)}")
            saved.append(h)
            h = shrink(h)
            print(f"Down {stage} (downsample): {tuple(h.shape)}")

        # Bottleneck.
        h = self.mid2(self.mid1(h))
        print(f"Mid (2x resblocks)   : {tuple(h.shape)}")

        # Decoder: consume skips deepest-first (last saved, first used).
        for stage, (grow, res) in enumerate(zip(self.upsample, self.up_blocks), start=1):
            h = grow(h)
            print(f"Up {stage} (upsample)    : {tuple(h.shape)}")
            skip = saved.pop()
            # Odd encoder lengths lose a step when halved; realign before adding.
            h = self._fit_length(h, skip.size(-1))
            h = res(h + skip, None, None)
            print(f"Up {stage} (+skip,resblk): {tuple(h.shape)}")

        h = self.out_conv(h)
        print(f"After out_conv       : {tuple(h.shape)}")
        h = self.out_proj(h)
        print(f"After bt2/out_proj   : {tuple(h.shape)}")
        h = h.unsqueeze(1)
        print(f"Final output         : {tuple(h.shape)}")
        return h

# Seed the RNG first so the random input and the model's initial weights are reproducible.
torch.manual_seed(0)
# Dummy batch of shape (BATCH_SIZE, 1, FEAT_DIM, SEQ_LEN), created directly on `device`.
x = torch.rand(BATCH_SIZE, 1, FEAT_DIM, SEQ_LEN, device=device)
# Bring the time axis to TARGET_LEN; `mask` is 1 on original steps and 0 on padded ones.
x_pad, mask = pad_time(x, TARGET_LEN)
print(f"Input (after padding) : {tuple(x_pad.shape)}")
model = GistUNet1D(FEAT_DIM, HIDDEN, LEVELS).to(device)
# Forward pass without autograd tracking; the model prints the shape after each stage.
# NOTE(review): `mask` is not passed to the model in this cell — confirm that is intended.
with torch.no_grad():
    y = model(x_pad)


(pad_time) right-pad from L=60 to 100 (+40)
Input (after padding) : (2, 1, 22, 100)
After bt1/in_proj    : (2, 64, 100)
Down 1 (resblock)  : (2, 64, 100)
Down 1 (downsample): (2, 64, 50)
Down 2 (resblock)  : (2, 64, 50)
Down 2 (downsample): (2, 64, 25)
Down 3 (resblock)  : (2, 64, 25)
Down 3 (downsample): (2, 64, 12)
Mid (2x resblocks)   : (2, 64, 12)
Up 1 (upsample)    : (2, 64, 24)
Align (pad)         : (2, 64, 24) -> +1 to 25
Up 1 (+skip,resblk): (2, 64, 25)
Up 2 (upsample)    : (2, 64, 50)
Up 2 (+skip,resblk): (2, 64, 50)
Up 3 (upsample)    : (2, 64, 100)
Up 3 (+skip,resblk): (2, 64, 100)
After out_conv       : (2, 64, 100)
After bt2/out_proj   : (2, 22, 100)
Final output         : (2, 1, 22, 100)
