In [3]:
import os
import time
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torchvision import transforms
from PIL import Image

In [13]:
def concat(layers):
    """Join a sequence of tensors along dimension 1.

    The networks in this notebook feed the result straight into ``nn.Conv2d``
    layers, so dimension 1 acts as the channel axis of a batched image tensor.

    Args:
        layers: Sequence of tensors accepted by ``torch.cat``.

    Returns:
        A single tensor with the inputs stacked along dimension 1.
    """
    stacked = torch.cat(layers, dim=1)
    return stacked

In [14]:
class DecomNet(nn.Module):
    """Convolutional network that splits an image into two sigmoid-bounded maps.

    The 3-channel input is extended with its per-pixel channel maximum, passed
    through one wide convolution, ``layer_num`` conv+ReLU layers and a final
    4-channel convolution. The first three output channels form ``R`` and the
    fourth forms ``L``.

    Args:
        layer_num: Number of conv+ReLU layers between the first and last
            convolution.
        channel: Number of feature channels used by the hidden layers.
        kernel_size: Kernel size of the hidden and output convolutions. The
            first convolution uses ``kernel_size * 3``. Spatial size is
            preserved for odd values.
    """

    def __init__(self, layer_num, channel=64, kernel_size=3):
        super(DecomNet, self).__init__()
        self.layer_num = layer_num
        # The first convolution uses an enlarged kernel, so its padding has to
        # be derived from that enlarged kernel; padding by kernel_size // 2
        # would shrink the feature map (and therefore R and L).
        shallow_kernel = kernel_size * 3
        self.shallow_feature_extraction = nn.Conv2d(
            4, channel, kernel_size=shallow_kernel, padding=shallow_kernel // 2
        )
        self.activated_layers = nn.ModuleList(
            [
                nn.Conv2d(channel, channel, kernel_size=kernel_size, padding=kernel_size // 2)
                for _ in range(layer_num)
            ]
        )
        self.recon_layers = nn.Conv2d(
            channel, 4, kernel_size=kernel_size, padding=kernel_size // 2
        )

    def forward(self, input_im):
        """Run the decomposition.

        Args:
            input_im: Tensor whose dimension 1 holds 3 channels (the first
                convolution expects 3 + 1 input channels).

        Returns:
            Tuple ``(R, L)``: ``R`` is the sigmoid of output channels 0-2 and
            ``L`` is the sigmoid of output channel 3.
        """
        # Append the per-pixel maximum over channels as a fourth channel.
        input_max, _ = torch.max(input_im, dim=1, keepdim=True)
        features = self.shallow_feature_extraction(concat([input_im, input_max]))
        for layer in self.activated_layers:
            features = F.relu(layer(features))
        recon = self.recon_layers(features)
        R = torch.sigmoid(recon[:, 0:3])
        L = torch.sigmoid(recon[:, 3:4])
        return R, L
        

In [17]:
class RelightNet(nn.Module):
    """Encoder-decoder that maps a 4-channel input to a 1-channel output map.

    The two inputs are joined along dimension 1, encoded by three stride-2
    convolutions, decoded by nearest-neighbour up-sampling with additive skip
    connections, and the three decoder stages are fused by a 1x1 convolution
    before the final 3x3 output convolution.

    Args:
        channel: Number of feature channels used by every hidden layer.
        kernel_size: Kernel size of the encoder and decoder convolutions.
    """

    def __init__(self, channel=64, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2

        # Encoder: one full-resolution convolution followed by three
        # stride-2 convolutions.
        self.conv0 = nn.Conv2d(4, channel, kernel_size, padding=pad)
        self.conv1 = nn.Conv2d(channel, channel, kernel_size, stride=2, padding=pad)
        self.conv2 = nn.Conv2d(channel, channel, kernel_size, stride=2, padding=pad)
        self.conv3 = nn.Conv2d(channel, channel, kernel_size, stride=2, padding=pad)

        # Decoder: plain convolutions applied after each up-sampling step.
        self.deconv1 = nn.Conv2d(channel, channel, kernel_size, padding=pad)
        self.deconv2 = nn.Conv2d(channel, channel, kernel_size, padding=pad)
        self.deconv3 = nn.Conv2d(channel, channel, kernel_size, padding=pad)

        # 1x1 convolution that merges the three decoder stages.
        self.feature_fusion = nn.Conv2d(channel * 3, channel, 1, padding=0)
        # Final convolution producing the single-channel output.
        self.output_layer = nn.Conv2d(channel, 1, 3, padding=1)

    @staticmethod
    def _resize_like(source, reference):
        """Nearest-neighbour resize of ``source`` to ``reference``'s dims 2 and 3."""
        target = (reference.size(2), reference.size(3))
        return F.interpolate(source, size=target, mode='nearest')

    def forward(self, input_L, input_R):
        """Compute the output map from the two input tensors.

        Args:
            input_L: Tensor joined after ``input_R`` along dimension 1.
            input_R: Tensor placed first along dimension 1. Together the two
                inputs must provide 4 channels.

        Returns:
            Single-channel tensor with the spatial size of the inputs; no
            activation is applied to it.
        """
        stacked = concat([input_R, input_L])

        # Encoder: the first stage has no activation, the strided ones use ReLU.
        enc0 = self.conv0(stacked)
        enc1 = F.relu(self.conv1(enc0))
        enc2 = F.relu(self.conv2(enc1))
        enc3 = F.relu(self.conv3(enc2))

        # Decoder: up-sample to the skip's size, convolve, add the skip, ReLU.
        decoded = []
        current = enc3
        for layer, skip in zip(
            (self.deconv1, self.deconv2, self.deconv3), (enc2, enc1, enc0)
        ):
            current = F.relu(layer(self._resize_like(current, skip)) + skip)
            decoded.append(current)

        # Bring the coarser decoder stages to full resolution and stack them.
        finest = decoded[-1]
        gathered = [self._resize_like(stage, finest) for stage in decoded[:-1]]
        gathered.append(finest)
        fused = self.feature_fusion(concat(gathered))

        return self.output_layer(fused)
