In [1]:
# Mount Google Drive into the Colab VM at /content/drive; the later cells read
# the project sources, checkpoints and frames from that mount.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# List the project folder on the mounted drive to confirm it is reachable.
!ls "/content/drive/My Drive/Colab Notebooks/CV2/Project"

build.ipynb	common.py  Model   models.py	test.ipynb  utility.py	video.py
common_data.py	data	   models  __pycache__	trainer.py  utils.py


In [3]:

cd drive/My Drive/Colab Notebooks/CV2/Project

/content/drive/My Drive/Colab Notebooks/CV2/Project


In [4]:
# List the current working directory to confirm the directory change took effect.
!ls

build.ipynb	common.py  Model   models.py	test.ipynb  utility.py	video.py
common_data.py	data	   models  __pycache__	trainer.py  utils.py


In [5]:
class Args:
    """Configuration object read by the EDSR constructor.

    Every setting is a keyword argument; calling ``Args()`` with no arguments
    yields the same values that used to be hard-coded (32 residual blocks,
    256 features, scale 4), which corresponds to the 'r32f256x4' checkpoint key.

    Args:
        n_resblocks: number of residual blocks in the network body.
        n_feats: number of feature channels.
        scale: sequence of scale factors; EDSR reads the first entry.
        n_colors: number of image colour channels.
        rgb_range: maximum pixel value.
        res_scale: residual scaling factor handed to each residual block.
    """

    def __init__(self, n_resblocks=32, n_feats=256, scale=(4,),
                 n_colors=3, rgb_range=255, res_scale=0.1):
        self.n_resblocks = n_resblocks
        self.n_feats = n_feats
        # Stored as a fresh list so instances never share (or leak) the default.
        self.scale = list(scale)
        self.n_colors = n_colors
        self.rgb_range = rgb_range
        self.res_scale = res_scale

# Global EDSR configuration; it is passed to EDSR(...) when the model is built.
args = Args()


In [None]:
import common
import torch

import torch.nn as nn

# Download locations of pretrained EDSR checkpoints, keyed by
# 'r{n_resblocks}f{n_feats}x{scale}' -- the same key EDSR.__init__ formats from
# its args to look up a URL.
url = {
    'r16f64x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x2-1bc95232.pt',
    'r16f64x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x3-abf2a44e.pt',
    'r16f64x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_baseline_x4-6b446fab.pt',
    'r32f256x2': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x2-0edfb8a3.pt',
    'r32f256x3': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x3-ea3ef2c6.pt',
    'r32f256x4': 'https://cv.snu.ac.kr/research/EDSR/models/edsr_x4-4f62e9ef.pt'
}

def make_model(args, parent=False):
    """Factory helper: build and return an ``EDSR`` network from ``args``.

    ``parent`` is accepted but ignored.
    """
    network = EDSR(args)
    return network

