# Import Libraries

In [None]:
import os, random, shutil, torch, cv2
import numpy as np
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from torchvision.transforms.functional import to_tensor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torchvision.models import vgg16
from torchvision import transforms as T
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Image Filtering & Processing

In [None]:
source_dir = r"C:\Users\atole\OneDrive\Desktop\Python\archive(1)\Image Super Resolution - Unsplash\high res"        
target_root = r"C:\Users\atole\OneDrive\Desktop\Python\Working Dataset"
os.makedirs(target_root, exist_ok=True)
valid_images = []

# This is to filter images to standard resolution i.e. in our case 1200*800
print(" Filtering images divisible by 8...")
for fname in tqdm(os.listdir(source_dir)): 
    if fname.lower().endswith(('.jpg', '.png')):
        fpath = os.path.join(source_dir, fname)
        img = cv2.imread(fpath)
        if img is None:
            continue
        h, w = img.shape[:2]
        if h % 8 == 0 and w % 8 == 0:
            valid_images.append(fname)
print(f" Valid images: {len(valid_images)}")

random.seed(42)
train_imgs, test_imgs = train_test_split(valid_images, test_size=0.2, random_state=42)
train_A, train_B = train_test_split(train_imgs, test_size=0.5, random_state=42)
showcase_imgs = random.sample(test_imgs, min(50, len(test_imgs)))

#For getting a wider variety of images for training we are breaking down the images into patches
def extract_patches(img, patch_size=256, stride=256):
    patches = []
    h, w = img.shape[:2]
    for y in range(0, h - patch_size + 1, stride):
        for x in range(0, w - patch_size + 1, stride):
            patch = img[y:y+patch_size, x:x+patch_size]
            patches.append(patch)
    return patches

def save_patches_from_list(file_list, folder_name, patch_size=256, stride=256):
    dest = os.path.join(target_root, folder_name)
    if os.path.exists(dest):
        shutil.rmtree(dest)
    os.makedirs(dest, exist_ok=True)

    patch_id = 0
    for fname in tqdm(file_list, desc=f" Creating patches for {folder_name}"):
        img = cv2.imread(os.path.join(source_dir, fname))
        if img is None:
            continue
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        h, w = h - (h % 8), w - (w % 8)
        img = cv2.resize(img, (w, h))
        patches = extract_patches(img, patch_size=patch_size, stride=stride)
        for patch in patches:
            out_path = os.path.join(dest, f"{patch_id:06}.png")
            cv2.imwrite(out_path, cv2.cvtColor(patch, cv2.COLOR_RGB2BGR))
            patch_id += 1

def copy_images(file_list, folder_name):
    dest = os.path.join(target_root, folder_name)
    if os.path.exists(dest):
        shutil.rmtree(dest)
    os.makedirs(dest, exist_ok=True)

    for fname in tqdm(file_list, desc=f" Copying full images to {folder_name}"):
        src_path = os.path.join(source_dir, fname)
        dst_path = os.path.join(dest, fname)
        shutil.copy(src_path, dst_path)
            
save_patches_from_list(train_A, "train_A", patch_size=256, stride=256)   
save_patches_from_list(train_B, "train_B", patch_size=256, stride=256)
copy_images(test_imgs, "test")
copy_images(showcase_imgs, "showcase")

print("\n Dataset Preparation Complete:")
print(f" Train A: {len(os.listdir(os.path.join(target_root, 'train_A')))} patches")
print(f" Train B: {len(os.listdir(os.path.join(target_root, 'train_B')))} patches")
print(f" Test: {len(os.listdir(os.path.join(target_root, 'test')))} full images")
print(f" Showcase: {len(os.listdir(os.path.join(target_root, 'showcase')))} full images")

 Filtering images divisible by 8...


100%|██████████| 1254/1254 [00:09<00:00, 131.16it/s]


 Valid images: 989


 Creating patches for train_A: 100%|██████████| 395/395 [00:13<00:00, 28.27it/s]
 Creating patches for train_B: 100%|██████████| 396/396 [00:13<00:00, 28.33it/s]
 Copying full images to test: 100%|██████████| 198/198 [00:00<00:00, 1240.56it/s]
 Copying full images to showcase: 100%|██████████| 50/50 [00:00<00:00, 1218.35it/s]


 Dataset Preparation Complete:
 Train A: 4740 patches
 Train B: 4752 patches
 Test: 198 full images
 Showcase: 50 full images





# Importing restormer 

