<a href="https://colab.research.google.com/github/Arminh388/Generative-Photo-Reconstruction/blob/Armin_SRGANTestingArea/GrPEN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Install requirements

In [None]:
# Install the third-party packages this notebook depends on (IPython shell escape).
!pip install torch torchvision opencv-python matplotlib



Neural Network Classes

In [None]:
import torch
from torch import nn
from torchvision.models import vgg19
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import os
import cv2
from PIL import Image

# Residual Block
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped by an identity shortcut.

    Both convolutions keep ``in_channels`` channels and use padding 1 with the
    default stride of 1, so the output has exactly the input's shape. A ReLU
    follows the first normalisation and another follows the shortcut sum.

    Args:
        in_channels: channel count of the input (and output) feature map.
    """

    def __init__(self, in_channels):
        super().__init__()
        ch = in_channels
        # Registration order (conv1, bn1, relu, conv2, bn2) fixes both the
        # state_dict layout and the order in which weights draw from the RNG.
        self.conv1 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(ch)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(ch, ch, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(ch)

    def forward(self, x):
        # The in-place ReLU only ever touches freshly computed tensors, never x.
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + x)

# GPEN Generator
class GPENGenerator(nn.Module):
    """Image-to-image generator: conv encoder, residual trunk, conv decoder.

    Every convolution uses a 3x3 kernel with stride 1 and padding 1, so the
    output has the same spatial size as the input. The final Tanh bounds the
    output to [-1, 1].

    NOTE(review): there is no upsampling stage, so low-resolution inputs
    presumably have to be pre-resized to the target resolution before being
    compared pixel-wise with high-resolution images. Confirm with the data prep.

    Args:
        in_channels: channels of the input image; the output has the same count.
        num_residuals: number of 64-channel ``ResidualBlock`` modules in the trunk.
    """

    def __init__(self, in_channels=3, num_residuals=8):
        super().__init__()
        width = 64
        # Submodules are built in the order encoder -> trunk -> decoder. That
        # keeps the parameter-initialisation RNG sequence and state_dict keys stable.
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
        )
        trunk = [ResidualBlock(width) for _ in range(num_residuals)]
        self.residual_blocks = nn.Sequential(*trunk)
        self.decoder = nn.Sequential(
            nn.Conv2d(width, width, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(width, in_channels, kernel_size=3, stride=1, padding=1),
            nn.Tanh(),
        )

    def forward(self, x):
        features = self.encoder(x)
        features = self.residual_blocks(features)
        return self.decoder(features)

# GPEN Discriminator
class GPENDiscriminator(nn.Module):
    """Convolutional real/fake classifier that emits one raw logit per image.

    Four stride-2 3x3 convolutions halve the spatial size each time (64x64 ->
    4x4 for a 64x64 input). An adaptive average pool then fixes the map at 4x4
    before the linear head, so any input resolution is accepted. For 64x64
    inputs the pool is an identity and the output matches the unpooled network.

    Args:
        in_channels: channels of the input image.

    ``forward`` returns a ``(batch, 1)`` tensor of logits with no sigmoid
    applied, so use it with a logits-based loss such as ``nn.BCEWithLogitsLoss``.
    """

    def __init__(self, in_channels=3):
        super(GPENDiscriminator, self).__init__()
        pooled = 4  # spatial size fed to the linear head
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # Index 11 used to be a bare Flatten, which hard-wired the model to
            # 64x64 inputs. Pool + Flatten are grouped in one parameter-free
            # sub-Sequential so the Linear stays at index 12. Previously saved
            # state_dicts ("model.12.weight", ...) therefore still load.
            nn.Sequential(nn.AdaptiveAvgPool2d(pooled), nn.Flatten()),
            nn.Linear(512 * pooled * pooled, 1),
        )

    def forward(self, x):
        return self.model(x)


In [None]:
class FaceDataset(Dataset):
    """Paired low-/high-quality face images read from two directories.

    Files are paired by position after sorting each directory's file names, so
    the i-th LR file (in sorted order) is matched with the i-th HR file. Name
    the files so that this ordering lines up, for example by using identical
    names in both directories.

    Args:
        lr_dir: directory holding the degraded input images.
        hr_dir: directory holding the target images.
        transform: optional callable applied to both PIL images of a pair.

    Raises:
        ValueError: if the two directories contain different numbers of files.
    """

    def __init__(self, lr_dir, hr_dir, transform=None):
        self.lr_dir = lr_dir
        self.hr_dir = hr_dir
        self.transform = transform
        # os.listdir order is arbitrary and can differ between directories.
        # Sorting makes index i refer to corresponding files in both.
        self.lr_images = self._list_files(lr_dir)
        self.hr_images = self._list_files(hr_dir)
        if len(self.lr_images) != len(self.hr_images):
            raise ValueError(
                f"LR/HR image count mismatch: {len(self.lr_images)} files in "
                f"{lr_dir!r} vs {len(self.hr_images)} files in {hr_dir!r}"
            )

    @staticmethod
    def _list_files(directory):
        """Return the sorted names of regular files directly inside *directory*."""
        # Subdirectories (e.g. a notebook's .ipynb_checkpoints) are skipped
        # because they cannot be opened as images.
        return sorted(
            name
            for name in os.listdir(directory)
            if os.path.isfile(os.path.join(directory, name))
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path* as RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_image = self._load_rgb(os.path.join(self.lr_dir, self.lr_images[idx]))
        hr_image = self._load_rgb(os.path.join(self.hr_dir, self.hr_images[idx]))

        # NOTE(review): the same transform is called separately on each image.
        # A random augmentation would therefore be sampled independently for LR
        # and HR. Fine for deterministic transforms only.
        if self.transform is not None:
            lr_image = self.transform(lr_image)
            hr_image = self.transform(hr_image)

        return lr_image, hr_image

# ToTensor maps a PIL image to a float tensor in [0, 1]. Normalize with mean
# 0.5 and std 0.5 then computes (x - 0.5) / 0.5 per RGB channel, giving [-1, 1].
# That matches the range of the generator's final Tanh.
_to_tensor = transforms.ToTensor()
_to_signed_range = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)
transform = transforms.Compose([_to_tensor, _to_signed_range])


Training Loop

In [None]:
def train(dataloader, generator, discriminator, optimizer_g, optimizer_d, criterion,
          num_epochs=50, adv_weight=1.0):
    """Alternately optimise the discriminator and the generator on paired batches.

    On each step the discriminator is first updated to score real ``hr``
    images as 1 and generated images as 0. The generator is then updated on
    ``adv_weight * criterion(D(G(lr)), 1) + MSE(G(lr), hr)``. Losses are
    printed every 100 steps.

    Args:
        dataloader: iterable of ``(lr, hr)`` tensor batches that supports ``len()``.
        generator: model mapping ``lr`` batches to outputs comparable with ``hr``.
        discriminator: model scoring images; its outputs are fed to ``criterion``.
        optimizer_g: optimizer over the generator's parameters.
        optimizer_d: optimizer over the discriminator's parameters.
        criterion: adversarial loss taking ``(predictions, 0/1 targets)``.
        num_epochs: number of passes over ``dataloader``.
        adv_weight: weight of the adversarial term in the generator loss. The
            default of 1.0 reproduces the unweighted sum.
    """
    # Follow the models' device instead of hard-coding .cuda(), so training
    # also works when the models live on the CPU.
    device = next(generator.parameters()).device
    pixel_loss = nn.MSELoss()  # stateless, so build it once rather than per step
    # An earlier evaluation cell may have switched the models to eval mode,
    # where BatchNorm would use running statistics. Make sure both train.
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        for i, (lr, hr) in enumerate(dataloader):
            lr, hr = lr.to(device), hr.to(device)

            # One generator forward pass serves both updates: it is detached
            # for the discriminator step and attached for the generator step.
            fake_hr = generator(lr)

            # Train Discriminator
            optimizer_d.zero_grad()
            real_pred = discriminator(hr)
            fake_pred = discriminator(fake_hr.detach())
            d_loss = (criterion(real_pred, torch.ones_like(real_pred))
                      + criterion(fake_pred, torch.zeros_like(fake_pred)))
            d_loss.backward()
            optimizer_d.step()

            # Train Generator against the just-updated discriminator. Gradients
            # this also leaves on the discriminator are cleared by the next
            # optimizer_d.zero_grad().
            optimizer_g.zero_grad()
            fake_pred = discriminator(fake_hr)
            adv_loss = criterion(fake_pred, torch.ones_like(fake_pred))
            g_loss = adv_weight * adv_loss + pixel_loss(fake_hr, hr)
            g_loss.backward()
            optimizer_g.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i}/{len(dataloader)}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

# Initialize Models
# Prefer the GPU. On a CPU-only runtime the models are placed on the CPU
# instead of failing here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator = GPENGenerator().to(device)
discriminator = GPENDiscriminator().to(device)

# Optimizers and Loss
# Both networks use Adam (lr 1e-4, beta1 0.5). The discriminator ends in a
# plain Linear layer (raw logits), so the adversarial loss is BCEWithLogitsLoss.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.999))
criterion = nn.BCEWithLogitsLoss()

# Dataset and DataLoader
# NOTE: these paths look like placeholders. Point them at the real LR/HR folders.
lr_dir = "path_to_lr_images"
hr_dir = "path_to_hr_images"
dataset = FaceDataset(lr_dir, hr_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

# Train the Model
train(dataloader, generator, discriminator, optimizer_g, optimizer_d, criterion, num_epochs=50)

# Save Models (model weights only; optimizer state is not written)
torch.save(generator.state_dict(), "gpen_generator.pth")
torch.save(discriminator.state_dict(), "gpen_discriminator.pth")


Load Saved State and Resume Training

In [None]:
# Reload the weights written by the training cell, then train for 10 more epochs.
# map_location loads each checkpoint straight onto the device its model is on.
# Weights saved from a GPU session can therefore also be restored on a
# CPU-only runtime.
# NOTE(review): the optimizers reused here are the in-memory ones from the
# training cell (their state is not checkpointed). The weights produced by
# these extra epochs are not written back to disk. Add torch.save calls if
# they should persist.
generator.load_state_dict(
    torch.load("gpen_generator.pth", map_location=next(generator.parameters()).device)
)
discriminator.load_state_dict(
    torch.load("gpen_discriminator.pth", map_location=next(discriminator.parameters()).device)
)

train(dataloader, generator, discriminator, optimizer_g, optimizer_d, criterion, num_epochs=10)


Evaluation

In [None]:
from torchvision.utils import save_image

# Save the generator's outputs batch by batch: restored_face_{i}.png holds the
# image grid for batch i. A non-shuffled loader over the same dataset keeps
# the batch order, and so the file numbering, stable across runs.
eval_loader = DataLoader(dataloader.dataset, batch_size=dataloader.batch_size, shuffle=False)
gen_device = next(generator.parameters()).device

generator.eval()
with torch.no_grad():
    for i, (lr, _) in enumerate(eval_loader):
        restored_faces = generator(lr.to(gen_device))
        # The generator ends in Tanh, so outputs lie in [-1, 1]. value_range
        # maps exactly that interval to [0, 1]. normalize=True alone would
        # instead stretch each saved batch by its own min/max and distort brightness.
        save_image(restored_faces, f"restored_face_{i}.png", normalize=True, value_range=(-1, 1))
