In [4]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import os
import random
from tqdm import tqdm
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset
import torch.nn.functional as F

In [5]:
# Paired dataset folders: `Input` holds the motion-blurred images (model inputs),
# `Output` holds the corresponding sharp images (training targets).
Input, Output = (
    '/kaggle/input/blur-dataset/motion_blurred',
    '/kaggle/input/blur-dataset/sharp',
)

In [None]:
def data_load(path, size=(128, 128)):
    """Load every image in `path`, in sorted filename order, as RGB uint8 arrays.

    Files are selected by extension (.jpg, .jpeg, .png, case-insensitive).
    Each one is read with OpenCV, resized to `size` and converted from
    OpenCV's BGR channel order to RGB, so matplotlib shows correct colours.

    Args:
        path: Directory containing the images.
        size: Target size passed to ``cv2.resize`` as (width, height).
            Defaults to (128, 128).

    Returns:
        np.ndarray stacking the loaded images along a new leading axis. For
        colour images this is (num_images, height, width, 3). If no image
        files are found, the result is an empty array.

    Raises:
        ValueError: if a file with an image extension cannot be decoded. The
            blurred and sharp folders are paired by sorted position, so
            silently dropping one file would misalign every later pair.
    """
    image_exts = {'.jpg', '.jpeg', '.png'}
    result = []
    for file in tqdm(sorted(os.listdir(path))):
        # Compare the real suffix instead of a substring anywhere in the name.
        if os.path.splitext(file)[1].lower() not in image_exts:
            continue
        file_path = os.path.join(path, file)
        img = cv2.imread(file_path)
        if img is None:
            raise ValueError(f"Could not read image file: {file_path}")
        img = cv2.resize(img, size)
        # cv2.imread returns BGR; downstream display (plt.imshow) expects RGB.
        result.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    return np.array(result)

In [None]:
# Load the blurred inputs and sharp targets. Samples are paired by position in
# each folder's sorted file listing, so both folders must contain matching sets.
X = data_load(Input)
y = data_load(Output)

# Fail fast here rather than deep inside the tensor conversion or training loop.
assert X is not None and y is not None, "an image failed to load"
assert X.shape == y.shape, f"blurred/sharp arrays differ in shape: {X.shape} vs {y.shape}"

In [None]:
# Training hyperparameters.
batch_size = 35
learning_rate = 0.001
num_epochs = 10

# Seed every RNG used later (weight initialisation, DataLoader shuffling) so a
# Restart & Run All reproduces the same model and losses.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# ToTensor converts an HxWxC uint8 PIL image into a CxHxW float tensor in [0, 1].
all_transforms = transforms.Compose(transforms=[transforms.ToTensor()])

def apply_transforms(images_array, transform):
  """Wrap each image array as a PIL image, apply `transform`, and stack the results.

  Args:
    images_array: iterable of image arrays accepted by ``PIL.Image.fromarray``
      (e.g. HxWx3 uint8).
    transform: callable applied to each PIL image. It must return tensors of a
      common shape so they can be stacked.

  Returns:
    torch.Tensor with the transformed images stacked along a new leading axis.
  """
  return torch.stack([transform(Image.fromarray(arr)) for arr in images_array])

In [None]:
# Replace the uint8 image stacks with stacked float tensors (N, C, H, W) built via `all_transforms`.
X, y = [apply_transforms(images, all_transforms) for images in (X, y)]

In [None]:
# Sequential (unshuffled) split of the paired tensors:
# first 50% -> train, next 12.5% -> validation, remaining 37.5% -> test.
train_frac = 0.5
val_end_frac = 0.625

# Split indices come from len(X); y must line up with X for the pairs to survive.
assert len(X) == len(y), f"X and y differ in length: {len(X)} vs {len(y)}"

ind1 = int(len(X) * train_frac)
ind2 = int(len(X) * val_end_frac)

X_train, y_train = X[:ind1], y[:ind1]
# Validation targets must come from y (sharp images), not X (blurred inputs).
X_val, y_val = X[ind1:ind2], y[ind1:ind2]
X_test, y_test = X[ind2:], y[ind2:]

In [None]:
# Pair each blurred tensor with its sharp target so loaders yield (input, target) batches.
train_dataset, val_dataset, test_dataset = (
    TensorDataset(inputs, targets)
    for inputs, targets in ((X_train, y_train), (X_val, y_val), (X_test, y_test))
)

In [None]:
# Only the training loader shuffles (to decorrelate SGD steps). The validation
# and test loaders keep a fixed order, so evaluation passes and any per-batch
# inspection are reproducible.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
class Encoder_block(nn.Module):
  def __init__(self,in_channels,out_channels):
    super(Encoder_block,self).__init__()
    self.conv_layer1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=(1,1))
    self.conv_layer2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=(1,1))
    self.max_pool=nn.MaxPool2d(kernel_size=2,stride=2)
    self.relu=nn.ReLU()

  def forward(self,x):
    out=self.relu(self.conv_layer1(x))
    out=self.relu(self.conv_layer2(out))
    skip=out
    out=self.max_pool(out)
    return out,skip