In [27]:
import requests
url = "https://raw.githubusercontent.com/swz30/Restormer/main/basicsr/models/archs/restormer_arch.py"
save_path = "restormer_arch.py"
response = requests.get(url)
if response.status_code == 200:
    with open(save_path, "wb") as f:
        f.write(response.content)
    print(f"Saved to {save_path}")
else:
    print(f" Failed to download. Status code: {response.status_code}")

Saved to restormer_arch.py


# Loading Restormer with its weights

In [28]:
import sys
!pip install einops
sys.path.append(r'C:\Users\atole\OneDrive\Desktop\Python\Working Dataset')
from restormer_arch import Restormer
import torch

teacher = Restormer(
    inp_channels=3,               
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    LayerNorm_type='WithBias',
    dual_pixel_task=False        
)


weights = torch.load(r"C:\Users\atole\OneDrive\Desktop\Python\single_image_defocus_deblurring.pth", map_location=device)
teacher.load_state_dict(weights.get("params", weights))
teacher = teacher.to(device).eval()

print(" Restormer (3-channel teacher for single-image defocus deblurring) ready")

 Restormer (3-channel teacher for single-image defocus deblurring) ready


# Defining the student model

In [None]:
class StudentCNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder 
        self.encoder = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 3, padding=1), torch.nn.ReLU(), #RGB → 32 feature maps
            torch.nn.Conv2d(32, 64, 3, padding=1), torch.nn.ReLU(), #32 → 64 feature maps
            torch.nn.MaxPool2d(2), # Downsample by 2 
            torch.nn.Conv2d(64, 64, 3, padding=1), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 64, 3, padding=1), torch.nn.ReLU(),
            torch.nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False) # Upsample back to original size
        )
        #Decoder: To reconstruct the output
        self.decoder = torch.nn.Sequential(
            torch.nn.Conv2d(64, 32, 3, padding=1), torch.nn.ReLU(), # 64 → 32 feature maps
            torch.nn.Conv2d(32, 3, 3, padding=1), torch.nn.Tanh()  # 32 → RGB output, Output in range [-1, 1]
        )

    def forward(self, x):
        x = self.encoder(x) # Summarize: Feature extraction and size restoration
        x = self.decoder(x) # Decoding Image back to RGB
        return (x + 1) / 2  # Output in range [0, 1]

student = StudentCNN().to(device)
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3) # Adam optimizer with learning rate 0.001
criterion = torch.nn.L1Loss() # Use L1 loss (pixel-wise difference)

# Perceptual Loss Definition 

In [30]:
vgg = vgg16(pretrained=True).features.to(device).eval()
for param in vgg.parameters():
    param.requires_grad = False

vgg_normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                            std=[0.229, 0.224, 0.225])

def perceptual_loss(pred, target):
    pred_resized = F.interpolate(pred, size=(224, 224), mode='bilinear', align_corners=False)
    target_resized = F.interpolate(target, size=(224, 224), mode='bilinear', align_corners=False)
    pred_norm = vgg_normalize(pred_resized.squeeze(0)).unsqueeze(0)
    target_norm = vgg_normalize(target_resized.squeeze(0)).unsqueeze(0)
    pred_feat = vgg(pred_norm)
    target_feat = vgg(target_norm)
    return F.l1_loss(pred_feat, target_feat)



Downloading: "https://download.pytorch.org/models/vgg16-397923af.pth" to C:\Users\atole/.cache\torch\hub\checkpoints\vgg16-397923af.pth


100%|██████████| 528M/528M [03:16<00:00, 2.82MB/s] 


# Training Loop for Set A of images & Saving weights

In [None]:
train_path = os.path.join(target_root, "train_A")  
train_files = os.listdir(train_path)
student.train()
teacher.eval()
best_ssim = 0.0