class EDSR(nn.Module):
    """EDSR network assembled from the building blocks in ``common``.

    Pipeline: ``sub_mean`` -> head conv -> residual body (with a global skip
    connection) -> upsampling tail -> ``add_mean``.

    Args:
        args: object exposing ``n_resblocks``, ``n_feats``, ``scale`` (only the
            first entry is used), ``n_colors``, ``rgb_range`` and ``res_scale``.
        conv: convolution factory called as ``conv(in_ch, out_ch, kernel_size)``.
    """

    def __init__(self, args, conv=common.default_conv):
        super().__init__()

        depth = args.n_resblocks
        width = args.n_feats
        ksize = 3
        factor = args.scale[0]
        relu = nn.ReLU(True)

        # Pretrained-weights URL for this configuration, or None if unlisted.
        self.url = url.get('r{}f{}x{}'.format(depth, width, factor))

        self.sub_mean = common.MeanShift(args.rgb_range)
        self.add_mean = common.MeanShift(args.rgb_range, sign=1)

        # head: a single convolution lifting the image to `width` features
        head_layers = [conv(args.n_colors, width, ksize)]

        # body: `depth` residual blocks followed by one convolution
        body_layers = []
        for _ in range(depth):
            body_layers.append(
                common.ResBlock(conv, width, ksize, act=relu, res_scale=args.res_scale)
            )
        body_layers.append(conv(width, width, ksize))

        # tail: upsampler, then a convolution back to the colour channels
        tail_layers = [
            common.Upsampler(conv, factor, width, act=False),
            conv(width, args.n_colors, ksize),
        ]

        self.head = nn.Sequential(*head_layers)
        self.body = nn.Sequential(*body_layers)
        self.tail = nn.Sequential(*tail_layers)

    def forward(self, x):
        """Run the network on ``x`` and return the upscaled result."""
        shallow = self.head(self.sub_mean(x))

        deep = self.body(shallow)
        deep += shallow  # global residual connection around the body

        return self.add_mean(self.tail(deep))

    def load_state_dict(self, state_dict, strict=True):
        """Copy matching entries of ``state_dict`` into this model in place.

        Entries whose name contains 'tail' are tolerated: a failed copy or an
        unknown 'tail' key is silently skipped.  For any other entry a failed
        copy raises ``RuntimeError`` and, when ``strict`` is true, an unknown
        key raises ``KeyError``.  Keys missing from ``state_dict`` are not
        reported, and nothing is returned.
        """
        target = self.state_dict()
        for key, value in state_dict.items():
            if key not in target:
                if strict and key.find('tail') == -1:
                    raise KeyError('unexpected key "{}" in state_dict'
                                   .format(key))
                continue

            if isinstance(value, nn.Parameter):
                value = value.data
            try:
                target[key].copy_(value)
            except Exception:
                if key.find('tail') == -1:
                    raise RuntimeError('While copying the parameter named {}, '
                                       'whose dimensions in the model are {} and '
                                       'whose dimensions in the checkpoint are {}.'
                                       .format(key, target[key].size(), value.size()))

In [None]:
# Build EDSR from the global config and load the pretrained x4 weights.
model = EDSR(args)
# map_location='cpu' lets a checkpoint that was saved from a GPU be read on a
# CPU-only runtime; the model is moved to the GPU below when one is available.
state_dict = torch.load('/content/drive/My Drive/Colab Notebooks/CV2/Project/Model/EDSR_x4.pt',
                        map_location='cpu')
model.load_state_dict(state_dict)
if torch.cuda.is_available():
    model = model.cuda()  # Move model to GPU

model.eval()  # Set the model to evaluation mode


In [None]:
from PIL import Image
import torchvision.transforms as transforms
import imageio
import common_data

# Function to load and preprocess the image
def load_image(image_path, img_size=(64, 112)):
    """Read an image file and convert it to the tensor handed to the model.

    The file is read with ``imageio.imread``, passed through
    ``common_data.set_channel(..., n_channels=3)`` and converted with
    ``common_data.np2Tensor(..., rgb_range=255)``.

    ``img_size`` is accepted but not used: no resizing is performed.
    """
    pixels = imageio.imread(image_path)
    (pixels,) = common_data.set_channel(pixels, n_channels=3)
    (as_tensor,) = common_data.np2Tensor(pixels, rgb_range=255)
    return as_tensor



# Function to convert a tensor to a PIL image
def tensor_to_pil(tensor):
    """Build a PIL image in mode 'RGB' from ``tensor``'s NumPy view.

    NOTE(review): the caller passes a CPU byte tensor permuted to put the
    channel axis last -- presumably (H, W, 3) uint8; confirm.
    """
    return Image.fromarray(tensor.numpy(), 'RGB')

In [None]:
# Up scale

import utility

# Frames are read from data/input/im1.png ... im<NUM_FRAMES>.png and the
# upscaled results are written to data/frame_seq/ under the same file names.
NUM_FRAMES = 7

# Evaluation mode does not change between frames, so set it once up front.
model.eval()