In [None]:
class Decoder_block(nn.Module):
  def __init__(self,in_channels,out_channels):
    super(Decoder_block,self).__init__()
    self.upconv_layer=nn.ConvTranspose2d(in_channels=in_channels,out_channels=out_channels,kernel_size=2,stride=2)
    self.conv_layer1=nn.Conv2d(in_channels=in_channels,out_channels=out_channels,kernel_size=3,padding=(1,1))
    self.conv_layer2=nn.Conv2d(in_channels=out_channels,out_channels=out_channels,kernel_size=3,padding=(1,1))
    self.relu=nn.ReLU()

  def forward(self,x,skip):
    out=self.upconv_layer(x)
    if out.size() != skip.size():
         out = F.interpolate(out, size=(skip.size(2), skip.size(3)), mode='bilinear', align_corners=True)
            
    out=torch.cat((out,skip),dim=1)
    out=self.relu(self.conv_layer1(out))
    out=self.relu(self.conv_layer2(out))
    return out

In [None]:
class UNet(nn.Module):
  """Four-level U-Net mapping an `in_channels` image to an `out_channels` image.

  The encoder widens 64 -> 128 -> 256 -> 512 channels while max-pooling 2x at
  each level. A two-conv 1024-channel bottleneck follows. The decoder mirrors
  the encoder and consumes the encoder skips in reverse order. A final 1x1
  conv produces the output with no activation, so values are unbounded. Input
  sizes divisible by 16 avoid resampling; other sizes are reconciled by the
  decoder blocks' interpolation.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()
    # Construction order matches the original so parameter initialisation,
    # state_dict keys and the printed structure are identical.
    self.enc1 = Encoder_block(in_channels, 64)
    self.enc2 = Encoder_block(64, 128)
    self.enc3 = Encoder_block(128, 256)
    self.enc4 = Encoder_block(256, 512)

    self.middle_conv1 = nn.Conv2d(512, 1024, kernel_size=3, padding=1)
    self.middle_conv2 = nn.Conv2d(1024, 1024, kernel_size=3, padding=1)
    self.relu = nn.ReLU()

    self.dec1 = Decoder_block(1024, 512)
    self.dec2 = Decoder_block(512, 256)
    self.dec3 = Decoder_block(256, 128)
    self.dec4 = Decoder_block(128, 64)

    self.final = nn.Conv2d(64, out_channels, kernel_size=1)

  def forward(self, x):
    skips = []
    for encoder in (self.enc1, self.enc2, self.enc3, self.enc4):
      x, skip = encoder(x)
      skips.append(skip)

    x = self.relu(self.middle_conv2(self.relu(self.middle_conv1(x))))

    # Deepest decoder pairs with the deepest skip (enc4), shallowest with enc1.
    decoders = (self.dec1, self.dec2, self.dec3, self.dec4)
    for decoder, skip in zip(decoders, reversed(skips)):
      x = decoder(x, skip)

    return self.final(x)

In [None]:
# 3-channel (colour) blurred image in, 3-channel sharp estimate out.
model = UNet(
    in_channels=3,
    out_channels=3,
)
print(model)

In [None]:
# Pixel-wise regression loss. The model predicts an image and the targets are
# float images of the same shape (ToTensor output), so compare them with mean
# squared error. CrossEntropyLoss would treat the 3 channels as class logits and
# the pixel values as class probabilities, which is meaningless for deblurring.
criterion = nn.MSELoss()

# SGD with momentum and L2 weight decay.
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, weight_decay=0.005, momentum=0.9)

total_step = len(train_loader)  # number of training batches per epoch

In [None]:
# Train for `num_epochs` epochs. After each epoch, report the sample-weighted
# mean training loss (instead of only the last batch's loss) and the loss on
# the held-out validation set.
for epoch in range(num_epochs):
    # Re-enter training mode every epoch. The validation pass below, and the
    # inference cell's model.eval() if cells are re-run, switch to eval mode.
    model.train()
    train_loss_sum, train_count = 0.0, 0
    for images, labels in train_loader:
        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Assuming criterion uses the default reduction='mean', weight each
        # batch by its size to get a per-sample average over the epoch.
        train_loss_sum += loss.item() * images.size(0)
        train_count += images.size(0)

    model.eval()
    val_loss_sum, val_count = 0.0, 0
    with torch.no_grad():
        for images, labels in val_loader:
            val_loss_sum += criterion(model(images), labels).item() * images.size(0)
            val_count += images.size(0)

    # Guard against an empty split instead of dividing by zero.
    train_loss = train_loss_sum / train_count if train_count else float('nan')
    val_loss = val_loss_sum / val_count if val_count else float('nan')
    print('Epoch [{}/{}], Loss: {:.4f}, Val Loss: {:.4f}'.format(epoch + 1, num_epochs, train_loss, val_loss))

In [None]:
# Visual check on one training example: blurred input, model output, sharp target.
sample_idx = 1
image = X_train[sample_idx]

# Preprocess the image
transform = transforms.Compose([
    transforms.Resize((128, 128))  # Resize to match model's input size
])
input_image = transform(image).unsqueeze(0)  # Add batch dimension

model.eval()  # Set model to evaluation mode

# Perform inference
with torch.no_grad():
    outputs = model(input_image)

# Take the single batch item and move channels last for matplotlib. Then clip
# to [0, 1]: the network's final conv has no activation, and imshow expects
# float RGB data in that range.
output_image = outputs[0].permute(1, 2, 0).clamp(0, 1).cpu().numpy()

# Display the blurred input, the deblurred output and the sharp ground truth
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
panels = [
    (image.permute(1, 2, 0).cpu().numpy(), 'Input Image'),
    (output_image, 'Output'),
    (y_train[sample_idx].permute(1, 2, 0).cpu().numpy(), 'Target (sharp)'),
]
for ax, (panel, title) in zip(axes, panels):
    ax.imshow(panel)
    ax.set_title(title)
    ax.axis('off')

plt.show()