In [None]:
import os, random, shutil, torch, cv2
import numpy as np
from tqdm import tqdm
from skimage.metrics import structural_similarity as ssim
from torchvision.transforms.functional import to_tensor
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torchvision.models import vgg16
from torchvision import transforms as T
import torch.nn.functional as F

In [None]:
# Location of the raw Unsplash high-res images and of the generated working dataset.
# NOTE(review): absolute, machine-specific paths — consider deriving them from one
# configurable base directory.
source_dir = r"C:\Users\atole\OneDrive\Desktop\Python\archive(1)\Image Super Resolution - Unsplash\high res"
# Keep the drive letter so this matches the path used by the dataset-preparation
# cell; without it the path resolves against the drive of the kernel's cwd.
target_root = r"C:\Users\atole\OneDrive\Desktop\Python\Working Dataset"
os.makedirs(target_root, exist_ok=True)

In [None]:
import os, cv2, shutil, random
from tqdm import tqdm
from sklearn.model_selection import train_test_split


source_dir = r"C:\Users\atole\OneDrive\Desktop\Python\archive(1)\Image Super Resolution - Unsplash\high res"
target_root = r"C:\Users\atole\OneDrive\Desktop\Python\Working Dataset"
valid_images = []

# Keep only .jpg/.png images that OpenCV can decode and whose height and width
# are both multiples of 8.
print(" Filtering images divisible by 8...")

# Sort the listing: os.listdir order is filesystem-dependent, and the seeded
# splits below are only reproducible if their input order is fixed.
for fname in tqdm(sorted(os.listdir(source_dir))):
    if not fname.lower().endswith(('.jpg', '.png')):
        continue
    img = cv2.imread(os.path.join(source_dir, fname))
    if img is None:  # unreadable / undecodable file
        continue
    h, w = img.shape[:2]
    if h % 8 == 0 and w % 8 == 0:
        valid_images.append(fname)

print(f" Valid images: {len(valid_images)}")

# 80/20 train/test split; the training part is then halved into two disjoint
# pools (train_A / train_B). Up to 50 test images are also set aside for showcase.
random.seed(42)
train_imgs, test_imgs = train_test_split(valid_images, test_size=0.2, random_state=42)
train_A, train_B = train_test_split(train_imgs, test_size=0.5, random_state=42)
showcase_imgs = random.sample(test_imgs, min(50, len(test_imgs)))

def extract_patches(img, patch_size=256, stride=256):
    """Tile `img` into square crops of side `patch_size`, stepping by `stride`.

    Crops are produced row by row (top-to-bottom, then left-to-right) and only
    where a full patch fits; any border narrower than `patch_size` is dropped.
    For a NumPy array each returned patch is a slice of `img`.

    Returns:
        list of patches (empty if the image is smaller than `patch_size`).
    """
    height, width = img.shape[:2]
    tops = range(0, height - patch_size + 1, stride)
    lefts = range(0, width - patch_size + 1, stride)
    return [
        img[top:top + patch_size, left:left + patch_size]
        for top in tops
        for left in lefts
    ]

def save_patches_from_list(file_list, folder_name, patch_size=256, stride=256):
    """Cut every image in `file_list` into patches and write them as PNG files.

    Images are read from the global `source_dir`; patches go to
    `target_root/folder_name` with sequential zero-padded names (000000.png, ...).
    The destination folder is deleted and recreated first, so re-running the
    cell gives the same result. Images OpenCV cannot read are skipped.

    Args:
        file_list: image file names relative to `source_dir`.
        folder_name: sub-folder of `target_root` to (re)create.
        patch_size, stride: forwarded to `extract_patches`.

    Raises:
        OSError: if OpenCV fails to write a patch.
    """
    dest = os.path.join(target_root, folder_name)
    if os.path.exists(dest):
        shutil.rmtree(dest)
    os.makedirs(dest, exist_ok=True)

    patch_id = 0
    for fname in tqdm(file_list, desc=f" Creating patches for {folder_name}"):
        img = cv2.imread(os.path.join(source_dir, fname))
        if img is None:
            continue
        # Crop (not resize) to multiples of 8: resizing to the truncated size
        # rescales the whole image and slightly distorts it. The image stays in
        # OpenCV's BGR order end-to-end, so no colour-conversion round trip is needed.
        h, w = img.shape[:2]
        img = img[:h - h % 8, :w - w % 8]
        for patch in extract_patches(img, patch_size=patch_size, stride=stride):
            out_path = os.path.join(dest, f"{patch_id:06}.png")
            # cv2.imwrite signals failure via its return value, not an exception.
            if not cv2.imwrite(out_path, patch):
                raise OSError(f"Failed to write patch {out_path}")
            patch_id += 1

