In [1]:
'''
!pip install opencv-python

import os
import cv2

input_dir = 'dataset/DIV2K_train_HR'
blurred_dir = 'dataset/DIV2K_train_LR'

os.makedirs(blurred_dir, exist_ok=True)

for img_name in os.listdir(input_dir):
    img_path = os.path.join(input_dir, img_name)
    img = cv2.imread(img_path)

    if img is None:
        print(f"Error reading {img_name}")
        continue

    # Use a larger kernel size and sigma for stronger blur
    blurred = cv2.GaussianBlur(img, (11, 11), sigmaX=5.0)

    cv2.imwrite(os.path.join(blurred_dir, img_name), blurred)
'''
!pip install opencv-python pillow

import os
import cv2
from PIL import Image
import numpy as np

input_dir = 'dataset/DIV2K_train_HR'
blurred_dir = 'dataset/DIV2K_train_LR'

os.makedirs(blurred_dir, exist_ok=True)

for img_name in os.listdir(input_dir):
    img_path = os.path.join(input_dir, img_name)

    try:
        # Read image using PIL (reliable inside Anaconda)
        img_pil = Image.open(img_path).convert('RGB')
        img = np.array(img_pil)                    # PIL to numpy array
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # Convert RGB to BGR for OpenCV

        blurred = cv2.GaussianBlur(img, (11, 11), sigmaX=5.0)

        save_path = os.path.join(blurred_dir, img_name)
        cv2.imwrite(save_path, blurred)
    except Exception as e:
        print(f"Error processing {img_name}: {e}")




In [2]:
!pip install torch torchvision

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import os

class DIV2KSharpenDataset(Dataset):
    """Paired (input, target) image dataset read from two parallel folders.

    A sample is identified by its file name: the same name is opened in
    ``input_dir`` (network input) and in ``target_dir`` (ground truth).
    Both images are converted to RGB.

    Args:
        input_dir: Folder containing the input images.
        target_dir: Folder containing the target images, stored under the
            same file names as the inputs.
        transform: Optional callable applied to the input image and to the
            target image in two separate calls. Because the calls are
            independent, a transform with random behaviour would not be
            applied identically to both images of a pair.
    """

    def __init__(self, input_dir, target_dir, transform=None):
        self.input_dir = input_dir
        self.target_dir = target_dir
        self.transform = transform
        # Keep only names that are regular files in BOTH folders. A bare
        # os.listdir() also returns sub-directories and names that have no
        # counterpart in target_dir; either would make __getitem__ fail
        # part-way through an epoch.
        self.filenames = sorted(
            name
            for name in os.listdir(input_dir)
            if os.path.isfile(os.path.join(input_dir, name))
            and os.path.isfile(os.path.join(target_dir, name))
        )

    def __len__(self):
        """Return the number of usable (input, target) pairs."""
        return len(self.filenames)

    @staticmethod
    def _load_rgb(path):
        """Open ``path`` and return an RGB copy, closing the file afterwards."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``."""
        name = self.filenames[idx]
        input_img = self._load_rgb(os.path.join(self.input_dir, name))
        target_img = self._load_rgb(os.path.join(self.target_dir, name))

        if self.transform:
            input_img = self.transform(input_img)
            target_img = self.transform(target_img)

        return input_img, target_img

# Preprocessing shared by inputs and targets: resize to a fixed 256x256
# square, then convert the PIL image to a tensor.
_resize_step = transforms.Resize((256, 256))
_to_tensor_step = transforms.ToTensor()
transform = transforms.Compose([_resize_step, _to_tensor_step])

# Inputs are read from the LR folder, targets from the HR folder.
train_dataset = DIV2KSharpenDataset(
    input_dir='dataset/DIV2K_train_LR',
    target_dir='dataset/DIV2K_train_HR',
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=8, shuffle=True)




In [3]:
import torch
import torch.nn as nn

class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch normalisation and ReLU.

    Both convolutions use ``padding=1``, so the spatial size of the input is
    preserved; only the channel count changes (``in_channels`` ->
    ``out_channels`` -> ``out_channels``).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, the second keeps
        # out_channels; the layers are otherwise identical.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv-BN-ReLU pair twice and return the feature map."""
        features = self.double_conv(x)
        return features

