In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os
import numpy as np
import torchvision.models as models
import matplotlib.pyplot as plt
from torch.cuda.amp import autocast, GradScaler  # Import mixed precision training tools

# 1. Generator Class
class Generator(nn.Module):
    def __init__(self, num_channels=3, num_residual_blocks=16):
        super(Generator, self).__init__()
        
        # Initial convolution layer
        self.conv1 = nn.Conv2d(num_channels, 64, kernel_size=9, stride=1, padding=4, bias=True)
        self.conv1 = self.conv1.to(dtype=torch.float32)  # Explicitly set to float32
        
        # Residual blocks
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(64) for _ in range(num_residual_blocks)]
        )
        
        # Upsampling layers (subpixel convolution)
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=True).to(dtype=torch.float32),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=3, stride=1, padding=1, bias=True).to(dtype=torch.float32),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, num_channels, kernel_size=3, stride=1, padding=1, bias=True).to(dtype=torch.float32)
        )

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = self.residual_blocks(x)
        x = self.upsample(x)
        return x

# 2. Residual Block used in Generator
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1 = self.conv1.to(dtype=torch.float32)  # Explicitly set to float32
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = self.conv2.to(dtype=torch.float32)  # Explicitly set to float32
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        x = self.relu(self.conv1(x))
        x = self.conv2(x)
        return x + residual

# 3. Discriminator (PatchGAN)
class Discriminator(nn.Module):
    def __init__(self, num_channels=3):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=3, stride=1, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=True),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, kernel_size=3, stride=1, padding=1, bias=True)
        )

    def forward(self, x):
        return self.model(x)

# 4. VGG for perceptual loss
class VGG19(nn.Module):
    def __init__(self):
        super(VGG19, self).__init__()
        vgg = models.vgg19(pretrained=True).features
        self.slice = nn.Sequential(*[vgg[i] for i in range(16)])  # Up to relu_4_2

    def forward(self, x):
        return self.slice(x)

# 5. Dataset Class for HR and LR pairs
class SuperResolutionDataset(Dataset):
    def __init__(self, hr_dir, lr_dir, transform=None):
        self.hr_dir = hr_dir
        self.lr_dir = lr_dir
        self.transform = transform
        self.hr_images = os.listdir(hr_dir)

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, idx):
        hr_image = Image.open(os.path.join(self.hr_dir, self.hr_images[idx])).convert('RGB')
        lr_image = Image.open(os.path.join(self.lr_dir, self.hr_images[idx])).convert('RGB')
        
        if self.transform:
            hr_image = self.transform(hr_image)
            lr_image = self.transform(lr_image)
        
        return lr_image, hr_image

# 6. Define Loss Functions
def adversarial_loss(D, real, fake):
    real_loss = torch.mean((D(real) - 1) ** 2)
    fake_loss = torch.mean(D(fake) ** 2)
    return (real_loss + fake_loss) / 2

def content_loss(x, y):
    return torch.mean((x - y) ** 2)

def perceptual_loss(vgg, x, y):
    x_features = vgg(x)
    y_features = vgg(y)
    return content_loss(x_features, y_features)

# 7. Training Loop with Image Resizing to Match HR Size
def train(generator, discriminator, vgg, dataloader, num_epochs=50, lr=0.0002):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    generator.to(device)
    discriminator.to(device)
    vgg.to(device)

    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.9, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.9, 0.999))

    scaler = GradScaler()

    for epoch in range(num_epochs):
        for i, (lr, hr) in enumerate(dataloader):
            lr, hr = lr.to(device), hr.to(device)

            # Train Discriminator
            optimizer_d.zero_grad()
            with autocast():  # Automatic mixed precision
                fake = generator(lr)
                real_loss = adversarial_loss(discriminator, hr, fake.detach())
                fake_loss = adversarial_loss(discriminator, hr, fake)
                d_loss = (real_loss + fake_loss) / 2
            scaler.scale(d_loss).backward(retain_graph=True)  # Retain graph for the next backward pass
            scaler.step(optimizer_d)
            scaler.update()

            # Train Generator
            optimizer_g.zero_grad()
            with autocast():  # Automatic mixed precision
                fake_resized = nn.functional.interpolate(fake, size=hr.shape[2:], mode='bilinear', align_corners=False)
                
                g_loss_adv = adversarial_loss(discriminator, hr, fake_resized)
                g_loss_content = content_loss(fake_resized, hr)
                g_loss_perceptual = perceptual_loss(vgg, fake_resized, hr)
                g_loss = g_loss_adv + g_loss_content + 0.006 * g_loss_perceptual
            scaler.scale(g_loss).backward()  # No need to retain graph here as it's the last backward pass
            scaler.step(optimizer_g)
            scaler.update()

            # Print loss every 100 steps
            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(dataloader)}], "
                      f"D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

            # Free unused memory after each step
            torch.cuda.empty_cache()

        # Save the model after each epoch
        save_model(generator, discriminator, optimizer_g, optimizer_d, epoch, checkpoint_path=f'model_epoch_{epoch+1}.pth')

