<a href="https://colab.research.google.com/github/Aagam11/Aagam11.github.io/blob/main/Untitled1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [38]:
%%writefile gan_super_resolution.py
# Paste the entire gan_super_resolution.py script here
import os
import zipfile
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import requests
import urllib3

# ------------------------------
# 1. Download and Extract DIV2K Dataset using requests
# ------------------------------

def download_and_extract_div2k():
    # Disable SSL warnings for insecure requests
    urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

    url = 'http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip'
    zip_path = 'DIV2K_train_HR.zip'
    extract_folder = 'DIV2K_train_HR'

    if not os.path.exists(zip_path):
        print("Downloading DIV2K high-resolution images...")
        r = requests.get(url, verify=False)
        with open(zip_path, 'wb') as f:
            f.write(r.content)
        print("Download complete.")
    else:
        print("Zip file already exists.")

    if not os.path.exists(extract_folder):
        print("Extracting the dataset...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall()  # Extracts into the current directory
        print("Extraction complete.")
    else:
        print("Dataset already extracted.")

# ------------------------------
# 2. Define the Custom Dataset with Fixed Crop Size
# ------------------------------

class SuperResolutionDataset(Dataset):
    """
    Custom dataset for super-resolution.
    Assumes hr_dir contains high-resolution images.
    A random crop of size crop_size x crop_size is taken from the HR image,
    then downsampled (by factor of 4) to generate the LR image.
    """
    def __init__(self, hr_dir, crop_size=128,
                 transform_hr=None, transform_lr=None):
        self.hr_dir = hr_dir
        self.hr_images = os.listdir(hr_dir)
        self.crop_size = crop_size
        self.transform_hr = transform_hr if transform_hr is not None else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])
        self.transform_lr = transform_lr if transform_lr is not None else transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5])
        ])

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, index):
        hr_image_path = os.path.join(self.hr_dir, self.hr_images[index])
        hr_image = Image.open(hr_image_path).convert("RGB")

        # Ensure the image is large enough for cropping.
        if hr_image.width < self.crop_size or hr_image.height < self.crop_size:
            hr_image = hr_image.resize((max(hr_image.width, self.crop_size),
                                        max(hr_image.height, self.crop_size)),
                                       Image.BICUBIC)
        # Randomly crop a patch from the HR image.
        crop_transform = transforms.RandomCrop(self.crop_size)
        hr_crop = crop_transform(hr_image)
        # Downsample the cropped HR patch to get the LR image.
        lr_image = hr_crop.resize((self.crop_size // 4, self.crop_size // 4), Image.BICUBIC)

        hr_tensor = self.transform_hr(hr_crop)
        lr_tensor = self.transform_lr(lr_image)
        return lr_tensor, hr_tensor

# ------------------------------
# 3. Define the GAN Model Components
# ------------------------------

class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += residual
        return out

class Generator(nn.Module):
    def __init__(self, num_residual_blocks=4):  # Reduced number of residual blocks for speed
        super(Generator, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, padding=4)
        self.relu = nn.ReLU(inplace=True)
        self.residuals = nn.Sequential(*[ResidualBlock(64) for _ in range(num_residual_blocks)])
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.upsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 256, kernel_size=3, padding=1),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True)
        )
        self.conv3 = nn.Conv2d(64, 3, kernel_size=9, padding=4)

    def forward(self, x):
        out1 = self.relu(self.conv1(x))
        out = self.residuals(out1)
        out = self.bn2(self.conv2(out))
        out = out + out1  # Global skip connection
        out = self.upsample(out)
        out = self.conv3(out)
        return out

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        def disc_block(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2, inplace=True)
            )
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            disc_block(64, 64, 2),
            disc_block(64, 128, 1),
            disc_block(128, 128, 2),
            disc_block(128, 256, 1),
            disc_block(256, 256, 2),
            disc_block(256, 512, 1),
            disc_block(512, 512, 2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=1)
        )

    def forward(self, x):
        out = self.model(x)
        return torch.sigmoid(out.view(out.size(0), -1))

# ------------------------------
# 4. Training Function with Mixed Precision (if GPU available)
# ------------------------------

