In [None]:
# ─── 1. Install Dependencies ──────────────────────────────────────────────
# IPython/Jupyter shell escape (`!`), so this line only runs inside a notebook.
# Installs the libraries imported further down: torch + torchvision (model,
# DataLoader, transforms), pillow (PIL image I/O), scikit-image (PSNR/SSIM)
# and tqdm (progress bars). No versions are pinned, so pip picks the latest.
!pip install torch torchvision pillow scikit-image tqdm

# ─── 2. Download & Extract DIV2K ─────────────────────────────────────────
# Fetches and unpacks the DIV2K_train_HR archive into ./DIV2K_train_HR using
# IPython `!` shell escapes (requires wget, unzip and rm on the host).
# The download is skipped when the extracted directory already exists.
# NOTE(review): only the directory's existence is checked, so an interrupted
# unzip from a previous run is not detected and will not be redone.
import os
if not os.path.isdir('DIV2K_train_HR'):
    # -q: quiet; failures are not checked here.
    !wget -q https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip
    !unzip -q DIV2K_train_HR.zip
    # Remove the archive once extracted to free disk space.
    !rm DIV2K_train_HR.zip

# ─── 3. Preprocess: Create LR Images ──────────────────────────────────────
from PIL import Image
def create_lr(hr_dir='DIV2K_train_HR', lr_dir='DIV2K_train_LR', scale=4):
    """Write bicubic-downsampled low-resolution copies of every HR image.

    For each PNG/JPEG in ``hr_dir``:

    1. Crop the image (anchored at the top-left) so both sides are divisible
       by ``scale``.
    2. Downsample it by ``scale`` with bicubic interpolation.
    3. Save the result to ``lr_dir`` under the same filename.

    The LR image is deliberately kept at 1/``scale`` resolution. The
    generators in this file upsample by 4x themselves, so re-enlarging the LR
    image to HR size here would make their output 4x larger than the target.

    Args:
        hr_dir: Directory containing the high-resolution images.
        lr_dir: Output directory; created if missing. Same-named files are
            overwritten.
        scale: Integer downsampling factor (>= 1).

    Raises:
        ValueError: If ``scale`` < 1.
    """
    if scale < 1:
        raise ValueError(f'scale must be >= 1, got {scale}')
    os.makedirs(lr_dir, exist_ok=True)
    for fn in sorted(os.listdir(hr_dir)):
        if not fn.lower().endswith(('png', 'jpg', 'jpeg')):
            continue
        # `with` closes the file; convert() returns a fully loaded copy.
        with Image.open(os.path.join(hr_dir, fn)) as im:
            hr = im.convert('RGB')
        w, h = hr.size
        # Drop the remainder so every LR pixel maps onto an exact
        # scale x scale HR patch (standard "modcrop").
        hr = hr.crop((0, 0, w - w % scale, h - h % scale))
        lr = hr.resize((w // scale, h // scale), Image.BICUBIC)
        lr.save(os.path.join(lr_dir, fn))


create_lr()

# ─── 4. Imports & Dataset ─────────────────────────────────────────────────
import torch, torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from PIL import Image
from tqdm import tqdm

class DIV2KDataset(Dataset):
    """Paired (LR, HR) images read from two directories with matching filenames.

    Every PNG/JPEG in ``hr_dir`` must have a same-named partner in ``lr_dir``.
    When the LR image is smaller than the HR image, the HR image is cropped
    from its top-left corner to exactly ``scale * LR size``, with
    ``scale = HR width // LR width``. A ``scale``-x generator output then
    lines up pixel-for-pixel with the target. Same-size pairs are returned
    unchanged.

    Args:
        hr_dir: Directory of high-resolution images.
        lr_dir: Directory of low-resolution images.
        transform: Applied independently to the LR and HR PIL images. Defaults
            to ``ToTensor()``, i.e. float CHW tensors in [0, 1]. That is the
            same range the generators produce via ``(tanh + 1) / 2``, so no
            ``Normalize`` is applied.
    """

    def __init__(self, hr_dir, lr_dir, transform=None):
        self.hr_dir, self.lr_dir = hr_dir, lr_dir
        # Sorted for a stable index -> file mapping across runs and platforms.
        self.fns = sorted(
            f for f in os.listdir(hr_dir)
            if f.lower().endswith(('png', 'jpg', 'jpeg'))
        )
        self.transform = transform or ToTensor()

    def __len__(self):
        return len(self.fns)

    @staticmethod
    def _aligned_hr_box(hr_size, lr_size):
        """Return the PIL crop box trimming HR to an integer multiple of the LR size.

        Both sizes are ``(width, height)``. The box is
        ``(left, upper, right, lower)`` anchored at the top-left corner and never
        exceeds the HR image.
        """
        (hw, hh), (lw, lh) = hr_size, lr_size
        scale = max(1, hw // lw)
        return (0, 0, min(hw, lw * scale), min(hh, lh * scale))

    def __getitem__(self, i):
        fn = self.fns[i]
        # `with` closes the underlying files; convert() returns loaded copies.
        with Image.open(os.path.join(self.hr_dir, fn)) as im:
            hr = im.convert('RGB')
        with Image.open(os.path.join(self.lr_dir, fn)) as im:
            lr = im.convert('RGB')
        box = self._aligned_hr_box(hr.size, lr.size)
        if box != (0, 0) + hr.size:
            hr = hr.crop(box)
        return self.transform(lr), self.transform(hr)

# ─── 5. Model Definitions ─────────────────────────────────────────────────
class ResidualBlock(nn.Module):
    """Conv-BN-PReLU-Conv-BN body added back onto its input (identity shortcut).

    Args:
        c: Channel count of both the input and the output.
    """

    def __init__(self, c=64):
        super().__init__()
        layers = [
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c),
            nn.PReLU(),
            nn.Conv2d(c, c, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(c),
        ]
        # Kept as a single Sequential named `net` so state_dict keys are `net.<i>.*`.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return x + residual

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampler: 3x3 conv to ``c * scale**2`` channels, PixelShuffle, PReLU.

    Maps ``(N, c, H, W)`` to ``(N, c, H * scale, W * scale)``.

    Args:
        c: Input (and output) channel count.
        scale: Spatial upscaling factor of this block.
    """

    def __init__(self, c, scale=2):
        super().__init__()
        expanded = c * scale ** 2
        self.net = nn.Sequential(
            nn.Conv2d(c, expanded, kernel_size=3, stride=1, padding=1),
            nn.PixelShuffle(upscale_factor=scale),
            nn.PReLU(),
        )

    def forward(self, x):
        upsampled = self.net(x)
        return upsampled

class Generator(nn.Module):
    """SRGAN generator.

    Layout: 9x9 head conv, residual trunk, sub-pixel upsampler, 9x9 tail conv.
    The output passes through ``(tanh + 1) / 2`` and so lies in [0, 1].

    Args:
        num_res: Number of ``ResidualBlock``s in the trunk.
        up: Total upscaling factor. A power of two (1, 2, 4, 8, ...) is built
            from chained 2x ``UpsampleBlock``s; 3 uses a single 3x block.

    Raises:
        ValueError: If ``up`` is neither 3 nor a positive power of two.
    """

    def __init__(self, num_res=16, up=4):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 9, 1, 4); self.pre1 = nn.PReLU()
        self.res   = nn.Sequential(*[ResidualBlock(64) for _ in range(num_res)])
        self.conv2 = nn.Conv2d(64, 64, 3, 1, 1); self.bn2 = nn.BatchNorm2d(64)
        # One UpsampleBlock per factor; up=4 still yields ups.0/ups.1 (2x each),
        # so existing checkpoints keep loading.
        self.ups = nn.Sequential(
            *[UpsampleBlock(64, f) for f in self._upsample_factors(up)]
        )
        self.conv3 = nn.Conv2d(64, 3, 9, 1, 4)

    @staticmethod
    def _upsample_factors(up):
        """Split total scale ``up`` into per-block PixelShuffle factors.

        Chaining n 2x blocks scales by 2**n, so a power-of-two ``up`` needs
        log2(up) blocks (``up.bit_length() - 1``).
        """
        if up == 3:
            return [3]
        if isinstance(up, int) and up >= 1 and up & (up - 1) == 0:
            return [2] * (up.bit_length() - 1)
        raise ValueError(
            f'unsupported upscale factor {up!r}; use 3 or a power of two')

    def forward(self, x):
        out1 = self.pre1(self.conv1(x))
        out  = self.res(out1)
        out  = self.bn2(self.conv2(out)) + out1   # global skip around the trunk
        out  = self.ups(out)
        out  = self.conv3(out)
        return (torch.tanh(out) + 1) / 2          # map to [0, 1]

class Discriminator(nn.Module):
    """Image discriminator returning one real/fake logit per input image.

    The network has eight conv-BN-LeakyReLU stages. Stride 2 on every second
    stage halves the resolution four times. These are followed by global
    average pooling and a 512 -> 1024 -> 1 MLP head. Because of the adaptive
    pooling, any input size that survives the four stride-2 convs is accepted.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride) per stage.
        stages = (
            (3, 64, 1), (64, 64, 2), (64, 128, 1), (128, 128, 2),
            (128, 256, 1), (256, 256, 2), (256, 512, 1), (512, 512, 2),
        )
        features = []
        for in_c, out_c, stride in stages:
            features += [
                nn.Conv2d(in_c, out_c, 3, stride, 1),
                nn.BatchNorm2d(out_c),
                nn.LeakyReLU(0.2),
            ]
        head = [
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1),
        ]
        # One flat Sequential so state_dict keys stay `net.<index>.*`.
        self.net = nn.Sequential(*features, *head)

    def forward(self, x):
        logits = self.net(x)
        return logits

class DenseResidualBlock(nn.Module):
    """Five densely connected 3x3 conv layers fused by a 1x1 conv, added back at 0.2 scale.

    Layer k sees the concatenation of the input and all earlier layer outputs
    (``in_c + k * gr`` channels) and emits ``gr`` channels.

    Args:
        in_c: Input/output channel count.
        gr: Growth rate (channels produced by each dense layer).
    """

    def __init__(self, in_c, gr=32):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.Sequential(nn.Conv2d(in_c + k * gr, gr, 3, 1, 1), nn.LeakyReLU(0.2))
            for k in range(5)
        )
        self.conv1x1 = nn.Conv2d(in_c + 5 * gr, in_c, 1, 1, 0)

    def forward(self, x):
        features = [x]
        for layer in self.layers:
            features.append(layer(torch.cat(features, 1)))
        fused = self.conv1x1(torch.cat(features, 1))
        return x + 0.2 * fused

