In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import wandb
import random
import intel_extension_for_pytorch as ipex
import os.path
import random
import torch.utils.data as data
import numpy as np
import sys
import torch.nn.functional as F


from torch.autograd import Variable
from tqdm import tqdm
from PIL import Image
from torchvision import transforms
from torch import save
from math import log10
from tensorboardX import SummaryWriter
from torch.autograd import Variable
from math import exp

In [1]:
def default_conv(in_channels, out_channels, kernel_size, bias):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias
    )


class UpConv(nn.Module):
    def __init__(self):
        super(UpConv, self).__init__()
        self.body = nn.Sequential(
            default_conv(3, 12, 3, True),
            nn.PixelShuffle(2),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.body(x)


class ResidualBlock(nn.Module):
    def __init__(self, n_feats):
        super(ResidualBlock, self).__init__()
        modules_body = [
            default_conv(n_feats, n_feats, 3, bias=True),
            nn.ReLU(inplace=True),
            default_conv(n_feats, n_feats, 3, bias=True)
        ]
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        res = self.body(x)
        res += x
        return res


class SingleScaleNet(nn.Module):
    def __init__(self, n_feats, n_resblocks, is_skip, n_channels=3):
        super(SingleScaleNet, self).__init__()
        self.is_skip = is_skip

        modules_head = [
            default_conv(n_channels, n_feats, 5, bias=True),
            nn.ReLU(inplace=True)]

        modules_body = [
            ResidualBlock(n_feats)
            for _ in range(n_resblocks)
        ]

        modules_tail = [default_conv(n_feats, 3, 5, bias=True)]

        self.head = nn.Sequential(*modules_head)
        self.body = nn.Sequential(*modules_body)
        self.tail = nn.Sequential(*modules_tail)

    def forward(self, x):
        x = self.head(x)

        res = self.body(x)
        if self.is_skip:
            res += x

        res = self.tail(res)

        return res


class MultiScaleNet(nn.Module):
    def __init__(self, n_feats, n_resblocks, is_skip):
        super(MultiScaleNet, self).__init__()

        self.scale3_net = SingleScaleNet(n_feats, n_resblocks, is_skip, n_channels=3)
        self.upconv3 = UpConv()

        self.scale2_net = SingleScaleNet(n_feats, n_resblocks, is_skip, n_channels=6)
        self.upconv2 = UpConv()

        self.scale1_net = SingleScaleNet(n_feats, n_resblocks, is_skip, n_channels=6)

    def forward(self, mulscale_input):
        input_b1, input_b2, input_b3 = mulscale_input

        output_l3 = self.scale3_net(input_b3)
        output_l3_up = self.upconv3(output_l3)

        output_l2 = self.scale2_net(torch.cat((input_b2, output_l3_up), 1))
        output_l2_up = self.upconv2(output_l2)

        output_l1 = self.scale2_net(torch.cat((input_b1, output_l2_up), 1))

        return output_l1, output_l2, output_l3

In [2]:

def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()


def create_window(window_size, channel):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
    return window


def _ssim(img1, img2, window, window_size, channel, size_average=True):
    mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
    mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
    sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
    sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2

    C1 = 0.01 ** 2
    C2 = 0.03 ** 2

    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

    if size_average:
        return ssim_map.mean()
    else:
        return ssim_map.mean(1).mean(1).mean(1)


class SSIM(torch.nn.Module):
    def __init__(self, window_size=11, size_average=True):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.channel = 1
        self.window = create_window(window_size, self.channel)

    def forward(self, img1, img2):
        (_, channel, _, _) = img1.size()

        if channel == self.channel and self.window.data.type() == img1.data.type():
            window = self.window
        else:
            window = create_window(self.window_size, channel)

            if img1.is_cuda:
                window = window.cuda(img1.get_device())
            window = window.type_as(img1)

            self.window = window
            self.channel = channel

        return _ssim(img1, img2, window, self.window_size, channel, self.size_average)


def ssim(img1, img2, window_size=11, size_average=True):
    (_, channel, _, _) = img1.size()
    window = create_window(window_size, channel)

    if img1.is_cuda:
        window = window.cuda(img1.get_device())
    window = window.type_as(img1)

    return _ssim(img1, img2, window, window_size, channel, size_average)


In [3]:

def tensor_to_rgb(img_input):
    output = img_input.cpu()
    output = output.data.squeeze(0)

    output = output.numpy()
    output *= 255.0
    output = output.clip(0, 255)

    return output


def compute_psnr(img1, img2):
    mse = ((img1 - img2) ** 2).mean()
    psnr = 10 * log10(255 * 255 / (mse + 10 ** (-10)))
    return psnr


class SaveData():
    def __init__(self, save_dir, exp_name, finetuning):
        self.save_dir = os.path.join(save_dir, exp_name)

        if not finetuning:
            if os.path.exists(self.save_dir):
                os.system('rm -rf ' + self.save_dir)
            os.makedirs(self.save_dir)
        else:
            if not os.path.exists(self.save_dir):
                os.makedirs(self.save_dir)

        self.save_dir_model = os.path.join(self.save_dir, 'model')
        if not os.path.exists(self.save_dir_model):
            os.makedirs(self.save_dir_model)

        self.logFile = open(self.save_dir + '/log.txt', 'a')

        save_dir_tensorboard = os.path.join(self.save_dir, 'logs')
        if not os.path.exists(save_dir_tensorboard):
            os.makedirs(save_dir_tensorboard)
        self.writer = SummaryWriter(save_dir_tensorboard)


    def save_params(self, args):
        with open(self.save_dir + '/params.txt', 'w') as params_file:
            params_file.write(str(args.__dict__) + "\n\n")


    def save_model(self, model, epoch):
        torch.save(model.state_dict(), self.save_dir_model + '/model_lastest.pt')
        torch.save(model.state_dict(), self.save_dir_model + '/model_' + str(epoch) + '.pt')
        torch.save(model, self.save_dir_model + '/model_obj.pt')
        torch.save(epoch, self.save_dir_model + '/last_epoch.pt')

    def save_log(self, log):
        sys.stdout.flush()
        self.logFile.write(log + '\n')
        self.logFile.flush()

    def load_model(self, model):
        model.load_state_dict(torch.load(self.save_dir_model + '/model_lastest.pt'))
        last_epoch = torch.load(self.save_dir_model + '/last_epoch.pt')
        print("Load mode_status from {}/model_lastest.pt, epoch: {}".format(self.save_dir_model, last_epoch))
        return model, last_epoch

    def add_scalar(self, tag, value, step):
        self.writer.add_scalar(tag, value, step)


In [4]:



def augment(img_input, img_target):
    degree = random.choice([0, 90, 180, 270])
    img_input = transforms.functional.rotate(img_input, degree)
    img_target = transforms.functional.rotate(img_target, degree)

    # color augmentation
    img_input = transforms.functional.adjust_gamma(img_input, 1)
    img_target = transforms.functional.adjust_gamma(img_target, 1)
    sat_factor = 1 + (0.2 - 0.4 * np.random.rand())
    img_input = transforms.functional.adjust_saturation(img_input, sat_factor)
    img_target = transforms.functional.adjust_saturation(img_target, sat_factor)

    return img_input, img_target


def getPatch(img_input, img_target, path_size):
    w, h = img_input.size
    p = path_size
    x = random.randrange(0, w - p + 1)
    y = random.randrange(0, h - p + 1)
    img_input = img_input.crop((x, y, x + p, y + p))
    img_target = img_target.crop((x, y, x + p, y + p))
    return img_input, img_target


class Gopro(data.Dataset):
    def __init__(self, data_dir, patch_size=256, is_train=False, multi=True):
        super(Gopro, self).__init__()
        self.is_train = is_train
        self.patch_size = patch_size
        self.multi = multi

        self.sharp_file_paths = []

        sub_folders = os.listdir(data_dir)

        for folder_name in sub_folders:
            sharp_sub_folder = os.path.join(data_dir, folder_name, 'sharp')
            sharp_file_names = os.listdir(sharp_sub_folder)

            for file_name in sharp_file_names:
                sharp_file_path = os.path.join(sharp_sub_folder, file_name)
                self.sharp_file_paths.append(sharp_file_path)

        self.n_samples = len(self.sharp_file_paths)

    def get_img_pair(self, idx):
        sharp_file_path = self.sharp_file_paths[idx]
        blur_file_path = sharp_file_path.replace("sharp", "blur")

        img_input = Image.open(blur_file_path).convert('RGB')
        img_target = Image.open(sharp_file_path).convert('RGB')

        return img_input, img_target

    def __getitem__(self, idx):
        img_input, img_target = self.get_img_pair(idx)

        if self.is_train:
            img_input, img_target = getPatch(img_input, img_target, self.patch_size)
            img_input, img_target = augment(img_input, img_target)

        input_b1 = transforms.ToTensor()(img_input)
        target_s1 = transforms.ToTensor()(img_target)

        H = input_b1.size()[1]
        W = input_b1.size()[2]

        if self.multi:
            input_b1 = transforms.ToPILImage()(input_b1)
            target_s1 = transforms.ToPILImage()(target_s1)

            input_b2 = transforms.ToTensor()(transforms.Resize([int(H / 2), int(W / 2)])(input_b1))
            input_b3 = transforms.ToTensor()(transforms.Resize([int(H / 4), int(W / 4)])(input_b1))

            if self.is_train:
                target_s2 = transforms.ToTensor()(transforms.Resize([int(H / 2), int(W / 2)])(target_s1))
                target_s3 = transforms.ToTensor()(transforms.Resize([int(H / 4), int(W / 4)])(target_s1))
            else:
                target_s2 = []
                target_s3 = []

            input_b1 = transforms.ToTensor()(input_b1)
            target_s1 = transforms.ToTensor()(target_s1)
            return {'input_b1': input_b1, 'input_b2': input_b2, 'input_b3': input_b3,
                    'target_s1': target_s1, 'target_s2': target_s2, 'target_s3': target_s3}
        else:
            return {'input_b1': input_b1, 'target_s1': target_s1}

    def __len__(self):
        return self.n_samples


In [6]:
data_dir = '/home/kalyan/DataSets/GOPRO_Large/train/'
save_dir = './result'
patch_size = 256
batch_size = 8
val_data_dir = None
n_threads = 8
exp_name = 'Net1'
finetuning = False
multi = False
skip = False
n_resblocks = 9
n_feats = 64
lr = 1e-4
epochs = 40
lr_step_size = 600
lr_gamma = 0.1
period = 1
gpu = 0

In [7]:
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"] = "0"
!export 'PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512'

def get_dataset(data_dir, patch_size=None, batch_size=1, n_threads=8, is_train=False, multi=False):
    dataset = Gopro(data_dir, patch_size=patch_size, is_train=is_train, multi=multi)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             drop_last=True, shuffle=is_train, num_workers=int(n_threads))
    return dataloader

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


