In [None]:
import tensorflow as tf
import matplotlib.pyplot as plt
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import torch.nn as nn
import torch
from PIL import Image
from torchvision import transforms
import os
import cv2
import json
from google.colab import files
from zipfile import ZipFile

In [None]:
# Mount Google Drive inside the Colab runtime and point `image_dir` at the
# `UOD/DUO/train` folder on that drive.
# NOTE(review): the following cell assigns `image_dir` again, so when the cells
# are run in order the value set here is replaced.
from google.colab import drive
drive.mount('/content/drive')
image_dir = '/content/drive/My Drive/UOD/DUO/train'

In [None]:
# Switch to the `train_image_10` folder (presumably a 10-image subset, judging
# by the folder names -- confirm) and its matching reference images.
#   image_dir        -> raw input images read by the training cell
#   ground_image_dir -> target images looked up under the same file names
from google.colab import drive
drive.mount('/content/drive')
image_dir = '/content/drive/My Drive/UOD/DUO/train_image_10'
ground_image_dir = '/content/drive/My Drive/UOD/DUO/enhanced_image_10'

**JOINT COMMON ENCODER:**

In [None]:
class CommonEncoder(nn.Module):
    """Multi-level encoder: a strided convolutional stem followed by four levels.

    ``C1`` is a stem of three stride-2 3x3 convolutions (each followed by a ReLU)
    that maps a 3-channel image to 64 channels. ``C2``..``C5`` are each a chain of
    four ``ResidualBlock`` modules and are applied one after another.

    ``forward`` returns the outputs of all four levels as a tuple
    ``(c2_out, c3_out, c4_out, c5_out)``.
    """

    def __init__(self):
        super().__init__()

        # Stem: every convolution halves the spatial resolution (stride 2).
        self.C1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1),
            nn.ReLU()
        )

        self.C2 = self._make_residual_level(in_channels=64, out_channels=64, stride=4)
        self.C3 = self._make_residual_level(in_channels=64, out_channels=64, stride=8)
        self.C4 = self._make_residual_level(in_channels=64, out_channels=64, stride=16)
        self.C5 = self._make_residual_level(in_channels=64, out_channels=64, stride=32)

    def _make_residual_level(self, in_channels, out_channels, stride):
        """Build one level made of four chained ``ResidualBlock`` modules.

        Args:
            in_channels: channel count of the tensor entering the level.
            out_channels: channel count produced by every block in the level.
            stride: accepted so existing call sites keep working, but it is not
                applied to any layer -- the level label passed by ``__init__``
                has no effect on the modules that are built.

        Returns:
            An ``nn.Sequential`` holding the four blocks.
        """
        layers = []
        block_in = in_channels
        for _ in range(4):
            layers.append(ResidualBlock(block_in, out_channels))
            # Only the first block sees ``in_channels``; every later block is fed
            # the previous block's ``out_channels`` output, so the chain stays
            # valid when the two counts differ.
            block_in = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem and the four levels; return every level's output."""
        x = self.C1(x)
        c2_out = self.C2(x)
        c3_out = self.C3(c2_out)
        c4_out = self.C4(c3_out)
        c5_out = self.C5(c4_out)
        return c2_out, c3_out, c4_out, c5_out

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions (stride 1, no padding), each followed by a ReLU.

    Despite the name, the input is NOT added back to the output: no skip
    connection is applied. Because neither convolution is padded, each one trims
    two pixels from the height and the width, so the output is four pixels
    smaller than the input along both spatial axes.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU and return the result."""
        hidden = self.relu(self.conv1(x))
        return self.relu(self.conv2(hidden))


In [None]:
class ConvolutionBlock(nn.Module):
    """Three 3x3 convolutions (stride 1, padding 1), each followed by a ReLU.

    The padding keeps height and width unchanged; only the channel count moves
    from ``in_channels`` to ``out_channels`` (in the first convolution).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.conv3 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        # In-place ReLU: overwrites each convolution's output tensor.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Pass ``x`` through the three conv + ReLU stages in order."""
        out = x
        for conv in (self.conv1, self.conv2, self.conv3):
            out = self.relu(conv(out))
        return out

class CommonEncoder(nn.Module):
    """Single-scale encoder: two ``ConvolutionBlock`` stages in series (3 -> 64 -> 64).

    Note: an earlier cell of this notebook also defines a class named
    ``CommonEncoder`` (the multi-level variant). Running this cell rebinds the
    name to this class. Unlike that variant, ``forward`` here returns a single
    tensor rather than a tuple of four feature maps.
    """

    def __init__(self):
        super().__init__()
        self.C1 = ConvolutionBlock(in_channels=3, out_channels=64)
        self.C2 = ConvolutionBlock(in_channels=64, out_channels=64)

    def forward(self, x):
        """Return the output of ``C2`` applied to the output of ``C1``."""
        return self.C2(self.C1(x))

**FEATURE RECONSTRUCTION BLOCK:**