def copy_images(file_list, folder_name):
    """Copy full-resolution images from `source_dir` into `target_root/folder_name`.

    Any existing destination folder is removed first, so afterwards it holds
    only the images copied by this call.
    """
    dest_dir = os.path.join(target_root, folder_name)
    if os.path.exists(dest_dir):
        shutil.rmtree(dest_dir)
    os.makedirs(dest_dir, exist_ok=True)

    progress = tqdm(file_list, desc=f" Copying full images to {folder_name}")
    for name in progress:
        shutil.copy(os.path.join(source_dir, name), os.path.join(dest_dir, name))
            
# Materialise the dataset: patch folders for both training pools and full
# images for the test and showcase sets.
save_patches_from_list(train_A, "train_A", patch_size=256, stride=256)
save_patches_from_list(train_B, "train_B", patch_size=256, stride=256)
copy_images(test_imgs, "test")
copy_images(showcase_imgs, "showcase")

# Report how many files ended up in each output folder.
print("\n Dataset Preparation Complete:")
for label, folder, unit in (
    ("Train A", "train_A", "patches"),
    ("Train B", "train_B", "patches"),
    ("Test", "test", "full images"),
    ("Showcase", "showcase", "full images"),
):
    count = len(os.listdir(os.path.join(target_root, folder)))
    print(f" {label}: {count} {unit}")

In [None]:
# Use the GPU when PyTorch can see one; later cells move models and tensors here.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using:", device)

In [None]:
# Fetch the Restormer architecture source so the next cell can import it as `restormer_arch`.
# SECURITY NOTE(review): this downloads code that is later imported and executed,
# taken from the repo's moving `main` branch. Pin a specific commit hash in the URL
# for reproducibility and to avoid silently picking up upstream changes.
import requests
url = "https://raw.githubusercontent.com/swz30/Restormer/main/basicsr/models/archs/restormer_arch.py"
save_path = "restormer_arch.py"
# Without a timeout, requests can block indefinitely on a stalled connection.
response = requests.get(url, timeout=30)
if response.status_code == 200:
    with open(save_path, "wb") as f:
        f.write(response.content)
    print(f"Saved to {save_path}")
else:
    print(f" Failed to download. Status code: {response.status_code}")

In [None]:
import sys
!pip install einops
sys.path.append(r'C:\Users\atole\OneDrive\Desktop\Python\Working Dataset')
from restormer_arch import Restormer
import torch

teacher = Restormer(
    inp_channels=3,               
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
    LayerNorm_type='WithBias',
    dual_pixel_task=False        
)


weights = torch.load(r"C:\Users\atole\OneDrive\Desktop\Python\single_image_defocus_deblurring.pth", map_location=device)
teacher.load_state_dict(weights.get("params", weights))
teacher = teacher.to(device).eval()

print(" Restormer (3-channel teacher for single-image defocus deblurring) ready")

