# **Image Sharpening using Knowledge Distillation**

# First Steps

Uninstalling Libraries

In [None]:
# Remove any preinstalled copies so the pinned versions in the next cell are installed cleanly.
# NOTE(review): Colab may need a runtime restart after swapping torch/numpy versions — confirm.
!pip uninstall -y torch torchvision realesrgan piq basicsr

Installing Libraries

In [None]:
# Install pinned versions (presumably chosen to be mutually compatible with basicsr 1.4.2 — confirm).
# --force-reinstall reinstalls the listed packages even if already present; --no-cache-dir skips pip's wheel cache.
!pip install torch==2.0.1 torchvision==0.15.2 basicsr==1.4.2 numpy==1.26.4 --force-reinstall --no-cache-dir

Checking Versions

In [None]:
import torch
import torchvision
import numpy as np

# Report the library versions actually loaded into this kernel.
for _label, _module in (("Torch:", torch), ("Torchvision:", torchvision), ("Numpy:", np)):
    print(_label, _module.__version__)

Connecting to CUDA

In [None]:
# Use the GPU when one is visible to PyTorch, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"✅ Using device: {device}")

Cloning Real-ESRGAN Model

In [None]:
# Clone Real-ESRGAN to a fixed absolute path and make it the working directory.
# Absolute paths keep this cell safe to re-run: with a relative `%cd`, a second
# run would clone (and cd into) a nested Real-ESRGAN/Real-ESRGAN copy.
# Later cells use paths relative to this checkout (e.g. experiments/pretrained_models).
![ -d /content/Real-ESRGAN ] || git clone https://github.com/xinntao/Real-ESRGAN.git /content/Real-ESRGAN
%cd /content/Real-ESRGAN
!pip install facexlib gfpgan
!pip install -r requirements.txt
!python setup.py develop
!pip install piq

Importing Libraries

In [None]:
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.transforms import ToTensor
from PIL import Image
import os, cv2
from tqdm import tqdm
import matplotlib.pyplot as plt
import torchvision.models as models
import torch.nn.functional as F

Perceptual Loss

In [None]:
# ✅ Load VGG19 for perceptual loss
# Layers [0, 16) of torchvision's VGG19 `features` stack (through relu3_3), frozen and in
# eval mode so it acts purely as a fixed feature extractor.
vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features[:16].to(device).eval()

for param in vgg.parameters():
    param.requires_grad = False

# The pretrained VGG weights expect ImageNet-normalised input, while images in this
# notebook are [0, 1] tensors from `ToTensor`; normalise before every VGG pass.
_VGG_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
_VGG_STD = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)


def perceptual_loss(x, y):
    """L1 distance between frozen VGG19 feature maps of `x` and `y`.

    Args:
        x, y: RGB image batches shaped (N, 3, H, W) with values in [0, 1],
            on the same device as `vgg`.

    Returns:
        Scalar loss tensor. Gradients flow back to the inputs; VGG weights stay frozen.
    """
    x_feat = vgg((x - _VGG_MEAN) / _VGG_STD)
    y_feat = vgg((y - _VGG_MEAN) / _VGG_STD)
    return F.l1_loss(x_feat, y_feat)

Mounting Drive

In [None]:
# Mount Google Drive (remounting even if a mount already exists); checkpoints and
# degraded images are stored under it.
from google.colab import drive

DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT, force_remount=True)

# Dataset

Downloading the DIV2K Dataset

In [None]:
# Download and extract the DIV2K high-resolution training images to /content/DIV2K_train_HR.
# `wget -c` resumes a partial download; `unzip -n` never overwrites existing files, so
# re-running the cell does not stall on an interactive "replace?" prompt.
!mkdir -p /content/DIV2K
!wget -c https://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip -P /content/DIV2K
!unzip -q -n /content/DIV2K/DIV2K_train_HR.zip -d /content/

Prepare Dataset

In [None]:
# Originals extracted from the zip, and the folder (under the mounted Drive) that
# receives the synthetically degraded copies.
sharp_dir, degraded_dir = (
    "/content/DIV2K_train_HR",
    "/content/drive/MyDrive/intel/DIV2K/degraded",
)
os.makedirs(name=degraded_dir, exist_ok=True)

Degrade Images

