In [None]:
# Required imports

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.functional as TF
import torchvision
import gc
from torch.nn import init

import functools
from PIL import Image
import random
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
from pathlib import Path
from natsort import natsorted
from glob import glob, escape
import json
from collections import OrderedDict
from tqdm.notebook import tqdm
import math

from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim 
from scipy.signal import medfilt
from scipy import ndimage


In [None]:
# Main network blocks 1
# Code taken from GT-RAIN: https://github.com/UCLA-VMG/GT-RAIN

# Code modified from: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

# Basic Blocks
class Identity(nn.Module):
    """Pass-through module: ``forward`` hands its argument back untouched."""

    def forward(self, x):
        # No computation is performed; the very same object is returned.
        return x

def get_norm_layer(norm_type='instance'):
    """Build a factory for the requested normalization layer.
    Parameters:
            norm_type (str) -- which normalization to use: batch | instance | none
    Returns a callable that takes the number of channels and produces the layer:
        batch    -- nn.BatchNorm2d with learnable affine parameters and running statistics
        instance -- nn.InstanceNorm2d without affine parameters and without running statistics
        none     -- an Identity module (the channel count is ignored)
    Raises NotImplementedError for any other name.
    """
    if norm_type == 'batch':
        return functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)

    if norm_type == 'instance':
        return functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)

    if norm_type == 'none':
        # Same call shape as the partials above so callers can treat all three alike.
        def identity_factory(_num_features):
            return Identity()
        return identity_factory

    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)

class Conv2d(torch.nn.Module):
    '''
    Padding + 2D convolution + normalization, followed by an optional activation
    Args:
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        kernel_size : int
            size of kernel; the padding amount is kernel_size // 2
        stride : int
            stride of convolution
        activation_func : func
            activation applied to the normalized output (None disables it)
        norm_layer : functools.partial
            callable that builds the normalization layer from a channel count
        use_bias : bool
            if set, the convolution has a bias term
        padding_type : str
            the name of padding layer: reflect | replicate | zero
    '''

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
            norm_layer=nn.BatchNorm2d,
            use_bias=False,
            padding_type='reflect'):
        super().__init__()

        self.activation_func = activation_func

        half_kernel = kernel_size // 2
        layers = []
        # Only zero padding is delegated to the convolution itself;
        # reflect/replicate are realised as explicit layers in front of it.
        conv_padding = 0
        if padding_type == 'zero':
            conv_padding = half_kernel
        elif padding_type == 'reflect':
            layers.append(nn.ReflectionPad2d(half_kernel))
        elif padding_type == 'replicate':
            layers.append(nn.ReplicationPad2d(half_kernel))
        else:
            raise NotImplementedError('padding [%s] is not implemented' % padding_type)

        layers.append(
                nn.Conv2d(
                        in_channels,
                        out_channels,
                        kernel_size=kernel_size,
                        stride=stride,
                        padding=conv_padding,
                        bias=use_bias))
        layers.append(norm_layer(out_channels))

        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        features = self.conv(x)
        if self.activation_func is None:
            return features
        return self.activation_func(features)

class DeformableConv2d(nn.Module):
    '''
    Modulated deformable 2D convolution built on torchvision.ops.deform_conv2d
    Two auxiliary convolutions predict, from the input itself, the sampling
    offsets and a per-tap modulation mask; both start at zero so the layer
    initially samples on the regular grid with a mask of 2 * sigmoid(0) = 1.
    Args:
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        kernel_size : int
            size of kernel
        stride : int
            stride of convolution (an int is expanded to a pair)
        padding : int
            padding
        bias : bool
            if set, the main convolution has a bias term
    '''
    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False):

        super().__init__()

        # deform_conv2d is handed the stride as a pair in forward().
        if type(stride) == tuple:
            self.stride = stride
        else:
            self.stride = (stride, stride)
        self.padding = padding

        n_taps = kernel_size * kernel_size
        geometry = dict(kernel_size=kernel_size, stride=stride, padding=self.padding)

        # Two offset values are predicted per kernel tap.
        self.offset_conv = nn.Conv2d(in_channels, 2 * n_taps, bias=True, **geometry)
        for tensor in (self.offset_conv.weight, self.offset_conv.bias):
            nn.init.constant_(tensor, 0.)

        # One modulation scalar is predicted per kernel tap.
        self.modulator_conv = nn.Conv2d(in_channels, 1 * n_taps, bias=True, **geometry)
        for tensor in (self.modulator_conv.weight, self.modulator_conv.bias):
            nn.init.constant_(tensor, 0.)

        # Holds the weight (and optional bias) used by the deformable convolution.
        self.regular_conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                **geometry)

    def forward(self, x):
        sampling_offsets = self.offset_conv(x)
        # Sigmoid output scaled to the range (0, 2).
        modulation = 2. * torch.sigmoid(self.modulator_conv(x))

        return torchvision.ops.deform_conv2d(
                input=x,
                offset=sampling_offsets,
                weight=self.regular_conv.weight,
                bias=self.regular_conv.bias,
                stride=self.stride,
                padding=self.padding,
                mask=modulation)

class UpConv2d(torch.nn.Module):
    '''
    Up-convolution (2x upsample + convolution) block class
    Args:
        in_channels : int
            number of input channels
        out_channels : int
            number of output channels
        kernel_size : int
            size of kernel (k x k)
        activation_func : func
            activation function after convolution
        norm_layer : functools.partial
            normalization layer
        use_bias : bool
            if set, then use bias
        padding_type : str
            the name of padding layer: reflect | replicate | zero
        interpolate_mode : str
            the mode for interpolation: bilinear | nearest
    '''

    # torch.nn.functional.interpolate only accepts an explicit align_corners
    # for these modes; for any other mode (e.g. 'nearest') it raises ValueError.
    _ALIGN_CORNERS_MODES = ('linear', 'bilinear', 'bicubic', 'trilinear')

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size=3,
            activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
            norm_layer=nn.BatchNorm2d,
            use_bias=False,
            padding_type='reflect',
            interpolate_mode='bilinear'):

        super(UpConv2d, self).__init__()
        self.interpolate_mode = interpolate_mode

        self.conv = Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                stride=1,
                activation_func=activation_func,
                norm_layer=norm_layer,
                use_bias=use_bias,
                padding_type=padding_type)

    def forward(self, x):
        """Double the spatial size of x (dims 2 and 3) and convolve the result."""
        n_height, n_width = x.shape[2:4]
        shape = (int(2 * n_height), int(2 * n_width))

        # Keep align_corners=True where it is legal (unchanged behavior for
        # 'bilinear'); pass None otherwise so the documented 'nearest' mode works.
        if self.interpolate_mode in self._ALIGN_CORNERS_MODES:
            align_corners = True
        else:
            align_corners = None

        upsample = torch.nn.functional.interpolate(
                x, size=shape, mode=self.interpolate_mode, align_corners=align_corners)
        conv = self.conv(upsample)
        return conv

