<a href="https://colab.research.google.com/github/AnovaYoung/AI-System-for-Image-Restoration-and-Enhancement/blob/Modeling/REAL_ESRGAN_FOR_SUPER_RES.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
import os.path as osp
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset

In [4]:
# Utilize available GPU
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("GPU not available, using CPU.")

tensor = torch.tensor([1.0, 2.0, 3.0]).to(device)
print(f"Tensor is on device: {tensor.device}")


Using GPU: Tesla T4
Tensor is on device: cuda:0


In [8]:
import shutil

# List of paths to force delete
paths_to_delete = [
    "/content/Real-ESRGAN",
    "/content/super_resolution_dataset/augmented_split",
    "/content/super_resolution_dataset/augmented_super_resolution_hr",
    "/content/super_resolution_dataset/augmented_super_resolution_lr",
    "/content/super_resolution_dataset/normalized_super_resolution_hr",
    "/content/super_resolution_dataset/normalized_super_resolution_lr",
    "/content/super_resolution_dataset/super_resolution_hr",
    "/content/super_resolution_dataset/super_resolution_lr"
]

for path in paths_to_delete:
    try:
        shutil.rmtree(path)
        print(f"Successfully deleted: {path}")
    except FileNotFoundError:
        print(f"Path not found (already deleted): {path}")
    except Exception as e:
        print(f"Error deleting {path}: {e}")


Successfully deleted: /content/Real-ESRGAN
Successfully deleted: /content/super_resolution_dataset/augmented_split
Successfully deleted: /content/super_resolution_dataset/augmented_super_resolution_hr
Successfully deleted: /content/super_resolution_dataset/augmented_super_resolution_lr
Successfully deleted: /content/super_resolution_dataset/normalized_super_resolution_hr
Successfully deleted: /content/super_resolution_dataset/normalized_super_resolution_lr
Successfully deleted: /content/super_resolution_dataset/super_resolution_hr
Successfully deleted: /content/super_resolution_dataset/super_resolution_lr


# REAL-ESRGAN MODEL

REAL-ESRGAN MODEL PLANNING

1. Set Up Training Configuration: Define learning rates, batch sizes, and epochs.
2. Model Initialization: Set up the Real-ESRGAN architecture and optimizer.
3. DataLoader Integration: Feed preprocessed train, validation, and test datasets into PyTorch DataLoaders.
4. Training Loop: Implement the training loop with periodic validation.
5. Checkpointing: Save the model weights periodically during training.
6. Testing and Evaluation: Evaluate the trained model on the test set and visualize some results.
7. Save Final Model: Save the final trained model for use in the future.

In [15]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
from tqdm import tqdm
import numpy as np
from torch import nn, optim


# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the pre-trained Real-ESRGAN model
model = realesrgan(device, scale=4)
model.load_weights('/content/realesrgan_models/RealESRGAN_x4plus.pth')
model.train()  # Set to training mode
model.to(device)

# Loss function and optimizer
criterion = nn.L1Loss()  # Can be replaced with another loss (e.g., L2, Perceptual)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Fine-tuning parameters
epochs = 10
checkpoint_dir = "/content/realesrgan_checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)

# Training loop
for epoch in range(epochs):
    model.train()
    train_loss = 0.0
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}")

    for lr_batch, hr_batch in progress_bar:
        lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)

        # Forward pass
        optimizer.zero_grad()
        sr_batch = model(lr_batch)

        # Loss calculation
        loss = criterion(sr_batch, hr_batch)
        train_loss += loss.item()

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        progress_bar.set_postfix({"Loss": loss.item()})

    avg_train_loss = train_loss / len(train_loader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Train Loss: {avg_train_loss:.4f}")

    # Validation
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for lr_batch, hr_batch in val_loader:
            lr_batch, hr_batch = lr_batch.to(device), hr_batch.to(device)
            sr_batch = model(lr_batch)
            val_loss += criterion(sr_batch, hr_batch).item()

    avg_val_loss = val_loss / len(val_loader)
    print(f"Epoch {epoch+1}/{epochs}, Avg Val Loss: {avg_val_loss:.4f}")

    # Save model checkpoint
    torch.save(model.state_dict(), f"{checkpoint_dir}/realesrgan_epoch{epoch+1}.pth")
    print(f"Model checkpoint saved for epoch {epoch+1}.")

print("Fine-tuning complete!")


ModuleNotFoundError: No module named 'torchvision.transforms.functional_tensor'