In [None]:
def degrade_image(img, low_res=128, out_size=256, noise_sigma=2.0, rng=None):
    """Synthesise a degraded copy of `img` (an image array as returned by `cv2.imread`).

    Bicubic down-sample to `low_res` x `low_res`, bicubic up-sample to
    `out_size` x `out_size` (both ignore the original aspect ratio), then add
    Gaussian noise with standard deviation `noise_sigma` in 0-255 units.

    Args:
        img: input image array.
        low_res: side length of the intermediate down-sampled image.
        out_size: side length of the returned image.
        noise_sigma: std-dev of the additive Gaussian noise; 0 disables the noise.
        rng: optional `np.random.Generator` for reproducible noise. When None the
            global `np.random` state is used (the original behaviour).

    Returns:
        uint8 array of size `out_size` x `out_size` (channels preserved).
    """
    img = cv2.resize(img, (low_res, low_res), interpolation=cv2.INTER_CUBIC)
    img = cv2.resize(img, (out_size, out_size), interpolation=cv2.INTER_CUBIC)
    if noise_sigma > 0:
        normal = rng.normal if rng is not None else np.random.normal
        img = img + normal(0, noise_sigma, img.shape)
    return np.clip(img, 0, 255).astype(np.uint8)

# Write one degraded copy of every image OpenCV can read from `sharp_dir`.
# Seeding NumPy makes the synthetic noise (and so the degraded dataset) reproducible,
# and sorted iteration ties each random draw to the same file on every run.
np.random.seed(0)
failed_writes = []
for fname in tqdm(sorted(os.listdir(sharp_dir))):
    img = cv2.imread(os.path.join(sharp_dir, fname))
    if img is None:  # not decodable as an image — skip it
        continue
    out_path = os.path.join(degraded_dir, fname)
    # cv2.imwrite reports failure by returning False rather than raising.
    if not cv2.imwrite(out_path, degrade_image(img)):
        failed_writes.append(out_path)

if failed_writes:
    print(f"⚠️ Failed to write {len(failed_writes)} degraded images, e.g. {failed_writes[0]}")

In [None]:
# Filenames present in BOTH the sharp and degraded folders, in sorted order.
sharp_names = set(os.listdir(sharp_dir))
degraded_names = set(os.listdir(degraded_dir))
common_names = sorted(sharp_names.intersection(degraded_names))

print(f"✅ Matched files: {len(common_names)} found in BOTH folders")

Paired Sharpening Dataset (degraded / sharp)

In [None]:
class SharpeningDataset(Dataset):
    """Paired (degraded, sharp) image dataset, matched by filename across two folders.

    Args:
        degraded_dir: folder with the degraded inputs.
        sharp_dir: folder with the original sharp images (same filenames).
        image_names: filenames to load; each must exist in both folders.
        transform: transform applied to BOTH images of a pair (default
            `T.ToTensor()`). Random transforms are replayed with identical RNG
            state for the two images, so e.g. `T.RandomCrop` cuts the same
            window from each.

    Each item is `(degraded, sharp)` after `transform`.
    """

    def __init__(self, degraded_dir, sharp_dir, image_names, transform=None):
        self.degraded_dir = degraded_dir
        self.sharp_dir = sharp_dir
        self.image_names = list(image_names)  # list() so sets/iterables can be indexed
        self.transform = transform or T.ToTensor()

    def __len__(self):
        return len(self.image_names)

    def _paired_transform(self, degraded, sharp):
        """Apply `self.transform` to both images using the same random draws."""
        import random  # local: only needed to snapshot/restore Python's RNG state

        torch_state = torch.get_rng_state()
        py_state = random.getstate()
        degraded = self.transform(degraded)
        # Rewind both RNGs so random ops (crop offsets, flips) repeat exactly for `sharp`.
        torch.set_rng_state(torch_state)
        random.setstate(py_state)
        sharp = self.transform(sharp)
        return degraded, sharp

    def __getitem__(self, idx):
        name = self.image_names[idx]
        with Image.open(os.path.join(self.degraded_dir, name)) as im:
            degraded = im.convert("RGB")
        with Image.open(os.path.join(self.sharp_dir, name)) as im:
            sharp = im.convert("RGB")
        # Degraded copies are written at a fixed size that differs from the originals;
        # resize the sharp image onto the degraded grid so both (and any shared crop)
        # cover the same content.
        if sharp.size != degraded.size:
            sharp = sharp.resize(degraded.size, Image.BICUBIC)
        return self._paired_transform(degraded, sharp)