for i in range(NUM_FRAMES):
    image_path = 'data/input/im' + str(i+1) + '.png'
    input_image = load_image(image_path)

    # Keep the input on the same device as the model.
    if torch.cuda.is_available():
        input_image = input_image.cuda()

    with torch.no_grad():  # inference only, no gradients needed
        output_image = model(input_image)

    # Quantize for an rgb range of 255, then move the channel axis last for PIL.
    # NOTE(review): permute(1, 2, 0) assumes a 3-D (C, H, W) result -- confirm.
    output_image = utility.quantize(output_image, 255)
    tensor_cpu = output_image.byte().permute(1, 2, 0).cpu()

    output_pil_image = tensor_to_pil(tensor_cpu)
    output_file_path = 'data/frame_seq/im' + str(i+1) + '.png'
    output_pil_image.save(output_file_path, 'PNG')


In [None]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

def sub_mean(x):
    """Subtract the spatial mean (over dims 2 and 3) from a 4-D tensor.

    Args:
        x: tensor with at least four dimensions; dims 2 and 3 are averaged.

    Returns:
        ``(centered, mean)`` where ``mean`` keeps dims 2 and 3 with size 1 and
        ``centered`` is a new tensor equal to ``x - mean``.

    The input tensor is left untouched; an in-place ``x -= mean`` would
    silently modify the caller's data.
    """
    mean = x.mean(2, keepdim=True).mean(3, keepdim=True)
    return x - mean, mean

def InOutPaddings(x):
    """Build reflection paddings that round H and W of ``x`` up to multiples of 128.

    Returns:
        ``(paddingInput, paddingOutput)``: the first pads dims 2 and 3 up to the
        next multiple of 128 (the extra pixels are split between both sides,
        with the odd pixel on the right/bottom); the second uses the negated
        amounts, intended to undo the first.
    """
    height, width = x.size(2), x.size(3)

    # (-n) % 128 is the distance to the next multiple of 128 (0 if aligned).
    extra_w = (-width) % 128
    extra_h = (-height) % 128

    left, top = extra_w // 2, extra_h // 2
    right, bottom = extra_w - left, extra_h - top

    paddingInput = nn.ReflectionPad2d(padding=[left, right, top, bottom])
    paddingOutput = nn.ReflectionPad2d(padding=[-left, -right, -top, -bottom])
    return paddingInput, paddingOutput


class ConvNorm(nn.Module):
    """Reflection padding, a 2-D convolution and an optional normalisation layer.

    Args:
        in_feat, out_feat: input / output channel counts of the convolution.
        kernel_size: convolution kernel size; the padding is ``kernel_size // 2``.
        stride: convolution stride.
        norm: 'IN' for ``InstanceNorm2d`` (with running stats), 'BN' for
            ``BatchNorm2d``; any other value is stored unchanged, and a falsy
            value disables normalisation.
    """

    def __init__(self, in_feat, out_feat, kernel_size, stride=1, norm=False):
        super().__init__()

        self.reflection_pad = nn.ReflectionPad2d(kernel_size // 2)
        self.conv = nn.Conv2d(in_feat, out_feat, kernel_size=kernel_size, stride=stride, bias=True)

        if norm == 'IN':
            self.norm = nn.InstanceNorm2d(out_feat, track_running_stats=True)
        elif norm == 'BN':
            self.norm = nn.BatchNorm2d(out_feat)
        else:
            self.norm = norm

    def forward(self, x):
        """Pad, convolve and, when a norm layer is configured, normalise."""
        features = self.conv(self.reflection_pad(x))
        return self.norm(features) if self.norm else features


class UpConvNorm(nn.Module):
    """2x spatial upsampling layer with a selectable implementation.

    ``mode`` picks the operator:
      * 'transpose' -- ``ConvTranspose2d`` (kernel 4, stride 2, padding 1);
      * 'shuffle'   -- ``ConvNorm`` to ``4 * out_channels`` then ``PixelShuffle(2)``;
      * otherwise   -- bilinear ``Upsample`` followed by a 1x1 ``ConvNorm``.

    ``norm`` is forwarded to ``ConvNorm`` and is unused in 'transpose' mode.
    """

    def __init__(self, in_channels, out_channels, mode='transpose', norm=False):
        super().__init__()

        if mode == 'transpose':
            upsampler = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        elif mode == 'shuffle':
            upsampler = nn.Sequential(
                ConvNorm(in_channels, 4*out_channels, kernel_size=3, stride=1, norm=norm),
                PixelShuffle(2))
        else:
            # out_channels is always going to be the same as in_channels
            upsampler = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                ConvNorm(in_channels, out_channels, kernel_size=1, stride=1, norm=norm))
        self.upconv = upsampler

    def forward(self, x):
        """Apply the configured upsampling operator."""
        return self.upconv(x)