/bin/zsh: /opt/miniconda3/lib/libncursesw.so.6: no version information available (required by /bin/zsh)


In [8]:

def validation(model, dataloader, multi):
    total_psnr = 0
    for batch, images in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            input_b1 = Variable(images['input_b1'].cuda())
            target_s1 = Variable(images['target_s1'].cuda())

            if multi:
                input_b2 = Variable(images['input_b2'].cuda())
                input_b3 = Variable(images['input_b3'].cuda())
                output_l1, _, _ = model((input_b1, input_b2, input_b3))
            else:
                output_l1 = model(input_b1)

        output_l1 = tensor_to_rgb(output_l1)
        target_s1 = tensor_to_rgb(target_s1)

        psnr = compute_psnr(target_s1, output_l1)
        total_psnr += psnr

    return total_psnr / (batch + 1)


In [9]:
# start a new wandb run to track this script

wandb.init(
    # set the wandb project where this run will be logged
    project="DeBlurring")


def train():
    print("Training started.")
    
    if multi:
        my_model = MultiScaleNet(n_feats=n_feats, n_resblocks=n_resblocks, is_skip=skip)
    else:
        my_model = SingleScaleNet(n_feats=n_feats, n_resblocks=n_resblocks, is_skip=skip)
    my_model = my_model.cuda()
    loss_function = nn.MSELoss().cuda()
    optimizer = optim.Adam(my_model.parameters(), lr=lr)
    scheduler = lr_scheduler.StepLR(optimizer, lr_step_size, lr_gamma)

    my_model , optimizer = ipex.optimize(model=my_model,optimizer=optimizer)
    # utility for saving models, parameters, and logs
    #save = SaveData(save_dir, exp_name, finetuning)
    #save.save_params(locals())  # Save local variables as parameters
    #num_params = count_parameters(my_model)
    #save.save_log(str(num_params))

    # load pre-trained model if provided
    last_epoch = -1
    if finetuning:
        my_model, last_epoch = save.load_model(my_model)
    start_epoch = last_epoch + 1

    # load dataset
    data_loader = get_dataset(data_dir, patch_size=patch_size, batch_size=batch_size,
                              n_threads=n_threads, is_train=True, multi=multi)
    if val_data_dir:
        valid_data_loader = get_dataset(val_data_dir, n_threads=n_threads, multi=multi)

    for epoch in range(start_epoch, epochs):
        print("* Epoch {}/{}".format(epoch + 1, epochs))

        scheduler.step()
        learning_rate = optimizer.param_groups[0]['lr']
        total_loss = 0

        for batch, images in tqdm(enumerate(data_loader)):
            input_b1 = Variable(images['input_b1'].cuda())
            target_s1 = Variable(images['target_s1'].cuda())

            if multi:
                input_b2 = Variable(images['input_b2'].cuda())
                target_s2 = Variable(images['target_s2'].cuda())
                input_b3 = Variable(images['input_b3'].cuda())
                target_s3 = Variable(images['target_s3'].cuda())
                output_l1, output_l2, output_l3 = my_model((input_b1, input_b2, input_b3))
                loss = (loss_function(output_l1, target_s1)
                        + loss_function(output_l2, target_s2)
                        + loss_function(output_l3, target_s3)) / 3
            else:
                output_l1 = my_model(input_b1)
                loss = loss_function(output_l1, target_s1)

            my_model.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.data.cpu().numpy()

        loss = total_loss / (batch + 1)
        wandb.log({"train_loss": loss, "epoch": epoch})


        if epoch % period == 0:
            if val_data_dir:
                my_model.eval()
                psnr = validation(my_model, valid_data_loader, multi)
                my_model.train()

                log = "Epoch {}/{} \t Learning rate: {:.5f} \t Train total_loss: {:.5f} \t * Val PSNR: {:.2f}\n".format(
                    epoch + 1, epochs, learning_rate, loss, psnr)
                print(log)
                save.save_log(log)
                wandb.log({'valid/psnr': psnr, "epoch": epoch})
                
        else:
            log = "Epoch {}/{} \t Learning rate: {:.5f} \t Train total_loss: {:.5f}\n".format(
                epoch + 1, epochs, learning_rate, loss)
            print(log)
            save.save_log(log)

