In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    """Residual unit: two 3x3 conv + BatchNorm layers plus an identity skip.

    Computes ``x + BN(Conv(PReLU(BN(Conv(x)))))``. Both convolutions use
    ``kernel_size=3, padding=1`` with the default stride of 1, so the output
    has the same channel count and spatial size as the input.

    Args:
        in_channels: Number of input channels (also the output channel count).
    """

    def __init__(self, in_channels):
        super().__init__()
        channels = in_channels
        layers = [
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.PReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
        ]
        # Kept as one nn.Sequential named ``rb`` so state-dict keys
        # (rb.0.*, rb.1.*, ...) match checkpoints saved from this module.
        self.rb = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the residual branch evaluated on ``x``."""
        return x + self.rb(x)


class UpSampleBlock(nn.Module):
    """Sub-pixel (PixelShuffle) upsampling block.

    A 3x3 conv expands the input to ``out_channels * scale_factor**2``
    channels. ``PixelShuffle(scale_factor)`` then rearranges those into an
    ``out_channels``-channel map whose height and width are ``scale_factor``
    times larger. A PReLU is applied last.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels after the pixel shuffle.
        scale_factor: Spatial upscaling applied by this block (default 2).
    """

    def __init__(self, in_channels, out_channels, scale_factor=2):
        super().__init__()
        expanded_channels = out_channels * scale_factor ** 2
        self.upsample_block = nn.Sequential(
            nn.Conv2d(in_channels, expanded_channels, kernel_size=3, padding=1),
            nn.PixelShuffle(scale_factor),
            nn.PReLU(),
        )

    def forward(self, x):
        """Upsample ``x`` by ``scale_factor`` in both spatial dimensions."""
        return self.upsample_block(x)


class SRResNet(nn.Module):
    """SRResNet-style super-resolution network.

    The network runs these stages in order:

    1. A shallow 3x3 conv + PReLU maps the input to 64 feature channels.
    2. ``num_residual_blocks`` ``ResidualBlock`` modules process the features.
    3. A 3x3 conv + BatchNorm follows, with a long skip connection from the
       shallow features.
    4. A stack of x2 ``UpSampleBlock`` modules upsamples the features.
    5. A final 3x3 conv maps to ``out_channels``.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the super-resolved output.
        num_residual_blocks: Number of residual blocks in the trunk.
        upscale_factor: Total spatial upscaling; must be 2, 4 or 8.
    """

    def __init__(self, in_channels=3, out_channels=3, num_residual_blocks=16, upscale_factor=4):
        super(SRResNet, self).__init__()
        assert upscale_factor in [2, 4, 8], "Upscale factor must be one of [2, 4, 8]."
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.PReLU()
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)]
        )
        self.res_out_conv = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        # Each UpSampleBlock (default scale_factor=2) doubles H and W, so
        # log2(upscale_factor) blocks are needed: 1 for x2, 2 for x4, 3 for x8.
        # (`upscale_factor // 2` built 4 blocks, i.e. x16, for factor 8.)
        num_upsample_blocks = int(upscale_factor).bit_length() - 1
        self.upsample_blocks = nn.Sequential(
            *[UpSampleBlock(64, 64) for _ in range(num_upsample_blocks)]
        )
        self.out_conv = nn.Conv2d(64, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        """Super-resolve ``x`` by ``upscale_factor`` in height and width."""
        x = self.in_conv(x)
        residual = x  # long skip over the whole residual trunk
        x = self.residual_blocks(x)
        x = self.res_out_conv(x)
        x = torch.add(x, residual)
        x = self.upsample_blocks(x)
        x = self.out_conv(x)
        return x

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F



class Generator(nn.Module):
    """SRGAN generator: a thin wrapper that delegates to an ``SRResNet``.

    The wrapped network is held in ``self.srresnet``, so every state-dict key
    is prefixed with ``srresnet.``.

    Args:
        in_channels: Channels of the low-resolution input.
        out_channels: Channels of the generated image.
        num_residual_blocks: Residual blocks in the SRResNet trunk (default 5).
        upscale_factor: Spatial upscaling passed through to ``SRResNet``.
    """

    def __init__(self, in_channels=3, out_channels=3, num_residual_blocks=5, upscale_factor=4):
        super().__init__()
        self.srresnet = SRResNet(
            in_channels=in_channels,
            out_channels=out_channels,
            num_residual_blocks=num_residual_blocks,
            upscale_factor=upscale_factor,
        )

    def forward(self, x):
        """Return the super-resolved image produced by the wrapped SRResNet."""
        generated = self.srresnet(x)
        return generated


class Discriminator(nn.Module):
    """SRGAN-style discriminator producing one probability per input image.

    The layers are arranged as follows:

    * A stride-1 3x3 conv + LeakyReLU.
    * Seven conv blocks whose strides alternate 2, 1, 2, 1, 2, 1, 2. This
      downsamples the feature map by 16 in total (96x96 -> 6x6).
    * A dense head. It pools to 6x6, flattens the 512 x 6 x 6 features, and
      ends in a single sigmoid unit.

    Args:
        img_size: Nominal input resolution. It is only stored; the head's
            ``AdaptiveAvgPool2d((6, 6))`` accepts other sizes as well.
        in_channels: Channels of the input image.
        num_filters: Output channels of the first conv.

    Input is a batch ``(N, in_channels, H, W)``. Output is ``(N, 1)`` with
    values in [0, 1] (sigmoid).
    """

    def __init__(self, img_size=96, in_channels=3, num_filters=64):
        super(Discriminator, self).__init__()
        self.img_size = img_size  # not used by the layers below
        self.in_conv = nn.Sequential(
            nn.Conv2d(in_channels, num_filters, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Strides alternate 2/1 so the total downsampling is 16x. That yields
        # the 6x6 grid the dense head is sized for at a 96x96 input. With every
        # block at stride 2 (128x), a 96x96 input collapsed to 1x1, and the
        # adaptive pool merely replicated each activation 36 times.
        self.model = nn.Sequential(
            self._conv_block(num_filters, 64, kernel_size=3, stride=2, padding=1),
            self._conv_block(64, 128, kernel_size=3, stride=1, padding=1),
            self._conv_block(128, 128, kernel_size=3, stride=2, padding=1),
            self._conv_block(128, 256, kernel_size=3, stride=1, padding=1),
            self._conv_block(256, 256, kernel_size=3, stride=2, padding=1),
            self._conv_block(256, 512, kernel_size=3, stride=1, padding=1),
            self._conv_block(512, 512, kernel_size=3, stride=2, padding=1),
            self._dense_block()
        )

    def forward(self, x):
        """Return the real/fake probability for each image in ``x``."""
        x = self.in_conv(x)
        x = self.model(x)
        return x

    def _conv_block(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        """Build a Conv2d -> LeakyReLU(0.2) -> BatchNorm2d block.

        BatchNorm is replaced by Identity when ``out_channels == 1``.
        NOTE(review): the SRGAN paper orders Conv -> BN -> LeakyReLU. This
        order is kept so that layer indices (and thus state-dict keys) stay
        stable.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
            nn.LeakyReLU(0.2, inplace=True),
            nn.BatchNorm2d(out_channels) if out_channels != 1 else nn.Identity()
        )

    def _dense_block(self):
        """Build the classification head.

        The head runs: pool to 6x6 -> flatten -> 1024 units -> 1 sigmoid unit.
        """
        return nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512 * 6 * 6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            # NOTE(review): sigmoid output suits BCELoss; drop it if training
            # with BCEWithLogitsLoss instead.
            nn.Sigmoid()
        )

    # Backward-compatible alias for the original (misspelled) helper name.
    _dence_block = _dense_block