Transform Pipeline

In [None]:
# Random 128x128 patch, then PIL image -> float CHW tensor in [0, 1].
transform = T.Compose(transforms=[
    T.RandomCrop(size=128),
    T.ToTensor(),
])

Loading Data

In [None]:
# Train only on filenames present in BOTH folders (`common_names`, computed above).
# A raw listing of the persistent Drive folder can contain files with no sharp
# counterpart, which would crash `SharpeningDataset.__getitem__`.
all_imgs = list(common_names)
train_set = SharpeningDataset(degraded_dir, sharp_dir, all_imgs, transform)
train_loader = DataLoader(train_set, batch_size=4, shuffle=True)

print(f"✅ Degraded images: {len(all_imgs)} ready for training.")

# Teacher Model

Download the RealESRGAN Anime 6B Weights

In [None]:
# Fetch the Anime 6B teacher weights into the Real-ESRGAN checkout (path is relative to the `%cd` above).
# `-nc` skips the download if the file already exists instead of saving duplicate `.pth.1` copies on re-run.
!wget -q -nc https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth -P experiments/pretrained_models

Import Dependencies

In [None]:
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer

Define Model

In [None]:
# RRDBNet generator matching the Anime 6B checkpoint: 6 RRDB blocks, 64 features, growth 32.
model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32)

Load Model

In [None]:
# ✅ Load the RealESRGANer with RRDBNet
# FP16 is enabled only on CUDA: the device cell can fall back to CPU, where
# half-precision convolutions are generally not supported.
teacher = RealESRGANer(
    scale=4,
    model_path='experiments/pretrained_models/RealESRGAN_x4plus_anime_6B.pth',  # relative to the Real-ESRGAN checkout
    model=model,
    tile=0,       # 0 = run each image in a single pass (no tiling)
    tile_pad=10,
    pre_pad=0,
    half=(device.type == 'cuda'),
)

print("✅ RealESRGAN Anime 6B Teacher loaded.")

Freeze Teacher

In [None]:
# Freeze the teacher: eval mode, and no gradients for any of its parameters.
teacher.model.eval()
teacher.model.requires_grad_(False)

print("✅ Teacher model frozen & ready.")

# Student Model

Student Model

In [None]:
class StudentCNN(nn.Module):
    """Small x4 super-resolution CNN distilled from the Real-ESRGAN teacher.

    Three 3x3 convolutions extract features, then two (conv -> PixelShuffle(2))
    stages upscale by 2 each (x4 total), and a final 3x3 conv maps to RGB.
    Layer order is fixed so `body.<i>` state_dict keys match saved checkpoints.
    """

    def __init__(self):
        super().__init__()
        layers = [
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 32, 3, padding=1), nn.ReLU(),
            nn.Conv2d(32, 48, 3, padding=1),   # 48 = 12 * 2**2 channels for the shuffle
            nn.PixelShuffle(2),                # x2 -> 12 channels
            nn.Conv2d(12, 48, 3, padding=1),
            nn.PixelShuffle(2),                # x2 again -> x4 overall, 12 channels
            nn.Conv2d(12, 3, 3, padding=1),    # back to RGB
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        return self.body(x)

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
student = StudentCNN().to(device)
print("✅ Student model created.")

Loading existing weights

In [None]:
student_path = "/content/drive/MyDrive/intel/student_model_1.pth"

# Resume from a previous checkpoint when one exists. Loading is best-effort: any
# error is reported and training continues from the current (fresh) weights.
if os.path.exists(student_path):
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU runtime;
        # weights_only restricts unpickling to tensors/containers.
        state_dict = torch.load(student_path, map_location=device, weights_only=True)
        # strict=False tolerates key mismatches, so surface them instead of hiding them.
        result = student.load_state_dict(state_dict, strict=False)
        print(f"✅ Loaded existing student weights from {student_path} (non-strict)")
        if result.missing_keys or result.unexpected_keys:
            print(f"⚠️ Missing keys: {result.missing_keys} | Unexpected keys: {result.unexpected_keys}")
    except Exception as e:
        print(f"❌ Error loading state dict: {e}")
        print("ℹ️ Starting training with fresh weights.")