if __name__ == '__main__':
    train()


2023-08-15 02:37:19,997 - torch.distributed.nn.jit.instantiator - INFO - Created a temporary directory at /tmp/tmpcd9j7d8c
2023-08-15 02:37:19,999 - torch.distributed.nn.jit.instantiator - INFO - Writing /tmp/tmpcd9j7d8c/_remote_module_non_scriptable.py
2023-08-15 02:37:20.400176: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-08-15 02:37:22,371 - numexpr.utils - INFO - Note: NumExpr detected 12 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.
2023-08-15 02:37:22,372 - numexpr.utils - INFO - NumExpr defaulting to 8 threads.
2023-08-15 02:37:24,690 - wandb.jupyter - ERROR - Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb

VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016670254616716798, max=1.0…

Training started.
* Epoch 1/40


262it [04:54,  1.12s/it]

* Epoch 2/40



262it [05:04,  1.16s/it]

* Epoch 3/40



262it [05:04,  1.16s/it]

* Epoch 4/40



262it [05:03,  1.16s/it]

* Epoch 5/40



262it [05:03,  1.16s/it]

* Epoch 6/40



262it [05:03,  1.16s/it]

* Epoch 7/40



262it [05:03,  1.16s/it]

* Epoch 8/40



262it [05:04,  1.16s/it]

