In [None]:
class CNNEncoder(nn.Module):
    """Compress a square image patch into a flat latent vector.

    Three 3x3 convolutions (stride 1, padding 1, so the spatial size is
    unchanged) widen the channels input_channels -> 64 -> 128 -> 256, each
    followed by BatchNorm and ReLU. The resulting feature map is flattened and
    projected through a 512-unit hidden layer down to ``encoded_size`` values.

    Input shape:  (batch, input_channels, patch_size, patch_size)
    Output shape: (batch, encoded_size)

    The layer order inside ``self.encoder`` fixes the ``state_dict`` keys
    (``encoder.<index>.<param>``); keep it stable so saved weights still load.
    """

    def __init__(self, input_channels=3, patch_size=8, encoded_size=24):
        super().__init__()
        layers = []
        channels = input_channels
        for width in (64, 128, 256):
            layers.extend([
                nn.Conv2d(channels, width, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(width),
                nn.ReLU(),
            ])
            channels = width
        # `channels` is now 256; the Linear size assumes a patch_size x patch_size input.
        layers.extend([
            nn.Flatten(),  # -> (batch, 256 * patch_size * patch_size)
            nn.Linear(channels * patch_size * patch_size, 512),
            nn.ReLU(),
            nn.Linear(512, encoded_size),  # -> (batch, encoded_size)
        ])
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        return self.encoder(x)

class CNNDecoder(nn.Module):
    """Expand a latent vector back into a square image patch.

    Two fully connected layers (encoded_size -> 512 -> 256 * patch_size**2)
    produce a 256-channel feature map. Three 3x3 transposed convolutions then
    narrow it 256 -> 128 -> 64 -> output_channels. Stride 1 and padding 1 keep
    the spatial size at patch_size x patch_size. The two inner transposed convs
    are followed by BatchNorm and ReLU. The last layer has no activation, so
    the output is unbounded.

    Input shape:  (batch, encoded_size)
    Output shape: (batch, output_channels, patch_size, patch_size)

    The layer order inside ``self.decoder`` fixes the ``state_dict`` keys
    (``decoder.<index>.<param>``); keep it stable so saved weights still load.
    """

    def __init__(self, encoded_size=24, output_channels=3, patch_size=8):
        super().__init__()
        dense_stem = [
            nn.Linear(encoded_size, 512),
            nn.ReLU(),
            nn.Linear(512, 256 * patch_size * patch_size),
            nn.ReLU(),
            nn.Unflatten(1, (256, patch_size, patch_size)),  # -> (batch, 256, p, p)
        ]
        deconv_blocks = []
        for c_in, c_out in ((256, 128), (128, 64)):
            deconv_blocks += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]
        # Final projection to image channels -> (batch, output_channels, p, p)
        output_layer = nn.ConvTranspose2d(64, output_channels, kernel_size=3, stride=1, padding=1)
        self.decoder = nn.Sequential(*dense_stem, *deconv_blocks, output_layer)

    def forward(self, x):
        return self.decoder(x)

class CNNEncoderDecoder(nn.Module):
    """Patch autoencoder: a ``CNNEncoder`` followed by a ``CNNDecoder``.

    The decoder is built with ``output_channels=input_channels`` and the same
    ``patch_size``, so its output has the same shape as its input:
    (batch, input_channels, patch_size, patch_size), squeezed through an
    ``encoded_size``-dimensional bottleneck. There are no skip connections
    between encoder and decoder.
    """

    def __init__(self, input_channels=3, patch_size=8, encoded_size=24):
        super().__init__()
        # Attribute names become state_dict prefixes ("encoder.", "decoder."); keep them.
        self.encoder = CNNEncoder(input_channels, patch_size, encoded_size)
        self.decoder = CNNDecoder(encoded_size, input_channels, patch_size)

    def forward(self, patches):
        return self.decoder(self.encoder(patches))

# Utility Functions to Convert Images to Patches and Vice Versa
def images_to_patches(images, patch_size):
    """Cut a batch of images into non-overlapping square patches.

    Args:
        images: 4-D tensor ``(b, c, w, h)``; dims 2 and 3 must both be
            divisible by ``patch_size``.
        patch_size: side length of each square patch.

    Returns:
        Tensor of shape ``(b * (w // patch_size) * (h // patch_size), c,
        patch_size, patch_size)``. Patches are grouped image by image. Within
        an image they are ordered by the dim-2 patch index, then by the dim-3
        patch index.
    """
    _, channels, width, height = images.shape
    assert width % patch_size == 0 and height % patch_size == 0, "Width and height must be divisible by patch size."

    # Tile dims 2 and 3 in turn: (b, c, w/p, h/p, p, p)
    tiles = images.unfold(2, patch_size, patch_size)
    tiles = tiles.unfold(3, patch_size, patch_size)
    # Move the patch-grid axes ahead of channels so each patch is (c, p, p).
    grid_first = tiles.permute(0, 2, 3, 1, 4, 5)
    return grid_first.reshape(-1, channels, patch_size, patch_size)

def patches_to_images(patches, image_size, patch_size):
    """Stitch square patches back into a batch of images.

    Args:
        patches: tensor of ``b * num_patches`` patches, each
            ``(c, patch_size, patch_size)``. They are grouped image by image
            and ordered by the dim-2 patch index, then the dim-3 patch index.
            This is the layout ``images_to_patches`` produces.
        image_size: 4-tuple or ``torch.Size`` ``(b, c, w, h)`` of the output.
        patch_size: side length of each square patch.

    Returns:
        Tensor of shape ``(b, c, w, h)``.
    """
    batch, channels, width, height = image_size
    grid_w = width // patch_size   # patches along dim 2
    grid_h = height // patch_size  # patches along dim 3

    # Split the flat patch axis back into (image, grid index along dim 2, along dim 3).
    grid = patches.view(batch, grid_w, grid_h, channels, patch_size, patch_size)
    # Interleave grid and in-patch axes: (b, c, grid_w, p, grid_h, p).
    interleaved = grid.permute(0, 3, 1, 4, 2, 5).contiguous()
    return interleaved.view(batch, channels, width, height)


# Build the three patch autoencoders (latent sizes 3, 6 and 12) and restore
# their saved weights. `num_channels`, `patch_size` and `device` are not defined
# in this cell; presumably they come from an earlier cell — confirm.
# NOTE(review): despite the "Unet" names, CNNEncoderDecoder has no skip connections.
Unet3 = CNNEncoderDecoder(input_channels=num_channels, patch_size=patch_size, encoded_size=3)
Unet6 = CNNEncoderDecoder(input_channels=num_channels, patch_size=patch_size, encoded_size=6)
Unet12 = CNNEncoderDecoder(input_channels=num_channels, patch_size=patch_size, encoded_size=12)


Unet3.to(device)
Unet6.to(device)
Unet12.to(device)

# map_location remaps every stored tensor onto `device`. Without it, torch.load
# restores tensors to the device they were saved from, so a checkpoint saved on
# a GPU fails to load on a CPU-only machine.
# torch.load unpickles the file: only load trusted checkpoints.
Unet3.load_state_dict(torch.load("models/unet_3.pt", map_location=device))
Unet6.load_state_dict(torch.load("models/unet_6.pt", map_location=device))
Unet12.load_state_dict(torch.load("models/unet_12.pt", map_location=device))
# NOTE(review): the models remain in training mode here. Call .eval() before
# inference so BatchNorm uses its running statistics, unless further training
# follows — confirm.