class RRDB(nn.Module):
    """Residual-in-residual dense block.

    Three ``DenseResidualBlock``s are applied in sequence, wrapped by an outer
    skip connection whose residual is scaled by 0.2.

    Args:
        in_c: Input/output channel count.
    """

    def __init__(self, in_c=64):
        super().__init__()
        self.rdb1 = DenseResidualBlock(in_c)
        self.rdb2 = DenseResidualBlock(in_c)
        self.rdb3 = DenseResidualBlock(in_c)

    def forward(self, x):
        out = x
        for rdb in (self.rdb1, self.rdb2, self.rdb3):
            out = rdb(out)
        return x + 0.2 * out

class EnhancedGenerator(Generator):
    """4x generator whose trunk is a stack of RRDB blocks instead of ResidualBlocks.

    The output passes through ``(tanh + 1) / 2`` and so lies in [0, 1].

    NOTE(review): ``super().__init__(num_res=0, up=4)`` still registers the
    parent's ``conv2``/``bn2``/``conv3``. ``forward`` never uses them, but they
    are part of this class's state_dict, so removing them would break strict
    loading of checkpoints saved from it.

    Args:
        rrdb_blocks: Number of RRDB blocks in the trunk.
    """

    def __init__(self, rrdb_blocks=23):
        super().__init__(num_res=0, up=4)
        # Re-created here, replacing the parent's instances.
        self.conv1 = nn.Conv2d(3, 64, 9, 1, 4)
        self.pre1 = nn.PReLU()
        self.trunk = nn.Sequential(*(RRDB(64) for _ in range(rrdb_blocks)))
        self.trunk_conv = nn.Conv2d(64, 64, 3, 1, 1)
        stages = []
        for _ in range(2):
            # 64 -> 256 channels, then PixelShuffle(2) back to 64 at 2x resolution.
            stages.extend((nn.Conv2d(64, 256, 3, 1, 1), nn.PixelShuffle(2), nn.PReLU()))
        self.ups = nn.Sequential(*stages)
        self.conv_last = nn.Conv2d(64, 3, 9, 1, 4)

    def forward(self, x):
        shallow = self.pre1(self.conv1(x))
        deep = self.trunk_conv(self.trunk(shallow))
        upsampled = self.ups(deep + shallow)
        return (torch.tanh(self.conv_last(upsampled)) + 1) / 2