* Epoch 9/40



262it [05:03,  1.16s/it]

* Epoch 10/40



262it [05:03,  1.16s/it]

* Epoch 11/40



262it [05:04,  1.16s/it]

* Epoch 12/40



262it [05:03,  1.16s/it]

* Epoch 13/40



262it [05:03,  1.16s/it]

* Epoch 14/40



262it [05:03,  1.16s/it]

* Epoch 15/40



262it [05:03,  1.16s/it]

* Epoch 16/40



262it [05:03,  1.16s/it]

* Epoch 17/40



262it [05:03,  1.16s/it]

* Epoch 18/40



262it [05:04,  1.16s/it]

* Epoch 19/40



262it [05:03,  1.16s/it]

* Epoch 20/40



262it [05:03,  1.16s/it]

* Epoch 21/40



262it [05:03,  1.16s/it]

* Epoch 22/40



262it [05:03,  1.16s/it]

* Epoch 23/40



262it [05:03,  1.16s/it]

* Epoch 24/40



262it [05:03,  1.16s/it]

* Epoch 25/40



262it [05:03,  1.16s/it]

* Epoch 26/40



262it [05:03,  1.16s/it]

* Epoch 27/40



262it [05:03,  1.16s/it]

* Epoch 28/40



262it [05:03,  1.16s/it]

* Epoch 29/40



262it [05:04,  1.16s/it]

* Epoch 30/40



262it [05:04,  1.16s/it]

* Epoch 31/40



262it [05:03,  1.16s/it]

* Epoch 32/40



262it [05:03,  1.16s/it]

* Epoch 33/40



262it [05:03,  1.16s/it]

* Epoch 34/40



262it [05:03,  1.16s/it]

* Epoch 35/40



262it [05:03,  1.16s/it]

* Epoch 36/40



262it [05:03,  1.16s/it]

* Epoch 37/40



262it [05:03,  1.16s/it]

* Epoch 38/40



262it [05:03,  1.16s/it]

* Epoch 39/40



262it [05:03,  1.16s/it]

* Epoch 40/40
