In [None]:
!pip install pytorch_msssim

Collecting pytorch_msssim
  Downloading pytorch_msssim-1.0.0-py3-none-any.whl.metadata (8.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch->pytorch_msssim)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch->pytorch_msssim)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch->pytorch_msssim)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch->pytorch_msssim)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch->pytorch_msssim)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch->pytorch_msssim)
  Us

In [None]:
from google.colab import drive

# Later cells read their data from paths under /content/drive/MyDrive/.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

Mounted at /content/drive


In [None]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image

def apply_1x1_convolution(image_tensor, in_channels, out_channels):
    """
    Applies a freshly initialised 1x1 convolution to an image tensor.

    NOTE: a new ``nn.Conv2d`` with random weights is created on every call, so the
    projection is not learned and differs between calls unless the torch RNG is
    re-seeded identically beforehand.

    Parameters:
        image_tensor (torch.Tensor): Input image tensor with shape (N, C, H, W)
        in_channels (int): Number of input channels C (must equal image_tensor.size(1))
        out_channels (int): Number of output channels (feature maps) after convolution

    Returns:
        torch.Tensor: Feature map of shape (N, out_channels, H, W), on the same device
        and with the same floating dtype as ``image_tensor``

    Raises:
        ValueError: If ``image_tensor`` is not 4-dimensional.
    """
    if image_tensor.ndim != 4:
        raise ValueError("Input tensor must have 4 dimensions (N, C, H, W)")

    conv1x1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)
    # nn.Conv2d is created on the CPU in the default dtype; move it next to the input
    # so CUDA tensors and half/double inputs work as well.
    conv1x1 = conv1x1.to(device=image_tensor.device, dtype=image_tensor.dtype)

    return conv1x1(image_tensor)

def load_and_preprocess_image(image_path):
    """
    Loads an image as RGB and normalises it, without resizing.

    Parameters:
        image_path (str): Path to the image file

    Returns:
        torch.Tensor: Tensor of shape (1, 3, H, W), normalised per channel with the
        mean/std listed below (the standard ImageNet statistics)
    """
    # Context manager releases the file handle right away; convert() loads the
    # pixel data into memory before the file is closed.
    with Image.open(image_path) as img:
        image = img.convert('RGB')

    transform = transforms.Compose([
        transforms.ToTensor(),          # (H, W, 3) uint8 -> (3, H, W) float in [0, 1]
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize image
    ])

    return transform(image).unsqueeze(0)  # Add batch dimension

def process_images(reference_image_path, current_image_path, in_channels, out_channels):
    """
    Builds a channel-stacked tensor from a reference/current image pair.

    Both images are loaded and normalised. The reference image is also passed through
    a randomly initialised 1x1 convolution. The three tensors are then concatenated
    along the channel axis, and their shapes are printed along the way.

    Parameters:
        reference_image_path (str): File path of the reference image
        current_image_path (str): File path of the current image
        in_channels (int): Channel count of the loaded images (3, since images are read as RGB)
        out_channels (int): Channel count produced by the 1x1 convolution

    Returns:
        torch.Tensor: Tensor of shape (1, 3 + 3 + out_channels, H, W), ordered
        [reference, current, feature map] along dim 1
    """
    reference_image_tensor, current_image_tensor = (
        load_and_preprocess_image(path)
        for path in (reference_image_path, current_image_path)
    )
    print(f"Reference image tensor shape: {reference_image_tensor.shape}")
    print(f"Current image tensor shape: {current_image_tensor.shape}")

    feature_map = apply_1x1_convolution(reference_image_tensor, in_channels, out_channels)
    print(f"Feature map shape: {feature_map.shape}")

    pieces = (reference_image_tensor, current_image_tensor, feature_map)
    concatenated_tensor = torch.cat(pieces, dim=1)
    print(f"Concatenated tensor shape: {concatenated_tensor.shape}")

    return concatenated_tensor

# Example usage: pair the first two frames of the dataset folder.
data_dir = '/content/drive/MyDrive/ImageSuperResolutionData'
reference_image_path = f'{data_dir}/image_0.png'  # Replace with your reference image path
current_image_path = f'{data_dir}/image_1.png'    # Replace with your current image path

# RGB images have 3 channels; the 1x1 convolution maps the reference to 10 feature maps.
in_channels, out_channels = 3, 10

concatenated_tensor = process_images(
    reference_image_path, current_image_path, in_channels, out_channels
)


Reference image tensor shape: torch.Size([1, 3, 256, 256])
Current image tensor shape: torch.Size([1, 3, 256, 256])
Feature map shape: torch.Size([1, 10, 256, 256])
Concatenated tensor shape: torch.Size([1, 16, 256, 256])


