<a href="https://colab.research.google.com/github/Rudrabha/SS2021-19-08-2021/blob/main/Image_Super_Resolve_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**Import Headers**

In [None]:
import os
!pip install wget
import wget
import shutil
import glob
import cv2
import numpy as np
import random
from tqdm import tqdm

Collecting wget
  Using cached wget-3.2.zip (10 kB)
Building wheels for collected packages: wget
  Building wheel for wget (setup.py) ... [?25l[?25hdone
  Created wheel for wget: filename=wget-3.2-py3-none-any.whl size=9672 sha256=6eb5cd85d3fbfe1e96dabbe39c8a28f536a2abb966c35589a63c0a95342dd8a2
  Stored in directory: /root/.cache/pip/wheels/a1/b6/7c/0e63e34eb06634181c63adacca38b79ff8f35c37e3c13e3c02
Successfully built wget
Installing collected packages: wget
Successfully installed wget-3.2


In [None]:
from torch.utils.data import Dataset, DataLoader
import torch 
from torch import nn
from torch.nn import functional as F
from torch import optim

In [None]:
use_cuda = torch.cuda.is_available()
print('use_cuda: {}'.format(use_cuda))
device = torch.device("cuda" if use_cuda else "cpu")
!nvidia-smi

use_cuda: True
Wed Aug 18 10:09:07 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   72C    P8    32W / 149W |      3MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+------------------------------------------------------------------------

**Setting up Data Path**

In [None]:
#shutil.rmtree("/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data")

In [None]:
parent_folder = "/content/IMAGE_SUPER_RESOLVE_DATA"

if os.path.isdir(parent_folder):
    shutil.rmtree(parent_folder)
os.mkdir(parent_folder)

#Create Folder to download Raw Data
raw_data_folder = os.path.join(parent_folder,"raw_data")
extracted_data_folder = os.path.join(parent_folder,"extracted_data")

if not os.path.isdir(raw_data_folder):
    os.mkdir(raw_data_folder)

if not os.path.isdir(extracted_data_folder):
    os.mkdir(extracted_data_folder)

image_data_folder = os.path.join(extracted_data_folder, "images")

**Downloading Data**

In [None]:
dataset_link = "https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz"
raw_data = os.path.join(raw_data_folder, "images.tar.gz")
print("Downloading Data")
wget.download(dataset_link, raw_data)
print("Downloading Done")

Downloading Data
Downloading Done


**Extracting the Data**

In [None]:
shutil.unpack_archive(raw_data, extracted_data_folder)

**Listing the Dataset Features**

In [None]:
image_address_list = []
image_address_list = glob.glob(os.path.join(image_data_folder,"*.jpg"))
for img_addr in image_address_list:
    try :
        img = cv2.imread(img_addr)
        # print(img.shape)
    except :
        image_address_list.remove(img_addr)

In [None]:
print(len(image_address_list))

7390


**MODULE_1 : Data Loader**

In [None]:
class DataGenerator(Dataset):
	
    def __init__(self, image_list):
        self.files = image_list
        

    def __len__(self):
        return len(self.files)
        

    def __getitem__(self,idx):

        #print(files[idx])
        img = cv2.imread(self.files[idx])
        high_res_img = cv2.resize(img,(512,512))
        high_res_img = np.transpose(high_res_img, (2, 0, 1))
        low_res_img = cv2.resize(img,(128,128))
        low_res_img = cv2.resize(low_res_img, (512, 512))
        low_res_img = np.transpose(low_res_img, (2, 0, 1))
        return torch.FloatTensor(high_res_img/255.), torch.FloatTensor(low_res_img/255.)
		
	
def load_data(image_list, batch_size=32, num_workers=10, shuffle=True):

    dataset = DataGenerator(image_list)
    data_loader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers, shuffle=shuffle)

    return data_loader

**MODULE 2 : Model Creation**

**Conv2D**

In [None]:
class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
                            nn.BatchNorm2d(cout)
                            )
        self.act = nn.ReLU()
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            out += x
        return self.act(out)

**Conv2D-T**

