In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# SITS Tokenization

In [121]:
class PatchEmbedding(nn.Module):
    """Turn a batch of image time series into per-frame patch tokens.

    Each frame is split into non-overlapping ``patch_size`` x ``patch_size``
    patches. A ``Conv3d`` whose kernel and stride are ``(1, P, P)`` projects
    every patch of every time step to ``embed_dim`` features, and a learned
    positional embedding (one vector per patch position, shared by all time
    steps) is added.

    Args:
        img_height: Height ``H`` of the input frames, in pixels.
        img_width: Width ``W`` of the input frames, in pixels.
        patch_size: Side length ``P`` of the square patches.
        in_channel: Number of channels ``C`` per frame.
        embed_dim: Size ``d`` of each patch embedding.

    Attributes:
        nh, nw: Number of patches along the height / width (``H // P``, ``W // P``).
        N: Total number of patches per frame, ``nh * nw``.
        n: ``int(sqrt(N))`` (patches per side when the patch grid is square).
        projection: The ``Conv3d`` patch projection.
        pos_emb: Learned positional embedding of shape ``(d, 1, N)``.
    """

    def __init__(self, img_height=128, img_width=128, patch_size=16, in_channel=10, embed_dim=768):
        super().__init__()
        self.H = img_height
        self.W = img_width
        self.P = patch_size
        self.C = in_channel
        self.d = embed_dim

        # Patch grid actually produced by a stride-P convolution.
        self.nh = self.H // self.P
        self.nw = self.W // self.P
        # Derive N from the grid so that pos_emb always matches the projection
        # output, even when H or W is not a multiple of P.
        self.N = self.nh * self.nw
        self.n = int(self.N**0.5)

        # Kernel size 1 along time: every time step is embedded independently.
        self.projection = nn.Conv3d(self.C, self.d, kernel_size=(1, self.P, self.P), stride=(1, self.P, self.P))

        self.pos_emb = nn.Parameter(torch.zeros(self.d, 1, self.N))  # (d, 1, N)

    def forward(self, x):
        """Embed the patches of every frame.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B, d, T, N)``.
        """
        # Unpacking five values also rejects inputs that are not 5-D.
        B, T, _, _, _ = x.shape

        # Conv3d expects (B, C, T, H, W). The axes have to be swapped with
        # permute; view would merely reinterpret the memory and mix channels
        # with time steps.
        x = x.permute(0, 2, 1, 3, 4)  # (B, C, T, H, W)
        x = self.projection(x)  # (B, d, T, nh, nw)
        x = x.reshape(B, self.d, T, self.nh * self.nw)  # (B, d, T, N)

        # pos_emb is (d, 1, N): broadcasting repeats it over batch and time.
        return x + self.pos_emb



# Smoke test on random input of shape (B=4, T=30, C=10, H=128, W=128);
# the displayed shape should be (B, d, T, N) = (4, 768, 30, 64).
x = torch.randn((4, 30, 10, 128, 128))
PatchEmbedding(
    img_height=128,
    img_width=128,
    patch_size=16,
    in_channel=10,
    embed_dim=768,
)(x).shape

torch.Size([4, 768, 30, 64])

# Temporal Encoding

In [103]:
class TemporalEncoder(nn.Module):
    """Temporal encoder module (stub: no layers are defined yet)."""

    def __init__(self):
        # `super` has to be called to get the proxy object; `super.__init__()`
        # invokes the unbound slot of the `super` type and raises TypeError.
        super().__init__()
        
        
        