In [3]:
# Smoke test: a 64x64 low-resolution image should come out 4x larger.
torch.manual_seed(0)  # make the random example input reproducible
lr_input = torch.randn(1, 3, 64, 64)  # batch of one 3-channel 64x64 image; avoids shadowing builtin `input`
generator = Generator()
output = generator(lr_input)
print(output.shape)  # Should print torch.Size([1, 3, 256, 256])
assert output.shape == (1, 3, 256, 256), output.shape

torch.Size([1, 3, 256, 256])


In [4]:
# Use a lowercase instance name. Rebinding `Discriminator` to an instance
# shadowed the class, so re-running this cell failed.
discriminator = Discriminator()
disc_output = discriminator(output)
print(disc_output.shape)  # Should print torch.Size([1, 1]): one probability per image

torch.Size([1, 1])


In [5]:
# Load the pretrained generator weights (a state dict) into `generator`.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# kernel; load_state_dict then copies the tensors into the existing parameters.
# NOTE(review): torch.load unpickles the file. On torch >= 1.13, consider
# weights_only=True if the checkpoint source is not fully trusted.
G_WEIGHTS_PATH = "../logs/srgan_lightning/g_pre_weights.pth"
g_model = torch.load(G_WEIGHTS_PATH, map_location="cpu")
generator.load_state_dict(g_model)

<All keys matched successfully>