In [32]:
import torch
import torch.nn as nn

class PatchProjection(nn.Module):
    """Convolutional encoder that reduces an image to one flat feature vector.

    Four strided ``Conv2d`` + ``ReLU`` stages widen the channel count in steps
    of ``token_dim // 4``, ``token_dim // 2``, ``(token_dim * 3) // 4`` and
    finally ``token_dim``. Every stage uses ``stride == kernel_size``.

    Args:
        in_channels: Number of channels of the input images.
        token_dim: Number of channels produced by the last convolution.
        kernel_size: Kernel size, also used as the stride, of every stage.

    Note:
        The result has shape ``(batch, token_dim)`` only when the spatial
        extent has collapsed to 1x1 after the last stage, as it does for
        64x64 inputs with ``kernel_size=4``. For other sizes the leftover
        spatial positions are folded into the feature axis.
    """

    def __init__(self, in_channels, token_dim, kernel_size):
        super().__init__()
        quarter = token_dim // 4
        half = token_dim // 2
        three_quarters = (token_dim * 3) // 4

        self.conv_layers = nn.Sequential(
            # Use the caller-supplied channel count rather than assuming RGB.
            nn.Conv2d(in_channels, quarter, kernel_size=kernel_size,
                      stride=kernel_size, padding=1),
            nn.ReLU(),
            nn.Conv2d(quarter, half, kernel_size=kernel_size,
                      stride=kernel_size, padding=1),
            nn.ReLU(),
            nn.Conv2d(half, three_quarters, kernel_size=kernel_size,
                      stride=kernel_size, padding=1),
            nn.ReLU(),
            # The last stage pads by 2 (not 1) so that a 1x1 input map still
            # yields a 1x1 output when kernel_size is 4.
            nn.Conv2d(three_quarters, token_dim, kernel_size=kernel_size,
                      stride=kernel_size, padding=2),
            nn.ReLU(),
        )

    def forward(self, x):
        """Encode a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_channels, height, width)``.

        Returns:
            Tensor of shape ``(batch, token_dim * H' * W')`` where ``H' x W'``
            is the spatial size left after the four convolutions.
        """
        features = self.conv_layers(x)  # (batch, token_dim, H', W')
        # Collapse channel and spatial axes into a single feature axis.
        return features.flatten(start_dim=1)


# Demo configuration: square 64x64 inputs with three channels.
h = 64
w = 64
div = h // 4

batch_size = 2816
channels = 3
img_height = h
img_width = w
token_dim = 64

# Build the encoder before sampling the batch so the order of random draws
# (weight initialisation first, then the input tensor) stays the same.
model = PatchProjection(in_channels=channels, token_dim=token_dim, kernel_size=4)
x = torch.randn(batch_size, channels, img_height, img_width)

# Encode the whole batch; the trailing expression displays the result shape.
feature_vector = model(x)
feature_vector.shape

torch.Size([2816, 64])

In [33]:
# Decoder: treats each token vector as a (token_dim, 1, 1) feature map and
# upsamples it with transposed convolutions to a (channels, 64, 64) image.
# The shape comments below are per sample: (channels, height, width).
unpatcher = nn.Sequential(
    # (token_dim, 1, 1) -> (256, 4, 4)
    nn.ConvTranspose2d(token_dim, 256, kernel_size=4, stride=1, padding=0),
    nn.BatchNorm2d(256),
    nn.ReLU(inplace=True),

    # (256, 4, 4) -> (128, 8, 8)
    nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),

    # (128, 8, 8) -> (64, 16, 16)
    nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(inplace=True),

    # (64, 16, 16) -> (32, 32, 32)
    nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),

    # (32, 32, 32) -> (channels, 64, 64)
    nn.ConvTranspose2d(32, channels, kernel_size=4, stride=2, padding=1),
    nn.Sigmoid(),  # squashes outputs into the [0, 1] range
)

# Append two singleton spatial axes: (batch, token_dim) -> (batch, token_dim, 1, 1).
output = unpatcher(feature_vector[:, :, None, None])
print(f"Input shape: {feature_vector.shape}")
print(f"Output shape: {output.shape}")

Input shape: torch.Size([2816, 64])
Output shape: torch.Size([2816, 3, 64, 64])
