In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# MODEL ARCHITECTURE

In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class SupervisedAttentionModule(nn.Module):
    def __init__(self,num_features ):
        super(SupervisedAttentionModule, self).__init__()
        self.conv1 = nn.Conv2d(num_features, 3, 1)
        self.conv2 = nn.Conv2d(num_features, num_features, 1)
        self.conv3 = nn.Conv2d(3, num_features, 1)
    def forward(self, f_in, degraded_1):
        r_s = self.conv1(f_in)
        x_s = r_s + degraded_1
        f_out = self.conv2(f_in)
        m = torch.sigmoid(self.conv3(x_s))
        f_out = f_out * m
        f_out = f_out + f_in
        return f_out, x_s
    
class ChannelAttentionBlock(nn.Module):
    def __init__(self, num_features, kernel, reduction):
        super(ChannelAttentionBlock, self).__init__()
        self.prelu = nn.PReLU()
        self.conv0 = nn.Conv2d(num_features, num_features, kernel, padding = 'same')
        self.conv1 = nn.Conv2d(num_features, num_features, kernel, padding = 'same')
        self.conv2 = nn.Conv2d(num_features, num_features//reduction, 1)
        self.conv3 = nn.Conv2d(num_features//reduction, num_features, 1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()
        self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
    
    def forward(self, x):
        x1 = self.conv0(x)
        #print(x1.shape)
        x1 = self.prelu(x1)
        x1 = self.conv1(x1)
        #print(x1.shape)
        x2 = self.global_avg_pooling(x1)
        #print(x2.shape)
        x2 = self.conv2(x2)
        #print(x2.shape)
        x2 = self.prelu(x2)
        x2 = self.conv3(x2)
        #print(x2.shape)
        x2 = self.sigmoid(x2)
        x2 = x2*x1
        #print(x2.shape)
        #print(x.shape)
        out = x2+x
        return out
        
        

class OriginalResolutionBlock(nn.Module):
    def __init__(self, kernel, reduction, num_cabs = 8, num_features = 96):
        super(OriginalResolutionBlock, self).__init__()
        self.cab_list = []
        for i in range(num_cabs):
            self.cab_list.append(ChannelAttentionBlock(num_features,kernel,reduction))
        self.cab_list.append(nn.Conv2d(num_features,num_features,kernel, padding = 'same'))
        self.orb = nn.Sequential(*self.cab_list)
    
    def forward(self,x):
        out = self.orb(x)
        out += x
        return out
        
        
    

class OriginalResolutionSubNetwork(nn.Module):
    def __init__(self, num_features, kernel_size, reduction, num_cabs):
        super(OriginalResolutionSubNetwork, self).__init__()
        self.orb1 = OriginalResolutionBlock(kernel_size, reduction, num_cabs, num_features)
        self.orb2 = OriginalResolutionBlock(kernel_size, reduction, num_cabs, num_features)
        self.orb3 = OriginalResolutionBlock(kernel_size, reduction, num_cabs, num_features)
        
        
        self.csffenc1 = nn.Conv2d(num_features, num_features, 1)
        self.csffenc2 = nn.Conv2d(num_features + (num_features//reduction), num_features, 1)
        self.csffenc3 = nn.Conv2d(num_features + 2*(num_features//reduction), num_features, 1)
         
        self.csffdec1 = nn.Conv2d(num_features, num_features, 1)
        self.csffdec2= nn.Conv2d(num_features, num_features, 1)
        self.csffdec3= nn.Conv2d(num_features+ (num_features//reduction), num_features, 1)
        
        self.up1 = nn.Upsample(scale_factor= 2, mode = 'bilinear', align_corners=True)
        self.up2 = nn.Upsample(scale_factor = 4, mode = 'bilinear', align_corners = True)
        
    def forward(self, x , e1, d1, e2, d2, e3, d3):
        
        #print(x.shape, e1.shape, d1.shape, e2.shape, d2.shape, e3.shape, d3.shape)
        
        x = self.orb1(x)
        #print(f"Shape of e1 is {self.csffenc1(e1).shape}, Shape of d1 is {self.csffdec1(d1).shape}, Shape of x is {x.shape}")
        x = x + self.csffenc1(e1) + self.csffdec1(d1)
        x = self.orb2(x)
        #print(f"Shape of e2 is {self.csffenc2(e2).shape}, Shape of d2 is {self.csffdec2(d2).shape}, Shape of x is {x.shape}")
        x = x + self.up1(self.csffenc2(e2)) + self.up1(self.csffdec2(d2))
        x = self.orb3(x)
        #print(f"Shape of e3 is {self.csffenc3(e3).shape}, Shape of d3 is {self.csffdec3(d3).shape}, Shape of x is {x.shape}")
        x = x + self.up2(self.csffenc3(e3)) + self.up2(self.csffdec3(d3))
        return x

same = 'same'
class Encoder(nn.Module):
    def __init__(self, num_features, kernel, reduction, recieve_csff_ip = False):
        super(Encoder, self).__init__()
        self.encoder_level1 = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features, num_features, kernel, padding = same),
            nn.ReLU()
        )
        self.encoder_level2 = nn.Sequential(
            nn.Conv2d(num_features, num_features + num_features // reduction, kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features + num_features // reduction, num_features + num_features // reduction, kernel, padding = same),
            nn.ReLU()
        )
        self.encoder_level3 = nn.Sequential(
            nn.Conv2d(num_features + num_features // reduction, num_features + 2 * (num_features // reduction), kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features + 2 * (num_features // reduction), num_features + 2 * (num_features // reduction), kernel, padding = same),
            nn.ReLU()
        )
        self.down12 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        self.down23 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
        
        if recieve_csff_ip:
            self.csff1_1 = nn.Conv2d(num_features, num_features, 1)
            self.csff1_2 = nn.Conv2d(num_features, num_features, 1)
            self.csff2_1 = nn.Conv2d(num_features + num_features // reduction, num_features + num_features // reduction, 1)
            self.csff2_2 = nn.Conv2d(num_features , num_features + num_features // reduction, 1)
            self.csff3_1 = nn.Conv2d( num_features + 2 * (num_features // reduction),  num_features + 2 * (num_features // reduction), 1)
            self.csff3_2 = nn.Conv2d( num_features + 1 * (num_features // reduction),  num_features + 2 * (num_features // reduction), 1)


    def forward(self, x, csff_input1_1 = None, csff_input1_2 = None, csff_input2_1 = None, csff_input2_2 = None, csff_input3_1 = None, csff_input3_2 = None):
        
        enc1 = self.encoder_level1(x)
        if csff_input1_1 is not None and csff_input1_2 is not None:
            #print(f"Shape of csff_input1_1 is {csff_input1_1.shape}, csff_input1_2 is {csff_input1_2.shape}")
            enc1 = enc1 + self.csff1_1(csff_input1_1) + self.csff1_2(csff_input1_2)
            
        x = self.down12(enc1)
        
        enc2 = self.encoder_level2(x)
        if csff_input2_1 is not None and csff_input2_2 is not None:
            #print(f"Shape of csff_input2_1 is {csff_input2_1.shape}, csff_input2_2 is {csff_input2_2.shape}")
            enc2 = enc2 + self.csff2_1(csff_input2_1) + self.csff2_2(csff_input2_2)
        
        x = self.down23(enc2)
        enc3 = self.encoder_level3(x)
        if csff_input3_1 is not None and csff_input3_2 is not None:
            enc3 = enc3 + self.csff3_1(csff_input3_1) + self.csff3_2(csff_input3_2)
            
        return [enc1, enc2, enc3]


class Decoder(nn.Module):
    def __init__(self, num_features, kernel, reduction):
        super(Decoder, self).__init__()
        self.decoder_level1 = nn.Sequential(
            nn.Conv2d(num_features, num_features, kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features, num_features, kernel, padding = same),
            nn.ReLU()
        )
        self.decoder_level2 = nn.Sequential(
            nn.Conv2d(num_features + num_features // reduction, num_features + num_features // reduction, kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features + num_features // reduction, num_features, kernel, padding = same),
            nn.ReLU()
        )
        self.decoder_level3 = nn.Sequential(
            nn.Conv2d(num_features + 2*(num_features // reduction), num_features + 2*(num_features // reduction), kernel, padding = same),
            nn.ReLU(),
            nn.Conv2d(num_features + 2*(num_features // reduction), num_features + 1*(num_features // reduction), kernel, padding = same),
            nn.ReLU()
        )
        self.up32 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up21 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, enc1, enc2,enc3):
        dec3 = self.decoder_level3(enc3)
        x3 = self.up32(dec3)+enc2
        dec2 = self.decoder_level2(x3)
        x2 = self.up21(dec2)+enc1
        dec1 = self.decoder_level1(x2)
        return [dec1, dec2, dec3]
    


class MPRNET(nn.Module):
    def __init__(self, num_features = 80, kernel_size = 3, reduction =4 , num_cabs = 8 ):
        super(MPRNET, self).__init__()
        self.conv_stage1 = nn.Conv2d(3, num_features, kernel_size, padding = 'same')
        self.conv_stage2 = nn.Conv2d(3, num_features, kernel_size, padding = 'same')
        self.conv_stage3 = nn.Conv2d(3, num_features, kernel_size, padding = 'same')
        self.final_conv = nn.Conv2d(num_features, 3, kernel_size, padding = 'same')
        
        
        self.sam_stage1 = SupervisedAttentionModule(num_features)
        self.sam_stage2 = SupervisedAttentionModule(num_features)
        
        self.cab_stage1 = ChannelAttentionBlock(num_features, kernel_size, reduction)
        self.cab_stage2 = ChannelAttentionBlock(num_features, kernel_size, reduction)
        self.cab_stage3 = ChannelAttentionBlock(num_features, kernel_size, reduction)
        
        self.encoder1 = Encoder(num_features, kernel_size, reduction)
        self.encoder2 = Encoder(num_features, kernel_size, reduction, True)
        
        self.decoder1 = Decoder(num_features, kernel_size, reduction)
        self.decoder2 = Decoder(num_features, kernel_size, reduction)
        
        self.ors_net = OriginalResolutionSubNetwork(num_features, kernel_size, reduction, num_cabs)
        
    def forward(self,image):
        
        height = image.shape[-2]
        width = image.shape[-1]
        
        left_image = image[:,:,:,0:int(width/2)]
        right_image = image[:,:,:,int(width/2):width]
        
        top_left_image = left_image[:,:,0:int(height/2),:]
        bottom_left_image = left_image[:,:,int(height/2):height,:]
        top_right_image = right_image[:,:,0:int(height/2),:]
        bottom_right_image= right_image[:,:,int(height/2):height,:]
        
        x1_tl = self.conv_stage1(top_left_image)
        x1_tl = self.cab_stage1(x1_tl)
        enc1_stage1_tl, enc2_stage1_tl, enc3_stage1_tl = self.encoder1(x1_tl)
        
        
        x1_bl = self.conv_stage1(bottom_left_image)
        x1_bl = self.cab_stage1(x1_bl)
        enc1_stage1_bl, enc2_stage1_bl, enc3_stage1_bl = self.encoder1(x1_bl)
        
        enc1_stage1_left, enc2_stage1_left, enc3_stage1_left = torch.cat((enc1_stage1_tl,enc1_stage1_bl), dim = 2), torch.cat((enc2_stage1_tl,enc2_stage1_bl), dim = 2), torch.cat((enc3_stage1_tl,enc3_stage1_bl), dim = 2)
        
        x1_tr = self.conv_stage1(top_right_image)
        x1_tr = self.cab_stage1(x1_tr)
        enc1_stage1_tr, enc2_stage1_tr, enc3_stage1_tr = self.encoder1(x1_tr)
        
        x1_br = self.conv_stage1(bottom_right_image)
        x1_br = self.cab_stage1(x1_br)
        enc1_stage1_br, enc2_stage1_br, enc3_stage1_br = self.encoder1(x1_br)
        
        enc1_stage1_right, enc2_stage1_right, enc3_stage1_right = torch.cat((enc1_stage1_tr,enc1_stage1_br), dim = 2), torch.cat((enc2_stage1_tr,enc2_stage1_br), dim = 2), torch.cat((enc3_stage1_tr,enc3_stage1_br), dim = 2)
        
        dec1_stage1_left, dec2_stage1_left, dec3_stage1_left = self.decoder1(enc1_stage1_left, enc2_stage1_left, enc3_stage1_left)
        dec1_stage1_right, dec2_stage1_right, dec3_stage1_right = self.decoder1(enc1_stage1_right, enc2_stage1_right, enc3_stage1_right)
        
        f_out_stage1_left, x_s1_left = self.sam_stage1(dec1_stage1_left, left_image)
        f_out_stage1_right, x_s1_right = self.sam_stage1(dec1_stage1_right, right_image)
        
        x_s1 = torch.cat((x_s1_left, x_s1_right), dim = 3)
        
        x2_left = self.conv_stage2(left_image)
        x2_left = self.cab_stage2(x2_left)
        enc1_stage2_left, enc2_stage2_left, enc3_stage2_left = self.encoder2(x2_left, enc1_stage1_left, dec1_stage1_left, enc2_stage1_left, dec2_stage1_left, enc3_stage1_left, dec3_stage1_left)
        
        x2_right = self.conv_stage2(right_image)
        x2_right = self.cab_stage2(x2_right)
        enc1_stage2_right, enc2_stage2_right, enc3_stage2_right = self.encoder2(x2_right, enc1_stage1_right, dec1_stage1_right, enc2_stage1_right, dec2_stage1_right, enc3_stage1_right, dec3_stage1_right)
        
        enc1_stage2, enc2_stage2, enc3_stage2 = torch.cat((enc1_stage2_left, enc1_stage2_right), dim= 3), torch.cat((enc2_stage2_left, enc2_stage2_right), dim= 3), torch.cat((enc3_stage2_left, enc3_stage2_right), dim= 3)
        
        dec1_stage2, dec2_stage2, dec3_stage2 = self.decoder2(enc1_stage2, enc2_stage2, enc3_stage2)
        
        f_out_stage2, x_s2 = self.sam_stage2(dec1_stage2, image)
        
        x3 = self.conv_stage3(image)
        x3 = self.cab_stage3(x3)
        x3 = x3 + f_out_stage2
        x3 = self.ors_net(x3, enc1_stage2, dec1_stage2, enc2_stage2, dec2_stage2, enc3_stage2, dec3_stage2)
        x3 = self.final_conv(x3)
        x_s3 = x3 + image
        
        
        return x_s1, x_s2, x_s3
        
        
        
        

In [2]:
TRAIN_DIR = "/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/Gopro/train"
VAL_DIR = "/kaggle/input/a-curated-list-of-image-deblurring-datasets/DBlur/Gopro/test"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

import os
import random
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

to_tensor = transforms.ToTensor()

class GoProDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.blur_dir = os.path.join(root_dir, 'blur')
        self.sharp_dir = os.path.join(root_dir, 'sharp')
        self.image_files = os.listdir(self.blur_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        blur_image_path = os.path.join(self.blur_dir, self.image_files[idx])
        sharp_image_path = os.path.join(self.sharp_dir, self.image_files[idx])

        blur_image = Image.open(blur_image_path)
        sharp_image = Image.open(sharp_image_path)

        blur_image = to_tensor(blur_image)
        sharp_image = to_tensor(sharp_image)
            
        
        """hh, ww = tar_img.shape[1], tar_img.shape[2]

        rr     = random.randint(0, hh-ps)
        cc     = random.randint(0, ww-ps)
        aug    = random.randint(0, 8)

        # Crop patch
        inp_img = inp_img[:, rr:rr+ps, cc:cc+ps]
        tar_img = tar_img[:, rr:rr+ps, cc:cc+ps]"""
            
        patch_size = 256
        
        
        height = blur_image.shape[1]
        width = blur_image.shape[2]
        
        
        patch_start_y = random.randint(0, height - patch_size)
        patch_start_x = random.randint(0, width - patch_size)
        
        blur_image = blur_image[: , patch_start_y : patch_start_y + patch_size, patch_start_x : patch_start_x + patch_size]
        sharp_image = sharp_image[: , patch_start_y : patch_start_y + patch_size, patch_start_x : patch_start_x + patch_size]
        

        return blur_image, sharp_image
    
transform = transforms.Compose([
    transforms.ToTensor(),
    
])

train_dataset = GoProDataset(root_dir = TRAIN_DIR, transform=transform)


TrainLoader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=16, drop_last=False, pin_memory=True)

val_dataset = GoProDataset(root_dir = VAL_DIR)

ValidationLoader = DataLoader(val_dataset, batch_size = 2, shuffle = True, num_workers=8, drop_last=False, pin_memory=True)

Using device: cuda




In [3]:
'''Code for these losses are taken directly from the authors of the paper'''

class CharbonnierLoss(nn.Module):
    """Charbonnier Loss (L1)"""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
        return loss

class EdgeLoss(nn.Module):
    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        self.kernel = torch.matmul(k.t(),k).unsqueeze(0).repeat(3,1,1,1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        n_channels, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw//2, kh//2, kw//2, kh//2), mode='replicate')
        return F.conv2d(img, self.kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        filtered    = self.conv_gauss(current)    # filter
        down        = filtered[:,:,::2,::2]       # downsample
        new_filter  = torch.zeros_like(filtered)
        new_filter[:,:,::2,::2] = down*4          # upsample
        filtered    = self.conv_gauss(new_filter) # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss

        

# Train Loop

In [None]:
from tqdm import tqdm
import random
import time

import torch

def compute_psnr(img1, img2, max_val=1.0):
    """
    Compute the PSNR (Peak Signal-to-Noise Ratio) between two images after normalizing them.
    
    Args:
    - img1 (torch.Tensor): The first image tensor.
    - img2 (torch.Tensor): The second image tensor.
    - max_val (float): The maximum possible pixel value of the images before normalization.
    
    Returns:
    - float: The PSNR value.
    """
    
    # Normalize the images to [0, 1]
    img1 = img1 / max_val
    img2 = img2 / max_val
    
    mse = torch.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    return 10 * torch.log10(1.0 / mse)



    

learning_rate = 2e-4
weight_decay = 1e-1
cycle = 8
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

#model = MPRNET().to(device)

char_loss = CharbonnierLoss()
edge_loss = EdgeLoss()

#optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay, betas=(0.9, 0.999),eps=1e-8)
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=cycle, eta_min=1e-6)

train_losses = []
val_losses = []

from tqdm import tqdm

best_psnr = 0
best_epoch = 0

for epoch in range(start_epoch, 101):
    #epoch_start_time = time.time()
    epoch_loss = 0
    train_id = 1

    model.train()
    
    # Initialize tqdm progress bar
    pbar = tqdm(total=len(TrainLoader), desc=f"Epoch {epoch}")
    
    for i, (blur_image, sharp_image) in enumerate(TrainLoader, 0):

        # zero_grad
        for param in model.parameters():
            param.grad = None

        blur_image = blur_image.to(device)
        sharp_image = sharp_image.to(device)

        restored_images = model(blur_image)
 
        # Compute loss at each stage
        loss_char = sum([char_loss(restored_images[j],sharp_image) for j in range(len(restored_images))])
        loss_edge = sum([edge_loss(restored_images[j],sharp_image) for j in range(len(restored_images))])
        loss = (loss_char) + (0.05*loss_edge)
       
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

        # Update tqdm progress bar
        pbar.set_postfix(loss=epoch_loss/(i+1))
        pbar.update(1)

    #### Evaluation ####
    
    model.eval()
    psnr_val_rgb = []
    for ii, (blur_image, sharp_image) in enumerate(ValidationLoader, 0):
        blur_image = blur_image.to(device)
        sharp_image = sharp_image.to(device)

        with torch.no_grad():
            restored_images = model(blur_image)
        restored = restored_images[-1]

        for res, tar in zip(restored, sharp_image):
            psnr_val_rgb.append(compute_psnr(res, tar))

    psnr_val_rgb = torch.stack(psnr_val_rgb).mean().item()

    if psnr_val_rgb > best_psnr:
        best_psnr = psnr_val_rgb
        best_epoch = epoch
        torch.save({'epoch': epoch, 
                    'state_dict': model.state_dict(),
                    'optimizer' : optimizer.state_dict()
                    }, os.path.join("/kaggle/working/","model_best.pth"))

    # Update tqdm progress bar with PSNR
    pbar.set_postfix(loss=epoch_loss/len(TrainLoader), psnr=psnr_val_rgb)
    print("[epoch %d PSNR: %.4f --- best_epoch %d Best_PSNR %.4f]" % (epoch, psnr_val_rgb, best_epoch, best_psnr))

    

    scheduler.step()
    
    print("------------------------------------------------------------------")
    print("Epoch: {}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch,  epoch_loss, scheduler.get_lr()[0]))
    print("------------------------------------------------------------------")

    

    # Close the tqdm progress bar
    pbar.close()


Using device: cuda


Epoch 16:   1%|          | 13/1052 [00:14<19:56,  1.15s/it, loss=0.0753]
Epoch 16: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0921, psnr=26.6]


[epoch 16 PSNR: 26.6191 --- best_epoch 16 Best_PSNR 26.6191]
------------------------------------------------------------------
Epoch: 16	Loss: 96.9167	LearningRate 0.000192
------------------------------------------------------------------


Epoch 17: 100%|██████████| 1052/1052 [12:58<00:00,  1.35it/s, loss=0.091, psnr=26.8]


[epoch 17 PSNR: 26.7636 --- best_epoch 17 Best_PSNR 26.7636]
------------------------------------------------------------------
Epoch: 17	Loss: 95.7264	LearningRate 0.000171
------------------------------------------------------------------


Epoch 18: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0896, psnr=27]


[epoch 18 PSNR: 26.9614 --- best_epoch 18 Best_PSNR 26.9614]
------------------------------------------------------------------
Epoch: 18	Loss: 94.2581	LearningRate 0.000139
------------------------------------------------------------------


Epoch 19: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0895, psnr=27.1]


[epoch 19 PSNR: 27.0577 --- best_epoch 19 Best_PSNR 27.0577]
------------------------------------------------------------------
Epoch: 19	Loss: 94.1752	LearningRate 0.000101
------------------------------------------------------------------


Epoch 20: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0885, psnr=27.2]


[epoch 20 PSNR: 27.1827 --- best_epoch 20 Best_PSNR 27.1827]
------------------------------------------------------------------
Epoch: 20	Loss: 93.1031	LearningRate 0.000062
------------------------------------------------------------------


Epoch 21: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0868, psnr=27.1]


[epoch 21 PSNR: 27.1442 --- best_epoch 20 Best_PSNR 27.1827]
------------------------------------------------------------------
Epoch: 21	Loss: 91.3613	LearningRate 0.000030
------------------------------------------------------------------


Epoch 22: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0854, psnr=27.4]


[epoch 22 PSNR: 27.3970 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 22	Loss: 89.8188	LearningRate 0.000009
------------------------------------------------------------------


Epoch 23: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0855, psnr=27.4]


[epoch 23 PSNR: 27.3678 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 23	Loss: 89.9294	LearningRate 0.000200
------------------------------------------------------------------


Epoch 24: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0881, psnr=27.1]


[epoch 24 PSNR: 27.0505 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 24	Loss: 92.7280	LearningRate 0.000192
------------------------------------------------------------------


Epoch 25: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0867, psnr=27.1]


[epoch 25 PSNR: 27.0558 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 25	Loss: 91.2145	LearningRate 0.000171
------------------------------------------------------------------


Epoch 26: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0879, psnr=27]


[epoch 26 PSNR: 26.9676 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 26	Loss: 92.4737	LearningRate 0.000139
------------------------------------------------------------------


Epoch 27: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0862, psnr=27.1]


[epoch 27 PSNR: 27.0770 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 27	Loss: 90.6441	LearningRate 0.000101
------------------------------------------------------------------


Epoch 28: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0851, psnr=27.2]


[epoch 28 PSNR: 27.1866 --- best_epoch 22 Best_PSNR 27.3970]
------------------------------------------------------------------
Epoch: 28	Loss: 89.5390	LearningRate 0.000062
------------------------------------------------------------------


Epoch 29: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0852, psnr=27.5]


[epoch 29 PSNR: 27.4639 --- best_epoch 29 Best_PSNR 27.4639]
------------------------------------------------------------------
Epoch: 29	Loss: 89.6691	LearningRate 0.000030
------------------------------------------------------------------


Epoch 30: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0837, psnr=27.6]


[epoch 30 PSNR: 27.5901 --- best_epoch 30 Best_PSNR 27.5901]
------------------------------------------------------------------
Epoch: 30	Loss: 88.0898	LearningRate 0.000009
------------------------------------------------------------------


Epoch 31: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0813, psnr=27.7]


[epoch 31 PSNR: 27.6990 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 31	Loss: 85.5287	LearningRate 0.000200
------------------------------------------------------------------


Epoch 32: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0876, psnr=27.2]


[epoch 32 PSNR: 27.2099 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 32	Loss: 92.1260	LearningRate 0.000192
------------------------------------------------------------------


Epoch 33: 100%|██████████| 1052/1052 [12:56<00:00,  1.36it/s, loss=0.086, psnr=27.2]


[epoch 33 PSNR: 27.2195 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 33	Loss: 90.5118	LearningRate 0.000171
------------------------------------------------------------------


Epoch 34: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0862, psnr=27.2]


[epoch 34 PSNR: 27.2362 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 34	Loss: 90.6630	LearningRate 0.000139
------------------------------------------------------------------


Epoch 35: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0859, psnr=27.5]


[epoch 35 PSNR: 27.4863 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 35	Loss: 90.3148	LearningRate 0.000101
------------------------------------------------------------------


Epoch 36: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0839, psnr=27.6]


[epoch 36 PSNR: 27.6340 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 36	Loss: 88.2140	LearningRate 0.000062
------------------------------------------------------------------


Epoch 37: 100%|██████████| 1052/1052 [12:56<00:00,  1.35it/s, loss=0.0815, psnr=27.6]


[epoch 37 PSNR: 27.5560 --- best_epoch 31 Best_PSNR 27.6990]
------------------------------------------------------------------
Epoch: 37	Loss: 85.7241	LearningRate 0.000030
------------------------------------------------------------------


Epoch 38: 100%|██████████| 1052/1052 [12:58<00:00,  1.35it/s, loss=0.0808, psnr=27.9]


[epoch 38 PSNR: 27.8535 --- best_epoch 38 Best_PSNR 27.8535]
------------------------------------------------------------------
Epoch: 38	Loss: 85.0278	LearningRate 0.000009
------------------------------------------------------------------


Epoch 39: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0803, psnr=27.9]


[epoch 39 PSNR: 27.9261 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 39	Loss: 84.5061	LearningRate 0.000200
------------------------------------------------------------------


Epoch 40: 100%|██████████| 1052/1052 [12:58<00:00,  1.35it/s, loss=0.0856, psnr=27.1]


[epoch 40 PSNR: 27.0541 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 40	Loss: 90.0916	LearningRate 0.000192
------------------------------------------------------------------


Epoch 41: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0866, psnr=27.5]


[epoch 41 PSNR: 27.5349 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 41	Loss: 91.0735	LearningRate 0.000171
------------------------------------------------------------------


Epoch 42: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0847, psnr=27.4]


[epoch 42 PSNR: 27.3722 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 42	Loss: 89.1218	LearningRate 0.000139
------------------------------------------------------------------


Epoch 43: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0832, psnr=27.7]


[epoch 43 PSNR: 27.6609 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 43	Loss: 87.4779	LearningRate 0.000101
------------------------------------------------------------------


Epoch 44: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0832, psnr=27.7]


[epoch 44 PSNR: 27.7382 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 44	Loss: 87.4839	LearningRate 0.000062
------------------------------------------------------------------


Epoch 45: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0814, psnr=27.9]


[epoch 45 PSNR: 27.9120 --- best_epoch 39 Best_PSNR 27.9261]
------------------------------------------------------------------
Epoch: 45	Loss: 85.6396	LearningRate 0.000030
------------------------------------------------------------------


Epoch 46: 100%|██████████| 1052/1052 [12:58<00:00,  1.35it/s, loss=0.08, psnr=28]


[epoch 46 PSNR: 27.9767 --- best_epoch 46 Best_PSNR 27.9767]
------------------------------------------------------------------
Epoch: 46	Loss: 84.2066	LearningRate 0.000009
------------------------------------------------------------------


Epoch 47: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0781, psnr=28.1]


[epoch 47 PSNR: 28.0776 --- best_epoch 47 Best_PSNR 28.0776]
------------------------------------------------------------------
Epoch: 47	Loss: 82.1090	LearningRate 0.000200
------------------------------------------------------------------


Epoch 48: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0855, psnr=27.7]


[epoch 48 PSNR: 27.6670 --- best_epoch 47 Best_PSNR 28.0776]
------------------------------------------------------------------
Epoch: 48	Loss: 89.9392	LearningRate 0.000192
------------------------------------------------------------------


Epoch 49: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0844, psnr=27.6]


[epoch 49 PSNR: 27.5885 --- best_epoch 47 Best_PSNR 28.0776]
------------------------------------------------------------------
Epoch: 49	Loss: 88.7875	LearningRate 0.000171
------------------------------------------------------------------


Epoch 50: 100%|██████████| 1052/1052 [12:57<00:00,  1.35it/s, loss=0.0844, psnr=27.7]


[epoch 50 PSNR: 27.6612 --- best_epoch 47 Best_PSNR 28.0776]
------------------------------------------------------------------
Epoch: 50	Loss: 88.7729	LearningRate 0.000139
------------------------------------------------------------------


Epoch 51: 100%|██████████| 1052/1052 [12:58<00:00,  1.35it/s, loss=0.0832, psnr=27.6]


[epoch 51 PSNR: 27.6425 --- best_epoch 47 Best_PSNR 28.0776]
------------------------------------------------------------------
Epoch: 51	Loss: 87.5285	LearningRate 0.000101
------------------------------------------------------------------


Epoch 52:  17%|█▋        | 184/1052 [02:04<09:26,  1.53it/s, loss=0.0806]

In [7]:
def load_checkpoint(model, optimizer, filename):
    # Note: Input model & optimizer should be pre-defined. This routine only updates their states.
    start_epoch = 0
    if os.path.isfile(filename):
        print("=> loading checkpoint '{}'".format(filename))
        checkpoint = torch.load(filename)
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(filename, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(filename))

    return model, optimizer, start_epoch
# Define the path to your saved checkpoint
checkpoint_path = "/kaggle/working/model_best.pth"

# Load the checkpoint
model, optimizer, start_epoch = load_checkpoint(model, optimizer, checkpoint_path)


=> loading checkpoint '/kaggle/working/model_best.pth'
=> loaded checkpoint '/kaggle/working/model_best.pth' (epoch 16)