# 8. Save Model Function
def save_model(generator, discriminator, optimizer_g, optimizer_d, epoch, checkpoint_path):
    checkpoint = {
        'epoch': epoch,
        'generator_state_dict': generator.state_dict(),
        'discriminator_state_dict': discriminator.state_dict(),
        'optimizer_g_state_dict': optimizer_g.state_dict(),
        'optimizer_d_state_dict': optimizer_d.state_dict()
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Model saved to {checkpoint_path}")

# 9. Dataset and DataLoader
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = SuperResolutionDataset(hr_dir=r'dataset\small train\high_res', 
                                       lr_dir=r'dataset\small train\low_res', 
                                       transform=transform)

# Reduced batch size to 8 to manage memory
train_dataloader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)

# 10. Initialize models and start training
generator = Generator()
discriminator = Discriminator()
vgg = VGG19()

train(generator, discriminator, vgg, train_dataloader, num_epochs=5, lr=0.0002)




KeyboardInterrupt: 

In [6]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToPILImage
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import os

# 1. Load the Saved Generator Model
def load_generator(checkpoint_path, device='cuda'):
    generator = Generator()
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generatort'])  # Ensure the key matches the saved state
    generator.to(device).eval()  # Move to device and set to evaluation mode
    print(f"Loaded generator model from {checkpoint_path}")
    return generator

# 2. Dataset Class for Testing
class TestDataset(Dataset):
    def __init__(self, lr_dir, transform=None):
        self.lr_dir = lr_dir
        self.transform = transform
        self.lr_images = os.listdir(lr_dir)

    def __len__(self):
        return len(self.lr_images)

    def __getitem__(self, idx):
        lr_image = Image.open(os.path.join(self.lr_dir, self.lr_images[idx])).convert('RGB')
        if self.transform:
            lr_image = self.transform(lr_image)
        return lr_image, self.lr_images[idx]

# 3. Test Function
def test(generator, dataloader, device='cuda', output_dir='output_images'):
    os.makedirs(output_dir, exist_ok=True)
    to_pil = ToPILImage()

    generator.eval()
    with torch.no_grad():
        for i, (lr, filenames) in enumerate(dataloader):
            lr = lr.to(device)
            fake = generator(lr)

            for j in range(len(filenames)):
                lr_image = to_pil(lr[j].cpu().detach())
                sr_image = to_pil(fake[j].cpu().detach())

                # Save images
                lr_image.save(os.path.join(output_dir, f"LR_{filenames[j]}"))
                sr_image.save(os.path.join(output_dir, f"SR_{filenames[j]}"))

                # Display the first batch
                if i == 0 and j == 0:
                    plt.figure(figsize=(10, 5))
                    plt.subplot(1, 2, 1)
                    plt.title("Low-Resolution")
                    plt.imshow(lr_image)
                    plt.axis("off")

                    plt.subplot(1, 2, 2)
                    plt.title("Super-Resolution")
                    plt.imshow(sr_image)
                    plt.axis("off")
                    plt.show()

# 4. Set Paths and Load Model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
checkpoint_path = 'model_epoch_5.pth'  # Path to the trained generator checkpoint
lr_test_dir = r'C:\ESRGANs\input'  # Path to low-resolution test images

# 5. Transform for Test Dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

# 6. Create Test Dataset and DataLoader
test_dataset = TestDataset(lr_dir=lr_test_dir, transform=transform)
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0)

# 7. Load Generator Model
generator = load_generator(checkpoint_path, device=device)

# 8. Run Testing
test(generator, test_dataloader, device=device, output_dir='output_images')


KeyError: 'generatort'