In [6]:
import torch, torch.nn as nn

class HeteroDemo(nn.Module):
    """
    One Conv, one Encoder-Attn, one LSTM, one Dense.
    Bx1x60x90 → Bx10
    """
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)          # CPU
        self.patch = nn.Unfold(kernel_size=(4,4), stride=(4,4))        # 15×22 patches
        self.attn = nn.TransformerEncoderLayer(d_model=128,
                                               nhead=4, batch_first=True) # GPU
        self.lstm = nn.LSTM(input_size=8*60*90 + 128, hidden_size=64,
                            num_layers=1, batch_first=True)            # CPU
        self.fc   = nn.Linear(64, 10)                                  # CPU

    def forward(self, x, h0=None):
        b = x.size(0)

        # 1) convolution branch
        a = self.conv(x).flatten(1)        # (B, 8*60*90)

        # 2) ViT-like branch
        p = self.patch(x)                  # (B, 8*4*4, 15*22)
        p = p.transpose(1,2)               # (B, 330, 128)
        print(p.shape)
        p = self.attn(p)[:,0]              # CLS token (B,128)

        # 3) fusion → LSTM
        fused = torch.cat([a, p], dim=-1).unsqueeze(1)  # time-len = 1
        y, h = self.lstm(fused, h0)        # keep state external

        # 4) classifier
        return self.fc(y.squeeze(1)), h

In [7]:

model = HeteroDemo().eval()
example = (torch.randn(1,1,60,90),)      # compile-time shape
output_example = model(*example)               # run-time shape
print("Output shape:", output_example[0].shape)  # (1,10)
print("Hidden shape:", output_example[1][0].shape)  # (1,1,64)

torch.Size([1, 330, 16])


AssertionError: was expecting embedding dimension of 128, but got 16