for epoch in range(15):
    total_loss = 0
    student.train()
    ssim_total = 0.0

    for fname in tqdm(train_files, desc=f" Training Student | Epoch {epoch+1}"):
        img = cv2.imread(os.path.join(train_path, fname))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        h, w = img.shape[:2]
        h, w = h - (h % 8), w - (w % 8)
        gt = cv2.resize(img, (w, h))
        lr = cv2.resize(gt, (w//2, h//2), interpolation=cv2.INTER_LINEAR)
        input_img = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)

        input_tensor = to_tensor(input_img).unsqueeze(0).to(device)
        gt_tensor = to_tensor(gt).unsqueeze(0).to(device)

        with torch.no_grad():
            target_tensor = teacher(input_tensor)
        output = student(input_tensor)
        l1 = criterion(output, target_tensor)
        gt_l1 = criterion(output, gt_tensor)
        p_loss = perceptual_loss(output, gt_tensor)
        loss = 0.9 * l1 + 0.1 * gt_l1 + 0.005 * p_loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        with torch.no_grad():
            s_output = output.squeeze(0).clamp(0, 1).cpu().numpy()
            s_gt = to_tensor(gt).numpy()
            s_output = np.transpose(s_output, (1, 2, 0))
            s_gt = np.transpose(s_gt, (1, 2, 0))
            score = ssim(s_output, s_gt, channel_axis=2, data_range=1.0, win_size=11)
            ssim_total += score

    avg_loss = total_loss / len(train_files)
    avg_ssim = ssim_total / len(train_files)
    print(f" Epoch {epoch+1} | Avg L1 Loss: {avg_loss:.6f} | Avg SSIM: {avg_ssim:.4f}")

if avg_ssim > best_ssim:
        best_ssim = avg_ssim
        torch.save(student.state_dict(), "best_student.pth")
        print(f" Best model saved at Epoch {epoch+1} with SSIM: {best_ssim:.4f}")

 Training Student | Epoch 1: 100%|██████████| 4740/4740 [28:35<00:00,  2.76it/s]


 Epoch 1 | Avg L1 Loss: 0.077644 | Avg SSIM: 0.7805


 Training Student | Epoch 2: 100%|██████████| 4740/4740 [28:52<00:00,  2.74it/s]


 Epoch 2 | Avg L1 Loss: 0.059260 | Avg SSIM: 0.8404


 Training Student | Epoch 3: 100%|██████████| 4740/4740 [28:47<00:00,  2.74it/s]


 Epoch 3 | Avg L1 Loss: 0.044978 | Avg SSIM: 0.8674


 Training Student | Epoch 4: 100%|██████████| 4740/4740 [28:18<00:00,  2.79it/s]


 Epoch 4 | Avg L1 Loss: 0.041758 | Avg SSIM: 0.8753


 Training Student | Epoch 5: 100%|██████████| 4740/4740 [28:39<00:00,  2.76it/s]


 Epoch 5 | Avg L1 Loss: 0.039660 | Avg SSIM: 0.8830


 Training Student | Epoch 6: 100%|██████████| 4740/4740 [28:30<00:00,  2.77it/s]


 Epoch 6 | Avg L1 Loss: 0.039505 | Avg SSIM: 0.8864


 Training Student | Epoch 7: 100%|██████████| 4740/4740 [28:42<00:00,  2.75it/s]


 Epoch 7 | Avg L1 Loss: 0.039845 | Avg SSIM: 0.8862


 Training Student | Epoch 8: 100%|██████████| 4740/4740 [28:39<00:00,  2.76it/s]


 Epoch 8 | Avg L1 Loss: 0.040055 | Avg SSIM: 0.8875


 Training Student | Epoch 9: 100%|██████████| 4740/4740 [28:01<00:00,  2.82it/s]


 Epoch 9 | Avg L1 Loss: 0.040332 | Avg SSIM: 0.8873


 Training Student | Epoch 10: 100%|██████████| 4740/4740 [28:12<00:00,  2.80it/s]


 Epoch 10 | Avg L1 Loss: 0.038801 | Avg SSIM: 0.8907


 Training Student | Epoch 11: 100%|██████████| 4740/4740 [28:10<00:00,  2.80it/s]


 Epoch 11 | Avg L1 Loss: 0.038823 | Avg SSIM: 0.8892


 Training Student | Epoch 12: 100%|██████████| 4740/4740 [28:07<00:00,  2.81it/s]


 Epoch 12 | Avg L1 Loss: 0.039860 | Avg SSIM: 0.8884


 Training Student | Epoch 13: 100%|██████████| 4740/4740 [29:05<00:00,  2.71it/s]


 Epoch 13 | Avg L1 Loss: 0.038409 | Avg SSIM: 0.8895


 Training Student | Epoch 14: 100%|██████████| 4740/4740 [28:12<00:00,  2.80it/s]


 Epoch 14 | Avg L1 Loss: 0.038985 | Avg SSIM: 0.8930


 Training Student | Epoch 15: 100%|██████████| 4740/4740 [28:20<00:00,  2.79it/s]

 Epoch 15 | Avg L1 Loss: 0.038613 | Avg SSIM: 0.8933
 Best model saved at Epoch 15 with SSIM: 0.8933



