**Spatio-temporal propagation and reconstruction for low-light video enhancement**

In [None]:
#Importing the necessary libraries

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models

In [None]:
#Input module

def load_video_frames(video_path):
    """Read every frame of a video into memory.

    Args:
        video_path: Any source accepted by ``cv2.VideoCapture`` (typically a
            file path).

    Returns:
        A list of the frames returned by ``cv2.VideoCapture.read`` (NumPy
        arrays; OpenCV decodes color video in BGR channel order). The list is
        empty if the source cannot be opened or contains no frames.
    """
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:  # end of stream (or a decode failure)
                break
            frames.append(frame)
    finally:
        # Release the capture even if reading is interrupted (e.g. MemoryError
        # while buffering a long video, or KeyboardInterrupt).
        cap.release()
    return frames

Optical flow is computed with the pyramidal Lucas–Kanade method, which tracks a sparse set of feature points from one frame to the next.

In [None]:
#Flow estimation module

def flow_estimation(prev_frame, curr_frame, *, max_corners=200,
                    quality_level=0.01, min_distance=7):
    """Track sparse feature points from ``prev_frame`` to ``curr_frame``.

    Shi-Tomasi corners are detected in the previous frame and tracked into
    the current frame with pyramidal Lucas-Kanade optical flow. Only points
    whose track was found (LK status == 1) are kept.

    Args:
        prev_frame: Previous frame, either 3-channel BGR or already 2-D
            grayscale (8-bit, as produced by ``cv2.VideoCapture.read``).
        curr_frame: Current frame, same format as ``prev_frame``.
        max_corners: Maximum number of corners to detect
            (``cv2.goodFeaturesToTrack`` ``maxCorners``).
        quality_level: Minimal accepted corner quality, relative to the best
            corner (``qualityLevel``).
        min_distance: Minimum pixel distance between detected corners
            (``minDistance``).

    Returns:
        ``(good_prev, good_curr)``: two float32 arrays of shape ``(N, 2)``
        holding matching (x, y) coordinates in the previous and current
        frame. ``N`` is 0 when no trackable features are found.
    """
    # Convert to grayscale unless the frames are already single-channel.
    prev_gray = prev_frame if prev_frame.ndim == 2 else cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    curr_gray = curr_frame if curr_frame.ndim == 2 else cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)

    # Parameters for Lucas-Kanade optical flow
    lk_params = dict(winSize=(15, 15),
                     maxLevel=2,
                     criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

    # Lucas-Kanade is a sparse method: it tracks explicit points, so the
    # points to follow must be detected first.
    prev_pts = cv2.goodFeaturesToTrack(prev_gray,
                                       maxCorners=max_corners,
                                       qualityLevel=quality_level,
                                       minDistance=min_distance)
    if prev_pts is None:  # no corners found (e.g. a flat, very dark frame)
        return (np.empty((0, 2), dtype=np.float32),
                np.empty((0, 2), dtype=np.float32))

    # calcOpticalFlowPyrLK returns (next_points, status, error).
    curr_pts, status, _err = cv2.calcOpticalFlowPyrLK(
        prev_gray, curr_gray, prev_pts, None, **lk_params)

    # Keep only the points whose track was found.
    found = status.reshape(-1) == 1
    good_prev = prev_pts.reshape(-1, 2)[found]
    good_curr = curr_pts.reshape(-1, 2)[found]
    return good_prev, good_curr

In [None]:
# Optical flow propagation and reconstruction

def flow_propagation_reconstruction(prev_frame, curr_frame, good_prev, good_curr,
                                    *, alpha=0.3, beta=0.9, fb_threshold=1.0):
    """Align ``prev_frame`` to ``curr_frame`` using sparse flow and blend them.

    Steps:
      1. Backward flow: track ``good_curr`` from the current frame back into
         the previous frame with pyramidal Lucas-Kanade.
      2. Forward-backward check: drop tracks whose round trip does not land
         within ``fb_threshold`` pixels of the original ``good_prev`` point.
      3. Temporal consistency: per point, blend the forward motion
         (``good_curr - good_prev``) with the motion implied by the backward
         track, weighted ``alpha`` / ``1 - alpha``.
      4. Spatial propagation: fit a 4-DOF partial affine transform (rotation,
         uniform scale, translation) to the blended sparse motion and warp the
         previous frame into the current frame's coordinates.
      5. Blend the warped previous frame and the current frame,
         ``beta * warped + (1 - beta) * current``.

    Args:
        prev_frame: Previous frame, 3-channel BGR or 2-D grayscale.
        curr_frame: Current frame with the same shape and dtype as
            ``prev_frame`` (required by ``cv2.addWeighted``).
        good_prev: ``(N, 2)`` point coordinates in the previous frame, e.g.
            from ``flow_estimation``.
        good_curr: ``(N, 2)`` matching coordinates in the current frame.
        alpha: Weight of the forward motion estimate in step 3.
        beta: Weight of the warped previous frame in the final blend.
        fb_threshold: Maximum forward-backward round-trip error, in pixels.

    Returns:
        An array with the same shape and dtype as ``curr_frame``. If there
        are too few reliable correspondences to fit a motion model, a copy of
        ``curr_frame`` is returned unchanged.

    Raises:
        ValueError: if ``good_prev`` and ``good_curr`` hold different numbers
            of points.
    """
    good_prev = np.asarray(good_prev, dtype=np.float32).reshape(-1, 2)
    good_curr = np.asarray(good_curr, dtype=np.float32).reshape(-1, 2)
    if len(good_prev) != len(good_curr):
        raise ValueError("good_prev and good_curr must contain the same number of points")
    # estimateAffinePartial2D needs at least two correspondences.
    if len(good_curr) < 2:
        return curr_frame.copy()

    prev_gray = prev_frame if prev_frame.ndim == 2 else cv2.cvtColor(prev_frame, cv2.COLOR_BGR2GRAY)
    curr_gray = curr_frame if curr_frame.ndim == 2 else cv2.cvtColor(curr_frame, cv2.COLOR_BGR2GRAY)

    # Parameters for Lucas-Kanade optical flow (kept in sync with flow_estimation).
    lk_params = dict(winSize=(15, 15),
                     maxLevel=2,
                     criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03))

    # Backward flow: where do the current-frame points come from in prev?
    back_prev, status, _err = cv2.calcOpticalFlowPyrLK(
        curr_gray, prev_gray, good_curr.reshape(-1, 1, 2), None, **lk_params)
    back_prev = back_prev.reshape(-1, 2)

    # A reliable track returns (almost) to its starting point.
    fb_error = np.linalg.norm(back_prev - good_prev, axis=1)
    reliable = (status.reshape(-1) == 1) & (fb_error < fb_threshold)
    if np.count_nonzero(reliable) < 2:
        return curr_frame.copy()

    src = good_prev[reliable]
    forward_motion = good_curr[reliable] - src
    # prev->curr motion implied by the backward track.
    backward_motion = good_curr[reliable] - back_prev[reliable]

    # Temporal consistency: blend the two motion estimates.
    reconstructed_motion = alpha * forward_motion + (1 - alpha) * backward_motion
    dst = (src + reconstructed_motion).astype(np.float32)

    # Only sparse correspondences exist, so cv2.remap (which needs a dense
    # per-pixel map) cannot be used. Propagate the motion spatially through
    # a global partial-affine model instead.
    matrix, _inliers = cv2.estimateAffinePartial2D(src, dst)
    if matrix is None:
        return curr_frame.copy()

    height, width = curr_frame.shape[:2]
    warped_prev = cv2.warpAffine(prev_frame, matrix, (width, height),
                                 flags=cv2.INTER_LINEAR,
                                 borderMode=cv2.BORDER_REPLICATE)

    # addWeighted keeps the input dtype (with saturation), unlike a plain
    # float expression, which would silently promote uint8 frames to float64.
    return cv2.addWeighted(warped_prev, beta, curr_frame, 1 - beta, 0)