In [64]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import os

class ConvolutionalUnit(nn.Module):
    """
    Depthwise (multi-)scale convolution, squeeze-and-excitation gating, then a 1x1 projection.

    Args:
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels produced by the final 1x1 convolution.
        use_multi_scale (bool): If True, sum three depthwise convs (7x7, 5x5, 3x3);
            otherwise use a single 7x7 depthwise conv.
        reduction (int): Squeeze-and-excitation bottleneck ratio. The hidden width is
            ``max(in_channels // reduction, 1)``, so small channel counts no longer
            collapse the bottleneck to zero units.
    """

    def __init__(self, in_channels, out_channels, use_multi_scale=True, reduction=16):
        super(ConvolutionalUnit, self).__init__()

        if use_multi_scale:
            self.single_scale_conv = None
            self.multi_scale_convs = nn.ModuleList([
                nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels),
                nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels),
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
            ])
        else:
            self.single_scale_conv = nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
            self.multi_scale_convs = None

        hidden = max(in_channels // reduction, 1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, in_channels)
        self.sigmoid = nn.Sigmoid()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        if self.multi_scale_convs is not None:
            x = sum(conv(x) for conv in self.multi_scale_convs)
        else:
            x = self.single_scale_conv(x)

        # Squeeze-and-excitation: per-channel gates in (0, 1) from global average context.
        n, c = x.shape[:2]
        se = self.global_avg_pool(x).view(n, c)
        se = self.sigmoid(self.fc2(self.relu(self.fc1(se))))
        x = x * se.view(n, c, 1, 1)

        return self.conv1x1(x)


class SmallEncoderDecoderNet(nn.Module):
    """
    Lightweight encoder/decoder.

    Encoder: 5x5 conv (64 ch) -> 5x5 stride-2 conv (128 ch, halves H and W) ->
    1x1 conv (256 ch), each followed by ReLU. A ConvolutionalUnit refines the
    256-channel features. A 1x1 conv + PixelShuffle(2) then turns 256 channels into
    64 at twice the resolution, which restores the input size when H and W are even.
    A final 1x1 conv maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, use_multi_scale=True):
        super(SmallEncoderDecoderNet, self).__init__()
        width = 256

        self.encoder_conv1 = nn.Conv2d(in_channels, 64, kernel_size=5, padding=2)
        self.encoder_conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.encoder_conv3 = nn.Conv2d(128, width, kernel_size=1)

        self.convolutional_unit = ConvolutionalUnit(width, width, use_multi_scale=use_multi_scale)

        self.conv_before_shuffle = nn.Conv2d(width, width, kernel_size=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
        # PixelShuffle(2) trades a factor of 4 in channels for 2x height and width.
        self.decoder_conv1x1 = nn.Conv2d(width // 4, out_channels, kernel_size=1)

    def forward(self, x):
        for encoder_layer in (self.encoder_conv1, self.encoder_conv2, self.encoder_conv3):
            x = F.relu(encoder_layer(x))

        refined = self.convolutional_unit(x)
        upsampled = self.pixel_shuffle(self.conv_before_shuffle(refined))
        return self.decoder_conv1x1(upsampled)


class ImageProcessor:
    """
    Loads a folder of burst frames and runs every (reference, frame) pair through ``model``.

    Each model input is the channel-wise concatenation
    [reference, current frame, 1x1-conv projection of the reference].

    Args:
        model: Callable applied to each (1, 2 * in_channels + intermediate_out_channels, H, W) stack.
        in_channels (int): Channels per loaded frame (3 for RGB).
        intermediate_out_channels (int): Channels of the reference projection.
        final_out_channels (int): Stored for reference only; not used by this class.
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, model, in_channels=3, intermediate_out_channels=10, final_out_channels=64):
        self.model = model
        self.in_channels = in_channels
        self.intermediate_out_channels = intermediate_out_channels
        self.final_out_channels = final_out_channels
        # Built once so every frame of a burst sees the SAME projection of the reference.
        # Building it per call drew fresh random weights for every frame.
        # NOTE: this is a plain object, so the conv is not registered as a trainable parameter.
        self.conv1x1 = nn.Conv2d(in_channels, intermediate_out_channels, kernel_size=1)

    @staticmethod
    def _natural_key(filename):
        # Split out digit runs so "image_2" sorts before "image_10"
        # (plain sorted() yields image_0, image_1, image_10, ..., image_2).
        import re
        return [int(part) if i % 2 else part for i, part in enumerate(re.split(r'(\d+)', filename))]

    def load_and_preprocess_image(self, image_path):
        """Load ``image_path`` as RGB and return a normalised (1, 3, H, W) tensor."""
        with Image.open(image_path) as img:  # closes the file handle promptly
            image = img.convert('RGB')
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return transform(image).unsqueeze(0)  # Add batch dimension

    def apply_1x1_convolution(self, image_tensor):
        """Project ``image_tensor`` with the shared 1x1 conv, following the input's device/dtype."""
        self.conv1x1 = self.conv1x1.to(device=image_tensor.device, dtype=image_tensor.dtype)
        return self.conv1x1(image_tensor)

    def process_images(self, reference_image_tensor, current_image_tensor):
        """Concatenate [reference, current, projected reference] along the channel axis."""
        feature_map = self.apply_1x1_convolution(reference_image_tensor)
        concatenated_tensor = torch.cat((reference_image_tensor, current_image_tensor, feature_map), dim=1)
        return concatenated_tensor

    def process_burst_images(self, folder_path):
        """
        Run the model on every non-reference image in ``folder_path``.

        Images are ordered naturally by filename; the first one is the reference.

        Returns:
            (list of per-frame outputs with the batch dim squeezed, the same outputs stacked)

        Raises:
            ValueError: If the folder holds fewer than two .png/.jpg/.jpeg files.
        """
        image_files = sorted(
            (f for f in os.listdir(folder_path) if f.endswith(self.IMAGE_EXTENSIONS)),
            key=self._natural_key,
        )
        if len(image_files) < 2:
            raise ValueError(
                f"Need at least 2 images (.png/.jpg/.jpeg) in {folder_path!r}, found {len(image_files)}"
            )

        reference_image_path = os.path.join(folder_path, image_files[0])
        reference_image_tensor = self.load_and_preprocess_image(reference_image_path)

        output_tensors = []
        for current_image_filename in image_files[1:]:
            current_image_path = os.path.join(folder_path, current_image_filename)
            current_image_tensor = self.load_and_preprocess_image(current_image_path)
            concatenated_tensor = self.process_images(reference_image_tensor, current_image_tensor)
            output_tensor = self.model(concatenated_tensor)
            output_tensors.append(output_tensor.squeeze(0))

        stacked_output_tensors = torch.stack(output_tensors)
        return output_tensors, stacked_output_tensors