class UNetTeacher(nn.Module):
    """U-Net style encoder-decoder used as the teacher network.

    Structure:
      * four encoder stages (``ConvBlock`` followed by 2x2 max-pooling) whose
        widths are ``base_features`` times 1, 2, 4 and 8;
      * a bottleneck ``ConvBlock`` with ``base_features * 16`` channels;
      * four decoder stages, each a stride-2 transposed convolution whose
        result is concatenated with the matching encoder output and passed
        through a ``ConvBlock``;
      * a final 1x1 convolution to ``out_channels`` (no output activation).

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted image.
        base_features: Width of the first encoder stage.

    Input height/width do not have to be multiples of 16: when pooling
    rounds an odd size down, the upsampled decoder map is zero-padded to the
    size of its skip connection before concatenation. For sizes that are
    multiples of 16 no padding happens and the computation is unchanged.
    Each side must still be at least 16 so that four poolings leave a
    non-empty bottleneck.
    """

    def __init__(self, in_channels=3, out_channels=3, base_features=64):
        super().__init__()

        # Encoder. Attribute names and creation order are kept stable so
        # that previously saved state dicts still load.
        self.enc1 = ConvBlock(in_channels, base_features)
        self.pool1 = nn.MaxPool2d(2)
        self.enc2 = ConvBlock(base_features, base_features * 2)
        self.pool2 = nn.MaxPool2d(2)
        self.enc3 = ConvBlock(base_features * 2, base_features * 4)
        self.pool3 = nn.MaxPool2d(2)
        self.enc4 = ConvBlock(base_features * 4, base_features * 8)
        self.pool4 = nn.MaxPool2d(2)

        self.bottleneck = ConvBlock(base_features * 8, base_features * 16)

        # Decoder. Each dec* block receives the upsampled map concatenated
        # with the skip connection, hence twice the upconv output channels.
        self.upconv4 = nn.ConvTranspose2d(base_features * 16, base_features * 8, kernel_size=2, stride=2)
        self.dec4 = ConvBlock(base_features * 16, base_features * 8)

        self.upconv3 = nn.ConvTranspose2d(base_features * 8, base_features * 4, kernel_size=2, stride=2)
        self.dec3 = ConvBlock(base_features * 8, base_features * 4)

        self.upconv2 = nn.ConvTranspose2d(base_features * 4, base_features * 2, kernel_size=2, stride=2)
        self.dec2 = ConvBlock(base_features * 4, base_features * 2)

        self.upconv1 = nn.ConvTranspose2d(base_features * 2, base_features, kernel_size=2, stride=2)
        self.dec1 = ConvBlock(base_features * 2, base_features)

        self.output_conv = nn.Conv2d(base_features, out_channels, kernel_size=1)

    @staticmethod
    def _upsample_and_merge(upconv, x, skip):
        """Upsample ``x`` with ``upconv`` and concatenate it with ``skip``.

        If the upsampled map is smaller than ``skip`` (odd size before
        pooling), it is zero-padded, split as evenly as possible between the
        two sides, so that the channel-wise concatenation is valid.
        """
        up = upconv(x)
        pad_h = skip.size(2) - up.size(2)
        pad_w = skip.size(3) - up.size(3)
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            up = nn.functional.pad(
                up,
                [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2],
            )
        return torch.cat((up, skip), dim=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the output map."""
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        b = self.bottleneck(self.pool4(e4))

        d4 = self.dec4(self._upsample_and_merge(self.upconv4, b, e4))
        d3 = self.dec3(self._upsample_and_merge(self.upconv3, d4, e3))
        d2 = self.dec2(self._upsample_and_merge(self.upconv2, d3, e2))
        d1 = self.dec1(self._upsample_and_merge(self.upconv1, d2, e1))

        return self.output_conv(d1)


In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

# Train the teacher to map inputs from train_loader onto their targets using
# a pixel-wise MSE objective and the Adam optimiser.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = UNetTeacher().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

num_epochs = 15

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch

    for blurred_batch, sharp_batch in tqdm(train_loader):
        blurred_batch = blurred_batch.to(device)
        sharp_batch = sharp_batch.to(device)

        # Clear old gradients, then forward / backward / update.
        optimizer.zero_grad()
        batch_loss = criterion(model(blurred_batch), sharp_batch)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Report the mean batch loss of the epoch.
    epoch_loss = running_loss/len(train_loader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


100%|██████████| 100/100 [25:44<00:00, 15.44s/it]


Epoch [1/15], Loss: 0.0772


100%|██████████| 100/100 [25:46<00:00, 15.46s/it]


Epoch [2/15], Loss: 0.0063


100%|██████████| 100/100 [26:35<00:00, 15.96s/it]


Epoch [3/15], Loss: 0.0049


100%|██████████| 100/100 [24:09<00:00, 14.50s/it]


Epoch [4/15], Loss: 0.0047


100%|██████████| 100/100 [22:29<00:00, 13.49s/it]


Epoch [5/15], Loss: 0.0041


100%|██████████| 100/100 [21:19<00:00, 12.79s/it]


Epoch [6/15], Loss: 0.0039


100%|██████████| 100/100 [21:09<00:00, 12.70s/it]


Epoch [7/15], Loss: 0.0041


100%|██████████| 100/100 [21:01<00:00, 12.62s/it]


Epoch [8/15], Loss: 0.0031


100%|██████████| 100/100 [21:03<00:00, 12.63s/it]


Epoch [9/15], Loss: 0.0032


100%|██████████| 100/100 [29:10<00:00, 17.50s/it]


Epoch [10/15], Loss: 0.0031


100%|██████████| 100/100 [3:38:30<00:00, 131.10s/it]   


Epoch [11/15], Loss: 0.0028


100%|██████████| 100/100 [20:56<00:00, 12.56s/it]


Epoch [12/15], Loss: 0.0029


100%|██████████| 100/100 [20:58<00:00, 12.58s/it]


Epoch [13/15], Loss: 0.0026


100%|██████████| 100/100 [21:17<00:00, 12.77s/it]


Epoch [14/15], Loss: 0.0029


100%|██████████| 100/100 [21:08<00:00, 12.69s/it]

Epoch [15/15], Loss: 0.0028





In [5]:
# Persist the trained teacher's parameters (its state dict) to disk.
teacher_state = model.state_dict()
torch.save(teacher_state, 'teacher_unet.pth')


In [6]:
# Student model: a small CNN trained by distillation from the U-Net teacher.

In [7]:
import torch.nn as nn

class StudentCNN(nn.Module):
    """Small fully convolutional student network.

    Four 3x3 convolutions with ``padding=1`` (channel widths
    3 -> 16 -> 32 -> 16 -> 3), with a ReLU after every convolution except the
    last. The padding keeps the spatial size, so the output has the same
    height and width as the input.
    """

    def __init__(self):
        super().__init__()
        channel_plan = (3, 16, 32, 16, 3)
        layers = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            layers.append(nn.Conv2d(c_in, c_out, 3, padding=1))
            layers.append(nn.ReLU())
        # The final convolution maps back to 3 channels and is left linear.
        layers.pop()
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network's prediction for the batch ``x``."""
        prediction = self.model(x)
        return prediction


In [8]:
# Rebuild the teacher network and restore the weights saved after training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

teacher = UNetTeacher()
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU machine still loads on a CPU-only one.
teacher.load_state_dict(torch.load('teacher_unet.pth', map_location=device))
teacher = teacher.to(device)

# eval() puts BatchNorm layers into inference mode (running statistics);
# on its own it does not stop gradient tracking.
teacher.eval()
# Freeze weights: the teacher is only used to produce targets.
for param in teacher.parameters():
    param.requires_grad = False

In [9]:
import torch.nn.functional as F

def distillation_loss(student_output, teacher_output, target, alpha=0.7):
    """Blend of two L1 losses used to train the student.

    Args:
        student_output: Prediction of the student network.
        teacher_output: Prediction of the teacher network; it is detached, so
            no gradient flows back into the teacher through this loss.
        target: Ground-truth image.
        alpha: Weight of the ground-truth term; the teacher term is weighted
            by ``1 - alpha``.

    Returns:
        ``alpha * L1(student, target) + (1 - alpha) * L1(student, teacher)``.
    """
    teacher_reference = teacher_output.detach()
    ground_truth_term = F.l1_loss(student_output, target)
    imitation_term = F.l1_loss(student_output, teacher_reference)
    return alpha * ground_truth_term + (1 - alpha) * imitation_term


In [10]:
from tqdm import tqdm

# Train the student by distillation: its loss mixes the distance to the
# ground truth with the distance to the frozen teacher's prediction.
student = StudentCNN().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-4)
num_epochs = 20

