In [25]:
import import_ipynb
from ConsistencyIndexes import *
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from IPython.display import clear_output
import torch.optim as optim
from torch.nn.functional import mse_loss
from torchvision import transforms
from torch import nn

In [26]:
# Load the original clip and its cartoonized counterpart as two frame sequences.
# NOTE(review): open_vid / get_frames are not defined in this notebook (presumably
# provided by the ConsistencyIndexes star import). Whether get_frames releases the
# capture before `cap` is rebound below is not visible here -- confirm.
cap = open_vid("VDB/WID.mp4")
org = get_frames(cap)  # frames of the source video; used as training inputs later
cap = open_vid("Cartoonized/WID_toon.mp4")
car = get_frames(cap)  # frames of the cartoonized video; used as training targets later

In [27]:
# Function to create convolutional kernels and weights
def create_kernels_weights(num_kernels, kernel_size):
    """Create the trainable parameters of the kernel mixture.

    Args:
        num_kernels: Number of convolution kernels (and matching mixing
            weights) to create.
        kernel_size: Side length of each square kernel.

    Returns:
        A ``(kernels, weights)`` tuple of two lists, each of length
        ``num_kernels``. Every kernel is a ``(3, 3, kernel_size, kernel_size)``
        tensor drawn from ``torch.rand``; every weight is a one-element tensor
        drawn from ``torch.randn``. All tensors have ``requires_grad=True``.
    """
    # 3 input channels -> 3 output channels per kernel
    kernels = [torch.rand(3, 3, kernel_size, kernel_size, requires_grad=True) for _ in range(num_kernels)]
    # The weights must track gradients as well: an optimizer accepts tensors
    # without requires_grad but never updates them, because their .grad stays None.
    weights = [torch.randn(1, requires_grad=True) for _ in range(num_kernels)]
    return kernels, weights

# Weighted sum of the per-kernel convolutions of an image batch
def apply_kernels(input_image, kernels, weights):
    """Convolve ``input_image`` with every kernel and mix the results.

    Each kernel is applied with ``F.conv2d`` using a padding of half the
    kernel's last dimension, the response is multiplied by the weight at the
    same index, and all weighted responses are added together.

    Args:
        input_image: Tensor accepted by ``F.conv2d`` as its input.
        kernels: Sequence of convolution weight tensors.
        weights: Sequence indexed in parallel with ``kernels``.

    Returns:
        The accumulated weighted sum; the integer ``0`` when ``kernels`` is empty.
    """
    combined = 0
    for idx, kernel in enumerate(kernels):
        half_width = kernel.shape[-1] // 2
        response = F.conv2d(input_image, kernel, padding=half_width)
        combined = combined + weights[idx] * response
    return combined

# Function to apply kernels and return the result as a displayable image
def predict_and_display(input_image, kernels, weights):
    """Run the kernel mixture on one image and return it as a uint8 array.

    Despite its name, this function does not display anything; it returns the
    predicted image for the caller to show.

    Args:
        input_image: A single H x W x C image (array-like or tensor). It is
            converted to float32 and rearranged to 1 x C x H x W; it is NOT
            rescaled here.
        kernels: Convolution kernels, passed through to ``apply_kernels``.
        weights: Mixing weights, passed through to ``apply_kernels``.

    Returns:
        An H x W x C ``numpy`` array of dtype ``uint8``.

    Note:
        ``apply_kernels`` is a bias-free weighted sum of convolutions, so its
        output scales linearly with its input: an input in the 0-255 range
        yields an output in that same scale.
    """
    # as_tensor avoids the copy-construct warning when a tensor is passed in
    input_image = torch.as_tensor(input_image, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    with torch.no_grad():
        output_image = apply_kernels(input_image, kernels, weights)
    # Saturate to the valid pixel range first: casting an out-of-range float to
    # uint8 wraps around instead of clipping, which corrupts bright/dark areas.
    output_image = output_image.clamp(0, 255)
    return output_image.squeeze(0).permute(1, 2, 0).cpu().numpy().astype(np.uint8)

# Main function
def train_model(input_images, target_images, num_kernels, kernel_size=3, epochs=100, lr=0.01):
    """Fit a weighted mixture of convolution kernels mapping inputs to targets.

    Images are paired positionally and the parameters are updated once per
    pair (one Adam step per image) using the mean-squared error between the
    mixture output and the target.

    Args:
        input_images: Sequence of H x W x C images; each is divided by 255
            before training.
        target_images: Sequence of H x W x C target images, scaled the same way.
        num_kernels: Number of kernels / mixing weights to train.
        kernel_size: Side length of each square kernel.
        epochs: Number of passes over the image pairs.
        lr: Adam learning rate.

    Returns:
        ``(kernels, weights)``: the trained lists of parameter tensors.
    """
    # Scale to [0, 1] and stack each set into one float32 tensor (N x H x W x C)
    input_images = torch.stack([torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_images])
    target_images = torch.stack([torch.as_tensor(img / 255.0, dtype=torch.float32) for img in target_images])
    # zip() below stops at the shorter sequence, so average the loss over the
    # number of pairs actually processed rather than over all inputs.
    num_pairs = min(len(input_images), len(target_images))
    # Initialize kernels and weights
    kernels, weights = create_kernels_weights(num_kernels, kernel_size)
    params = kernels + weights
    # Adam skips any parameter whose .grad is None, so a tensor that does not
    # track gradients would silently stay at its random initial value.
    for param in params:
        param.requires_grad_(True)
    # Define optimizer
    optimizer = optim.Adam(params, lr=lr)
    # Training loop
    for epoch in range(epochs):
        total_loss = 0.0
        for input_image, target_image in zip(input_images, target_images):
            # H x W x C -> 1 x C x H x W, the layout conv2d expects
            input_image = input_image.permute(2, 0, 1).unsqueeze(0)
            target_image = target_image.permute(2, 0, 1).unsqueeze(0)
            optimizer.zero_grad()
            output_image = apply_kernels(input_image, kernels, weights)
            loss = mse_loss(output_image, target_image)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Replace the previous epoch's line with the current mean loss
        clear_output(wait=True)
        print(f'Epoch [{epoch+1}/{epochs}], Loss: {total_loss/num_pairs:.4f}')
    print('Training complete.')
    return kernels, weights

In [None]:
limit = 20  # number of leading frames taken from each video for training
# Example usage:
# train_model permutes each image with (2, 0, 1) and divides it by 255 itself, so it
# expects paired sequences of H x W x C images (presumably 0-255 pixel values -- confirm
# against get_frames), not pre-normalized [C, H, W] tensors.
kernels, weights = train_model(org[:limit], car[:limit], num_kernels=5,epochs=50, lr=1e-3)

Epoch [14/50], Loss: 22.1442


In [29]:
# Apply the trained kernels/weights to the first original frame; the result is a uint8 image array.
Predicted = predict_and_display(org[0],kernels,weights)
# NOTE(review): display_frame is not defined in this notebook (presumably from the
# ConsistencyIndexes star import) -- confirm the image layout it expects.
display_frame(Predicted)