class meanShift(nn.Module):
    """Frozen 1x1 convolution that adds a constant offset to every channel.

    The weight is the identity and the bias of channel ``k`` is
    ``rgbMean[k] * rgbRange * sign``.

    Args:
        rgbRange: multiplier applied to the means.
        rgbMean: per-channel means (only index 0 is read when ``nChannel == 1``).
        sign: multiplied into the offset (e.g. -1 to subtract, +1 to add).
        nChannel: 1 -> one channel; 3 -> three channels; any other value ->
            six channels, with the three offsets repeated twice.
    """

    def __init__(self, rgbRange, rgbMean, sign, nChannel=3):
        super().__init__()

        # Which entries of rgbMean feed each output channel.
        if nChannel == 1:
            picks = (0,)
        elif nChannel == 3:
            picks = (0, 1, 2)
        else:
            picks = (0, 1, 2, 0, 1, 2)

        width = len(picks)
        offsets = [rgbMean[k] * rgbRange * float(sign) for k in picks]

        self.shifter = nn.Conv2d(width, width, kernel_size=1, stride=1, padding=0)
        self.shifter.weight.data = torch.eye(width).view(width, width, 1, 1)
        self.shifter.bias.data = torch.Tensor(offsets)

        # Freeze the meanShift layer
        for frozen in self.shifter.parameters():
            frozen.requires_grad = False

    def forward(self, x):
        """Return ``x`` with the per-channel offset applied."""
        return self.shifter(x)


""" CONV - (BN) - RELU - CONV - (BN) """
class ResBlock(nn.Module):
    """Residual block: ConvNorm -> activation -> ConvNorm, plus a skip connection.

    With ``downscale`` the first ConvNorm uses stride 2 and the skip path goes
    through a 1x1 stride-2 convolution so both branches keep matching shapes.

    ``reduction`` is only a placeholder; ``bias`` and ``norm`` are accepted but
    not used by this block (its ConvNorm layers are built without a norm).
    """

    def __init__(self, in_feat, out_feat, kernel_size=3, reduction=False, bias=True,
                 norm=False, act=nn.ReLU(True), downscale=False):
        super().__init__()

        first_stride = 2 if downscale else 1
        stages = [
            ConvNorm(in_feat, out_feat, kernel_size=kernel_size, stride=first_stride),
            act,
            ConvNorm(out_feat, out_feat, kernel_size=kernel_size, stride=1),
        ]
        self.body = nn.Sequential(*stages)

        if downscale:
            self.downscale = nn.Conv2d(in_feat, out_feat, kernel_size=1, stride=2)
        else:
            self.downscale = None

    def forward(self, x):
        """Return ``body(x)`` plus the (optionally downscaled) input."""
        out = self.body(x)
        shortcut = x
        if self.downscale is not None:
            shortcut = self.downscale(shortcut)
        out += shortcut
        return out


## Channel Attention (CA) Layer
class CALayer(nn.Module):
    """Channel attention: rescale every channel by a learned gate.

    The input is average-pooled to 1x1, squeezed to ``channel // reduction``
    channels, expanded back to ``channel`` and passed through a sigmoid.

    ``forward`` returns a tuple ``(x * gate, gate)``.
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        squeezed = channel // reduction

        # global average pooling: feature --> point
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # channel squeeze and expand --> one weight per channel
        gate_layers = [
            nn.Conv2d(channel, squeezed, 1, padding=0, bias=True),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, channel, 1, padding=0, bias=True),
            nn.Sigmoid(),
        ]
        self.conv_du = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Return the gated features together with the gate itself."""
        gate = self.conv_du(self.avg_pool(x))
        return x * gate, gate


