# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import ToTensor, Resize
from PIL import Image
import os
import glob

# Residual Block Definition
class ResidualBlock(nn.Module):
    """Two same-width 3x3 convolutions wrapped by an identity skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + x)``. Both convolutions use
    stride 1 and padding 1, so the output has the same shape as the input.

    Args:
        in_channels: Number of input channels (also the number of output channels).
    """

    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # In-place add only mutates `out` (a fresh tensor produced by conv2);
        # the caller's tensor `identity` is read, never modified.
        out += identity
        return self.relu(out)

# Super Resolution Network with Residual Block
class SuperResolutionNet(nn.Module):
    """Sub-pixel (PixelShuffle) super-resolution network with one residual block.

    Pipeline: conv 5x5 (3 -> 64) -> ReLU -> ResidualBlock(64) -> conv 3x3
    (64 -> shuffle_channels * r^2) -> ReLU -> PixelShuffle(r) -> conv 3x3
    (shuffle_channels -> 3). All convolutions preserve spatial size, so an
    input of shape (N, 3, H, W) yields an output of shape (N, 3, H*r, W*r).

    Args:
        upscale_factor: Integer spatial upscaling factor ``r`` (>= 1).
        shuffle_channels: Channels left after PixelShuffle (fed into the final
            conv). Defaults to ``max(1, 32 // r**2)``, which reproduces the
            original 32-channel layout exactly for r = 1, 2 and 4 and still
            yields a valid layout for every other factor.

    Raises:
        ValueError: If ``upscale_factor`` or ``shuffle_channels`` is < 1.
    """

    # Channel budget entering PixelShuffle in the original design; used to derive
    # the default ``shuffle_channels`` so existing checkpoints keep their shapes.
    _DEFAULT_PRE_SHUFFLE_CHANNELS = 32

    def __init__(self, upscale_factor, shuffle_channels=None):
        super(SuperResolutionNet, self).__init__()

        if upscale_factor < 1:
            raise ValueError(f"upscale_factor must be >= 1, got {upscale_factor}")
        self.upscale_factor = upscale_factor
        factor_sq = upscale_factor ** 2

        if shuffle_channels is None:
            shuffle_channels = max(1, self._DEFAULT_PRE_SHUFFLE_CHANNELS // factor_sq)
        elif shuffle_channels < 1:
            raise ValueError(f"shuffle_channels must be >= 1, got {shuffle_channels}")
        self.shuffle_channels = shuffle_channels

        # Initial convolution layer
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, stride=1, padding=2)

        # Residual block
        self.residual_block = ResidualBlock(64)

        # Second convolution layer. PixelShuffle(r) requires its input channel
        # count to be divisible by r^2, so emit exactly shuffle_channels * r^2.
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=shuffle_channels * factor_sq,
                               kernel_size=3, stride=1, padding=1)

        # PixelShuffle layer: (N, C*r^2, H, W) -> (N, C, H*r, W*r)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        # Output convolution layer back to 3 (RGB) channels
        self.conv3 = nn.Conv2d(in_channels=shuffle_channels, out_channels=3, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # Apply initial convolution layer
        x = F.relu(self.conv1(x))

        # Apply residual block
        x = self.residual_block(x)

        # Apply second convolution layer
        x = F.relu(self.conv2(x))

        # Apply PixelShuffle layer (trades channels for spatial resolution)
        x = self.pixel_shuffle(x)

        # Apply final convolution layer
        x = self.conv3(x)
        return x

# Dataset creation by lowering resolution
def create_low_res_dataset(image_dir, low_res_dir, scale_factor, pattern='*.jpg'):
    """Write a downscaled copy of every matching image in ``image_dir`` to ``low_res_dir``.

    Each image is converted to RGB and resized with torchvision's ``Resize`` to
    ``(height // scale_factor, width // scale_factor)`` (clamped to at least one
    pixel per side), then saved under its original file name. PIL picks the
    output format from the file extension.

    Args:
        image_dir: Directory containing the high-resolution source images.
        low_res_dir: Output directory; created if it does not exist.
        scale_factor: Integer downscaling factor (>= 1).
        pattern: Glob pattern, relative to ``image_dir``, selecting the images.
            Defaults to ``'*.jpg'`` (the previously hard-coded pattern).

    Returns:
        List of the written low-resolution image paths, in sorted order.

    Raises:
        ValueError: If ``scale_factor`` is less than 1.
    """
    if scale_factor < 1:
        raise ValueError(f"scale_factor must be >= 1, got {scale_factor}")

    os.makedirs(low_res_dir, exist_ok=True)
    written = []

    for img_path in sorted(glob.glob(os.path.join(image_dir, pattern))):
        # Close the source file handle promptly; convert() returns an
        # independent image, so it stays usable after the `with` exits.
        with Image.open(img_path) as src:
            img = src.convert('RGB')

        # PIL reports (width, height) while Resize expects (height, width).
        # Clamp so very small images never produce a zero-sized target.
        width, height = img.size
        target_size = (max(1, height // scale_factor), max(1, width // scale_factor))
        low_res_img = Resize(target_size)(img)

        out_path = os.path.join(low_res_dir, os.path.basename(img_path))
        low_res_img.save(out_path)
        written.append(out_path)

    return written

# Loss function: Mean Squared Error
def loss_function(output, target):
    """Return the mean squared error between ``output`` and ``target``.

    Thin wrapper over ``torch.nn.functional.mse_loss`` using its default
    ``'mean'`` reduction: squared element-wise differences averaged over
    every element.
    """
    mse = F.mse_loss(input=output, target=target, reduction='mean')
    return mse

# Example usage
if __name__ == "__main__":
    # One factor drives both the model and the dataset so their shapes agree.
    upscale_factor = 2

    # Define the network
    net = SuperResolutionNet(upscale_factor=upscale_factor)

    # Random input tensor with shape (batch_size, channels, height, width)
    input_tensor = torch.randn(1, 3, 24, 24)

    # Forward pass
    output_tensor = net(input_tensor)

    print(f"Input tensor shape: {input_tensor.shape}")
    print(f"Output tensor shape: {output_tensor.shape}")

    # Example dataset creation. The paths are placeholders: skip the step when
    # the source directory is missing rather than creating empty output folders
    # relative to the current working directory.
    image_dir = 'path/to/high_res_images'
    low_res_dir = 'path/to/low_res_images'
    scale_factor = upscale_factor
    if os.path.isdir(image_dir):
        create_low_res_dataset(image_dir, low_res_dir, scale_factor)
    else:
        print(f"Skipping dataset creation: '{image_dir}' does not exist")

    # Example loss calculation against a random target shaped like the output
    # (derived from the output rather than hard-coding 48x48).
    high_res_tensor = torch.randn_like(output_tensor)
    loss = loss_function(output_tensor, high_res_tensor)
    print(f"Loss: {loss.item()}")