class BurstImagePipeline:
    """Binds a burst folder to an ``ImageProcessor`` and runs it on demand."""

    def __init__(self, folder_path, model, in_channels=3, intermediate_out_channels=10, final_out_channels=64):
        self.folder_path = folder_path
        self.processor = ImageProcessor(
            model,
            in_channels=in_channels,
            intermediate_out_channels=intermediate_out_channels,
            final_out_channels=final_out_channels,
        )

    def run(self):
        """Process the bound folder; returns ``(output_tensors, stacked_output_tensors)``."""
        processor, folder = self.processor, self.folder_path
        return processor.process_burst_images(folder)


# Example usage: align every frame of the burst folder against the first one.
folder_path = '/content/drive/MyDrive/ImageSuperResolutionData/'
# 16 input channels = 3 (reference) + 3 (current frame) + 10 (reference 1x1 projection).
model = SmallEncoderDecoderNet(in_channels=16, out_channels=64, use_multi_scale=True)

pipeline = BurstImagePipeline(folder_path, model)
# Inference only: don't keep an autograd graph alive for every per-frame forward pass.
with torch.no_grad():
    output_tensors, stacked_output_tensors = pipeline.run()

print(f"Stacked output tensor shape: {stacked_output_tensors.shape}")  # Should be [13, 64, 256, 256]


Stacked output tensor shape: torch.Size([13, 64, 256, 256])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from utils.metrics import PSNR

# Module-level PSNR metric; presumably ignores a 40-pixel border (per the argument name) — confirm in utils.metrics.
# NOTE(review): not referenced by the cells shown here; BIPNet builds its own PSNR instance.
psnr_fn = PSNR(boundary_ignore=40)
# Seed the global RNGs through Lightning so weight initialisation is reproducible.
seed_everything(13)

##############################################################################################
######################### Residual Global Context Attention Block ##########################################
##############################################################################################

class RGCAB(nn.Module):
    """
    Residual group: ``num_rcab`` RGCA blocks followed by a 3x3 conv (no bias),
    wrapped in a long skip connection, i.e. ``x + body(x)``.

    ``reduction`` is forwarded to every RGCA block.
    """

    def __init__(self, num_features, num_rcab, reduction):
        super(RGCAB, self).__init__()
        blocks = []
        for _ in range(num_rcab):
            blocks.append(RGCA(num_features, reduction))
        blocks.append(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1, bias=False))
        self.module = nn.Sequential(*blocks)

    def forward(self, x):
        residual = self.module(x)
        return x + residual