# ─── 6. Training Functions (with epoch prints) ─────────────────────────────
class _PairedRandomCrop:
    """DataLoader ``collate_fn`` cutting aligned random patches from (LR, HR) tensor pairs.

    Full DIV2K images differ in size, so they cannot be stacked into one batch
    (and would not fit in GPU memory anyway). For each pair this picks a
    random LR window of ``hr_patch // scale`` pixels and the HR window that
    covers exactly the same area. ``scale`` is inferred per pair as
    ``HR width // LR width``.

    A class rather than a closure, so it is picklable for multi-worker loading.
    """

    def __init__(self, hr_patch):
        self.hr_patch = hr_patch

    def __call__(self, batch):
        lr_patches, hr_patches = [], []
        for lr_t, hr_t in batch:
            scale = max(1, hr_t.shape[-1] // lr_t.shape[-1])
            lp = self.hr_patch // scale
            lh, lw = lr_t.shape[-2:]
            if lp < 1 or lp > lh or lp > lw:
                raise ValueError(
                    f'patch_size={self.hr_patch} does not fit an LR image of '
                    f'{lw}x{lh} at scale {scale}')
            top = int(torch.randint(0, lh - lp + 1, (1,)))
            left = int(torch.randint(0, lw - lp + 1, (1,)))
            lr_patches.append(lr_t[..., top:top + lp, left:left + lp])
            hr_patches.append(hr_t[..., top * scale:(top + lp) * scale,
                                   left * scale:(left + lp) * scale])
        return torch.stack(lr_patches), torch.stack(hr_patches)


def _train_gan(G, device, hr_dir, lr_dir, *, epochs, bs, learning_rate,
               save_every, patch_size, adv_weight, pix_weight, desc, prefix):
    """Shared adversarial training loop for generator ``G`` (already on ``device``).

    A fresh ``Discriminator`` is built here. Each step:

    1. Update D on real HR patches vs. detached generated patches (BCE).
    2. Update G on ``MSE + adv_weight * BCE(D(G(lr)), 1) + pix_weight * MSE``.

    The state dicts are saved as ``{prefix}gen_{e}.pth`` and
    ``{prefix}disc_{e}.pth``. Saves happen every ``save_every`` epochs and
    always after the final epoch.
    """
    ds = DIV2KDataset(hr_dir, lr_dir)
    dl = DataLoader(ds, bs, shuffle=True, num_workers=4,
                    collate_fn=_PairedRandomCrop(patch_size))
    D = Discriminator().to(device)
    optG = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    optD = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=(0.9, 0.999))
    mse = nn.MSELoss(); bce = nn.BCEWithLogitsLoss()

    for e in range(1, epochs + 1):
        G.train(); D.train()
        for lr_imgs, hr_imgs in tqdm(dl, desc=f"{desc} Ep {e}/{epochs}"):
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)
            valid = torch.ones(len(lr_imgs), 1, device=device)
            fake  = torch.zeros(len(lr_imgs), 1, device=device)

            # One generator forward per step: the detached copy feeds D and the
            # attached tensor feeds G's loss. G's weights do not change in
            # between, so the output equals two separate forwards.
            gen_hr = G(lr_imgs)
            if gen_hr.shape != hr_imgs.shape:
                raise RuntimeError(
                    f'generator output {tuple(gen_hr.shape)} != HR target '
                    f'{tuple(hr_imgs.shape)}; LR images must be downscaled by '
                    "the generator's upscale factor (see create_lr)")

            # Discriminator step
            lossD = (bce(D(hr_imgs), valid) + bce(D(gen_hr.detach()), fake)) / 2
            optD.zero_grad(); lossD.backward(); optD.step()

            # Generator step
            loss_content = mse(gen_hr, hr_imgs)
            loss_adv     = bce(D(gen_hr), valid)
            # NOTE(review): the "pixel" term is the same MSE as the content
            # term, so pix_weight only rescales it. SRGAN/ESRGAN use a VGG
            # feature loss as content loss; adding one needs pretrained weights.
            loss_pix     = loss_content
            lossG        = loss_content + adv_weight * loss_adv + pix_weight * loss_pix
            optG.zero_grad(); lossG.backward(); optG.step()

        # Show epoch progress
        print(f"Completed epoch {e}/{epochs}")

        if e % save_every == 0 or e == epochs:
            torch.save(G.state_dict(), f'{prefix}gen_{e}.pth')
            torch.save(D.state_dict(), f'{prefix}disc_{e}.pth')