# Make sure the teacher is in inference mode even if this cell is run
# without the teacher-loading cell having just been executed.
teacher.eval()

for epoch in range(num_epochs):
    student.train()
    running_loss = 0.0  # sum of per-batch losses for this epoch

    loop = tqdm(train_loader, desc="Epoch [{}/{}]".format(epoch+1, num_epochs))

    # An explicit 1-based step counter is used for the running average:
    # tqdm only refreshes its own counter when the bar is redrawn, so it can
    # lag behind the number of batches actually processed.
    for step, (inputs, targets) in enumerate(loop, start=1):
        inputs, targets = inputs.to(device), targets.to(device)

        # The teacher and the inputs already live on `device`, so its output
        # does too; no_grad avoids building a graph for the teacher pass.
        with torch.no_grad():
            teacher_outputs = teacher(inputs)

        student_outputs = student(inputs)

        loss = distillation_loss(student_outputs, teacher_outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        loop.set_postfix(loss=running_loss / step)

    print("Epoch [{}/{}] Loss: {:.4f}".format(epoch+1, num_epochs, running_loss / len(train_loader)))




Epoch [1/20]: 100%|██████████| 100/100 [07:56<00:00,  4.76s/it, loss=0.248]


Epoch [1/20] Loss: 0.2475


Epoch [2/20]: 100%|██████████| 100/100 [08:14<00:00,  4.94s/it, loss=0.0807]


Epoch [2/20] Loss: 0.0807


Epoch [3/20]: 100%|██████████| 100/100 [08:13<00:00,  4.94s/it, loss=0.0503]


Epoch [3/20] Loss: 0.0503


Epoch [4/20]: 100%|██████████| 100/100 [08:15<00:00,  4.96s/it, loss=0.0412]


Epoch [4/20] Loss: 0.0412


Epoch [5/20]: 100%|██████████| 100/100 [08:14<00:00,  4.94s/it, loss=0.0361]


Epoch [5/20] Loss: 0.0361


Epoch [6/20]: 100%|██████████| 100/100 [08:15<00:00,  4.95s/it, loss=0.0325]


Epoch [6/20] Loss: 0.0325


Epoch [7/20]: 100%|██████████| 100/100 [08:18<00:00,  4.99s/it, loss=0.03] 


Epoch [7/20] Loss: 0.0300


Epoch [8/20]: 100%|██████████| 100/100 [08:13<00:00,  4.94s/it, loss=0.0284]


Epoch [8/20] Loss: 0.0284


Epoch [9/20]: 100%|██████████| 100/100 [08:17<00:00,  4.97s/it, loss=0.0276]


Epoch [9/20] Loss: 0.0276


Epoch [10/20]: 100%|██████████| 100/100 [08:16<00:00,  4.96s/it, loss=0.0265]


Epoch [10/20] Loss: 0.0265


Epoch [11/20]: 100%|██████████| 100/100 [08:13<00:00,  4.93s/it, loss=0.0254]


Epoch [11/20] Loss: 0.0254


Epoch [12/20]: 100%|██████████| 100/100 [08:16<00:00,  4.96s/it, loss=0.0245]


Epoch [12/20] Loss: 0.0245


Epoch [13/20]: 100%|██████████| 100/100 [08:17<00:00,  4.98s/it, loss=0.0237]


Epoch [13/20] Loss: 0.0237


Epoch [14/20]: 100%|██████████| 100/100 [08:15<00:00,  4.95s/it, loss=0.0234]


Epoch [14/20] Loss: 0.0234


Epoch [15/20]: 100%|██████████| 100/100 [08:15<00:00,  4.96s/it, loss=0.0227]


Epoch [15/20] Loss: 0.0227


Epoch [16/20]: 100%|██████████| 100/100 [08:21<00:00,  5.01s/it, loss=0.0221]


Epoch [16/20] Loss: 0.0221


Epoch [17/20]: 100%|██████████| 100/100 [08:23<00:00,  5.03s/it, loss=0.0217]


Epoch [17/20] Loss: 0.0217


Epoch [18/20]: 100%|██████████| 100/100 [08:20<00:00,  5.01s/it, loss=0.0214]


Epoch [18/20] Loss: 0.0214


Epoch [19/20]: 100%|██████████| 100/100 [08:26<00:00,  5.06s/it, loss=0.0212]


Epoch [19/20] Loss: 0.0212


Epoch [20/20]: 100%|██████████| 100/100 [08:23<00:00,  5.04s/it, loss=0.021]

Epoch [20/20] Loss: 0.0210





In [12]:
from skimage.metrics import structural_similarity as ssim
import numpy as np
import matplotlib.pyplot as plt

# Quick SSIM preview of the student against the ground-truth images.
# NOTE: the batches come from train_loader, i.e. shuffled training data, so
# the score is neither a held-out estimate nor identical from run to run.
MAX_PREVIEW_BATCHES = 3  # number of batches scored before stopping

student.eval()
total_ssim = 0
num_samples = 0

with torch.no_grad():
    for i, (inputs, targets) in enumerate(train_loader):
        outputs = student(inputs.to(device))

        # Convert each batch once: NCHW tensor -> NHWC numpy array, clipped
        # to the [0, 1] range that data_range=1.0 assumes. The targets are
        # only needed on the CPU, so they are not moved to the device.
        preds = np.clip(outputs.cpu().permute(0, 2, 3, 1).numpy(), 0, 1)
        trues = np.clip(targets.permute(0, 2, 3, 1).numpy(), 0, 1)

        for pred, true in zip(preds, trues):
            ssim_score = ssim(pred, true, data_range=1.0, channel_axis=-1)
            total_ssim += ssim_score
            num_samples += 1

        # Checked after scoring so that no extra batch is loaded.
        if i + 1 >= MAX_PREVIEW_BATCHES:  # Limit for quick preview
            break

# An empty loader would otherwise surface as an opaque ZeroDivisionError.
if num_samples == 0:
    raise RuntimeError("train_loader yielded no batches; cannot compute an average SSIM")

avg_ssim = total_ssim / num_samples
print("Average SSIM of student model: {:.4f}".format(avg_ssim))


Average SSIM of student model: 0.9566
