In [None]:
import torch
import torch.nn as nn

In [None]:
class ConvBlock(nn.Module):
  """Convolution -> batch normalization -> ReLU, applied in that order.

  Args:
    in_c: Number of channels of the input feature map.
    out_c: Number of channels produced by the convolution.
    kernel_size: Kernel size passed to ``nn.Conv2d`` (default 3).
    padding: Padding passed to ``nn.Conv2d`` (default 1).
  """
  def __init__(self,in_c,out_c,kernel_size=3,padding=1):
    super().__init__()
    self.layers=nn.Sequential(
        nn.Conv2d(in_c,out_c,kernel_size=kernel_size,padding=padding),
        nn.BatchNorm2d(out_c),
        nn.ReLU(inplace=True)
    )

  # forward has to live at class level: when it was nested inside __init__
  # it was only a local function, so calling the module fell through to
  # nn.Module's default forward and raised "not implemented".
  def forward(self,input):
    """Run ``input`` through conv, batch norm and ReLU and return the result."""
    return self.layers(input)

In [None]:
class DeconvBlock(nn.Module):
  """Single transposed convolution with a 2x2 kernel, stride 2 and no padding.

  With these settings each input pixel expands into a non-overlapping 2x2
  output patch, so the spatial height and width are doubled.

  Args:
    in_c: Number of channels of the input feature map.
    out_c: Number of channels of the upsampled output.
  """
  def __init__(self,in_c,out_c):
    super().__init__()
    upsample=nn.ConvTranspose2d(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=2,
        stride=2,
        padding=0,
    )
    # Attribute name is kept so parameter names in state_dict stay the same.
    self.deconv=upsample

  def forward(self,input):
    """Apply the transposed convolution to ``input`` and return the result."""
    upsampled=self.deconv(input)
    return upsampled

In [None]:
class UNETR_2D(nn.Module):
  """Embedding stage of a 2D UNETR: linear patch projection plus learned positions.

  Only the embedding stage is implemented in this class so far.

  Args:
    cf: Configuration mapping. The keys read here are ``"patch_size"``,
      ``"num_channels"``, ``"hidden_dim"`` and ``"num_patches"``.
  """
  def __init__(self,cf):
    super().__init__()
    # Projects each flattened patch to the hidden dimension.
    self.patch_embed=nn.Linear(
        cf["patch_size"]*cf["patch_size"]*cf["num_channels"],
        cf["hidden_dim"]
    )
    # Indices 0..num_patches-1 used to look up the positional embedding.
    # Fixes the `torch.arrange` typo (the function is `torch.arange`).
    # Registered as a non-persistent buffer so it follows the module across
    # devices without adding an entry to state_dict; long is the index dtype
    # accepted by nn.Embedding on every PyTorch version.
    self.register_buffer(
        "positions",
        torch.arange(start=0,end=cf["num_patches"],step=1,dtype=torch.long),
        persistent=False,
    )
    self.pos_embed=nn.Embedding(cf["num_patches"],cf["hidden_dim"])

  def forward(self,input):
    """Embed a batch of flattened patches.

    Args:
      input: Float tensor whose last dimension is
        ``patch_size*patch_size*num_channels`` and whose second-to-last
        dimension is ``num_patches``, e.g. ``(batch, num_patches, patch_dim)``.

    Returns:
      Tensor of shape ``(..., num_patches, hidden_dim)``: the projected
      patches with the positional embedding added.
    """
    patch_embed=self.patch_embed(input)
    print("Patch Embedding:",patch_embed.shape)
    # The embedding table is indexed by patch position, not by the float
    # patch values (passing `input` here raised a RuntimeError).
    pos_embed=self.pos_embed(self.positions)
    print("Positional Embedding:",pos_embed.shape)
    # (num_patches, hidden_dim) broadcasts over the leading batch dimension.
    return patch_embed+pos_embed
if __name__ == "__main__":
    # Smoke test: build the configuration, create a random batch of
    # flattened patches and push it through the model once.
    config = {
        "image_size": 256,
        "num_layers": 12,
        "hidden_dim": 768,
        "mlp_dim": 3072,
        "num_heads": 12,
        "dropout_rate": 0.1,
        "num_patches": 256,
        "patch_size": 16,
        "num_channels": 3,
    }
    # Batch of 8 sequences; each position holds one flattened patch of
    # patch_size * patch_size * num_channels values.
    input=torch.randn(
        8,
        config["num_patches"],
        config["patch_size"]**2*config["num_channels"],
    )
    print("Transformed Input:",input.shape)
    model=UNETR_2D(config)
    model(input)


Transformed Input: torch.Size([8, 256, 768])
Patch Embedding: torch.Size([8, 256, 768])


RuntimeError: ignored