class RGCA(nn.Module):
    """
    Residual global-context attention block:
    3x3 conv -> act -> 3x3 conv -> GCA -> 1x1 conv, plus an identity skip.

    Args:
        n_feat: Channels in and out.
        reduction: Stored on the instance; not used by this block.
        bias: Whether the convolutions have bias terms.
        act: Activation between the two 3x3 convs. The default is one LeakyReLU
            instance created at definition time and shared by every RGCA built with
            the default (it has no parameters, so sharing it is harmless).
        groups: Group count for the 3x3 convs.
    """

    def __init__(self, n_feat, reduction=8, bias=False, act=nn.LeakyReLU(negative_slope=0.2,inplace=True), groups =1):
        super(RGCA, self).__init__()
        self.n_feat, self.groups, self.reduction = n_feat, groups, reduction

        def conv3x3():
            return nn.Conv2d(n_feat, n_feat, 3, 1, 1, bias=bias, groups=groups)

        self.body = nn.Sequential(conv3x3(), act, conv3x3())
        self.gcnet = nn.Sequential(GCA(n_feat, n_feat))
        self.conv1x1 = nn.Conv2d(n_feat, n_feat, kernel_size=1, bias=bias)

    def forward(self, x):
        attended = self.gcnet(self.body(x))
        return self.conv1x1(attended) + x

######################### Global Context Attention ##########################################

class GCA(nn.Module):
    """
    Global context attention with additive fusion.

    A 1x1 conv scores every spatial position, and a softmax over positions turns the
    scores into weights. The weighted sum of features gives one context vector of
    shape (B, C, 1, 1) per sample. That vector passes through a 1x1 conv bottleneck
    (inplanes -> planes -> inplanes) and is broadcast-added to the input.
    """

    def __init__(self, inplanes, planes, act=nn.LeakyReLU(negative_slope=0.2,inplace=True), bias=False):
        super(GCA, self).__init__()
        self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1, bias=bias)
        self.softmax = nn.Softmax(dim=2)
        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias),
            act,
            nn.Conv2d(planes, inplanes, kernel_size=1, bias=bias),
        )

    def spatial_pool(self, x):
        n, c, h, w = x.size()
        features = x.view(n, c, h * w)                                 # (n, c, h*w)
        weights = self.softmax(self.conv_mask(x).view(n, 1, h * w))    # (n, 1, h*w), sums to 1 over positions
        # Weighted sum over positions: (n, c, h*w) @ (n, h*w, 1) -> (n, c, 1)
        context = torch.matmul(features, weights.transpose(1, 2))
        return context.view(n, c, 1, 1)

    def forward(self, x):
        return x + self.channel_add_conv(self.spatial_pool(x))

##############################################################################################
######################### Multi-scale Feature Extractor ##########################################
##############################################################################################

class UpSample(nn.Module):
    """1x1 conv dividing the channel count by ``chan_factor`` (truncated to int), then 2x bilinear upsampling."""

    def __init__(self, in_channels, chan_factor, bias=False):
        super(UpSample, self).__init__()
        out_channels = int(in_channels / chan_factor)
        squeeze = nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=bias)
        upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.up = nn.Sequential(squeeze, upsample)

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled

class DownSample(nn.Module):
    """2x average pooling (ceil mode), then a 1x1 conv multiplying the channel count by ``chan_factor`` (truncated to int)."""

    def __init__(self, in_channels, chan_factor, bias=False):
        super(DownSample, self).__init__()
        out_channels = int(in_channels * chan_factor)
        pool = nn.AvgPool2d(2, ceil_mode=True, count_include_pad=False)
        expand = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
        self.down = nn.Sequential(pool, expand)

    def forward(self, x):
        downsampled = self.down(x)
        return downsampled

