In [None]:
import torch
import torch.nn as nn
import timm
from torchvision.transforms import Compose, Resize, ToTensor, Normalize


# Vision Transformer-based Image Restoration Model
class ViTImageRestoration(nn.Module):
    """Image-restoration model: pretrained ViT-B/16 encoder + small conv decoder.

    The encoder is always timm's ``vit_base_patch16_224`` (pretrained, with the
    classification head removed via ``num_classes=0``). Because the encoder is
    fixed, ``img_size``, ``patch_size`` and ``num_heads`` do not configure it.
    They are kept only so existing call sites keep working.

    Args:
        img_size: Unused; the encoder's positional embedding is fixed by the
            pretrained model (presumably 224x224 inputs, per its name).
        patch_size: Unused; the patch size is read from the encoder itself.
        embed_dim: Token width fed to the decoder. Must equal the encoder's
            ``num_features``.
        num_heads: Unused.
        decoder_dim: Hidden channel count of the decoder.

    Raises:
        ValueError: If ``embed_dim`` does not match the encoder's token width.
    """

    def __init__(
        self, img_size=224, patch_size=16, embed_dim=768, num_heads=12, decoder_dim=512
    ):
        super().__init__()
        # Pretrained Vision Transformer as encoder (classifier head removed).
        self.encoder = timm.create_model(
            "vit_base_patch16_224", pretrained=True, num_classes=0
        )
        # The decoder's first layer consumes encoder tokens directly, so a
        # mismatched width would only surface later as a cryptic conv error.
        encoder_dim = getattr(self.encoder, "num_features", embed_dim)
        if embed_dim != encoder_dim:
            raise ValueError(
                f"embed_dim={embed_dim} does not match the encoder width "
                f"{encoder_dim}"
            )

        # Decoder: 2x transposed-conv upsampling, then project to RGB.
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(
                embed_dim,
                decoder_dim,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.ReLU(),
            nn.Conv2d(
                decoder_dim, 3, kernel_size=3, padding=1
            ),  # Reconstruct RGB channels
            nn.Tanh(),  # Output in [-1, 1]
        )

    def forward(self, x):
        """Restore a batch of images.

        Args:
            x: Image batch of shape ``(B, 3, H, W)``. ``H`` and ``W`` must match
                the resolution the encoder's positional embedding was built for
                (assumed 224 for ``vit_base_patch16_224`` — confirm).

        Returns:
            Tensor of shape ``(B, 3, H, W)`` with values in ``[-1, 1]``.
        """
        batch_size, _, h, w = x.shape
        patch_h, patch_w = self.encoder.patch_embed.patch_size
        grid_h, grid_w = h // patch_h, w // patch_w

        # Convert the input into patch tokens.
        tokens = self.encoder.patch_embed(x)
        # For this model, timm's pos_embed has one extra position for the class
        # token. Prepend it first, otherwise the addition below has mismatched
        # token counts.
        cls_token = getattr(self.encoder, "cls_token", None)
        if cls_token is not None:
            tokens = torch.cat([cls_token.expand(batch_size, -1, -1), tokens], dim=1)
        tokens = self.encoder.pos_drop(tokens + self.encoder.pos_embed)
        # Pass through the Transformer encoder.
        tokens = self.encoder.blocks(tokens)
        tokens = self.encoder.norm(tokens)

        # Keep only the patch tokens (drop any prefix/class token); they are the
        # ones that correspond to spatial positions.
        tokens = tokens[:, -grid_h * grid_w :]

        # (B, N, C) -> (B, C, grid_h, grid_w) feature map for the conv decoder.
        feature_map = tokens.transpose(1, 2).reshape(batch_size, -1, grid_h, grid_w)
        restored = self.decoder(feature_map)

        # The decoder only upsamples 2x (one patch grid cell -> 2x2 pixels), so
        # resize to the input resolution so the output is comparable with the
        # target image. Bilinear interpolation keeps values within [-1, 1].
        return nn.functional.interpolate(
            restored, size=(h, w), mode="bilinear", align_corners=False
        )


# Loss function: pixel-wise L1 reconstruction loss
def loss_function(pred, target):
    """Return the mean absolute (L1) error between ``pred`` and ``target``.

    Args:
        pred: Model output tensor.
        target: Ground-truth tensor with exactly the same shape as ``pred``.

    Returns:
        Scalar tensor holding the mean L1 distance.

    Raises:
        ValueError: If the shapes differ. ``nn.L1Loss`` would otherwise only
            warn and broadcast, silently producing a meaningless loss.
    """
    if pred.shape != target.shape:
        raise ValueError(
            f"pred shape {tuple(pred.shape)} does not match "
            f"target shape {tuple(target.shape)}"
        )
    return nn.functional.l1_loss(pred, target)


# Data Preprocessing
def get_transforms(img_size=224):
    """Build the preprocessing pipeline for input images.

    Images are resized to a square ``img_size`` x ``img_size``, converted to
    tensors, and normalized with per-channel mean 0.5 and std 0.5.

    Args:
        img_size: Side length, in pixels, of the resized square image.

    Returns:
        A ``Compose`` pipeline applying the steps above in order.
    """
    target_hw = (img_size, img_size)
    channel_stats = [0.5] * 3
    steps = [
        Resize(target_hw),
        ToTensor(),
        # (x - 0.5) / 0.5: for 8-bit images ToTensor yields [0, 1], which this
        # maps to [-1, 1] -- the same range as the decoder's Tanh output.
        Normalize(mean=channel_stats, std=channel_stats),
    ]
    return Compose(steps)