def train_srgan(hr_dir, lr_dir, epochs=30, bs=16, lr=1e-4, save_every=10,
                patch_size=96):
    """Train the SRGAN ``Generator`` adversarially against a ``Discriminator``.

    Args:
        hr_dir: HR image directory (see ``DIV2KDataset``).
        lr_dir: Matching LR image directory.
        epochs: Number of passes over the dataset.
        bs: Batch size.
        lr: Adam learning rate for both networks.
        save_every: Checkpoint interval in epochs; the last epoch is always saved.
        patch_size: Side of the random HR training patch; the LR patch is
            ``patch_size // scale``.

    Writes ``gen_{e}.pth`` / ``disc_{e}.pth`` state dicts to the working directory.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = Generator().to(device)
    _train_gan(G, device, hr_dir, lr_dir, epochs=epochs, bs=bs,
               learning_rate=lr, save_every=save_every, patch_size=patch_size,
               adv_weight=1e-3, pix_weight=2e-6, desc="SRGAN", prefix='')


def train_enhanced(hr_dir, lr_dir, sr_ckpt=None, epochs=30, bs=16, lr=1e-4,
                   save_every=10, patch_size=96):
    """Train the RRDB-based ``EnhancedGenerator`` adversarially.

    Args:
        hr_dir: HR image directory (see ``DIV2KDataset``).
        lr_dir: Matching LR image directory.
        sr_ckpt: Optional state-dict path used to warm-start G. Only tensors
            whose name and shape match ``EnhancedGenerator`` are loaded, so a
            plain ``Generator`` checkpoint transfers partially instead of
            raising.
        epochs: Number of passes over the dataset.
        bs: Batch size.
        lr: Adam learning rate for both networks.
        save_every: Checkpoint interval in epochs; the last epoch is always saved.
        patch_size: Side of the random HR training patch.

    Writes ``enh_gen_{e}.pth`` / ``enh_disc_{e}.pth`` state dicts.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = EnhancedGenerator().to(device)
    if sr_ckpt:
        # A Generator checkpoint only partly matches this layout (e.g. conv1/pre1
        # line up, the upsampler keys do not), so a strict load would raise.
        state = torch.load(sr_ckpt, map_location=device)
        own = G.state_dict()
        compatible = {k: v for k, v in state.items()
                      if k in own and own[k].shape == v.shape}
        G.load_state_dict(compatible, strict=False)
        print(f"Warm-started {len(compatible)}/{len(own)} tensors from {sr_ckpt}")
    _train_gan(G, device, hr_dir, lr_dir, epochs=epochs, bs=bs,
               learning_rate=lr, save_every=save_every, patch_size=patch_size,
               adv_weight=0.01, pix_weight=0.006, desc="Enh", prefix='enh_')