class MSF(nn.Module):
    """
    Multi-scale feature extractor: a three-level U-Net built from RGCAB groups.

    Level widths are in_channels, int(in_channels*1.5) and int(in_channels*1.5*1.5).
    Each decoder level adds the matching encoder features (skip connections).

    NOTE: ``bias`` is accepted but unused. The skip additions only line up when the
    int() channel arithmetic round-trips (e.g. 64 -> 96 -> 144 -> 96 -> 64) and the
    spatial size survives ceil-mode pooling followed by 2x upsampling (even sizes).
    """

    def __init__(self, in_channels=64, reduction=8, bias=False):
        super(MSF, self).__init__()
        mid_width = int(in_channels*1.5)
        low_width = int(in_channels*1.5*1.5)

        def stage(width, depth):
            return nn.Sequential(*[RGCAB(width, 2, reduction) for _ in range(depth)])

        self.feat_ext1 = stage(in_channels, 2)

        self.down1 = DownSample(in_channels, chan_factor=1.5)
        self.feat_ext2 = stage(mid_width, 2)

        self.down2 = DownSample(mid_width, chan_factor=1.5)
        self.feat_ext3 = stage(low_width, 1)

        self.up2 = UpSample(low_width, chan_factor=1.5)
        self.feat_ext5 = stage(mid_width, 2)

        self.up1 = UpSample(mid_width, chan_factor=1.5)
        self.feat_ext6 = stage(in_channels, 2)

    def forward(self, x):
        top = self.feat_ext1(x)
        middle = self.feat_ext2(self.down1(top))
        bottom = self.feat_ext3(self.down2(middle))

        middle_dec = self.feat_ext5(self.up2(bottom) + middle)
        return self.feat_ext6(self.up1(middle_dec) + top)

##############################################################################################
######################### Adaptive Group Up-sampling Module ##########################################
##############################################################################################

class AGU(nn.Module):
    """
    Adaptive Group Up-sampling.

    Given ``height`` groups of feature maps, computes attention weights per group,
    channel and pixel (softmax across the groups) and re-weights every group. The
    groups are then flattened into the channel axis and upsampled 2x with a
    transposed convolution.

    Args:
        in_channels (int): Channels per group.
        height (int): Number of groups fused; ``inp_feats.size(1)`` must equal it.
        reduction (int): Squeeze ratio of the attention bottleneck (width is at least 4).
        bias (bool): Whether the convolutions use bias terms.

    Shapes:
        input:  (N, height, in_channels, H, W)
        output: (N, in_channels, 2H, 2W)
    """

    def __init__(self, in_channels, height, reduction=8, bias=False):
        super(AGU, self).__init__()

        self.height = height
        d = max(int(in_channels/reduction), 4)

        self.conv_du = nn.Sequential(
            nn.Conv2d(in_channels, d, 1, padding=0, bias=bias),
            nn.LeakyReLU(negative_slope=0.2, inplace=True)
        )

        # One 1x1 "excitation" conv per group.
        self.convs = nn.ModuleList([nn.Conv2d(d, in_channels, kernel_size=1, stride=1, bias=bias) for _ in range(self.height)])

        self.softmax = nn.Softmax(dim=1)
        # The re-weighted groups are flattened to in_channels * height channels.
        # This used to be hard-coded as in_channels * 4, which broke any height != 4.
        self.up = nn.ConvTranspose2d(in_channels*self.height, in_channels, 3, stride=2, padding=1, output_padding=1, bias=bias)

    def forward(self, inp_feats):
        batch_size, n_groups, n_feats, H, W = inp_feats.size()
        if n_groups != self.height:
            raise ValueError(f"AGU expects {self.height} groups in dim 1, got {n_groups}")

        # Squeeze: sum the groups, then reduce channels.
        feats_U = torch.sum(inp_feats, dim=1)
        feats_Z = self.conv_du(feats_U)

        # Excite: one attention map per group, normalised across the groups.
        dense_attention = torch.cat([conv(feats_Z) for conv in self.convs], dim=1)
        dense_attention = dense_attention.view(batch_size, self.height, n_feats, H, W)
        dense_attention = self.softmax(dense_attention)

        feats_V = inp_feats * dense_attention
        feats_V = feats_V.view(batch_size, -1, H, W)
        feats_V = self.up(feats_V)

        return feats_V

##############################################################################################
######################### Burst Image Processing Network (BIPNet) ##########################################
##############################################################################################