## Residual Channel Attention Block (RCAB)
class RCAB(nn.Module):
    """Residual channel-attention block.

    Body: ConvNorm -> activation -> ConvNorm -> CALayer, added to a skip path.
    With ``downscale`` the first ConvNorm uses stride 2 and the skip path goes
    through a 3x3 stride-2 convolution.  ``bias`` is accepted but unused.

    ``forward`` returns the block output, or ``(output, attention)`` when
    ``return_ca`` is true.
    """

    def __init__(self, in_feat, out_feat, kernel_size, reduction, bias=True,
            norm=False, act=nn.ReLU(True), downscale=False, return_ca=False):
        super().__init__()

        first_stride = 2 if downscale else 1
        stages = [
            ConvNorm(in_feat, out_feat, kernel_size, stride=first_stride, norm=norm),
            act,
            ConvNorm(out_feat, out_feat, kernel_size, stride=1, norm=norm),
            CALayer(out_feat, reduction),
        ]
        self.body = nn.Sequential(*stages)

        self.downscale = downscale
        if downscale:
            self.downConv = nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=2, padding=1)
        self.return_ca = return_ca

    def forward(self, x):
        """Apply the block; see the class docstring for the return value."""
        # The trailing CALayer makes the Sequential return (features, gate).
        out, attention = self.body(x)
        shortcut = x
        if self.downscale:
            shortcut = self.downConv(shortcut)
        out += shortcut

        if self.return_ca:
            return out, attention
        return out


## Residual Group (RG)
class ResidualGroup(nn.Module):
    """``n_resblocks`` blocks followed by one ConvNorm, wrapped in a skip connection.

    ``Block`` is the block class to instantiate; it is called as
    ``Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act)``.
    """

    def __init__(self, Block, n_resblocks, n_feat, kernel_size, reduction, act, norm=False):
        super().__init__()

        stages = []
        for _ in range(n_resblocks):
            stages.append(
                Block(n_feat, n_feat, kernel_size, reduction, bias=True, norm=norm, act=act))
        stages.append(ConvNorm(n_feat, n_feat, kernel_size, stride=1, norm=norm))
        self.body = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``body(x) + x``."""
        out = self.body(x)
        out += x
        return out


def pixel_shuffle(input, scale_factor):
    """Rearrange a (B, C, H, W) tensor between channel and spatial dimensions.

    ``scale_factor >= 1`` moves channels into space, giving
    (B, C / r**2, H * r, W * r).  A factor below 1 is the reverse operation:
    blocks of ``1 / scale_factor`` pixels per side are folded into channels.
    """
    batch, chans, h_in, w_in = input.size()

    c_out = int(int(chans / scale_factor) / scale_factor)
    h_out = int(h_in * scale_factor)
    w_out = int(w_in * scale_factor)

    source = input.contiguous()
    if scale_factor >= 1:
        tiles = source.view(batch, c_out, scale_factor, scale_factor, h_in, w_in)
        axes = (0, 1, 4, 2, 5, 3)
    else:
        block = int(1 / scale_factor)
        tiles = source.view(batch, chans, h_out, block, w_out, block)
        axes = (0, 1, 3, 5, 2, 4)

    return tiles.permute(*axes).contiguous().view(batch, c_out, h_out, w_out)


class PixelShuffle(nn.Module):
    """Module wrapper around ``pixel_shuffle`` with a fixed ``scale_factor``."""

    def __init__(self, scale_factor):
        super().__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        """Shuffle ``x`` by the stored scale factor."""
        shuffled = pixel_shuffle(x, self.scale_factor)
        return shuffled

    def extra_repr(self):
        """Text shown inside the module's printed representation."""
        return 'scale_factor={}'.format(self.scale_factor)


def conv(in_channels, out_channels, kernel_size,
         stride=1, bias=True, groups=1):
    """Create an ``nn.Conv2d`` padded with ``kernel_size // 2`` on each side.

    Args:
        in_channels, out_channels: channel counts.
        kernel_size: integer kernel size.
        stride: convolution stride; it is forwarded to ``nn.Conv2d`` (a
            hard-coded ``stride=1`` would silently ignore the argument).
        bias: whether the layer has a bias term.
        groups: number of channel groups.
    """
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        padding=kernel_size//2,
        stride=stride,
        bias=bias,
        groups=groups)