In [None]:
class FeatureReconstructionBlock(nn.Module):
    """Densely connected convolution stack that ends in a 3-channel output.

    Layers ``conv0``..``conv6`` produce ``out_channels`` feature maps with
    kernel sizes 1, 1, 3, 3, 5, 3 and 7; the padded layers keep the spatial size.
    Several layers consume the channel-wise concatenation of two earlier
    outputs, and ``conv7`` fuses four of them into 3 channels. No activation
    function is applied between the layers.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        width = out_channels
        paired = 2 * width  # channel count of a two-tensor concatenation

        self.conv0 = nn.Conv2d(in_channels, width, kernel_size=1)
        self.conv1 = nn.Conv2d(width, width, kernel_size=1)
        self.conv2 = nn.Conv2d(width, width, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(paired, width, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(paired, width, kernel_size=5, padding=2)
        self.conv5 = nn.Conv2d(paired, width, kernel_size=3, padding=1)
        self.conv6 = nn.Conv2d(paired, width, kernel_size=7, padding=3)
        self.conv7 = nn.Conv2d(4 * width, 3, kernel_size=3, padding=1)

    def forward(self, x):
        """Return the 3-channel map produced by the final fusion convolution."""
        f1 = self.conv1(self.conv0(x))
        f2 = self.conv2(f1)
        f3 = self.conv3(torch.cat([f1, f2], dim=1))
        f4 = self.conv4(torch.cat([f2, f3], dim=1))
        f5 = self.conv5(torch.cat([f2, f4], dim=1))
        f6 = self.conv6(torch.cat([f4, f5], dim=1))
        # Final fusion uses the outputs of conv1, conv2, conv4 and conv6.
        return self.conv7(torch.cat([f1, f2, f4, f6], dim=1))


**COLOR ADJUSTMENT BLOCK:**

In [None]:
class ColorAdjustmentBlock(nn.Module):
    """Three padded 3x3 convolutions that turn a 3-channel image into a 3-channel
    per-pixel adjustment map squashed by a sigmoid.

    Args:
        in_channels: kept for call-site compatibility only. It is not used: the
            first convolution always expects a 3-channel input.
        out_channels: width of the two hidden convolution layers.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.color_conv1 = nn.Conv2d(3, out_channels, kernel_size=3, padding=1)
        self.color_conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.color_conv3 = nn.Conv2d(out_channels, 3, kernel_size=3, padding=1)
        self.sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return the adjustment map for ``x``; same spatial size, 3 channels."""
        hidden = self.relu(self.color_conv1(x))
        hidden = self.relu(self.color_conv2(hidden))
        # No ReLU on the last layer: a ReLU here would force the sigmoid input to
        # be >= 0 and therefore pin every output value into [0.5, 1). Feeding the
        # raw logits lets the map cover the whole (0, 1) range.
        logits = self.color_conv3(hidden)
        return self.sigmoid(logits)


**LOSS FUNCTION**

In [None]:
class CustomLoss(nn.Module):
    """Sum of a pixel MSE term, a contrast term and an image-gradient term.

    ``forward(c1_lambda, c_lambda, m_lambda, m_g_lambda)`` computes:

    * MSE between ``c_lambda`` and ``c1_lambda``;
    * the 2-norm of ``m_lambda - m_g_lambda`` (contrast term);
    * the 2-norm of the difference between the absolute horizontal / vertical
      finite differences of ``m_lambda`` and ``m_g_lambda`` (gradient term).

    The gradient term is built from spatial differences of the maps themselves,
    so it works whether or not the inputs require grad. (Differentiating the
    contrast term with respect to its two operands yields gradients that differ
    only in sign, so a loss on their absolute values would always be zero.)
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _spatial_gradients(image):
        """Forward differences along the last (width) and second-to-last (height) axes."""
        grad_x = image[..., :, 1:] - image[..., :, :-1]
        grad_y = image[..., 1:, :] - image[..., :-1, :]
        return grad_x, grad_y

    def forward(self, c1_lambda, c_lambda, m_lambda, m_g_lambda):
        """Return the scalar total loss.

        Args:
            c1_lambda: target for the MSE term.
            c_lambda: prediction for the MSE term.
            m_lambda: predicted map for the contrast and gradient terms.
            m_g_lambda: reference map; must have the same shape as ``m_lambda``.
        """
        # Pixel-wise Mean Square Error (MSE) loss
        mse_loss = F.mse_loss(c_lambda, c1_lambda)

        # Contrast adjustment loss
        contrast_loss = torch.norm(m_lambda - m_g_lambda, p=2)

        # Gradient loss: compare edge magnitudes of the two maps along each axis
        pred_dx, pred_dy = self._spatial_gradients(m_lambda)
        ref_dx, ref_dy = self._spatial_gradients(m_g_lambda)
        gradient_loss = (
            torch.norm(pred_dx.abs() - ref_dx.abs(), p=2)
            + torch.norm(pred_dy.abs() - ref_dy.abs(), p=2)
        )

        return mse_loss + contrast_loss + gradient_loss

In [None]:
# --- Training configuration ---------------------------------------------------
NUM_EPOCHS = 100          # optimisation steps (one per epoch, see below)
MAX_TRAIN_IMAGES = 4      # only the first few images of the folder are used
LEARNING_RATE = 0.001