else:
    print("ℹ️ No previous student weights found. Starting fresh.")

Training

In [None]:
import torch.nn.functional as F
from torchvision.models import vgg16
from torchvision.models.feature_extraction import create_feature_extractor

# Distillation training: the student learns to reproduce the frozen teacher's x4
# output for each degraded crop, using L1 plus a weighted perceptual term.
# `perceptual_loss` is the one defined in the "Perceptual Loss" section above; it is
# deliberately not redefined here, so the notebook has a single definition of the loss.

criterion = nn.L1Loss()
optimizer = optim.Adam(student.parameters(), lr=1e-4)

num_epochs = 50
lambda_perc = 0.1  # weight of the perceptual term relative to L1; tunable


def _to_teacher_input(img_chw):
    """Turn one RGB CHW tensor in [0, 1] into the HWC uint8 BGR array `teacher.enhance` takes."""
    # Round (not truncate): values that came from uint8 PNGs then map back exactly.
    img_np = np.rint(img_chw.detach().cpu().permute(1, 2, 0).numpy() * 255)
    img_np = np.clip(img_np, 0, 255).astype(np.uint8)
    return cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)


def _teacher_targets(batch):
    """Run the teacher image-by-image over `batch`; return the stacked outputs on `device`."""
    outputs = []
    for img in batch:
        output, _ = teacher.enhance(_to_teacher_input(img))
        assert isinstance(output, np.ndarray), f"Teacher enhance output is not ndarray but {type(output)}"
        output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
        outputs.append(T.ToTensor()(output_rgb))
    return torch.stack(outputs).to(device)


n_batches = len(train_loader)
if n_batches == 0:
    raise RuntimeError("train_loader is empty — check the dataset cells above.")

for epoch in range(num_epochs):
    student.train()
    running_loss = running_l1 = running_perc = 0.0

    for degraded_imgs, _ in train_loader:
        degraded_imgs = degraded_imgs.to(device)
        # Targets are built from NumPy arrays, so they carry no autograd graph.
        teacher_out = _teacher_targets(degraded_imgs)
        student_out = student(degraded_imgs)

        loss_l1 = criterion(student_out, teacher_out)
        loss_perc = perceptual_loss(student_out, teacher_out)
        total_loss = loss_l1 + lambda_perc * loss_perc

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        running_loss += total_loss.item()
        running_l1 += loss_l1.item()
        running_perc += loss_perc.item()

    # Report epoch means for every term (not a mix of last-batch and mean values).
    avg_loss = running_loss / n_batches
    print(f"Epoch {epoch+1}/{num_epochs} | L1: {running_l1 / n_batches:.4f} | Perc: {running_perc / n_batches:.4f} | Avg total: {avg_loss:.4f}")

Saving Model

In [None]:
# Persist only the student's weights (its state_dict) to Drive.
student_save_path = "/content/drive/MyDrive/intel/student_model_1.pth"
torch.save(obj=student.state_dict(), f=student_save_path)
print(f"✅ Student saved to {student_save_path}")

Checking Output Shape

In [None]:
# Sanity-check the tensors left over from the final training batch.
for _label, _tensor in (
    ("Teacher out:", teacher_out),
    ("Student out:", student_out),
    ("Degraded Input Tensor (last batch):", degraded_imgs),
):
    print(_label, _tensor.shape)
assert student_out.shape == teacher_out.shape, "Teacher and student output shapes do not match!"

# Mapping

Evaluation Mode

In [None]:
# Eval mode, then one batch from the (shuffled, randomly cropped) loader.
# NOTE(review): this is the training loader, not a held-out split.
student.eval()
degraded, sharp = (tensor.to(device) for tensor in next(iter(train_loader)))

print(f"Degraded shape: {degraded.shape} | Sharp shape: {sharp.shape}")

Run Student & Upsample GT

In [None]:
with torch.no_grad():
    student_out = student(degraded)
    # Bicubic-resize the sharp crop onto the student's output grid for comparison.
    out_h, out_w = student_out.shape[-2:]
    gt_up = F.interpolate(sharp, size=(out_h, out_w), mode='bicubic', align_corners=False)

print(f"Student Output shape: {student_out.shape}")
print(f"GT Upsampled shape: {gt_up.shape}")