In [None]:
class StudentCNN(torch.nn.Module):
    """Small convolutional encoder/decoder trained as the distillation student.

    The encoder halves the spatial size with a 2x max-pool and restores it with
    bilinear 2x upsampling, so inputs with even height/width produce an output of
    the same size. The decoder ends in Tanh, and `forward` rescales that [-1, 1]
    range to [0, 1].
    """

    def __init__(self):
        super().__init__()
        nn = torch.nn
        # Layer order (and therefore Sequential indices / state_dict keys and the
        # parameter-initialisation order) is fixed; do not reorder.
        encoder_layers = [
            nn.Conv2d(3, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.Conv2d(64, 32, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 3, 3, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        features = self.encoder(x)
        y = self.decoder(features)
        # Map the Tanh output from [-1, 1] into [0, 1].
        return (y + 1) / 2

# Student model on the training device, optimised with Adam (lr 1e-3) under an L1 criterion.
student = StudentCNN()
student = student.to(device)
criterion = torch.nn.L1Loss()
optimizer = torch.optim.Adam(params=student.parameters(), lr=1e-3)

In [None]:
# Frozen ImageNet-pretrained VGG16 convolutional trunk, used as a fixed feature
# extractor by the perceptual loss.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=VGG16_Weights.IMAGENET1K_V1`; switch once the installed
# version is confirmed.
vgg = vgg16(pretrained=True).features
vgg = vgg.to(device).eval()
vgg.requires_grad_(False)

# ImageNet channel statistics the pretrained VGG weights expect.
vgg_normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

def perceptual_loss(pred, target, feature_extractor=None, normalize=None):
    """L1 distance between feature maps of `pred` and `target`.

    Both inputs are bilinearly resized to 224x224, normalised, and passed
    through the feature extractor.

    Args:
        pred: (B, C, H, W) image tensor; gradients flow back through it.
        target: (B, C, H, W) reference tensor; treated as a constant (no graph
            is built for its features).
        feature_extractor: module producing the compared features; defaults to
            the global frozen `vgg`.
        normalize: callable applied to the resized images; defaults to the global
            `vgg_normalize`.

    Returns:
        Scalar tensor: mean absolute difference between the two feature maps.
    """
    if feature_extractor is None:
        feature_extractor = vgg
    if normalize is None:
        normalize = vgg_normalize

    def _prepare(img):
        # Resize then normalise. Normalize works on (..., C, H, W) directly, so no
        # squeeze/unsqueeze is needed; that pattern broke for batch sizes > 1.
        resized = F.interpolate(img, size=(224, 224), mode='bilinear', align_corners=False)
        return normalize(resized)

    pred_feat = feature_extractor(_prepare(pred))
    # The target is a fixed reference, so skip building an autograd graph for it.
    with torch.no_grad():
        target_feat = feature_extractor(_prepare(target))
    return F.l1_loss(pred_feat, target_feat)

In [None]:
def _load_training_pair(fpath):
    """Return (gt, degraded) RGB uint8 arrays for one training patch.

    The ground truth is cropped to multiples of 8. The degraded input is the
    ground truth downscaled 2x (bilinear) and upscaled back (bicubic).

    Raises:
        FileNotFoundError: if OpenCV cannot read `fpath`.
    """
    img = cv2.imread(fpath)
    if img is None:
        raise FileNotFoundError(f"Could not read training image: {fpath}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    h, w = img.shape[:2]
    h, w = h - (h % 8), w - (w % 8)
    # Crop rather than resize, so the ground truth is never rescaled.
    gt = np.ascontiguousarray(img[:h, :w])
    lr = cv2.resize(gt, (w // 2, h // 2), interpolation=cv2.INTER_LINEAR)
    degraded = cv2.resize(lr, (w, h), interpolation=cv2.INTER_CUBIC)
    return gt, degraded


def _ssim_vs_gt(output, gt):
    """SSIM between a (1, 3, H, W) student output (clamped to [0, 1]) and the uint8 RGB `gt`."""
    with torch.no_grad():
        pred = output.squeeze(0).clamp(0, 1).cpu().numpy()
    ref = to_tensor(gt).numpy()
    return ssim(np.transpose(pred, (1, 2, 0)), np.transpose(ref, (1, 2, 0)),
                channel_axis=2, data_range=1.0, win_size=11)


train_path = os.path.join(target_root, "train_B")
# Sorted so the seeded per-epoch shuffles below are reproducible.
train_files = sorted(os.listdir(train_path))
if not train_files:
    raise RuntimeError(f"No training patches found in {train_path}")

NUM_EPOCHS = 15
shuffle_rng = random.Random(42)
teacher.eval()
best_ssim = 0.0

for epoch in range(NUM_EPOCHS):
    student.train()
    total_loss = 0.0
    ssim_total = 0.0
    # Shuffle every epoch: patches cut from the same source image have consecutive
    # file names, so a fixed order feeds strongly correlated samples back to back.
    shuffle_rng.shuffle(train_files)

    for fname in tqdm(train_files, desc=f" Training Student | Epoch {epoch+1}"):
        gt, input_img = _load_training_pair(os.path.join(train_path, fname))
        input_tensor = to_tensor(input_img).unsqueeze(0).to(device)
        gt_tensor = to_tensor(gt).unsqueeze(0).to(device)

        # NOTE(review): teacher outputs are identical every epoch; caching them
        # would avoid re-running the heavy teacher 15 times per patch.
        with torch.no_grad():
            # The student's output is confined to [0, 1], so clamp the teacher's
            # output to the same range to avoid an unreachable target.
            target_tensor = teacher(input_tensor).clamp(0, 1)

        output = student(input_tensor)

        # Mostly imitate the teacher, with a small pull toward the true image
        # plus a lightly weighted perceptual term.
        l1 = criterion(output, target_tensor)
        gt_l1 = criterion(output, gt_tensor)
        p_loss = perceptual_loss(output, gt_tensor)
        loss = 0.9 * l1 + 0.1 * gt_l1 + 0.005 * p_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

        # NOTE: SSIM of the pre-update output on *training* patches, not a held-out
        # validation score, so "best" below is selected on training data.
        ssim_total += _ssim_vs_gt(output, gt)

    avg_loss = total_loss / len(train_files)
    avg_ssim = ssim_total / len(train_files)
    print(f" Epoch {epoch+1} | Avg Loss: {avg_loss:.6f} | Avg SSIM: {avg_ssim:.4f}")

    # Checkpoint inside the epoch loop. It previously sat after the loop, so only
    # the final epoch was ever compared and saved.
    if avg_ssim > best_ssim:
        best_ssim = avg_ssim
        torch.save(student.state_dict(), "best_student.pth")
        print(f" Best model saved at Epoch {epoch+1} with SSIM: {best_ssim:.4f}")