def _load_rgb_tensor(path):
    """Read an image file as a (1, 3, H, W) float tensor.

    The image is converted to RGB first so grayscale or RGBA files still match
    the 3-channel convolutions, and the file handle is closed right after reading.
    """
    with Image.open(path) as img:
        return transforms.ToTensor()(img.convert('RGB')).unsqueeze(0)


def _first_feature_map(features):
    """Return the first feature map of an encoder output.

    Accepts both encoder variants defined in this notebook: the multi-level one
    returns a tuple of feature maps, the single-scale one returns a bare tensor.
    """
    if isinstance(features, (tuple, list)):
        return features[0]
    return features


encoder = CommonEncoder()
feature_recons = FeatureReconstructionBlock(in_channels=64, out_channels=64)
color_adjus = ColorAdjustmentBlock(in_channels=256, out_channels=256)
loss_function = CustomLoss()

# Optimise every module that takes part in the forward pass, not just the encoder;
# otherwise the reconstruction and colour blocks keep their random initial weights.
trainable_params = (
    list(encoder.parameters())
    + list(feature_recons.parameters())
    + list(color_adjus.parameters())
)
optimizer = torch.optim.Adam(trainable_params, lr=LEARNING_RATE)

if not os.path.exists(image_dir):
    print(f"Directory {image_dir} does not exist.")
    image_files = []
else:
    # Sorted so that "the first N images" is the same set on every run
    # (os.listdir returns entries in arbitrary order).
    image_files = sorted(
        filename for filename in os.listdir(image_dir)
        if filename.endswith((".jpg", ".jpeg"))
    )
    if not image_files:
        print(f"No .jpg/.jpeg images found in {image_dir}.")

train_files = image_files[:MAX_TRAIN_IMAGES]

if train_files:
    for epoch in range(NUM_EPOCHS):
        total_loss = 0.0  # running sum of per-image losses for this epoch
        optimizer.zero_grad()

        for filename in train_files:
            input_tensor = _load_rgb_tensor(os.path.join(image_dir, filename))
            # Every map is resized to the input resolution instead of a fixed
            # 640x640, so the element-wise products below always line up.
            target_size = tuple(input_tensor.shape[-2:])
            print(f"input shape for {filename}: {input_tensor.shape}")

            c2 = _first_feature_map(encoder(input_tensor))
            print(f"common encoder Output shape for {filename}: {c2.shape}")
            recons_output = feature_recons(c2)
            print(f"Feature Reconstruction Output shape for {filename}: {(recons_output).shape}")
            recons_output = F.interpolate(recons_output, size=target_size, mode='bilinear', align_corners=False)
            colorAdj_output = color_adjus(recons_output)
            print(f"Color Adjustment Output shape for {filename}: {(colorAdj_output).shape}")

            # Calculate c_lambda
            c_lambda = (recons_output * input_tensor) - recons_output + 1

            # Kept as a notebook-level variable: the display cell reads it.
            enhanced_output = c_lambda * colorAdj_output

            ground_truth_tensor = _load_rgb_tensor(os.path.join(ground_image_dir, filename))
            print(f"gt shape for {filename}: {ground_truth_tensor.shape}")
            if tuple(ground_truth_tensor.shape[-2:]) != target_size:
                # Bring the reference to the input resolution so the MSE term is defined.
                ground_truth_tensor = F.interpolate(
                    ground_truth_tensor, size=target_size, mode='bilinear', align_corners=False
                )

            c2_g = _first_feature_map(encoder(ground_truth_tensor))
            recons_output_gt = feature_recons(c2_g)
            recons_output_gt = F.interpolate(recons_output_gt, size=target_size, mode='bilinear', align_corners=False)

            loss = loss_function(ground_truth_tensor, c_lambda, colorAdj_output, recons_output_gt)

            # Accumulate gradients image by image (scaled so the sum is the gradient
            # of the mean loss). This way every image contributes to the update and
            # each image's graph can be released right after its backward pass.
            (loss / len(train_files)).backward()
            total_loss += loss.item()

        # Average over the images that were actually processed.
        average_loss = total_loss / len(train_files)
        print(f'Epoch {epoch + 1}, Average Loss: {average_loss}')

        # One optimisation step per epoch, using the accumulated gradients.
        optimizer.step()

In [None]:
# Convert the last processed training sample and its enhanced result to PIL images.
# `squeeze(0)` drops only the batch dimension.
input_image = TF.to_pil_image(input_tensor.squeeze(0))
# The enhanced tensor is not bounded to [0, 1]; clamp it first, because the
# float -> 8-bit conversion does not clamp and out-of-range values would wrap
# around into spurious colours.
enhanced_image = TF.to_pil_image(enhanced_output.detach().squeeze(0).clamp(0.0, 1.0))

# Display the images side by side
fig, axes = plt.subplots(1, 2, figsize=(12, 6))
panels = ((input_image, 'Input Image'), (enhanced_image, 'Enhanced Image'))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image)
    ax.set_title(title)
plt.show()