def conv1x1(in_channels, out_channels, stride=1, bias=True, groups=1):
    """Create a 1x1 ``nn.Conv2d`` (no padding)."""
    options = dict(kernel_size=1, stride=stride, bias=bias, groups=groups)
    return nn.Conv2d(in_channels, out_channels, **options)

def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    """Create a 3x3 ``nn.Conv2d``; the default padding of 1 keeps the spatial size at stride 1."""
    # Positional order after the channels: kernel_size, stride, padding.
    return nn.Conv2d(in_channels, out_channels, 3, stride, padding,
                     bias=bias, groups=groups)

def conv5x5(in_channels, out_channels, stride=1,
            padding=2, bias=True, groups=1):
    """Create a 5x5 ``nn.Conv2d``; the default padding of 2 keeps the spatial size at stride 1."""
    layer = nn.Conv2d(in_channels, out_channels, kernel_size=5,
                      stride=stride, padding=padding,
                      bias=bias, groups=groups)
    return layer

def conv7x7(in_channels, out_channels, stride=1,
            padding=3, bias=True, groups=1):
    """Create a 7x7 ``nn.Conv2d``; the default padding of 3 keeps the spatial size at stride 1."""
    settings = {
        'kernel_size': 7,
        'stride': stride,
        'padding': padding,
        'bias': bias,
        'groups': groups,
    }
    return nn.Conv2d(in_channels, out_channels, **settings)

def upconv2x2(in_channels, out_channels, mode='shuffle'):
    """Build a 2x upsampling layer.

    ``mode`` selects the operator:
      * 'transpose' -- ``ConvTranspose2d`` (kernel 4, stride 2, padding 1);
      * 'shuffle'   -- 3x3 convolution to ``4 * out_channels`` then ``PixelShuffle(2)``;
      * otherwise   -- bilinear ``Upsample`` followed by a 1x1 convolution.
    """
    if mode == 'transpose':
        return nn.ConvTranspose2d(in_channels, out_channels,
                                  kernel_size=4, stride=2, padding=1)

    if mode == 'shuffle':
        stages = [conv3x3(in_channels, 4*out_channels), PixelShuffle(2)]
    else:
        # out_channels is always going to be the same as in_channels
        stages = [
            nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
            conv1x1(in_channels, out_channels),
        ]
    return nn.Sequential(*stages)



class Interpolation(nn.Module):
    """Fuse two feature maps with a stack of RCAB residual groups.

    The inputs are concatenated on the channel axis, reduced to ``n_feats``
    channels by a 3x3 convolution, processed by ``n_resgroups`` residual groups
    (each holding ``n_resblocks`` RCABs) with a skip connection around the
    whole stack, and finished with a 3x3 convolution.
    """

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 reduction=16, act=nn.LeakyReLU(0.2, True), norm=False):
        super().__init__()

        # head: 2 * n_feats (concatenated inputs) -> n_feats
        self.headConv = conv3x3(n_feats * 2, n_feats)

        groups = []
        for _ in range(n_resgroups):
            groups.append(ResidualGroup(
                RCAB,
                n_resblocks=n_resblocks,
                n_feat=n_feats,
                kernel_size=3,
                reduction=reduction,
                act=act,
                norm=norm))
        self.body = nn.Sequential(*groups)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        """Return the fused features of ``x0`` and ``x1``."""
        merged = self.headConv(torch.cat([x0, x1], dim=1))

        refined = self.body(merged)
        refined += merged

        return self.tailConv(refined)