# Custom Convolutional Unit Class Definition
class ConvolutionalUnit(nn.Module):
    """
    Depthwise (multi-)scale convolution, squeeze-and-excitation gating, then a 1x1 projection.

    Args:
        in_channels (int): Channels of the input tensor.
        out_channels (int): Channels produced by the final 1x1 convolution.
        use_multi_scale (bool): If True, sum three depthwise convs (7x7, 5x5, 3x3);
            otherwise use a single 7x7 depthwise conv.
        reduction (int): Squeeze-and-excitation bottleneck ratio. The hidden width is
            ``max(in_channels // reduction, 1)``, so small channel counts no longer
            collapse the bottleneck to zero units.
    """

    def __init__(self, in_channels, out_channels, use_multi_scale=True, reduction=16):
        super(ConvolutionalUnit, self).__init__()
        if use_multi_scale:
            self.single_scale_conv = None
            self.multi_scale_convs = nn.ModuleList([
                nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels),
                nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels),
                nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels)
            ])
        else:
            self.single_scale_conv = nn.Conv2d(in_channels, in_channels, kernel_size=7, padding=3, groups=in_channels)
            self.multi_scale_convs = None

        hidden = max(in_channels // reduction, 1)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(in_channels, hidden)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(hidden, in_channels)
        self.sigmoid = nn.Sigmoid()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        if self.multi_scale_convs is not None:
            x = sum(conv(x) for conv in self.multi_scale_convs)
        else:
            x = self.single_scale_conv(x)

        # Squeeze-and-excitation: per-channel gates in (0, 1) from global average context.
        n, c = x.shape[:2]
        se = self.global_avg_pool(x).view(n, c)
        se = self.sigmoid(self.fc2(self.relu(self.fc1(se))))
        x = x * se.view(n, c, 1, 1)

        return self.conv1x1(x)

# Custom Small Encoder-Decoder Network Class Definition
class SmallEncoderDecoderNet(nn.Module):
    """
    Lightweight encoder/decoder.

    Encoder: 5x5 conv (64 ch) -> 5x5 stride-2 conv (128 ch, halves H and W) ->
    1x1 conv (256 ch), each followed by ReLU. A ConvolutionalUnit refines the
    256-channel features. A 1x1 conv + PixelShuffle(2) then turns 256 channels into
    64 at twice the resolution, which restores the input size when H and W are even.
    A final 1x1 conv maps to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, use_multi_scale=True):
        super(SmallEncoderDecoderNet, self).__init__()
        width = 256

        self.encoder_conv1 = nn.Conv2d(in_channels, 64, kernel_size=5, padding=2)
        self.encoder_conv2 = nn.Conv2d(64, 128, kernel_size=5, stride=2, padding=2)
        self.encoder_conv3 = nn.Conv2d(128, width, kernel_size=1)

        self.convolutional_unit = ConvolutionalUnit(width, width, use_multi_scale=use_multi_scale)

        self.conv_before_shuffle = nn.Conv2d(width, width, kernel_size=1)
        self.pixel_shuffle = nn.PixelShuffle(upscale_factor=2)
        # PixelShuffle(2) trades a factor of 4 in channels for 2x height and width.
        self.decoder_conv1x1 = nn.Conv2d(width // 4, out_channels, kernel_size=1)

    def forward(self, x):
        for encoder_layer in (self.encoder_conv1, self.encoder_conv2, self.encoder_conv3):
            x = F.relu(encoder_layer(x))

        refined = self.convolutional_unit(x)
        upsampled = self.pixel_shuffle(self.conv_before_shuffle(refined))
        return self.decoder_conv1x1(upsampled)

# Custom Image Processor for Burst Alignment
class ImageProcessor:
    """
    Pairs every non-reference frame of a burst with the reference (frame 0) and runs
    the concatenation [reference, frame, 1x1 projection of reference] through ``model``.

    Args:
        model: Callable applied to each (1, 2 * in_channels + intermediate_out_channels, H, W) stack.
        in_channels (int): Channels per frame.
        intermediate_out_channels (int): Channels of the reference projection.
        final_out_channels (int): Stored for reference only; not used by this class.
    """

    def __init__(self, model, in_channels=3, intermediate_out_channels=10, final_out_channels=64):
        self.model = model
        self.in_channels = in_channels
        self.intermediate_out_channels = intermediate_out_channels
        self.final_out_channels = final_out_channels
        # Built once so every frame of a burst sees the SAME projection of the reference.
        # Building it per call drew fresh random weights for every frame.
        # NOTE: this is a plain object, so the conv is not registered as a trainable parameter.
        self.conv1x1 = nn.Conv2d(in_channels, intermediate_out_channels, kernel_size=1)

    def load_and_preprocess_image(self, image):
        """Convert an image accepted by ``ToTensor`` (PIL image or HxWxC array) to a normalised (1, C, H, W) tensor."""
        # Imported here so this cell does not rely on an import from an earlier cell.
        from torchvision import transforms
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return transform(image).unsqueeze(0)  # Add batch dimension

    def apply_1x1_convolution(self, image_tensor):
        """Project ``image_tensor`` with the shared 1x1 conv, following the input's device/dtype."""
        self.conv1x1 = self.conv1x1.to(device=image_tensor.device, dtype=image_tensor.dtype)
        return self.conv1x1(image_tensor)

    def process_images(self, reference_image_tensor, current_image_tensor):
        """Concatenate [reference, current, projected reference] along the channel axis (dim 1)."""
        feature_map = self.apply_1x1_convolution(reference_image_tensor)
        concatenated_tensor = torch.cat((reference_image_tensor, current_image_tensor, feature_map), dim=1)
        return concatenated_tensor

    @staticmethod
    def _as_batch(frame):
        # Indexing a (B, C, H, W) burst yields (C, H, W) frames. The channel concat in
        # process_images needs an explicit batch dim, otherwise dim=1 would be the height axis.
        return frame.unsqueeze(0) if frame.dim() == 3 else frame

    def process_burst_images(self, burst):
        """
        Run the model on every non-reference frame of ``burst``.

        Args:
            burst: Indexable sequence of frames with frame 0 as the reference. Each frame
                may be (C, H, W) (e.g. ``burst`` is a (B, C, H, W) tensor) or (1, C, H, W).

        Returns:
            torch.Tensor: The model outputs (batch dim squeezed) stacked in burst order,
            one entry per non-reference frame.

        Raises:
            ValueError: If the burst has fewer than two frames.
        """
        if len(burst) < 2:
            raise ValueError(f"Need at least 2 frames in the burst, got {len(burst)}")

        reference_image_tensor = self._as_batch(burst[0])

        output_tensors = []
        for current_image_tensor in burst[1:]:
            concatenated_tensor = self.process_images(reference_image_tensor, self._as_batch(current_image_tensor))
            output_tensor = self.model(concatenated_tensor)
            output_tensors.append(output_tensor.squeeze(0))

        stacked_output_tensors = torch.stack(output_tensors)
        return stacked_output_tensors

# Custom Burst Image Pipeline
class BurstImagePipeline:
    """Adapter exposing ``ImageProcessor.process_burst_images`` through a ``run(burst)`` call."""

    def __init__(self, model, in_channels=3, intermediate_out_channels=10, final_out_channels=64):
        self.processor = ImageProcessor(
            model,
            in_channels=in_channels,
            intermediate_out_channels=intermediate_out_channels,
            final_out_channels=final_out_channels,
        )

    def run(self, burst):
        """Return the stacked model outputs for every non-reference frame of ``burst``."""
        processor = self.processor
        return processor.process_burst_images(burst)

# Modified BIPNet Class
class BIPNet(pl.LightningModule):
    """
    BIPNet-style burst super-resolution network. The original deformable alignment is
    replaced by a learned encoder/decoder (``SmallEncoderDecoderNet``). For each frame
    it is fed a [reference features, frame features, projected reference] stack.

    Args:
        num_features (int): Feature width. The three AGU stages fold the batch axis by 4
            each time, so forward() expects num_features == 64 (64 -> 16 -> 4 -> 1).
        burst_size (int): Frames per burst; input channels of ``conv2``.
        reduction (int): Channel-reduction ratio forwarded to the RGCAB blocks.
        bias (bool): Whether the convolutions created here use bias terms.
    """

    def __init__(self, num_features=64, burst_size=14, reduction=8, bias=False):
        super(BIPNet, self).__init__()

        self.train_loss = nn.L1Loss()
        self.valid_psnr = PSNR(boundary_ignore=40)

        # 4 input channels per frame (presumably packed RAW, e.g. RGGB — TODO confirm).
        self.conv1 = nn.Sequential(nn.Conv2d(4, num_features, kernel_size=3, padding=1, bias=bias))

        ####### Custom Burst Alignment Replacement
        # Each alignment input stacks [reference, frame, projected reference]:
        # num_features + num_features + align_proj_channels channels.
        align_proj_channels = 10
        self.small_encoder_decoder_net = SmallEncoderDecoderNet(
            in_channels=2 * num_features + align_proj_channels,
            out_channels=num_features,
            use_multi_scale=True,
        )
        # NOTE: the pipeline is a plain object; only small_encoder_decoder_net (registered
        # above) is trained. The pipeline's reference projection is not a module parameter.
        self.alignment_pipeline = BurstImagePipeline(
            self.small_encoder_decoder_net,
            in_channels=num_features,
            intermediate_out_channels=align_proj_channels,
            final_out_channels=num_features,
        )

        ## Feature Processing Module
        self.encoder = nn.Sequential(*[RGCAB(num_features, 3, reduction) for _ in range(3)])

        ## Refined Aligned Feature (used by forward(); previously never created)
        self.feat_ext1 = nn.Sequential(*[RGCAB(num_features, 2, reduction) for _ in range(1)])
        self.cor_conv1 = nn.Sequential(nn.Conv2d(num_features, num_features, kernel_size=3, padding=1, bias=bias))

        ####### Pseudo Burst Feature Fusion
        self.conv2 = nn.Sequential(nn.Conv2d(burst_size, num_features, kernel_size=3, padding=1, bias=bias))

        ## Multi-scale Feature Extraction
        self.UNet = nn.Sequential(MSF(num_features))

        ####### Adaptive Group Up-sampling
        self.SKFF1 = AGU(num_features, 4)
        self.SKFF2 = AGU(num_features, 4)
        self.SKFF3 = AGU(num_features, 4)

        ## Output Convolution
        self.conv3 = nn.Sequential(nn.Conv2d(num_features, 3, kernel_size=3, padding=1, bias=bias))

    def forward(self, burst):

        ###################
        # Only burst[0] (the first burst of the batch) is used; it is expected to be
        # (B, 4, H/2, W/2) with B == burst_size.
        # Output: (1, 3, 4H, 4W)
        ###################

        burst = burst[0]
        burst_feat = self.conv1(burst)                          # (B, num_features, H/2, W/2)
        base_frame_feat = burst_feat[0].unsqueeze(0)            # (1, num_features, H/2, W/2)

        ##################################################
        ####### Custom Burst Feature Alignment ################
        ##################################################
        # unsqueeze(1) hands the pipeline (1, C, H, W) frames, so its channel-wise
        # concatenation operates on 4-D tensors.
        aligned = self.alignment_pipeline.run(burst_feat.unsqueeze(1))   # (B-1, num_features, H/2, W/2)
        # The pipeline returns only the non-reference frames. Put the reference back in
        # front so the frame count matches the burst_size channels expected by conv2.
        # (Spatial sizes match only when H/2 and W/2 are even.)
        aligned_burst_feat = torch.cat([base_frame_feat, aligned], dim=0)  # (B, num_features, H/2, W/2)

        ##################################################
        ####### Edge Boosting Feature Alignment ################
        ##################################################
        burst_feat = self.encoder(aligned_burst_feat)

        ## Refined Aligned Feature
        burst_feat = self.feat_ext1(burst_feat)
        residual = self.cor_conv1(burst_feat - base_frame_feat)
        burst_feat = burst_feat + residual                      # out-of-place add keeps autograd safe

        ##################################################
        ####### Pseudo Burst Feature Fusion ####################
        ##################################################
        burst_feat = burst_feat.permute(1, 0, 2, 3).contiguous()   # (num_features, B, H/2, W/2)
        burst_feat = self.conv2(burst_feat)                     # (num_features, num_features, H/2, W/2)

        ## Multi-scale Feature Extraction
        burst_feat = self.UNet(burst_feat)                      # (num_features, num_features, H/2, W/2)

        ##################################################
        ####### Adaptive Group Up-sampling #####################
        ##################################################
        b, f, H, W = burst_feat.size()
        burst_feat = burst_feat.view(b // 4, 4, f, H, W)          # (num_features//4, 4, num_features, H/2, W/2)
        burst_feat = self.SKFF1(burst_feat)                     # (num_features//4, num_features, H, W)

        b, f, H, W = burst_feat.size()
        burst_feat = burst_feat.view(b // 4, 4, f, H, W)          # (num_features//16, 4, num_features, H, W)
        burst_feat = self.SKFF2(burst_feat)                     # (num_features//16, num_features, 2H, 2W)

        b, f, H, W = burst_feat.size()
        burst_feat = burst_feat.view(b // 4, 4, f, H, W)          # (1, 4, num_features, 2H, 2W)
        burst_feat = self.SKFF3(burst_feat)                     # (1, num_features, 4H, 4W)

        ## Output Convolution
        burst_feat = self.conv3(burst_feat)                     # (1, 3, 4H, 4W)

        return burst_feat

    def training_step(self, train_batch, batch_idx):
        """L1 loss between the clamped prediction and the target ``y``."""
        x, y, flow_vectors, meta_info = train_batch
        pred = self.forward(x)
        pred = pred.clamp(0.0, 1.0)
        loss = self.train_loss(pred, y)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Return the PSNR of the clamped prediction against ``y``."""
        x, y, flow_vectors, meta_info = val_batch
        pred = self.forward(x)
        pred = pred.clamp(0.0, 1.0)
        PSNR = self.valid_psnr(pred, y)
        return PSNR

    # NOTE(review): validation_epoch_end is a PyTorch Lightning 1.x hook (removed in 2.0).
    def validation_epoch_end(self, outs):
        PSNR = torch.stack(outs).mean()
        self.log('val_psnr', PSNR, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        # AdamW with cosine annealing over 300 scheduler steps (epochs by Lightning's default interval).
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-4)
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 300, eta_min=1e-6)
        return [optimizer], [lr_scheduler]

    # NOTE(review): the optimizer_idx argument matches the Lightning 1.x hook signature.
    def optimizer_zero_grad(self, epoch, batch_idx, optimizer, optimizer_idx):
        optimizer.zero_grad(set_to_none=True)
