In [None]:
import torch
import torch.nn as nn

def pixel_loss(output, target):
    """Return the summed squared error between two tensors, scaled by C*H*W.

    The squared differences are summed over every element (including the
    leading batch dimension) and divided by the product of ``target``'s
    dimensions 1, 2 and 3, so the value is *not* averaged over the batch.

    Args:
        output: Predicted tensor; must be broadcast-compatible with ``target``.
        target: Reference tensor with at least four dimensions.

    Returns:
        A scalar tensor holding the normalized loss.
    """
    # Normalizer: size of dims 1..3 of the target (channels, height, width
    # for an NCHW layout).
    elems_per_sample = target.size(1) * target.size(2) * target.size(3)
    squared_error_total = nn.functional.mse_loss(output, target, reduction='sum')
    return squared_error_total / elems_per_sample

def perceptual_loss(output, target, vgg_model):
    """Accumulate normalized L1 distances between feature maps of two inputs.

    ``vgg_model`` is called once on ``output`` and once on ``target``. Its
    result is treated as a two-level indexable collection: element ``[i][j]``
    must expose at least four dimensions. For every such pair the summed
    absolute difference is divided by the product of the output feature's
    dimensions 1, 2 and 3, and all those terms are added together.

    Args:
        output: Predicted input passed to ``vgg_model``.
        target: Reference input passed to ``vgg_model``.
        vgg_model: Callable returning the nested feature collection.

    Returns:
        The accumulated loss (a tensor), or the integer ``0`` when the model
        returns no features.
    """
    feats_out = vgg_model(output)
    feats_tgt = vgg_model(target)

    # One stateless criterion instance is enough for every feature pair.
    summed_l1 = nn.L1Loss(reduction='sum')

    total = 0
    for i, group in enumerate(feats_out):
        for j, feat in enumerate(group):
            feat_elems = feat.size(1) * feat.size(2) * feat.size(3)
            total += summed_l1(feat, feats_tgt[i][j]) / feat_elems

    return total

def lr_average_loss(output, target):
    """Compare the mirror-asymmetry maps of ``output`` and ``target``.

    Each input is smoothed with a 3x3 average pool (stride 1, padding 1),
    subtracted from its own copy flipped along dimension 3, and reduced with
    an L1 norm over dimension 1. The summed absolute difference between the
    two resulting maps is divided by the product of ``target``'s dimensions
    1, 2 and 3.

    Args:
        output: Predicted 4-D tensor (presumably NCHW, so that dimension 3 is
            the width axis -- confirm against callers).
        target: Reference tensor with the same layout.

    Returns:
        A scalar tensor holding the normalized loss.
    """
    elems_per_sample = target.size(1) * target.size(2) * target.size(3)

    def _asymmetry(img):
        # Smooth once, then measure how far the result is from its mirror image.
        smoothed = nn.functional.avg_pool2d(img, kernel_size=3, stride=1, padding=1)
        mirrored = torch.flip(smoothed, dims=[3])
        return torch.norm(smoothed - mirrored, p=1, dim=1)

    asym_target = _asymmetry(target)
    asym_output = _asymmetry(output)

    total = nn.functional.l1_loss(asym_target, asym_output, reduction='sum')
    return total / elems_per_sample

def w_smooth_loss(output, target):
    """Edge-weighted smoothness penalty on ``output``, guided by ``target``.

    Neighbouring-element differences of ``output`` along dimensions 2 and 3
    are penalized, each weighted by ``1 - ||difference of target||_1`` taken
    at the same location (the L1 norm runs over dimension 1). Where the
    target changes sharply the weight shrinks, so the output is allowed to
    change there as well.

    The previous implementation called ``nn.L1Loss`` with a single argument
    (a ``TypeError``) and tried to add maps of shape ``(N, H-1, W)`` and
    ``(N, H, W-1)``, which cannot be combined element-wise. The two
    directional terms are therefore reduced separately and then summed.

    Args:
        output: Predicted 4-D tensor (presumably NCHW -- confirm against
            callers).
        target: Guidance tensor with the same shape as ``output``.

    Returns:
        A scalar tensor: the summed absolute weighted gradients divided by
        the product of ``target``'s dimensions 1, 2 and 3.
    """
    c, h, w = target.size(1), target.size(2), target.size(3)

    # Edge-aware weights derived from the target's first differences.
    a_H = 1 - torch.norm(target[:, :, :-1, :] - target[:, :, 1:, :], p=1, dim=1)
    a_W = 1 - torch.norm(target[:, :, :, :-1] - target[:, :, :, 1:], p=1, dim=1)

    # First-difference magnitudes of the prediction.
    d_H = torch.norm(output[:, :, :-1, :] - output[:, :, 1:, :], p=1, dim=1)
    d_W = torch.norm(output[:, :, :, :-1] - output[:, :, :, 1:], p=1, dim=1)

    # L1 distance of each weighted gradient map to zero; the maps differ in
    # shape, so each direction is summed on its own before combining.
    vertical_term = (a_H * d_H).abs().sum()
    horizontal_term = (a_W * d_W).abs().sum()

    loss = (vertical_term + horizontal_term) / (c * h * w)
    return loss

