In [15]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import numpy as np
import cv2
from skimage.metrics import structural_similarity as ssim
#Student model
class lightweightstudentCNN(nn.Module):
  def __init__(self):
    super(lightweightstudentCNN,self).__init__()
    self.encoder=nn.Sequential(
        nn.Conv2d(3,16,kernel_size=3,padding=1),
        nn.ReLU(),
        nn.Conv2d(16,32,kernel_size=3,padding=1),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
    self.middle=nn.Sequential(
        nn.Conv2d(32,32,kernel_size=3,padding=1),
        nn.ReLU()
    )
    self.decoder=nn.Sequential(
        nn.Upsample(scale_factor=2,mode="bilinear",align_corners=False),
        nn.Conv2d(32,16,kernel_size=3,padding=1),
        nn.ReLU(),
        nn.Conv2d(16,3,kernel_size=3,padding=1),
        nn.Sigmoid()
    )
  def forward(self,x):
      x=self.encoder(x)
      x=self.middle(x)
      x=self.decoder(x)
      return x
#Custom Dataset
class ImagePairDataset(Dataset):
  def __init__(self,blurred,groundtruth,transform=None):
    self.blurred_paths=sorted([os.path.join(blurred,f) for f in os.listdir(blurred) if f.endswith(('.png','.jpg'))])
    self.groundtruth_paths=sorted([os.path.join(groundtruth,f) for f in os.listdir(groundtruth) if f.endswith(('.png','.jpg'))])
    self.transform=transform

  def __len__(self):
    return len(self.blurred_paths)

  def pad_to_multiple(self,image,m=16):
    width,height=image.size
    new_width=(width+m-1)//m*m
    new_height=(height+m-1)//m*m
    return image.resize((new_width,new_height))

  def __getitem__(self,idx):
    blur=Image.open(self.blurred_paths[idx]).convert("RGB")
    gt=Image.open(self.groundtruth_paths[idx]).convert("RGB")
    blur=self.pad_to_multiple(blur,m=16)
    gt=self.pad_to_multiple(gt,m=16)

    if self.transform:
      blur=self.transform(blur)
      gt=self.transform(gt)
    return blur,gt,os.path.basename(self.blurred_paths[idx])
   #Distillation training
def trainingCNN(blurred,groundtruth,teacher_model,epochs=10,batch_size=32,lr=1e-4):
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
    transform=transforms.Compose([
        transforms.Resize((256,480)),
        transforms.ToTensor()
        ])
    data=ImagePairDataset(blurred,groundtruth,transform=transform)
    dataloader=DataLoader(data,batch_size=batch_size,shuffle=True)

    teacher_model.to(device)
    teacher_model.eval()

    studentmodel=lightweightstudentCNN().to(device)
    optimizer=optim.Adam(studentmodel.parameters(),lr=lr)
    criteria=nn.MSELoss()

    for epoch in range(epochs):
      studentmodel.train()
      epoch_loss=0.0
      for blur,gt,_ in dataloader:
        blur,gt=blur.to(device),gt.to(device)
        with torch.no_grad():
          teacher_output=teacher_model(blur)
        student_output=studentmodel(blur)
        loss=distillation_loss(student_output,teacher_output,gt,alpha=0.5,beta=0.3,g=0.2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss=epoch_loss+loss.item()
        torch.cuda.empty_cache()
      print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss/len(dataloader)}")
    torch.save(studentmodel.state_dict(),"student_model.pth")


#Inference+SSIM Evaluation
def restore_and_evaluate(student_model,blurred,groundtruth):
  device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
  transform=transforms.Compose([
      transforms.Resize((1080,1920)),
      transforms.ToTensor()
  ])
  data=ImagePairDataset(blurred,groundtruth,transform=transform)
  dataloader=DataLoader(data,batch_size=1,shuffle=False)

  model=lightweightstudentCNN().to(device)
  model.load_state_dict(torch.load(student_model,map_location=device))
  model.eval()

  ssim_scores=[]
  os.makedirs("restored_images",exist_ok=True)

  for blur,gt,filename in dataloader:
    blur,gt=blur.to(device),gt.to(device)
    with torch.no_grad():
      output=model(blur)

    output_n=output.squeeze(0).permute(1,2,0).cpu().numpy()
    gt_n=gt.squeeze(0).permute(1,2,0).cpu().numpy()


    output_n = np.clip(output_n * 255, 0, 255).astype(np.uint8)
    gt_n = np.clip(gt_n * 255, 0, 255).astype(np.uint8)


    #save images
    output_image_path=os.path.join("restored_images",filename[0])
    Image.fromarray(output_n).save(output_image_path)

    #ssim
    score=ssim(output_n,gt_n,data_range=255,channel_axis=2)
    ssim_scores.append(score)
    print(f"SSIM for {filename[0]}:{score:.4f}")

  print(f"\nAverage SSIM:{np.mean(ssim_scores):.4f}")




In [4]:
import os
!pip install einops

if os.path.isdir('Restormer'):
  !rm -r Restormer

# Clone Restormer
!git clone https://github.com/swz30/Restormer.git
%cd Restormer

Cloning into 'Restormer'...
remote: Enumerating objects: 309, done.[K
remote: Counting objects: 100% (107/107), done.[K
remote: Compressing objects: 100% (51/51), done.[K
remote: Total 309 (delta 67), reused 56 (delta 56), pack-reused 202 (from 1)[K
Receiving objects: 100% (309/309), 1.56 MiB | 10.30 MiB/s, done.
Resolving deltas: 100% (123/123), done.
/content/Restormer


In [5]:
# task = 'Real_Denoising'
# task = 'Single_Image_Defocus_Deblurring'
task = 'Motion_Deblurring'
# task = 'Deraining'

# Download the pre-trained models
if task is 'Real_Denoising':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/real_denoising.pth -P Denoising/pretrained_models
if task is 'Single_Image_Defocus_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/single_image_defocus_deblurring.pth -P Defocus_Deblurring/pretrained_models
if task is 'Motion_Deblurring':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth -P Motion_Deblurring/pretrained_models
if task is 'Deraining':
  !wget https://github.com/swz30/Restormer/releases/download/v1.0/deraining.pth -P Deraining/pretrained_models


  if task is 'Real_Denoising':
  if task is 'Single_Image_Defocus_Deblurring':
  if task is 'Motion_Deblurring':


--2025-07-11 04:52:24--  https://github.com/swz30/Restormer/releases/download/v1.0/motion_deblurring.pth
Resolving github.com (github.com)... 140.82.112.3
Connecting to github.com (github.com)|140.82.112.3|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://objects.githubusercontent.com/github-production-release-asset-2e65be/418793252/55c7bcd2-cb39-4d8a-adc4-acf6f6131c27?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=releaseassetproduction%2F20250711%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20250711T045224Z&X-Amz-Expires=1800&X-Amz-Signature=a3db011f3f518b03a234694232aba66fc7d0e10a17c8a2c6e9504d60f2fabd7d&X-Amz-SignedHeaders=host&response-content-disposition=attachment%3B%20filename%3Dmotion_deblurring.pth&response-content-type=application%2Foctet-stream [following]
--2025-07-11 04:52:24--  https://objects.githubusercontent.com/github-production-release-asset-2e65be/418793252/55c7bcd2-cb39-4d8a-adc4-acf6f6131c27?X-Amz-Algorithm=AWS4-HMAC-SHA256&X

  if task is 'Deraining':


In [6]:
import os
import torch
from runpy import run_path

#Load Restormer Architecture
restormer_path=os.path.join("basicsr","models","archs","restormer_arch.py")
restormer_arch=run_path(restormer_path)
restormer=restormer_arch['Restormer']

#Define Model Parameters (match traininf config)
teacher_params={
    'inp_channels':3,
    'out_channels':3,
    'dim':48,
    'num_blocks':[4,6,6,8],
    'num_refinement_blocks':4,
    'heads':[1,2,4,8],
    'ffn_expansion_factor':2.66,
    'bias':False,
    'LayerNorm_type':'WithBias',
    'dual_pixel_task':False
}

#Load the pretrained model
restormer_model=restormer(**teacher_params)
ckpt_path=os.path.join("Motion_Deblurring","pretrained_models","motion_deblurring.pth")
checkpoint=torch.load(ckpt_path,map_location="cpu")
restormer_model.load_state_dict(checkpoint['params'])
restormer_model.eval()

Restormer(
  (patch_embed): OverlapPatchEmbed(
    (proj): Conv2d(3, 48, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (encoder_level1): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (attn): Attention(
        (qkv): Conv2d(48, 144, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (qkv_dwconv): Conv2d(144, 144, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=144, bias=False)
        (project_out): Conv2d(48, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (norm2): LayerNorm(
        (body): WithBias_LayerNorm()
      )
      (ffn): FeedForward(
        (project_in): Conv2d(48, 254, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (dwconv): Conv2d(254, 254, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=254, bias=False)
        (project_out): Conv2d(127, 48, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
    )
    (1): TransformerBlock(
 

In [1]:
!pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->pytorch_msssim)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->pytorch_msssim)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch->pytorch_msssim)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch->pytorch_mss

In [14]:
import torch
import torch.nn.functional as F
from pytorch_msssim import ssim as torch_ssim
from torchvision import models

vg=models.vgg16(pretrained=True).features.eval()
for param in vg.parameters():
  param.requires_grad=False

# Move VGG to the correct device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vg.to(device)

layers=[3,8,15]

def extract_features(a,model,layers):
  features=[]
  for i,layer in enumerate(model):
      a=layer(a)
      if i in layers:
        features.append(a)
  return features

def distillation_loss(student_out,teacher_out,ground_truth,alpha=0.5,beta=0.3,g=0.2):
  l1_loss=F.l1_loss(student_out,ground_truth)
  ssim_loss=1-torch_ssim(student_out,ground_truth,data_range=1.0)
  student_feats=extract_features(student_out,vg,layers)
  teacher_feats=extract_features(teacher_out,vg,layers)
  p_loss=sum( F.mse_loss(s,t) for s,t in zip(student_feats,teacher_feats))
  feature_loss=F.mse_loss(student_feats[-1],teacher_feats[-1])
  return(alpha*l1_loss+beta*ssim_loss+g*p_loss+0.1*feature_loss)



In [8]:
blur="/content/blurred_5200_0_6.zip"
gt="/content/numpysliced_5200.zip"

In [7]:
import zipfile
import os

#Extract ZIP files
def extract_zip(zip_path,extract_to):
  os.makedirs(extract_to,exist_ok=True)
  with zipfile.ZipFile(zip_path,"r") as zip_ref:
     zip_ref.extractall(extract_to)
  #return extract_to

blur_folder=extract_zip("/content/blurred_5200_0_6.zip","blur_extracted")
gt_folder=extract_zip("/content/numpysliced_5200.zip","gt_extracted")



In [9]:
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Ensure teacher model is on the correct device
teacher_model = restormer_model.to(device)  # ✅ This ensures compatibility with GPU input

# Later you already have:
studentmodel = lightweightstudentCNN().to(device)


In [16]:
trainingCNN("/content/Restormer/blur_extracted","/content/Restormer/gt_extracted",teacher_model=restormer_model,batch_size=16)

Epoch 1/10, Loss: 0.35737027493806983
Epoch 2/10, Loss: 0.16759135823983412
Epoch 3/10, Loss: 0.12452336146281316
Epoch 4/10, Loss: 0.09366513424194776
Epoch 5/10, Loss: 0.0864768362733034
Epoch 6/10, Loss: 0.08267088286005533
Epoch 7/10, Loss: 0.07986932713251847
Epoch 8/10, Loss: 0.07791924828520189
Epoch 9/10, Loss: 0.07624424122847044
Epoch 10/10, Loss: 0.07465646537450643


In [17]:
import zipfile
def extract_zip(zip_path, extract_to):
    os.makedirs(extract_to, exist_ok=True)
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_to)
    return extract_to

In [18]:
x=extract_zip("/content/new_valid.zip","/content/gt1_valid")
y=extract_zip("/content/new_blur_valid.zip","/content/blur1_valid")

In [19]:
restore_and_evaluate("/content/Restormer/student_model.pth",y,x)

SSIM for 0801.png:0.9272
SSIM for 0802.png:0.8643
SSIM for 0803.png:0.9548
SSIM for 0804.png:0.8810
SSIM for 0805.png:0.9006
SSIM for 0806.png:0.8965
SSIM for 0807.png:0.7193
SSIM for 0808.png:0.8776
SSIM for 0809.png:0.9375
SSIM for 0810.png:0.8970
SSIM for 0811.png:0.9048
SSIM for 0812.png:0.8995
SSIM for 0813.png:0.9146
SSIM for 0814.png:0.9425
SSIM for 0815.png:0.9439
SSIM for 0816.png:0.9276
SSIM for 0817.png:0.9070
SSIM for 0818.png:0.9227
SSIM for 0819.png:0.9004
SSIM for 0820.png:0.8800
SSIM for 0821.png:0.8853
SSIM for 0822.png:0.9127
SSIM for 0823.png:0.8966
SSIM for 0824.png:0.9302
SSIM for 0825.png:0.9001
SSIM for 0826.png:0.8522
SSIM for 0827.png:0.9325
SSIM for 0828.png:0.7106
SSIM for 0829.png:0.8201
SSIM for 0830.png:0.8784
SSIM for 0831.png:0.9125
SSIM for 0832.png:0.9241
SSIM for 0833.png:0.9399
SSIM for 0834.png:0.8854
SSIM for 0835.png:0.8070
SSIM for 0836.png:0.8674
SSIM for 0837.png:0.8901
SSIM for 0838.png:0.9557
SSIM for 0839.png:0.9320
SSIM for 0840.png:0.9101
