TransUNet-style encoder and decoder modules, written so they can be reused as building blocks.

In [None]:

# --------- Utility Blocks ---------
class ConvBlock(nn.Module):
    """Two stacked ``3x3 conv -> BatchNorm -> ReLU`` stages.

    ``padding=1`` keeps the spatial size unchanged; channels go
    ``in_c -> out_c`` in the first stage and stay at ``out_c`` in the second.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        layers = []
        # Stage 1 maps in_c -> out_c, stage 2 maps out_c -> out_c.
        for stage_in in (in_c, out_c):
            layers += [
                nn.Conv2d(stage_in, out_c, 3, padding=1),
                nn.BatchNorm2d(out_c),
                nn.ReLU(),
            ]
        # Attribute name and layer order are kept so state_dict keys match.
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        """Run both conv stages on ``x``."""
        return self.seq(x)

def center_crop(src, target):
    """Return the centred spatial window of ``src`` that matches ``target``.

    Only the last two dimensions are cropped; leading dimensions are kept.
    When the size difference is odd, the extra row/column is removed from the
    bottom/right side.

    Args:
        src: Array/tensor whose last two dims are at least as large as
            ``target``'s.
        target: Array/tensor whose ``shape[-2:]`` gives the desired size.

    Raises:
        ValueError: if ``src`` is smaller than ``target`` in either spatial
            dimension. A negative offset would otherwise slice silently wrong.
    """
    src_h, src_w = src.shape[-2:]
    tgt_h, tgt_w = target.shape[-2:]
    if src_h < tgt_h or src_w < tgt_w:
        raise ValueError(
            f"cannot center-crop a {src_h}x{src_w} map to a larger {tgt_h}x{tgt_w} map"
        )
    top = (src_h - tgt_h) // 2
    left = (src_w - tgt_w) // 2
    return src[..., top:top + tgt_h, left:left + tgt_w]

# --------- Encoder ---------
class TransUNetEncoder(nn.Module):
    """Hybrid CNN + Transformer encoder in the TransUNet style.

    Two ``ConvBlock`` stages, each followed by 2x max-pooling, bring the input
    to 1/4 resolution. A strided convolution then cuts that map into
    non-overlapping ``patch_size`` x ``patch_size`` patches. The patch tokens
    are processed by a stack of Transformer encoder layers.

    Args:
        img_size: Nominal input size ``(H, W)``; an int means square. Only
            used to size the optional positional embedding.
        patch_size: Patch edge length, measured on the 1/4-resolution map.
        d_model: Token width (output channels of the patch embedding).
        nhead: Attention heads per Transformer layer.
        num_blocks: Number of Transformer encoder layers.
        in_channels: Channels of the input image (default 1, as before).
        use_pos_embed: If True, add a learnable positional embedding to the
            patch tokens. Without one, self-attention cannot tell tokens apart
            by position. Off by default so existing models keep the same
            parameters/state_dict.

    Raises:
        ValueError: if ``use_pos_embed`` is set and ``img_size`` is too small
            to yield at least one patch.
    """

    def __init__(self, img_size, patch_size=8, d_model=128, nhead=4, num_blocks=4,
                 in_channels=1, use_pos_embed=False):
        super().__init__()
        self.enc1 = ConvBlock(in_channels, 32)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(32, 64)
        self.pool2 = nn.MaxPool2d(2)
        # kernel == stride == patch_size -> non-overlapping patch embedding.
        self.patch_embed = nn.Conv2d(64, d_model, patch_size, patch_size)
        encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
        self.transformer = nn.TransformerEncoder(encoder_layer, num_blocks)
        self.patch_size = patch_size
        self.img_size = img_size
        self.d_model = d_model

        if use_pos_embed:
            h, w = (img_size, img_size) if isinstance(img_size, int) else img_size
            # Two 2x pools, then one token per patch_size x patch_size window.
            grid = ((h // 4) // patch_size, (w // 4) // patch_size)
            if min(grid) < 1:
                raise ValueError(
                    f"img_size {img_size} is too small for patch_size {patch_size}"
                )
            self.pos_embed = nn.Parameter(torch.zeros(1, d_model, *grid))
            nn.init.trunc_normal_(self.pos_embed, std=0.02)
        else:
            self.pos_embed = None

    def forward(self, x):
        """Encode ``x``.

        Returns:
            ``(tr_out, [e1, e2])``. ``tr_out`` is the Transformer output
            reshaped back to a ``[B, d_model, ph, pw]`` grid. ``e1`` and
            ``e2`` are the full- and half-resolution CNN features, returned
            for use as skip connections.
        """
        e1 = self.enc1(x)                # [B, 32, H, W]
        e2 = self.enc2(self.pool1(e1))   # [B, 64, H/2, W/2]
        x2 = self.pool2(e2)              # [B, 64, H/4, W/4]
        # Rows/cols of x2 not divisible by patch_size are dropped by the
        # strided conv.
        patches = self.patch_embed(x2)   # [B, d_model, ph, pw]
        if self.pos_embed is not None:
            pos = self.pos_embed
            if pos.shape[-2:] != patches.shape[-2:]:
                # Resample so inputs whose size differs from img_size still work.
                pos = F.interpolate(pos, size=patches.shape[-2:],
                                    mode='bilinear', align_corners=False)
            patches = patches + pos
        B, C, H_, W_ = patches.shape
        # Row-major flatten of the grid into a token sequence.
        patches_flat = patches.permute(0, 2, 3, 1).reshape(B, H_ * W_, C)
        tr_out = self.transformer(patches_flat)
        # Inverse of the flatten above.
        tr_out = tr_out.reshape(B, H_, W_, C).permute(0, 3, 1, 2)  # [B, d_model, ph, pw]
        return tr_out, [e1, e2]

# --------- Decoder ---------
class TransUNetDecoder(nn.Module):
    """U-Net-style decoder for the outputs of ``TransUNetEncoder``.

    The Transformer grid is upsampled and fused with the half-resolution skip
    (``e2``, 64 channels) and then the full-resolution skip (``e1``,
    32 channels). The result is projected to ``out_channels`` with a 1x1
    convolution. The skip channel counts are fixed by the ``ConvBlock``
    widths below.

    Sizes are matched by resizing the decoder path, not by cropping the
    skips. With ``patch_size > 1`` the upsampled grid is much smaller than the
    skip maps (H/16 vs H/2 for ``patch_size=8``). Center-cropping the skips
    would keep only the image centre and pair it with features describing the
    whole image. Bilinearly resizing the decoder features keeps every skip
    pixel aligned with the location it describes.

    Args:
        img_size: Kept for interface compatibility; not used.
        patch_size: Kept for interface compatibility; not used.
        d_model: Channels of the Transformer output.
        out_channels: Channels of the prediction (default 1, as before).
    """

    def __init__(self, img_size, patch_size=8, d_model=128, out_channels=1):
        super().__init__()
        self.up2 = nn.ConvTranspose2d(d_model, 64, 2, stride=2)
        self.dec2 = ConvBlock(128, 32)          # 64 (up2) + 64 (e2 skip)
        self.up1 = nn.ConvTranspose2d(32, 32, 2, stride=2)
        self.dec1 = ConvBlock(64, 32)           # 32 (up1) + 32 (e1 skip)
        self.final = nn.Conv2d(32, out_channels, 1)

    @staticmethod
    def _match_size(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s spatial size (no-op if equal)."""
        if x.shape[-2:] == ref.shape[-2:]:
            return x
        return F.interpolate(x, size=ref.shape[-2:], mode='bilinear', align_corners=False)

    def forward(self, tr_out, skips, input_shape):
        """Decode ``tr_out`` using ``skips = [e1, e2]``.

        Args:
            tr_out: Transformer grid, ``[B, d_model, ph, pw]``.
            skips: ``[e1, e2]`` as returned by the encoder.
            input_shape: Shape of the original input; its last two entries
                give the output spatial size.

        Returns:
            Prediction with ``out_channels`` channels at ``input_shape[-2:]``.
        """
        e1, e2 = skips
        d2 = self._match_size(self.up2(tr_out), e2)
        d2 = self.dec2(torch.cat([d2, e2], 1))
        d1 = self._match_size(self.up1(d2), e1)
        d1 = self.dec1(torch.cat([d1, e1], 1))
        out = self.final(d1)
        if out.shape[-2:] != input_shape[-2:]:
            out = F.interpolate(out, size=input_shape[-2:], mode='bilinear', align_corners=False)
        return out