In [None]:
# Pyramid Recursive Residual Dense Subnet

class PRDSubnet(nn.Module):
    """Pyramid Recursive Residual Dense subnet (currently a placeholder stem).

    At present this is a single Conv2d -> BatchNorm2d -> ReLU stage. The
    pyramid / recursive / residual-dense structure named by the class is not
    implemented yet.
    """

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=None):
        """Build the subnet.

        Args:
            in_channels: Channels of the input tensor.
            out_channels: Number of feature maps produced.
            kernel_size: Convolution kernel size (int or ``(kh, kw)``).
            stride: Convolution stride.
            padding: Convolution padding. ``None`` selects ``k // 2`` per
                kernel dimension, which preserves the spatial size for odd
                kernels at stride 1.
        """
        super().__init__()
        if padding is None:
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # TODO: add the pyramid / recursive residual dense layers.

    def forward(self, x):
        """Apply conv -> batch-norm -> ReLU.

        Args:
            x: Tensor of shape ``(N, in_channels, H, W)``.

        Returns:
            Non-negative tensor of shape ``(N, out_channels, H', W')``.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # TODO: continue with the remaining layers / sub-modules.

        return x

In [None]:
# Spatio-temporal Feature Reconstruction Subnet

class STRSubnet(nn.Module):
    """Spatio-temporal feature reconstruction subnet (currently a placeholder stem).

    At present this is a single Conv3d -> BatchNorm3d -> ReLU stage operating
    jointly over time and space. Further reconstruction layers are not
    implemented yet.
    """

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=None):
        """Build the subnet.

        Args:
            in_channels: Channels of the input tensor. Set this to the feature
                width of whatever precedes this subnet (TODO confirm against
                the full pipeline).
            out_channels: Number of feature maps produced.
            kernel_size: Convolution kernel size (int or ``(kt, kh, kw)``).
            stride: Convolution stride.
            padding: Convolution padding. ``None`` selects ``k // 2`` per
                kernel dimension, which preserves T/H/W for odd kernels at
                stride 1.
        """
        super().__init__()
        if padding is None:
            if isinstance(kernel_size, (tuple, list)):
                padding = tuple(k // 2 for k in kernel_size)
            else:
                padding = kernel_size // 2
        self.conv3d_1 = nn.Conv3d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn3d_1 = nn.BatchNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # TODO: add the remaining spatio-temporal reconstruction layers.

    def forward(self, x):
        """Apply 3-D conv -> batch-norm -> ReLU.

        Args:
            x: Tensor of shape ``(N, in_channels, T, H, W)`` (the layout
                ``nn.Conv3d`` expects).

        Returns:
            Non-negative tensor of shape ``(N, out_channels, T', H', W')``.
        """
        x = self.conv3d_1(x)
        x = self.bn3d_1(x)
        x = self.relu(x)

        # TODO: continue with the remaining layers / sub-modules.

        return x