class Interpolation_res(nn.Module):
    """Variant of the feature-fusion network built from plain ``ResBlock`` groups.

    Same layout as the RCAB version: a head convolution reduces the
    concatenated inputs to ``n_feats`` channels, ``n_resgroups`` residual
    groups follow (with ``reduction=0``), a skip connection wraps the stack,
    and a tail convolution ends the network.
    """

    def __init__(self, n_resgroups, n_resblocks, n_feats,
                 act=nn.LeakyReLU(0.2, True), norm=False):
        super().__init__()

        # head: reduces the concatenated inputs to n_feats channels
        self.headConv = conv3x3(n_feats * 2, n_feats)

        groups = []
        for _ in range(n_resgroups):
            groups.append(ResidualGroup(ResBlock, n_resblocks=n_resblocks, n_feat=n_feats,
                                        kernel_size=3, reduction=0, act=act, norm=norm))
        self.body = nn.Sequential(*groups)

        self.tailConv = conv3x3(n_feats, n_feats)

    def forward(self, x0, x1):
        """Return the fused features of ``x0`` and ``x1``."""
        merged = self.headConv(torch.cat([x0, x1], dim=1))

        # Walk the groups one by one instead of calling the Sequential itself.
        refined = merged
        for group in self.body:
            refined = group(refined)
        refined += merged

        return self.tailConv(refined)

In [None]:
import math
import numpy as np

import torch
import torch.nn as nn



class Encoder(nn.Module):
    """Fold both frames into channels and fuse them.

    Each frame goes through ``PixelShuffle(1 / 2**depth)``, which turns
    ``in_channels`` into ``in_channels * 4**depth`` channels; the two results
    are fused by an ``Interpolation`` network with 5 residual groups of 12
    blocks each.
    """

    def __init__(self, in_channels=3, depth=3):
        super().__init__()

        # A single shuffle with factor 1/2**depth expands the channel dimension.
        self.shuffler = PixelShuffle(1 / 2**depth)

        activation = nn.LeakyReLU(0.2, True)

        fused_channels = in_channels * (4**depth)
        self.interpolate = Interpolation(5, 12, fused_channels, act=activation)

    def forward(self, x1, x2):
        """
        Encoder: Shuffle-spread --> Feature Fusion --> Return fused features
        """
        spread1 = self.shuffler(x1)
        spread2 = self.shuffler(x2)
        return self.interpolate(spread1, spread2)


class Decoder(nn.Module):
    """Unfold features back to image resolution with ``PixelShuffle(2**depth)``."""

    def __init__(self, depth=3):
        super().__init__()
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        """Return ``feats`` shuffled from channels back into space."""
        return self.shuffler(feats)