# ─── 7. Evaluation ────────────────────────────────────────────────────────
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim
import numpy as np

def _to_uint8_image(t):
    """Convert a (1, C, H, W) tensor with values in [0, 1] to an (H, W, C) uint8 array.

    Clips and rounds. A bare ``astype(np.uint8)`` truncates and wraps values
    outside [0, 255]; for example, negatives become large positives.
    """
    arr = t.detach().squeeze(0).permute(1, 2, 0).cpu().float().numpy()
    return np.clip(arr * 255.0, 0.0, 255.0).round().astype(np.uint8)


def evaluate(ckpt, hr_dir, lr_dir, enhanced=False):
    """Report mean PSNR/SSIM of a trained generator over an (LR, HR) image set.

    Args:
        ckpt: Path to a generator state dict saved by the training functions.
        hr_dir: HR image directory (see ``DIV2KDataset``).
        lr_dir: Matching LR image directory.
        enhanced: Load an ``EnhancedGenerator`` instead of a ``Generator``.

    Returns:
        ``(mean_psnr, mean_ssim)`` computed on 8-bit RGB images (also printed).

    Raises:
        ValueError: If a super-resolved image is larger than its HR target.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    G = (EnhancedGenerator() if enhanced else Generator()).to(device)
    G.load_state_dict(torch.load(ckpt, map_location=device))
    G.eval()
    ds = DIV2KDataset(hr_dir, lr_dir)
    ps, ss = [], []
    with torch.no_grad():
        for lr, hr in tqdm(DataLoader(ds, 1), desc="Eval"):
            out_np = _to_uint8_image(G(lr.to(device)))
            hr_np = _to_uint8_image(hr)
            oh, ow = out_np.shape[:2]
            if hr_np.shape[0] < oh or hr_np.shape[1] < ow:
                raise ValueError(
                    f'SR output {out_np.shape} is larger than HR target '
                    f'{hr_np.shape}; LR images must be downscaled by the '
                    "generator's upscale factor")
            # Trim an HR border left over when HR is not an exact multiple of LR.
            hr_np = hr_np[:oh, :ow]
            ps.append(psnr(hr_np, out_np, data_range=255))
            # channel_axis supersedes the deprecated `multichannel` flag
            # (scikit-image >= 0.19).
            ss.append(ssim(hr_np, out_np, channel_axis=-1, data_range=255))
    mean_psnr, mean_ssim = float(np.mean(ps)), float(np.mean(ss))
    print(f'Average PSNR: {mean_psnr:.4f}, Average SSIM: {mean_ssim:.4f}')
    return mean_psnr, mean_ssim

# ─── 8. Run Training & Eval ───────────────────────────────────────────────
# Example usage. Checkpoints are written as gen_{epoch}.pth / enh_gen_{epoch}.pth,
# so derive their names from EPOCHS instead of hard-coding epochs that are
# never reached. With the default save_every=10, a multiple of 10 is always saved.
# NOTE: evaluation reuses the training images, so PSNR/SSIM are optimistic.
EPOCHS = 30
train_srgan('DIV2K_train_HR', 'DIV2K_train_LR', epochs=EPOCHS)
train_enhanced('DIV2K_train_HR', 'DIV2K_train_LR', f'gen_{EPOCHS}.pth', epochs=EPOCHS)
evaluate(f'enh_gen_{EPOCHS}.pth', 'DIV2K_train_HR', 'DIV2K_train_LR', enhanced=True)