In [None]:
class Conv2dTranspose(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(
                            nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
                            nn.BatchNorm2d(cout)
                            )
        self.act = nn.ReLU()

    def forward(self, x):
        out = self.conv_block(x)
        return self.act(out)

**MODEL**

In [None]:
class Image_Super_Resolve(nn.Module):
    def __init__(self):
        super(Image_Super_Resolve, self).__init__()

        self.image_encoder = nn.Sequential(
            Conv2d(3, 4, kernel_size=3, stride=1, padding=1),
            
            Conv2d(4, 8, kernel_size=3, stride=1, padding=1),
            Conv2d(8, 8, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(8, 8, kernel_size=3, stride=1, padding=1, residual=True),
            
            Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(16, 16, kernel_size=3, stride=1, padding=1, residual=True),
            
            Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),      
            )
        self.image_decoder = nn.Sequential(

            # Conv2dTranspose(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),

            # Conv2dTranspose(32, 32, kernel_size=3, stride=2, padding=1, output_padding=1),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
            Conv2d(32, 3, kernel_size=3, stride=1, padding=1),
            )
        
    def forward(self, face_image):

        #print("Shape : ",face_image.shape)
        face_embedding = self.image_encoder(face_image)
        # print("Shape : ",face_embedding.shape)
        decoded_face = self.image_decoder(face_embedding)
        decoded_face += face_image

        decoded_face = torch.sigmoid(decoded_face)
        # print("Shape : ",decoded_face.shape)
        return decoded_face


In [None]:
model = Image_Super_Resolve()
data = torch.rand(8, 3, 512, 512)
print(data.shape)
decoded_data = model.forward(data)
print (decoded_data.shape)

torch.Size([8, 3, 512, 512])


In [None]:
class PSNR:
    """Peak Signal to Noise Ratio
    img1 and img2 have range [0, 255]"""

    def __init__(self):
        self.name = "PSNR"

    @staticmethod
    def __call__(img1, img2):
        mse = torch.mean((img1 - img2) ** 2)
        return 20 * torch.log10(255.0 / torch.sqrt(mse))


In [None]:
class SSIM:
    """Structure Similarity
    img1, img2: [0, 255]"""

    def __init__(self):
        self.name = "SSIM"

    @staticmethod
    def __call__(img1, img2):
        if not img1.shape == img2.shape:
            raise ValueError("Input images must have the same dimensions.")
        if img1.ndim == 2:  # Grey or Y-channel image
            return self._ssim(img1, img2)
        elif img1.ndim == 3:
            if img1.shape[2] == 3:
                ssims = []
                for i in range(3):
                    ssims.append(ssim(img1, img2))
                return np.array(ssims).mean()
            elif img1.shape[2] == 1:
                return self._ssim(np.squeeze(img1), np.squeeze(img2))
        else:
            print("Dimension : ",img1.ndim)
            raise ValueError("Wrong input image dimensions.")

    @staticmethod
    def _ssim(img1, img2):
        C1 = (0.01 * 255) ** 2
        C2 = (0.03 * 255) ** 2

        img1 = img1.astype(np.float64)
        img2 = img2.astype(np.float64)
        kernel = cv2.getGaussianKernel(11, 1.5)
        window = np.outer(kernel, kernel.transpose())

        mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
        mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
        mu1_sq = mu1 ** 2
        mu2_sq = mu2 ** 2
        mu1_mu2 = mu1 * mu2
        sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
        sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
        sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
        return ssim_map.mean()

**MODULE 3 : Training**

In [None]:
def train_epoch(train_loader, model, optimizer, n_epoch):
	
    for epoch in range(n_epoch):
        progress_bar = tqdm(enumerate(train_loader))
        total_loss = 0.0
        for step, (high_res_img, low_res_img) in progress_bar:
            
            model.train()
            optimizer.zero_grad()
            
            high_res_img = high_res_img.cuda()
            low_res_img = low_res_img.cuda()

            pred_img = model.forward(low_res_img)
            
            mse = nn.MSELoss()
            psnr = PSNR()
            ssim = SSIM()

            mse_loss = mse(pred_img, high_res_img)
            psnr_loss = psnr(pred_img, high_res_img)
            #ssim_loss = ssim(pred_img, high_res_img)

            loss = mse_loss
            
            #print(loss)
            loss.backward()
            optimizer.step()

            #progress_bar.set_description("MSE : {} PSNR : {} SSIM : {}".format(mse_loss, psnr_loss, ssim_loss))
            progress_bar.set_description("MSE : {} ".format(loss))
            #print("High Res : {} Low Res : {} Pred Res : {}".format(high_res_img.shape, low_res_img.shape, pred_img.shape))
            #break

In [None]:
def main():

    print(image_address_list)
    train_loader = load_data(image_address_list, batch_size=8, num_workers=2, shuffle=True)
    model = Image_Super_Resolve()
    model = model.cuda()
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],lr=0.01)
    n_epoch = 100
    train_epoch(train_loader, model, optimizer, n_epoch)

main()

['/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/english_cocker_spaniel_2.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/havanese_146.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/saint_bernard_164.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/Siamese_26.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/english_cocker_spaniel_135.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/wheaten_terrier_55.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/keeshond_137.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/leonberger_115.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/Birman_168.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/Russian_Blue_17.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/Ragdoll_176.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/images/chihuahua_45.jpg', '/content/IMAGE_SUPER_RESOLVE_DATA/extracted_data/

MSE : 0.04324343800544739 : : 30it [00:09,  3.14it/s]


KeyboardInterrupt: ignored

In [2]:
    def __getitem__(self,idx):

        #print(files[idx])
        try:
        	img = cv2.imread(self.files[idx])
            high_res_img = cv2.resize(img,(512,512))
            high_res_img = np.transpose(high_res_img, (2, 0, 1))
            low_res_img = cv2.resize(img,(128,128))
            low_res_img = np.transpose(low_res_img, (2, 0, 1))
            return torch.FloatTensor(high_res_img/255.0), torch.FloatTensor(low_res_img/255.0)
		except :
			None, None

TabError: ignored