def train():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)

    # Point hr_dir directly to the extracted DIV2K folder.
    hr_dir = 'DIV2K_train_HR'

    transform_hr = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    transform_lr = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    dataset = SuperResolutionDataset(hr_dir, crop_size=128,
                                     transform_hr=transform_hr, transform_lr=transform_lr)
    dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=2)

    generator = Generator().to(device)
    discriminator = Discriminator().to(device)

    criterion_GAN = nn.BCELoss()
    criterion_content = nn.MSELoss()

    optimizer_G = optim.Adam(generator.parameters(), lr=1e-4)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4)

    num_epochs = 10  # Fewer epochs for faster prototyping
    scaler_G = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    scaler_D = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None

    for epoch in range(num_epochs):
        for i, (lr_imgs, hr_imgs) in enumerate(dataloader):
            lr_imgs = lr_imgs.to(device)
            hr_imgs = hr_imgs.to(device)

            valid = torch.ones((lr_imgs.size(0), 1), device=device)
            fake = torch.zeros((lr_imgs.size(0), 1), device=device)

            # ---------------------
            #  Train Generator
            # ---------------------
            optimizer_G.zero_grad()
            with torch.cuda.amp.autocast(enabled=(scaler_G is not None)):
                sr_imgs = generator(lr_imgs)
                pred_fake = discriminator(sr_imgs)
                loss_GAN = criterion_GAN(pred_fake, valid)
                loss_content = criterion_content(sr_imgs, hr_imgs)
                loss_G = loss_content + 1e-3 * loss_GAN
            if scaler_G is not None:
                scaler_G.scale(loss_G).backward()
                scaler_G.step(optimizer_G)
                scaler_G.update()
            else:
                loss_G.backward()
                optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()
            with torch.cuda.amp.autocast(enabled=(scaler_D is not None)):
                pred_real = discriminator(hr_imgs)
                loss_real = criterion_GAN(pred_real, valid)
                pred_fake = discriminator(sr_imgs.detach())
                loss_fake = criterion_GAN(pred_fake, fake)
                loss_D = (loss_real + loss_fake) / 2
            if scaler_D is not None:
                scaler_D.scale(loss_D).backward()
                scaler_D.step(optimizer_D)
                scaler_D.update()
            else:
                loss_D.backward()
                optimizer_D.step()

            if i % 100 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] Batch {i} | Loss_D: {loss_D.item():.4f} | Loss_G: {loss_G.item():.4f}")

    print("Training complete. Saving generator model as 'generator.pth'.")
    torch.save(generator.state_dict(), "generator.pth")

# ------------------------------
# 5. Colab-Friendly Inference Function
# ------------------------------

def run_colab_inference():
    from google.colab import files
    import io
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    generator = Generator().to(device)
    if not os.path.exists("generator.pth"):
        print("Error: 'generator.pth' not found. Please run in train mode first.")
        return
    generator.load_state_dict(torch.load("generator.pth", map_location=device))
    generator.eval()

    # print("Please upload an image file...")
    image_path = "/content/image1.jpg" # Local path to image in Google Colab
    input_image = Image.open(image_path).convert("RGB")
    # uploaded = files.upload()
    # response = requests.get(image_url, stream=True)
    # response.raise_for_status()  # Raise an exception for bad responses
    # input_image = Image.open(bytes(response.content)).convert("RGB")
    # for fn in uploaded.keys():
    #     input_image = Image.open(io.BytesIO(uploaded[fn])).convert("RGB")
    transform_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    input_tensor = transform_input(input_image).unsqueeze(0).to(device)
    with torch.no_grad():
        sr_tensor = generator(input_tensor)
    sr_tensor = sr_tensor.squeeze(0).cpu()
    sr_tensor = sr_tensor * 0.5 + 0.5  # Denormalize
    sr_tensor = torch.clamp(sr_tensor, 0, 1)
    sr_image = transforms.ToPILImage()(sr_tensor)

    # Display original and super-resolved images using matplotlib
    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    axes[0].imshow(input_image)
    axes[0].set_title("Original Image")
    axes[0].axis("off")
    axes[1].imshow(sr_image)
    axes[1].set_title("Super-Resolved Image")
    axes[1].axis("off")
    plt.savefig('output_image2.png')  # Save the figure to a file named 'output_image.png'
    plt.show()

# ------------------------------
# 6. Main Execution
# ------------------------------

if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="GAN-based Super Resolution for Google Colab")
    parser.add_argument("--mode", type=str, default="colab", choices=["train", "colab"],
                        help="Select 'train' to train the model or 'colab' to run inference in Colab")
    args = parser.parse_args()

    if args.mode == "train":
        download_and_extract_div2k()
        train()
    elif args.mode == "colab":
        run_colab_inference()

Overwriting gan_super_resolution.py


In [13]:
!python3 gan_super_resolution.py --mode train

Zip file already exists.
Dataset already extracted.
Using device: cpu
  with torch.cuda.amp.autocast(enabled=(scaler_G is not None)):
  with torch.cuda.amp.autocast(enabled=(scaler_D is not None)):
Epoch [1/10] Batch 0 | Loss_D: 0.6951 | Loss_G: 0.3067
Epoch [2/10] Batch 0 | Loss_D: 0.4752 | Loss_G: 0.0799
Epoch [3/10] Batch 0 | Loss_D: 0.3730 | Loss_G: 0.0519
Epoch [4/10] Batch 0 | Loss_D: 0.2631 | Loss_G: 0.0450
Epoch [5/10] Batch 0 | Loss_D: 0.1556 | Loss_G: 0.0319
Epoch [6/10] Batch 0 | Loss_D: 0.1089 | Loss_G: 0.0257
Epoch [7/10] Batch 0 | Loss_D: 0.9848 | Loss_G: 0.0321
Epoch [8/10] Batch 0 | Loss_D: 0.0694 | Loss_G: 0.0242
Epoch [9/10] Batch 0 | Loss_D: 0.1346 | Loss_G: 0.0239
Epoch [10/10] Batch 0 | Loss_D: 0.2354 | Loss_G: 0.0201
Training complete. Saving generator model as 'generator.pth'.


In [4]:
!pip install requests



In [39]:
!python3 gan_super_resolution.py --mode colab

  generator.load_state_dict(torch.load("generator.pth", map_location=device))
Figure(1000x500)