In [None]:
import cv2
import torch
import numpy as np
from torchvision import transforms

class SILPAutoencoder(nn.Module):
    """Convolutional autoencoder with a latent code and a Laplacian-prior branch.

    The encoder applies three conv / ReLU / ``LRMatchModule`` / max-pool
    stages (64, 128, 256 channels), each pool halving the spatial size. The
    encoded map is flattened into a ``hidden_dim``-sized code by
    ``CodeConversionModule``; that code is tiled over the encoded map's
    spatial grid, concatenated with it along the channel axis, and decoded
    back through three transposed-convolution stages ending in ``Tanh``.

    Args:
        input_channels: Number of channels of the input image.
        hidden_dim: Length of the latent code vector.
    """

    def __init__(self, input_channels=1, hidden_dim=500):
        super().__init__()
        self.laplacian_prior = LaplacianPriorSubnetwork()

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(input_channels, 64, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            LRMatchModule(64, 64),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            LRMatchModule(128, 128),
            nn.MaxPool2d(2, stride=2),

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            LRMatchModule(256, 256),
            nn.MaxPool2d(2, stride=2)
        )

        # NOTE(review): the flattened size is fixed at 256 * 8 * 8, i.e. the
        # encoder output is assumed to be 8x8 (a 64x64 input if LRMatchModule
        # preserves spatial size) -- confirm against the data pipeline.
        self.code_conversion = CodeConversionModule(256 * 8 * 8, hidden_dim)

        # Decoder
        self.decoder = nn.Sequential(
            LRMatchModule(256 + hidden_dim, 256, is_deconv=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2),

            LRMatchModule(128, 128, is_deconv=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2),

            LRMatchModule(64, 64, is_deconv=True),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 1, kernel_size=2, stride=2),
            nn.Tanh()
        )

    def forward(self, x):
        """Encode ``x``, inject the latent code, and decode.

        Args:
            x: Input batch fed to both the encoder and the Laplacian-prior
                subnetwork.

        Returns:
            A tuple ``(reconstructed, laplacian_prior)``: the decoder output
            (single channel, passed through ``Tanh``) and the raw output of
            the Laplacian-prior subnetwork.
        """
        laplacian_prior = self.laplacian_prior(x)

        encoded = self.encoder(x)
        batch_size, _, height, width = encoded.size()
        code = self.code_conversion(encoded)

        # torch.cat does not broadcast: every non-concatenated dimension must
        # match, so the (B, hidden, 1, 1) code is tiled over the encoded map's
        # spatial grid before being stacked on the channel axis. This yields
        # the 256 + hidden_dim channels the first decoder layer expects.
        code = code.view(batch_size, -1, 1, 1).expand(-1, -1, height, width)

        combined_code = torch.cat((encoded, code), dim=1)
        reconstructed = self.decoder(combined_code)

        return reconstructed, laplacian_prior

class CodeConversionModule(nn.Module):
    """Project a 4-D feature map onto a flat code vector with one linear layer.

    Args:
        in_channels: Number of features per sample after flattening (the
            product of the input's dimensions 1, 2 and 3).
        hidden_dim: Length of the produced code vector.
    """

    def __init__(self, in_channels, hidden_dim):
        super().__init__()
        self.fc = nn.Linear(in_channels, hidden_dim)

    def forward(self, x):
        """Flatten each sample of ``x`` and apply the linear layer.

        Args:
            x: Tensor with exactly four dimensions; the first is the batch.

        Returns:
            Tensor of shape ``(batch, hidden_dim)``.
        """
        # Unpacking four sizes doubles as a rank check on the input.
        n_samples, _, _, _ = x.size()
        return self.fc(x.view(n_samples, -1))