class DeformableResnetBlock(nn.Module):
    """Residual block whose two 3x3 convolutions are deformable convolutions"""

    def __init__(
            self, dim, padding_type,
            norm_layer, use_dropout,
            use_bias, activation_func):
        """Create the block
        The residual branch is produced by build_conv_block; forward adds it to the input.
        """
        super().__init__()
        self.conv_block = self.build_conv_block(
                dim, padding_type,
                norm_layer, use_dropout,
                use_bias, activation_func)

    def build_conv_block(
            self, dim, padding_type,
            norm_layer, use_dropout,
            use_bias, activation_func):
        """Assemble the residual branch.
        Parameters:
                dim (int) -- channel count, identical for input and output of every layer.
                padding_type (str) -- the name of padding layer: reflect | replicate | zero
                norm_layer -- callable building a normalization layer from a channel count
                use_dropout (bool) -- insert Dropout(0.5) after the first activation.
                use_bias (bool) -- whether the deformable convolutions carry a bias
                activation_func (func) -- activation placed after the first normalization
        Returns nn.Sequential of: [pad], deformable conv, norm, activation, [dropout],
        [pad], deformable conv, norm. The explicit pad layers exist only for
        reflect/replicate padding; zero padding is done inside the convolution.
        """
        def padded_deformable_conv():
            # One "padding + 3x3 deformable convolution" stage for the chosen padding type.
            if padding_type == 'zero':
                return [DeformableConv2d(dim, dim, kernel_size=3, padding=1, bias=use_bias)]
            if padding_type == 'reflect':
                pad_layer = nn.ReflectionPad2d(1)
            elif padding_type == 'replicate':
                pad_layer = nn.ReplicationPad2d(1)
            else:
                raise NotImplementedError('padding [%s] is not implemented' % padding_type)
            return [pad_layer, DeformableConv2d(dim, dim, kernel_size=3, padding=0, bias=use_bias)]

        layers = padded_deformable_conv()
        layers.extend([norm_layer(dim), activation_func])
        if use_dropout:
            layers.append(nn.Dropout(0.5))

        layers.extend(padded_deformable_conv())
        layers.append(norm_layer(dim))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the input plus the residual branch applied to it"""
        return x + self.conv_block(x)

class DecoderBlock(torch.nn.Module):
    '''
    Decoder stage: upsample, optionally concatenate a skip tensor, then convolve
    Args:
        in_channels : int
            number of input channels
        skip_channels : int
            number of skip connection channels (0 disables the concatenation)
        out_channels : int
            number of output channels
        activation_func : func
            activation function after convolution
        norm_layer : functools.partial
            normalization layer
        use_bias : bool
            if set, then use bias
        padding_type : str
            the name of padding layer: reflect | replicate | zero
        upsample_mode : str
            'transpose' selects a stride-2 transposed convolution; any other
            value is forwarded to UpConv2d as its interpolation mode
    '''

    def __init__(
            self,
            in_channels,
            skip_channels,
            out_channels,
            activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
            norm_layer=nn.BatchNorm2d,
            use_bias=False,
            padding_type='reflect',
            upsample_mode='transpose'):
        super().__init__()

        self.skip_channels = skip_channels
        self.upsample_mode = upsample_mode

        # Stage 1: upsampling from in_channels to out_channels.
        if upsample_mode != 'transpose':
            self.deconv = UpConv2d(
                    in_channels, out_channels,
                    use_bias=use_bias,
                    activation_func=activation_func,
                    norm_layer=norm_layer,
                    padding_type=padding_type,
                    interpolate_mode=upsample_mode)
        else:
            transposed = nn.ConvTranspose2d(
                    in_channels, out_channels,
                    kernel_size=3, stride=2,
                    padding=1, output_padding=1,
                    bias=use_bias)
            self.deconv = nn.Sequential(
                    transposed,
                    norm_layer(out_channels),
                    activation_func)

        # Stage 2: fuse the upsampled features with the skip features.
        self.conv = Conv2d(
                skip_channels + out_channels,
                out_channels,
                kernel_size=3,
                stride=1,
                activation_func=activation_func,
                padding_type=padding_type,
                norm_layer=norm_layer,
                use_bias=use_bias)

    def forward(self, x, skip=None):
        merged = self.deconv(x)

        # The skip tensor is only consulted when the block was built with skip channels.
        if self.skip_channels > 0:
            merged = torch.cat([merged, skip], dim=1)

        return self.conv(merged)

In [None]:
# Main network blocks 2
# Code taken from GT-RAIN: https://github.com/UCLA-VMG/GT-RAIN

def init_weights(net, init_type='normal', init_gain=0.02):
    """
    Initialize network weights in place.
    Parameters:
            net (network) -- network to be initialized
            init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
            init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    Modules whose class name contains 'Conv' or 'Linear' and that have a weight
    get the selected initialization and a zero bias. BatchNorm2d modules get
    weights drawn from N(1, init_gain) and a zero bias; BatchNorm2d layers built
    with affine=False (weight and bias are None) are skipped instead of crashing.
    Raises NotImplementedError for an unknown init_type once a Conv/Linear
    module is visited.
    """
    def init_func(m):    # applied to every sub-module by net.apply
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and ('Conv' in classname or 'Linear' in classname):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            # BatchNorm weight is a vector, so only the normal distribution applies.
            # affine=False leaves weight/bias as None; there is nothing to initialize then.
            if getattr(m, 'weight', None) is not None:
                init.normal_(m.weight.data, 1.0, init_gain)
            if getattr(m, 'bias', None) is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)    # apply the initialization function <init_func>

def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
    """Prepare a network for use: optional GPU placement, weight initialization, deformable-conv reset
    Parameters:
                    net (network) -- the network to be initialized
                    init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal
                    init_gain (float) -- scaling factor for normal, xavier and orthogonal.
                    gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
    Returns the initialized network (wrapped in DataParallel when gpu_ids is non-empty).
    """
    wants_gpu = len(gpu_ids) > 0
    if wants_gpu:
        assert(torch.cuda.is_available())
        net.to(gpu_ids[0])
        net = torch.nn.DataParallel(net, gpu_ids)        # multi-GPUs

    init_weights(net, init_type, init_gain=init_gain)

    # The generic initialization above also touched the offset/modulator
    # convolutions of the deformable layers; put those back to zero.
    for param_name, param in net.named_parameters():
        if 'offset' in param_name or 'modulator' in param_name:
            nn.init.constant_(param, 0.)
    return net

class ResNetModified(nn.Module):
    """
    Encoder / deformable-ResNet bottleneck / decoder generator with skip connections.
    """

    def __init__(
            self,
            input_nc,
            output_nc,
            ngf=64,
            norm_layer=nn.BatchNorm2d,
            activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
            use_dropout=False,
            n_blocks=6,
            padding_type='reflect',
            upsample_mode='bilinear'):
        """Build the generator
        Parameters:
            input_nc (int) -- the number of channels in input images
            output_nc (int) -- the number of channels in output images
            ngf (int) -- channel width at full resolution (doubled at each of the two downsamplings)
            norm_layer -- normalization layer
            activation_func (func) -- activation shared by all conv blocks
            use_dropout (bool) -- if use dropout layers in the residual blocks
            n_blocks (int) -- the number of deformable ResNet blocks
            padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero
            upsample_mode (str) -- mode for upsampling: transpose | bilinear
        """
        assert(n_blocks >= 0)
        super().__init__()

        # Convolutions only carry a bias when the normalization is InstanceNorm2d.
        norm_cls = norm_layer.func if type(norm_layer) == functools.partial else norm_layer
        use_bias = norm_cls == nn.InstanceNorm2d

        # Options shared by every Conv2d wrapper in the encoder.
        shared = dict(
                padding_type=padding_type,
                norm_layer=norm_layer,
                activation_func=activation_func,
                use_bias=use_bias)

        # Channel width per resolution level: [ngf, 2*ngf, 4*ngf].
        n_downsampling = 2
        widths = [ngf * 2 ** level for level in range(n_downsampling + 1)]

        # Stem: 7x7 then 3x3 convolution at full resolution.
        self.initial_conv = nn.Sequential(
                Conv2d(in_channels=input_nc, out_channels=ngf, kernel_size=7, **shared),
                Conv2d(in_channels=ngf, out_channels=ngf, kernel_size=3, **shared))

        # Encoder: two stride-2 convolutions.
        self.downsample_1 = Conv2d(
                in_channels=widths[0], out_channels=widths[1],
                kernel_size=3, stride=2, **shared)
        self.downsample_2 = Conv2d(
                in_channels=widths[1], out_channels=widths[2],
                kernel_size=3, stride=2, **shared)

        # Bottleneck: n_blocks deformable residual blocks at the coarsest level.
        self.residual_blocks = nn.Sequential(*[
                DeformableResnetBlock(
                        widths[2],
                        padding_type=padding_type,
                        norm_layer=norm_layer,
                        use_dropout=use_dropout,
                        use_bias=use_bias,
                        activation_func=activation_func)
                for _ in range(n_blocks)])

        # Decoder: each stage halves the channels and consumes one skip tensor.
        decoder_options = dict(
                use_bias=use_bias,
                activation_func=activation_func,
                norm_layer=norm_layer,
                padding_type=padding_type,
                upsample_mode=upsample_mode)
        self.upsample_2 = DecoderBlock(
                widths[2],
                int(widths[2] / 2),
                int(widths[2] / 2),
                **decoder_options)
        self.upsample_1 = DecoderBlock(
                widths[1],
                int(widths[1] / 2),
                int(widths[1] / 2),
                **decoder_options)

        # Output head: 3x3 convolution followed by tanh.
        self.output_conv_naive = nn.Sequential(
                nn.ReflectionPad2d(1),
                nn.Conv2d(ngf, output_nc, kernel_size=3, padding=0),
                nn.Tanh())

    def forward(self, input):
        """Run the generator; returns a 1-tuple holding the output image tensor.

        The feature projection used for the rain-robust loss is disabled here,
        so the tuple carries the image only.
        """
        # Encoder
        stem = self.initial_conv(input)
        half_res = self.downsample_1(stem)
        quarter_res = self.downsample_2(half_res)

        # Bottleneck
        bottleneck = self.residual_blocks(quarter_res)

        # Decoder with skip connections from the matching encoder levels
        decoded_half = self.upsample_2(bottleneck, half_res)
        decoded_full = self.upsample_1(decoded_half, stem)

        return (self.output_conv_naive(decoded_full),)

class GTRainModel(nn.Module):
    def __init__(
            self,
            ngf=64,
            n_blocks=9,
            norm_layer_type='batch',
            activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),
            upsample_mode='bilinear',
            init_type='kaiming'):
        """
        GT-Rain model: a 3-channel in / 3-channel out ResNetModified generator
        Parameters:
            ngf (int) -- the number of conv filters
            n_blocks (int) -- the number of deformable ResNet blocks
            norm_layer_type (str) -- 'batch', 'instance'
            activation_func (func) -- activation functions
            upsample_mode (str) -- 'transpose', 'bilinear'
            init_type (str) -- None, 'normal', 'xavier', 'kaiming', 'orthogonal'
        """
        super().__init__()

        generator_options = dict(
            input_nc=3,
            output_nc=3,
            ngf=ngf,
            norm_layer=get_norm_layer(norm_layer_type),
            activation_func=activation_func,
            use_dropout=False,
            n_blocks=n_blocks,
            padding_type='reflect',
            upsample_mode=upsample_mode)
        self.resnet = ResNetModified(**generator_options)

        # A falsy init_type (e.g. None) skips initialization, which is what
        # you want when the weights come from a checkpoint afterwards.
        if init_type:
            init_net(self.resnet, init_type=init_type)

    def forward(self, x):
        # Whatever the wrapped generator returns is passed through unchanged.
        return self.resnet(x)

In [None]:
# Configuration of the seed model and where its weights live
model_params = dict(
  load_dir='path/to/model/ckpt',  # checkpoint file holding the model weights
  init_type=None,  # no random initialization; weights are loaded below
  norm_layer_type='batch',  # normalization type
  activation_func=torch.nn.LeakyReLU(negative_slope=0.10, inplace=True),  # activation function
  upsample_mode='bilinear',  # mode for upsampling
  ngf=64,
  n_blocks=9)

# Build the network from every setting except the checkpoint location
model = GTRainModel(**{
  key: model_params[key]
  for key in ('ngf', 'n_blocks', 'norm_layer_type',
              'activation_func', 'upsample_mode', 'init_type')})

# Restore the trained weights
# (on a CPU-only machine torch.load additionally needs map_location=torch.device('cpu'))
checkpoint = torch.load(model_params['load_dir'])
model.load_state_dict(checkpoint['state_dict'], strict=True)

# Move to the GPU and switch to inference mode
model.cuda()
model.eval()

GTRainModel(
  (resnet): ResNetModified(
    (initial_conv): Sequential(
      (0): Conv2d(
        (activation_func): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv): Sequential(
          (0): ReflectionPad2d((3, 3, 3, 3))
          (1): Conv2d(3, 64, kernel_size=(7, 7), stride=(1, 1), bias=False)
          (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
      (1): Conv2d(
        (activation_func): LeakyReLU(negative_slope=0.1, inplace=True)
        (conv): Sequential(
          (0): ReflectionPad2d((1, 1, 1, 1))
          (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), bias=False)
          (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (downsample_1): Conv2d(
      (activation_func): LeakyReLU(negative_slope=0.1, inplace=True)
      (conv): Sequential(
        (0): ReflectionPad2d((1, 1, 1, 1))
        (1): Conv2d(64, 128, kernel_size=(3, 3), s

In [None]:
# Set parameters for the pipeline.
# The functions below read this module-level dict directly.

pipeline_params = {
    "min_size": 256, # minimum height and width of an image
    "optical_flow_motion_threshold": 5, # magnitude threshold for optical flow
    "psnr_illumination_high_dif_threshold": .5, # difference in input/gt vs output/gt psnr for the high threshold
    "psnr_illumination_low_dif_threshold": 2, # difference in input/gt vs output/gt psnr for the low threshold
    "psnr_illumination_high_threshold": 25, # high psnr threshold for illumination check
    "psnr_illumination_low_threshold": 20, # low psnr threshold for illumination check
    "psnr_movement_threshold": 20, # hard threshold on psnr of input
    "input_frame_num": 10, # number of input frames to check
    "overlap_percentage": .3, # the amount of overlap above which potential crops are thrown away
    "optical_flow_window_size": 25, # window size for optical flow
    "mean_img_num": 15, # number of images to average for input optical flow (eliminates rain streaks/snow flakes)
    "mean_frame_skip": 2, # number of frames to skip when generating mean images (accelerates OF time)
    "gt_frame_skip": 3, # number of frames to skip when generating the gt optical flow (accelerates OF time)
    "downsample_of": True, # whether or not to downsample the optical flow (accelerates OF time)
    "chromatic_var_threshold": .2, # threshold for chromatic variation property
    "sobel_blur_threshold": -.05, # threshold for blurriness check with sobel filter
    "sobel_ksize": 15, # kernel size for sobel filter blurriness check
    "sobel_scale": 1, # scale for sobel filter blurriness check
    "gaussian_ksize": 7, # kernel size for gaussian blur before sobel filtering
    "fft_blur_threshold": -.05, # threshold for blurriness check with fft filter
    "fft_cutoff": 30, # cutoff for the fft in pixels
  }

# Derived error threshold: sqrt(10 ** (-psnr / 10)) of the PSNR movement threshold.
# NOTE(review): this equals the RMSE matching that PSNR if the signal peak is 1.0,
# which is presumably the case since frames are scaled to [0, 1] -- confirm.
pipeline_params["se_movement_threshold"] = (10**(-.1*pipeline_params["psnr_movement_threshold"]))**.5

In [None]:
def get_optical_flows(video_path, input_frame_num, mean_img_num, mean_frame_skip):

    """
    Method for getting optical flow filter masks, also returns the degraded images and checks for corrupted video
    Parameters:
        video_path: path to input video with weather effects
        input_frame_num: number of input frames from video to take into account for pipeline
        mean_img_num: number of input frames to take the mean over to eliminate rain streaks/snow flakes
        mean_frame_skip: number of input frames to skip over when taking the OF of the averaged images
    Output:
        binary_mask: boolean filter mask marking pixels that moved according to optical flow, shape: (H, W)
        degraded_imgs: list of stored RGB float32 frames in [0, 1], each of shape (H, W, 3);
            one frame is kept every (frame_count // input_frame_num) frames
        is_video: True once at least mean_img_num frames could be read; when it is
            False (unreadable, corrupted or too short video) the other two outputs are None
    """

    cap = cv2.VideoCapture(video_path)

    img_queue = []          # sliding window of the last mean_img_num grayscale frames
    degraded_imgs = []
    binary_mask = None      # stays None until the first averaged image is available
    prev_mean = None
    full_size = None        # (width, height) of the frames before any downsampling

    try:
        cap_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # A video reporting fewer frames than input_frame_num (or no frame count
        # at all) would give a stride of 0 and a ZeroDivisionError below.
        frame_skip = max(1, cap_frames // input_frame_num)

        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
            if i % frame_skip == 0:
                degraded_imgs.append(frame)
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
            if full_size is None:
                h, w = frame.shape
                full_size = (w, h)
            if pipeline_params["downsample_of"]:
                h, w = frame.shape
                frame = cv2.resize(frame, (w//2, h//2))
            img_queue.append(frame)
            if i >= mean_img_num-1:
                if i % mean_frame_skip == 0:
                    # Temporal mean over the window suppresses rain streaks / snow flakes.
                    pixel_wise_means = np.mean(np.stack(img_queue), axis=0)
                    if binary_mask is None:
                        # First averaged image: nothing to compare against yet.
                        binary_mask = np.zeros(frame.shape, dtype=bool)
                    else:
                        flow = cv2.calcOpticalFlowFarneback(pixel_wise_means, prev_mean, None, 0.5, 3, pipeline_params["optical_flow_window_size"], 3, 5, 1.2, 0)
                        mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
                        moving = mag>pipeline_params["optical_flow_motion_threshold"]
                        binary_mask = np.bitwise_or(binary_mask, moving)
                    prev_mean = pixel_wise_means
                img_queue.pop(0)
            i += 1
    finally:
        # Release the capture even if decoding or the flow computation fails.
        cap.release()
        cv2.destroyAllWindows()

    if binary_mask is None:
        return None, None, False

    if pipeline_params["downsample_of"]:
        # Resize back to the true frame size rather than (2*w, 2*h): for odd
        # frame dimensions the latter is one pixel short of the stored frames.
        binary_mask = cv2.resize(binary_mask.astype(np.uint8), full_size).astype(bool)
    return binary_mask, degraded_imgs, True

def get_optical_flow_clean(video_path, gt_frame_skip):

    """
    Grabbing the optical flow filter mask of gt videos, also grabs the gt frame and a boolean for corrupted videos
    Parameters:
        video_path: path to gt video with no weather effects
        gt_frame_skip: index of the frame that is compared against the first frame for the gt OF
    Output:
        gt_of_mask: boolean filter mask denoting pixel movement in gt frame, shape: (H, W)
        gt_img: first frame of the gt video as RGB float32 in [0, 1], shape: (H, W, 3)
        is_video: True when both the first frame and frame gt_frame_skip could be read;
            when it is False (unreadable, corrupted or too short video) the other
            two outputs are None
    """

    cap = cv2.VideoCapture(video_path)
    gt_img = None
    gt_of_mask = None
    prev = None

    try:
        i = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if i == 0:
                gt_img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB).astype(np.float32)/255.0
                prev = cv2.cvtColor(gt_img, cv2.COLOR_RGB2GRAY)
            if i == gt_frame_skip:
                cur = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY).astype(np.float32)/255.0
                flow = cv2.calcOpticalFlowFarneback(cur, prev, None, 0.5, 3, pipeline_params["optical_flow_window_size"], 3, 5, 1.2, 0)
                mag, ang = cv2.cartToPolar(flow[..., 0], flow[..., 1])
                gt_of_mask = mag>pipeline_params["optical_flow_motion_threshold"]
                break
            i += 1
    finally:
        # Release the capture even if decoding or the flow computation fails.
        cap.release()
        cv2.destroyAllWindows()

    # A video that ends before frame gt_frame_skip never produces a mask;
    # report it as unusable instead of failing on an unbound variable.
    if gt_img is None or gt_of_mask is None:
        return None, None, False

    return gt_of_mask, gt_img, True

In [None]:
def cal_skyline(mask):
    """
    Helper function for sky segmentation
    Cleans a 0/1 mask column by column, in place, and returns it.
    For each column a median filter of length 19 is applied; if the filtered
    column contains both a 0 and a 1 and its first 0 lies below row 20, the
    column is rewritten so that only rows [first 1, first 0) are 1.
    Columns that do not meet these conditions are left untouched.
    """
    h, w = mask.shape
    for i in range(w):
        raw = mask[:, i]
        after_median = medfilt(raw, 19)

        zero_rows = np.where(after_median == 0)[0]
        one_rows = np.where(after_median == 1)[0]
        # Explicit emptiness checks instead of a bare `except:`, which would
        # also hide unrelated errors and swallow KeyboardInterrupt.
        if zero_rows.size == 0 or one_rows.size == 0:
            continue

        first_zero_index = zero_rows[0]
        first_one_index = one_rows[0]
        if first_zero_index > 20:
            mask[first_one_index:first_zero_index, i] = 1
            mask[first_zero_index:, i] = 0
            mask[:first_one_index, i] = 0
    return mask


def get_sky_region_gradient(img):

    """
    Uses image gradients to segment the sky region out
    Parameters:
        img: 3-channel image whose values are scaled by 255 and cast to uint8
            (presumably floats in [0, 1] -- confirm against the caller)
    Output:
        boolean mask, True where the pixel survives the low-gradient
        (Laplacian < 6) test, the per-column skyline cleanup and the final erosion
    """

    img = (img*255).astype(np.uint8)

    # NOTE(review): frames elsewhere in this file are converted to RGB, while this
    # conversion assumes BGR channel order -- confirm which one is intended.
    img_gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    img_gray = cv2.blur(img_gray, (9, 3))
    # medianBlur returns the filtered image; the original call discarded the
    # result, so the median filter never took effect.
    img_gray = cv2.medianBlur(img_gray, 5)
    lap = cv2.Laplacian(img_gray, cv2.CV_8U)
    gradient_mask = (lap < 6).astype(np.uint8)

    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (9, 3))
    mask = cv2.morphologyEx(gradient_mask, cv2.MORPH_ERODE, kernel)
    mask = cal_skyline(mask)

    kernel_2 = np.ones((5,5), dtype=np.uint8)
    mask_refined = cv2.erode(mask, kernel_2, iterations=5)

    return mask_refined>0

In [None]:
def restore_degraded(degraded_imgs):
    """
    Runs every degraded frame through the seed model (the module-level `model`)
    Parameters:
        degraded_imgs: degraded images in shape: (input_frame_num, H, W, 3)
    Output:
        restored_imgs: list of restored images, one (H, W, 3) numpy array per input frame
    """
    restored_imgs = []
    for frame in degraded_imgs:
        # HWC -> CHW, rescale with x*2-1 (presumably [0, 1] -> [-1, 1]), add batch dim, move to GPU
        model_input = (torch.from_numpy(frame).permute((2, 0, 1)) * 2 - 1).unsqueeze(0).cuda()
        with torch.no_grad():
            # First element of the model output, mapped back with x*0.5+0.5 and returned to HWC
            output = model(model_input)[0] * 0.5 + 0.5
            output = output.squeeze().permute((1, 2, 0))
        restored_imgs.append(output.detach().cpu().numpy())
    # Drop the last GPU tensor references so the cache can actually be released
    model_input = None
    output = None
    gc.collect()
    torch.cuda.empty_cache()
    return restored_imgs

In [None]:
def generate_static_obj_masks(restored_imgs, gt_img):
    """
    Builds the binary filter mask flagging static object discrepancies between the
    restored input frames and the gt frame
    Parameter:
        restored_imgs: restored images, output from seed model, shape: (input_frame_num, H, W, 3)
        gt_img: gt frame, shape: (H, W, 3)
    Output:
        mask: boolean (H, W) mask, True where any restored frame differs from the gt frame
            by more than pipeline_params["se_movement_threshold"] (mean absolute
            difference over the channel axis)
    """
    n_rows, n_cols = gt_img.shape[0], gt_img.shape[1]
    mask = np.zeros((n_rows, n_cols), dtype=bool)
    for restored_img in restored_imgs:
        channel_mean_err = np.abs(restored_img - gt_img).mean(axis=2)
        # Accumulate the union over all frames
        mask = mask | (channel_mean_err > pipeline_params["se_movement_threshold"])
    return mask

In [None]:
def chromatic_variation(degraded_img, gt_img):
    """
    Flags pixels that violate the chromatic variation constraint
        Reference Section "Filtering Block Two: Color Verification" in Section 4
    Parameters:
        degraded_img: a single degraded frame with weather effects
        gt_img: a gt frame without weather effects
    Output:
        uint8 mask (result of cv2.morphologyEx on a 0/1 uint8 map) marking pixels whose
            per-channel difference spread exceeds pipeline_params["chromatic_var_threshold"];
            used to detect static object variations
    """
    residual = degraded_img - gt_img
    # Spread of the residual across the channel axis: max channel minus min channel
    channel_spread = np.max(residual, axis=2) - np.min(residual, axis=2)
    violations = (channel_spread > pipeline_params["chromatic_var_threshold"]).astype(np.uint8)
    # Morphological opening with a 21x21 kernel
    structuring_element = np.ones((21,21), np.uint8)
    return cv2.morphologyEx(violations, cv2.MORPH_OPEN, structuring_element)

In [None]:
def overlap_percent(XA1, YA1, XA2, YA2, XB1, YB1, XB2, YB2, SA, SB):
    """
    Helper function for cropping function
    Intersection-over-union of two axis-aligned crops
    Parameters:
        XA1, YA1, XA2, YA2: coordinates of first rectangle: left x, top y, right x, bottom y
        XB1, YB1, XB2, YB2: coordinates of second rectangle: left x, top y, right x, bottom y
        SA, SB: areas of the first and second rectangle, supplied by the caller
    Output:
        intersection area divided by (SA + SB - intersection area)
    """
    span_x = min(XA2, XB2) - max(XA1, XB1)
    span_y = min(YA2, YB2) - max(YA1, YB1)
    # A negative span means the rectangles do not overlap along that axis
    intersection = max(0, span_x) * max(0, span_y)
    union = SA + SB - intersection
    return intersection/union


def allRectangle(matrix):
    
    """
    Collects rectangular crops made only of cells equal to 1, with both sides at least
    pipeline_params["min_size"], while limiting overlap between kept crops using
    pipeline_params["overlap_percentage"]
    Parameters:
        matrix: binary validity map for crops (only cells equal to 1 count as valid)
    Output:
        crop_dict: output crop dictionary in the form: crop_dict[(y, x)] = (height, width)
            y is the top row of the crop; x is the column just past its right edge,
            i.e. the crop covers columns [x - width, x) and rows [y, y + height)
    """
    n = matrix.shape[1] 
    # height[i]: number of consecutive valid cells ending at the current row in column i.
    # The extra last slot stays 0; it also serves as height[-1] for the -1 stack sentinel.
    height = [0] * (n + 1)
    ans = 0  # never read in this function
    crop_dict = {}
    for row_idx in range(matrix.shape[0]):
        # Extend or reset the per-column run lengths with the current row
        for i in range(n):
            height[i] = height[i] + 1 if matrix[row_idx][i] == 1 else 0
        # Stack of column indices whose heights are non-decreasing from bottom to top
        stack = [-1]
        for i in range(n + 1):
            # A lower bar at column i closes every taller bar still on the stack
            while height[i] < height[stack[-1]]:
                h = height[stack.pop()]
                w = i - 1 - stack[-1]
                if h >= pipeline_params["min_size"] and w >= pipeline_params["min_size"]:
                    # Rectangle has its bottom on row_idx, so its top row is row_idx - h + 1
                    rect_row = row_idx -h + 1
                    # Right edge, exclusive
                    rect_col = i
                    if (rect_row, rect_col) in crop_dict:
                        # Same key already stored: keep whichever crop has the larger area
                        # (no overlap test is run in this branch)
                        dict_val = crop_dict[(rect_row, rect_col)]
                        if h*w > dict_val[0]*dict_val[1]:
                            crop_dict[(rect_row, rect_col)] = (h, w)
                    else:
                        overlap_bool = False
                        for (row2, col2) in crop_dict:
                            (h2, w2) = crop_dict[(row2, col2)]
                            overlap = overlap_percent(rect_col-w, rect_row, rect_col, rect_row+h,
                                                     col2-w2, row2, col2, row2+h2, h*w, h2*w2)
                            if overlap > pipeline_params["overlap_percentage"]:
                                if h*w > h2*w2:
                                    # New crop is larger: evict the stored one.
                                    # Mutating the dict is safe only because of the break below.
                                    crop_dict.pop((row2,col2))
                                else:
                                    overlap_bool = True
                                # Only the first stored crop exceeding the overlap limit is considered
                                break
                        if not overlap_bool:
                            crop_dict[(rect_row, rect_col)] = (h, w)
            stack.append(i)
    return crop_dict

In [None]:
def psnr_illumination_check(degraded_imgs, restored_imgs, gt_img, crop_dict):
    """
    Check for illumination shift using hysteresis of psnr of input/gt vs. psnr of restored/gt using seed model
    Parameters:
        degraded_imgs: input images of shape: (input_frame_num, H, W, 3)
        restored_imgs: restored images from seed model of shape: (input_frame_num, H, W, 3)
        gt_img: gt image of shape: (H, W, 3)
        crop_dict: crop dictionary in the form: crop_dict[(y, x)] = (height, width),
            where x is the column just past the right edge of the crop
    Output:
        accepted_list: list of all accepted scenes,
            element is tuple of form (degraded/gt psnr, restored/gt psnr, y, x, h, w)
            psnr values are averaged over the frames; x here is the left column of the crop
    """
    summed_psnr_in = {}
    summed_psnr_out = {}

    # Pass 1: accumulate per-crop PSNR over all frames
    for frame_idx, degraded_img in enumerate(degraded_imgs):
        restored_img = restored_imgs[frame_idx]
        for key, (h, w) in crop_dict.items():
            row, col = key
            rows = slice(row, row+h)
            cols = slice(col-w, col)
            gt_crop = gt_img[rows, cols, :]
            psnr_in = psnr(degraded_img[rows, cols, :], gt_crop)
            psnr_out = psnr(restored_img[rows, cols, :], gt_crop)
            if key not in summed_psnr_in:
                summed_psnr_in[key] = psnr_in
                summed_psnr_out[key] = psnr_out
            else:
                summed_psnr_in[key] += psnr_in
                summed_psnr_out[key] += psnr_out

    # Pass 2: average and apply the high / low acceptance rules
    accepted_list = []
    frame_count = len(degraded_imgs)
    for key, (h, w) in crop_dict.items():
        row, col = key
        psnr_in = summed_psnr_in[key]/frame_count
        psnr_out = summed_psnr_out[key]/frame_count
        if ((psnr_out > pipeline_params["psnr_illumination_high_threshold"]
             and psnr_out-psnr_in > pipeline_params["psnr_illumination_high_dif_threshold"])
            or (psnr_out > pipeline_params["psnr_illumination_low_threshold"]
                and psnr_out-psnr_in > pipeline_params["psnr_illumination_low_dif_threshold"])):
            # Stored with the left column (col-w), unlike the right-edge key of crop_dict
            accepted_list.append((psnr_in, psnr_out, row, col-w, h, w))
    return accepted_list
        

In [None]:
def calcMetricSobel(degraded, gt, ksize, scale, gaussian_ksize):
    """
    Blurriness metric comparing mean Sobel gradient magnitude of the gt and degraded images
    Parameters:
        degraded: degraded image
        gt: gt image
        ksize: kernel size of Sobel filter
        scale: scale of Sobel filter
        gaussian_ksize: kernel size of Gaussian blur filter before Sobel filter
    Output:
        (gt gradient mean - degraded gradient mean) / gt gradient mean
    """
    def mean_gradient_magnitude(gray):
        # Gaussian pre-blur, then Sobel in x and y, then mean of the gradient magnitude
        smoothed = cv2.GaussianBlur(gray, (gaussian_ksize, gaussian_ksize), sigmaX=0)
        grad_x = cv2.Sobel(smoothed, cv2.CV_64F, 1, 0, ksize=ksize, scale=scale)
        grad_y = cv2.Sobel(smoothed, cv2.CV_64F, 0, 1, ksize=ksize, scale=scale)
        return np.sqrt(np.square(grad_x) + np.square(grad_y)).mean()

    gt_var = mean_gradient_magnitude(cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY))
    degraded_var = mean_gradient_magnitude(cv2.cvtColor(degraded, cv2.COLOR_BGR2GRAY))

    return (gt_var - degraded_var)/gt_var

In [None]:
# FFT Metric from https://pyimagesearch.com/2020/06/15/opencv-fast-fourier-transform-fft-for-blur-detection-in-images-and-video-streams/
def calcMetricFFT(degraded, gt, cutoff):
    """
    Calculates a metric for bluriness based on the FFT
    Parameters:
        degraded: degraded image
        gt: gt image
        cutoff: half-width (in frequency bins) of the centred square of low
            frequencies that is removed from the shifted spectrum
    Output:
        (gt_mag - degraded_mag) / gt_mag, where each *_mag is the mean of
        20*log(|reconstruction|) after the low frequencies have been zeroed
    """
    # NOTE(review): calcMetricSobel converts with COLOR_BGR2GRAY while this uses
    # COLOR_RGB2GRAY -- confirm the channel order of the incoming crops.
    gt = cv2.cvtColor(gt, cv2.COLOR_RGB2GRAY)
    degraded = cv2.cvtColor(degraded, cv2.COLOR_RGB2GRAY)
    # The spectrum centre is taken from the degraded crop and reused for the gt crop
    (h, w) = degraded.shape
    (cX, cY) = (int(w / 2.0), int(h / 2.0))
    # Clamp the window start: a negative start index would wrap around to the far
    # end of the axis and zero the wrong band when cutoff exceeds half the crop size.
    row_start = max(cY - cutoff, 0)
    col_start = max(cX - cutoff, 0)

    def high_freq_log_magnitude(gray):
        # fftshift puts the zero frequency at the centre, so zeroing the centred
        # square removes the low frequencies before reconstructing the image
        spectrum = np.fft.fftshift(np.fft.fft2(gray))
        spectrum[row_start:cY + cutoff, col_start:cX + cutoff] = 0
        recon = np.fft.ifft2(np.fft.ifftshift(spectrum))
        return 20*np.log(np.abs(recon)).mean()

    gt_mag = high_freq_log_magnitude(gt)
    degraded_mag = high_freq_log_magnitude(degraded)

    return (gt_mag-degraded_mag)/gt_mag

In [None]:
def multi_scatter_check(degraded_imgs, gt_img, accepted_list):
    """
    Checks to see if degraded and gt frames conform to the multi-scatter constraint
    Reference Section "Filtering Block Three: Multi-scatter Verification" in Section 4
    Parameters:
        degraded_imgs: degraded frames with weather effects
        gt_img: gt frames without weather effects
        accepted_list: list of accepted crops output from the illumination shift check,
            elements of form (psnr_in, psnr_out, y, x, h, w)
    Output:
        filtered_accepted_list: list of accepted crops with same elements,
            keeping only crops whose Sobel and FFT metrics both exceed their thresholds
    """
    filtered_accepted_list = []
    # Unpack each entry directly; the original indexed accepted_list with `i`,
    # which is not bound until the inner loop runs.
    for (psnr_in, psnr_out, row, col, h, w) in accepted_list:
        # Stack the uint8 crop of every degraded frame once, then average over frames
        degraded_crops = np.stack([
            (degraded_img[row:row+h, col:col+w, :]*255).astype(np.uint8)
            for degraded_img in degraded_imgs
        ])
        degraded_avg = np.mean(degraded_crops, axis=0).astype(np.uint8)

        gt_img_uint8 = (gt_img[row:row+h, col:col+w, :]*255).astype(np.uint8)
        metric1 = calcMetricSobel(degraded_avg, gt_img_uint8,
                                  pipeline_params["sobel_ksize"],
                                  pipeline_params["sobel_scale"],
                                  pipeline_params["gaussian_ksize"])
        metric2 = calcMetricFFT(degraded_avg, gt_img_uint8, pipeline_params["fft_cutoff"])

        if metric1 > pipeline_params["sobel_blur_threshold"] and metric2 > pipeline_params["fft_blur_threshold"]:
            filtered_accepted_list.append((psnr_in, psnr_out, row, col, h, w))
    return filtered_accepted_list

In [None]:
def save_scenes(accepted_list, video_in, gt_video, gt_img, out_dir):
    """
    saves frames and json dictionary to an output directory
    Parameters:
        accepted_list: list of all accepted scenes, 
            element is tuple of form (degraded/gt psnr, restored/gt psnr, y, x, h, w)
        video_in: input video path with weather degradations
        gt_video: gt video path with no weather degradations
        gt_img: gt frame with no weather degradations
        out_dir: output directory prefix; crop number i is written to "{out_dir}_{i}/"
    Saves:
        logs_dict: json file with information regarding scene:
            gt_video: video path where gt frame was taken from
            degraded_video: video path where degraded frames were taken from
            crop: (y, x, h, w) location for crop
            psnr_in: degraded/gt psnr
            psnr_out: restored/gt psnr
        Frames Directory Structure:
            out_dir
                gt.png
                degraded_0.png
                degraded_1.png
                    ...
                degraded_n.png
    """
    if not accepted_list:
        # Nothing to save; avoid decoding the whole video for no output
        return

    out_dir_paths = []
    for i, (psnr_in, psnr_out, row, col, h, w) in enumerate(accepted_list):
        out_dir_path = f"{out_dir}_{i}/"
        os.makedirs(out_dir_path, exist_ok=True)
        out_dir_paths.append(out_dir_path)

        logs_dict = {
            'gt_video': gt_video,
            'degraded_video': video_in,
            'crop': (row, col, h, w),
            'psnr_in': psnr_in,
            'psnr_out': psnr_out,
        }
        with open(f"{out_dir_path}logs_dict.json", "w") as outfile:
            json.dump(logs_dict, outfile)

        # The gt crop does not depend on the video frame, so it is written once
        # per crop instead of being re-encoded for every decoded frame.
        gt_crop = (gt_img[row:row+h, col:col+w, :]*255).astype(np.uint8)
        Image.fromarray(gt_crop).save(f'{out_dir_path}gt.png')

    cap = cv2.VideoCapture(video_in)
    try:
        idx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Write the same crop window of this frame into each crop's directory
            for out_dir_path, (_, _, row, col, h, w) in zip(out_dir_paths, accepted_list):
                cv2.imwrite(f'{out_dir_path}degraded_{idx}.png', frame[row:row+h, col:col+w, :])
            idx += 1
    finally:
        # Release the capture even if writing a frame fails
        cap.release()
        cv2.destroyAllWindows()

In [None]:
"""
Main pipeline loop
"""

"""
Specify folder name with location of videos
Video directory structure should be:
folder_name
    scene_name
        gt
            (gt_video_name).mp4
        degraded
            (degraded_video_name).mp4
"""
folder_name = "path/to/folder_name"
scene_paths = natsorted(glob(f"{folder_name}/*"))
for scene_path in tqdm(scene_paths[:]):
    scene_name = Path(scene_path).name
    print(scene_name)
    # Find all degraded and gt video paths
    video_files_degraded = natsorted(glob(f"{scene_path}/degraded/*.mp4"))
    video_files_gt = natsorted(glob(f"{scene_path}/gt/*.mp4"))
    # Found scene flag: once set, both video loops stop for this scene
    found_scene = False
    # Loop through every third degraded video, starting from the last one
    for degraded_video in video_files_degraded[-1::-3]:
        if found_scene:
            break
        # Get optical flow masks and degraded frames
        degraded_video_name = Path(degraded_video).stem
        degraded_of, degraded_imgs, is_video = get_optical_flows(degraded_video, 
                                                 pipeline_params["input_frame_num"], 
                                                 pipeline_params["mean_img_num"], 
                                                 pipeline_params["mean_frame_skip"])
        # Continue loop if video is corrupted
        if not is_video:
            continue

        # Get restored images
        restored_imgs = restore_degraded(degraded_imgs)
        
        # Loop through every third gt video, starting from the first one
        for gt_video in video_files_gt[::3]:
            if found_scene:
                break
            
            # Get optical flow masks and gt frame
            gt_video_name = Path(gt_video).stem
            gt_of, gt_img, is_video = get_optical_flow_clean(gt_video, pipeline_params["gt_frame_skip"])

            # Continue loop if video is corrupted or gt frame is different size than rainy frame
            if not is_video:
                continue
            if gt_img.shape[:2] != degraded_imgs[0].shape[:2]:
                continue

            # Get static object discrepancy binary mask
            static_obj_mask = generate_static_obj_masks(restored_imgs, gt_img)

            # Get sky segmentation binary mask
            sky_mask = get_sky_region_gradient(gt_img)
            
            # Get chromatic variation binary mask
            chrom_var_mask = chromatic_variation(degraded_imgs[0], gt_img)

            # Union of filter masks
            of_mask = np.bitwise_or(degraded_of, gt_of)
            discrepancy_mask = np.bitwise_or(of_mask, static_obj_mask)
            discrepancy_mask_final = np.bitwise_or(discrepancy_mask, chrom_var_mask)
            final_mask = np.bitwise_or(discrepancy_mask_final, sky_mask)

            # Smooth mask then find all suitable rectangular crops.
            # chrom_var_mask is uint8, so final_mask is promoted to uint8 and np.invert
            # would be a bitwise NOT (0 -> 255, 1 -> 254); allRectangle only accepts
            # cells equal to 1, so a logical NOT is required to obtain a 0/1 map.
            inv_binary_mask = np.logical_not(final_mask).astype(np.uint8)
            kernel = np.ones((9,9), np.uint8)
            inv_binary_mask = cv2.morphologyEx(inv_binary_mask, cv2.MORPH_CLOSE, kernel)
            crop_dict = allRectangle(inv_binary_mask)

            # If no crops available, continue with next loop iteration
            if not crop_dict:
                continue

            # Find all accepted crops based on illumination shift check
            accepted_list = psnr_illumination_check(degraded_imgs, restored_imgs, gt_img, crop_dict)
            # Find all accepted crops based on multi scatter check
            accepted_list = multi_scatter_check(degraded_imgs, gt_img, accepted_list)

            # If no crops accepted, continue with next loop iteration
            if not accepted_list:
                continue

            # If crops accepted, save frames and json dictionary
            found_scene = True
            save_scenes(accepted_list, degraded_video, gt_video, gt_img, f"{folder_name}_out/{scene_name}")
            print("Found scene!")

  0%|          | 0/1 [00:00<?, ?it/s]

206_0