class CAIN(nn.Module):
    """Frame interpolation network: Encoder (shuffle + fusion) followed by Decoder.

    ``forward(x1, x2)`` returns ``(out, feats)``: the interpolated frame and
    the fused features produced by the encoder.
    """

    def __init__(self, depth=3):
        super().__init__()

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        """Interpolate between ``x1`` and ``x2``.

        Both inputs are mean-centred first.  Outside training mode they are
        also reflection-padded (paddings derived from ``x1``) before encoding,
        and the decoded result goes through the matching output padding.
        The average of the two input means is added back at the end.
        """
        x1, mean1 = sub_mean(x1)
        x2, mean2 = sub_mean(x2)

        evaluating = not self.training
        if evaluating:
            paddingInput, paddingOutput = InOutPaddings(x1)
            x1 = paddingInput(x1)
            x2 = paddingInput(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if evaluating:
            out = paddingOutput(out)

        out += (mean1 + mean2) / 2

        return out, feats

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Build CAIN (depth 3), wrap it in DataParallel and place it on `device`.
# NOTE(review): the checkpoint is loaded into the DataParallel wrapper, so its
# keys presumably carry the 'module.' prefix -- confirm.
cain = CAIN(3)
cain = torch.nn.DataParallel(cain).to(device)
# map_location=device lets a checkpoint saved from a GPU load on a CPU-only
# runtime.  The model already lives on `device`, so no further .cuda() call is
# needed.  `checkpoint` stays global: a later cell reads checkpoint['epoch'].
checkpoint = torch.load('Model/pretrained_cain.pth', map_location=device)
cain.load_state_dict(checkpoint['state_dict'])

cain.eval()  # Set the model to evaluation mode



DataParallel(
  (module): CAIN(
    (encoder): Encoder(
      (shuffler): PixelShuffle(scale_factor=0.125)
      (interpolate): Interpolation(
        (headConv): Conv2d(384, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (body): Sequential(
          (0): ResidualGroup(
            (body): Sequential(
              (0): RCAB(
                (body): Sequential(
                  (0): ConvNorm(
                    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
                    (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1))
                  )
                  (1): LeakyReLU(negative_slope=0.2, inplace=True)
                  (2): ConvNorm(
                    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
                    (conv): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1))
                  )
                  (3): CALayer(
                    (avg_pool): AdaptiveAvgPool2d(output_size=1)
                    (conv_du): Sequential(
         

In [None]:
import os
import sys
import time
import copy
import shutil
import random

import torch
import numpy as np
from tqdm import tqdm

# import config
import utils


class Args:
    """Run configuration read by ``test`` for CAIN inference.

    NOTE: this reuses the class name ``Args`` that an earlier cell defines for
    the EDSR configuration; after this cell runs, ``Args`` refers to this class.

    Every setting is a keyword argument whose default is the value that used to
    be hard-coded.  ``start_epoch=None`` means "one past the epoch stored in the
    global ``checkpoint``", which therefore must already be loaded.
    """

    def __init__(self, dataset='custom', start_epoch=None, data_root='data/frame_seq',
                 img_fmt='png', batch_size=32, test_batch_size=16, model='cain',
                 depth=3, mode='test', num_workers=4):
        if start_epoch is None:
            start_epoch = checkpoint['epoch'] + 1

        self.dataset = dataset
        self.start_epoch = start_epoch
        self.data_root = data_root      # frames are read from here; results are written here too
        self.img_fmt = img_fmt
        self.batch_size = batch_size
        self.test_batch_size = test_batch_size
        self.model = model
        self.depth = depth
        self.mode = mode                # 'test' makes test() save the interpolated frames
        self.num_workers = num_workers

# Build the CAIN run configuration.
# NOTE: this rebinds the global `args`, which until now held the EDSR config.
args = Args()


def test(args, epoch):
    print('Evaluating for epoch = %d' % epoch)
    ##### Load Dataset #####
    test_loader = utils.load_dataset(
        args.dataset, args.data_root, args.batch_size, args.test_batch_size, args.num_workers, img_fmt=args.img_fmt)
    cain.eval()

    t = time.time()
    with torch.no_grad():
        for i, (images, meta) in enumerate(tqdm(test_loader)):

            # Build input batch
            im1, im2 = images[0].to(device), images[1].to(device)

            # Forward
            out, _ = cain(im1, im2)

            # Save result images
            if args.mode == 'test':
                for b in range(images[0].size(0)):
                    paths = meta['imgpath'][0][b].split('/')
                    fp = args.data_root
                    fp = os.path.join(fp, paths[-1][:-4])   # remove '.png' extension

                    # Decide float index
                    i1_str = paths[-1][:-4]
                    i2_str = meta['imgpath'][1][b].split('/')[-1][:-4]
                    try:
                        i1 = float(i1_str.split('_')[-1])
                    except ValueError:
                        i1 = 0.0
                    try:
                        i2 = float(i2_str.split('_')[-1])
                        if i2 == 0.0:
                            i2 = 1.0
                    except ValueError:
                        i2 = 1.0
                    fpos = max(0, fp.rfind('_'))
                    fInd = (i1 + i2) / 2
                    savepath = "%s_%06f.%s" % (fp[:fpos], fInd, args.img_fmt)
                    utils.save_image(out[b], savepath)

    # Print progress
    # print('im_processed: {:d}/{:d} {:.3f}s   \r'.format(i + 1, len(test_loader), time.time() - t))

    return


# Number of interpolation passes.  test() reloads the dataset from
# args.data_root on every call, and the logged image count below grows from
# one pass to the next.
num_iter = 2 # x2**num_iter interpolation
for _ in range(num_iter):

    # run one inference pass; the second argument is only used in the log message
    test(args, args.start_epoch)

Evaluating for epoch = 175
[6] images ready to be loaded


100%|██████████| 1/1 [00:00<00:00,  1.52it/s]


Evaluating for epoch = 175
[12] images ready to be loaded


100%|██████████| 1/1 [00:01<00:00,  1.00s/it]