Visualize

In [None]:
# First image of each batch as HWC NumPy clipped to [0, 1], ready for imshow.
degraded_img, student_img, sharp_img = (
    batch[0].detach().cpu().permute(1, 2, 0).numpy().clip(0, 1)
    for batch in (degraded, student_out, gt_up)
)

panels = [
    (degraded_img, "Degraded Input"),
    (student_img, "Student Output"),
    (sharp_img, "Ground Truth (Upsampled)"),
]
fig, axes = plt.subplots(1, len(panels), figsize=(15, 5))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')

plt.show()

# Evaluation

Importing Essentials

In [None]:
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import matplotlib.pyplot as plt

Student Eval

In [None]:
student.eval()

# Fresh batch for the metrics below.
# NOTE(review): drawn from the training loader, so scores are not held-out estimates.
degraded, sharp = next(iter(train_loader))
degraded = degraded.to(device)
sharp = sharp.to(device)

with torch.no_grad():
    student_out = student(degraded)
    target_hw = tuple(student_out.shape[2:4])
    gt_up = F.interpolate(sharp, size=target_hw, mode='bicubic', align_corners=False)

In [None]:
# 👉 First sample of each batch as HWC NumPy in [0, 1]
student_np, gt_np = (
    batch[0].cpu().permute(1, 2, 0).numpy().clip(0, 1) for batch in (student_out, gt_up)
)

# ✅ SSIM across the colour channels (last axis)
ssim_score = ssim(student_np, gt_np, data_range=1.0, channel_axis=2)
print(f"✅ Student SSIM: {ssim_score:.4f}")

# ✅ PSNR with the ground truth as the reference image
psnr_val = psnr(gt_np, student_np, data_range=1.0)
print(f"✅ Student PSNR: {psnr_val:.2f} dB")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
ax_pred, ax_gt = axes
ax_pred.imshow(student_np)
ax_pred.set_title(f"Student Output\nSSIM: {ssim_score:.4f}  PSNR: {psnr_val:.2f} dB")
ax_gt.imshow(gt_np)
ax_gt.set_title("Ground Truth (Upsampled)")
for ax in (ax_pred, ax_gt):
    ax.axis('off')
plt.show()

Teacher Eval

In [None]:
teacher.model.eval()

# Run the teacher on every image of the evaluation batch. Each RGB [0, 1] tensor is
# converted to the uint8 BGR array `teacher.enhance` consumes; values are clamped and
# rounded (not truncated) so uint8-sourced pixels map back exactly.
teacher_out_list = []
for i in range(degraded.size(0)):
    img_np = np.rint(degraded[i].cpu().permute(1, 2, 0).numpy() * 255)
    img_np = np.clip(img_np, 0, 255).astype(np.uint8)
    img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)

    output, _ = teacher.enhance(img_bgr)
    output_rgb = cv2.cvtColor(output, cv2.COLOR_BGR2RGB)
    teacher_out_list.append(T.ToTensor()(output_rgb))

teacher_out = torch.stack(teacher_out_list).to(device)

In [None]:
# Bicubic-resize the sharp crop onto the teacher's output grid.
_teacher_hw = (teacher_out.shape[2], teacher_out.shape[3])
gt_up_teacher = F.interpolate(sharp, size=_teacher_hw, mode='bicubic', align_corners=False)

In [None]:
# First sample of each batch as HWC NumPy in [0, 1].
teacher_np, gt_teacher_np = (
    batch[0].cpu().permute(1, 2, 0).numpy().clip(0, 1) for batch in (teacher_out, gt_up_teacher)
)

ssim_score_teacher = ssim(teacher_np, gt_teacher_np, data_range=1.0, channel_axis=2)
psnr_teacher = psnr(gt_teacher_np, teacher_np, data_range=1.0)

print(f"✅ Teacher SSIM: {ssim_score_teacher:.4f}")
print(f"✅ Teacher PSNR: {psnr_teacher:.2f} dB")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (
    (teacher_np, f"Teacher Output\nSSIM: {ssim_score_teacher:.4f}  PSNR: {psnr_teacher:.2f} dB"),
    (gt_teacher_np, "Ground Truth (Upsampled)"),
)
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
    ax.axis('off')
plt.show()