In [None]:
# Clone Real-ESRGAN and enter the Real-ESRGAN
!git clone https://github.com/xinntao/Real-ESRGAN.git
%cd Real-ESRGAN
# Set up the environment
!pip install basicsr
!pip install facexlib
!pip install gfpgan
!pip install -r requirements.txt
!python setup.py develop

In [None]:
# Fix the import error in basicsr/data/degradations.py
# The file is located through the import system rather than a hard-coded
# site-packages path, so the patch works for any Python version or install
# location. find_spec() is used because it locates the top-level package
# without importing it (importing basicsr is what triggers the error).
import importlib
import importlib.util
import os


def patch_rgb_to_grayscale_import(path):
    """Rewrite the ``functional_tensor`` import of ``rgb_to_grayscale`` in a file.

    Args:
        path: Path of the Python source file to patch in place.

    Returns:
        True if the file contained the old import and was rewritten,
        False if the old import was not present (e.g. already patched).
    """
    old_import = "from torchvision.transforms.functional_tensor import rgb_to_grayscale"
    new_import = "from torchvision.transforms.functional import rgb_to_grayscale"
    with open(path, encoding="utf-8") as src:
        text = src.read()
    if old_import not in text:
        return False
    with open(path, "w", encoding="utf-8") as dst:
        dst.write(text.replace(old_import, new_import))
    return True


# The package may have been installed after the kernel started, so drop any
# cached "not found" results before looking it up.
importlib.invalidate_caches()
_basicsr_spec = importlib.util.find_spec("basicsr")
if _basicsr_spec is None or not _basicsr_spec.submodule_search_locations:
    print("basicsr is not installed; nothing to patch")
else:
    _degradations_path = os.path.join(
        list(_basicsr_spec.submodule_search_locations)[0], "data", "degradations.py"
    )
    if not os.path.isfile(_degradations_path):
        print(f"{_degradations_path} not found; nothing to patch")
    elif patch_rgb_to_grayscale_import(_degradations_path):
        print(f"Patched {_degradations_path}")
    else:
        print(f"{_degradations_path} does not need patching")

In [None]:
# Download the pre-trained models
!wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth -P experiments/pretrained_models

In [None]:
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

# Location of the pretrained x2 checkpoint
model_path = 'experiments/pretrained_models/RealESRGAN_x2plus.pth'

# Build the RRDBNet network in its x2 configuration
model = RRDBNet(
    num_in_ch=3,
    num_out_ch=3,
    num_feat=64,
    num_block=23,
    num_grow_ch=32,
    scale=2,
)

# Wrap the network and the checkpoint path in a RealESRGANer.
# tile=0, tile_pad=10 and pre_pad=0 are passed through unchanged.
upsampler = RealESRGANer(
    model=model,
    model_path=model_path,
    scale=2,
    half=True,  # use FP16 if GPU supports
    tile=0,
    tile_pad=10,
    pre_pad=0,
)


In [None]:
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import glob, os

# Paired iPhone / DSLR image dataset
class IphoneToDslrDataset(Dataset):
    """Dataset of (iPhone, DSLR) image pairs read from two directories.

    The ``*.jpg`` files of each directory are listed in sorted order and
    paired by position: item ``idx`` is the ``idx``-th iPhone file together
    with the ``idx``-th DSLR file. The dataset length is the smaller of the
    two file counts.

    Args:
        iphone_dir: Directory holding the input ``.jpg`` images.
        dslr_dir: Directory holding the target ``.jpg`` images.
        transform: Optional callable applied separately to the input image
            and to the target image.
    """

    def __init__(self, iphone_dir, dslr_dir, transform=None):
        iphone_pattern = os.path.join(iphone_dir, "*.jpg")
        dslr_pattern = os.path.join(dslr_dir, "*.jpg")
        self.iphone_imgs = sorted(glob.glob(iphone_pattern))
        self.dslr_imgs = sorted(glob.glob(dslr_pattern))
        self.transform = transform

    def __len__(self):
        """Return the number of pairs, limited by the shorter file list."""
        n_iphone = len(self.iphone_imgs)
        n_dslr = len(self.dslr_imgs)
        return n_iphone if n_iphone < n_dslr else n_dslr

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and convert it to RGB."""
        return Image.open(path).convert("RGB")

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair at position ``idx``."""
        inp = self._load_rgb(self.iphone_imgs[idx])
        tgt = self._load_rgb(self.dslr_imgs[idx])
        if not self.transform:
            return inp, tgt
        # Same transform for both images, input first and then target.
        return self.transform(inp), self.transform(tgt)

# Data
# An RRDBNet built with scale=2 returns an image twice the spatial size of
# its input. Both images of a pair are therefore loaded at the target
# (high-resolution) size, and the network input is obtained in the loop by
# downsampling the iPhone image by SCALE, so the output matches the target.
SCALE = 2      # must match the `scale` the model was built with
HR_SIZE = 256  # side length of the target; the network input is HR_SIZE // SCALE

transform = transforms.Compose([
    transforms.Resize((HR_SIZE, HR_SIZE)),
    transforms.ToTensor()
])
train_dataset = IphoneToDslrDataset(
    "/content/Real-ESRGAN/datasets/scannet++/iphone",
    "/content/Real-ESRGAN/datasets/scannet++/canon",
    transform,
)
if len(train_dataset) == 0:
    # Fail early with a readable message instead of an error from the sampler.
    raise RuntimeError("No .jpg image pairs found in the iphone/canon dataset directories")
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

# Training setup (transfer learning)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The same model object was handed to RealESRGANer with half=True, which
# leaves its weights in FP16. Updating FP16 weights directly with Adam is
# numerically unstable, so training runs in FP32 and the original precision
# is restored afterwards.
inference_dtype = next(model.parameters()).dtype
model.float()
model.to(device)
model.train()

# Optimizer + Loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
loss_fn = torch.nn.L1Loss()

lr_size = (HR_SIZE // SCALE, HR_SIZE // SCALE)

for epoch in range(2):  # keep small for Colab
    for i, (inp, tgt) in enumerate(train_loader):
        inp, tgt = inp.to(device), tgt.to(device)
        # "area" interpolation averages source pixels when shrinking.
        lr_inp = torch.nn.functional.interpolate(inp, size=lr_size, mode="area")
        out = model(lr_inp)  # (N, 3, HR_SIZE, HR_SIZE), same shape as tgt
        loss = loss_fn(out, tgt)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 5 == 0:
            print(f"Epoch {epoch} Batch {i} Loss {loss.item():.4f}")

# Put the model back into the state the upsampler expects for inference.
model.eval()
model.to(inference_dtype)