In [None]:
import os

# Opt in to OpenCV's OpenEXR codec (used later to read .exr HDR files).
# Set here, in the first cell, so it is in place before any cell imports cv2.
os.environ.update({"OPENCV_IO_ENABLE_OPENEXR": "1"})


In [None]:
import torch
import torch.nn as nn


class Encoder(nn.Module):
    """Three-layer convolutional feature extractor: C -> 16 -> 32 -> 64 channels.

    Every layer is a 3x3 convolution (stride 1, padding 1) followed by ReLU, so the
    spatial size of the input is preserved.
    """

    def __init__(self, input_channels):
        super().__init__()
        widths = (input_channels, 16, 32, 64)
        # Registered as conv1, conv2, conv3 (in that order) so parameter names and the
        # order in which weights are initialized match a hand-written layer list.
        for idx, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(
                self,
                f"conv{idx}",
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
            )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply conv -> ReLU three times; (N, C, H, W) in, (N, 64, H, W) out."""
        features = x
        for conv in (self.conv1, self.conv2, self.conv3):
            features = self.relu(conv(features))
        return features


# Encoder smoke test: 3x3 convs with stride 1 / padding 1 keep the 256x256 size.
input_channels = 3
encoder = Encoder(input_channels=input_channels)
input_tensor = torch.randn((1, input_channels, 256, 256))
output = encoder(input_tensor)
print(output.shape)  # expected: torch.Size([1, 64, 256, 256])


torch.Size([1, 64, 256, 256])


In [None]:
import torch
import torch.nn as nn


class Decoder(nn.Module):
    """Three-layer convolutional decoder with a residual connection to three images.

    Channels go input_channels -> 32 -> 16 -> 3 through 3x3 convolutions (stride 1,
    padding 1). ReLU follows the first two layers; the third layer's output is the
    residual, which is added to the three input images and squashed with a sigmoid.
    """

    def __init__(self, input_channels):
        super().__init__()
        layer_specs = ((input_channels, 32), (32, 16), (16, 3))
        # Registered as conv1, conv2, conv3 so parameter names / init order are stable.
        for idx, (c_in, c_out) in enumerate(layer_specs, start=1):
            setattr(
                self,
                f"conv{idx}",
                nn.Conv2d(
                    in_channels=c_in,
                    out_channels=c_out,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
            )
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input, img_1, img_2, img_3):
        """Decode ``input`` and fuse it with three images.

        Returns:
            (sigmoid(residual + img_1 + img_2 + img_3), residual), where ``residual``
            is the raw 3-channel output of ``conv3``.
        """
        hidden = input
        for conv in (self.conv1, self.conv2):
            hidden = self.relu(conv(hidden))
        residual = self.conv3(hidden)

        fused = residual + img_1 + img_2 + img_3
        return self.sigmoid(fused), residual

# Usage example: decoder smoke test on random tensors.
height, width = 256, 256
input_channels = 64  # match the encoder's output channels (or your own needs)
decoder = Decoder(input_channels=input_channels)
input_tensor = torch.randn(1, input_channels, height, width)
img_1, img_2, img_3 = (torch.randn(1, 3, height, width) for _ in range(3))
output, out_3 = decoder(input_tensor, img_1, img_2, img_3)

print(output.shape, out_3.shape, sep="\n")


torch.Size([1, 3, 256, 256])
torch.Size([1, 3, 256, 256])


In [None]:
import torch
import torch.nn as nn


class TMONet(nn.Module):
    """Three-input fusion network built from a shared Encoder and a residual Decoder.

    Each input goes through the SAME encoder instance; the three 64-channel feature
    maps are concatenated, passed through a 3x3 and then a 1x1 fusion convolution, and
    decoded together with the three original inputs.
    """

    def __init__(self, input_channels=3):
        super().__init__()
        fused_channels = 64 * 3  # 64 encoder channels per input x 3 inputs

        # One encoder instance, shared by all three inputs.
        self.encoder = Encoder(input_channels=input_channels)

        # Fusion layers (channel count preserved; 1x1 conv needs no padding).
        self.conv_gated_1 = nn.Conv2d(
            in_channels=fused_channels,
            out_channels=fused_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.conv_gated_2 = nn.Conv2d(
            in_channels=fused_channels,
            out_channels=fused_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.decoder = Decoder(input_channels=fused_channels)

    def forward(self, input_1, input_2, input_3):
        """Return the decoder's sigmoid image for the three inputs (residual is dropped)."""
        inputs = (input_1, input_2, input_3)
        stacked = torch.cat([self.encoder(x) for x in inputs], dim=1)  # concat on channels
        fused = self.conv_gated_2(self.conv_gated_1(stacked))
        out_img, _residual = self.decoder(fused, *inputs)
        return out_img

# Usage example: full TMONet forward pass on three random 256x256 inputs.
tmo_net = TMONet(input_channels=3)
height, width = 256, 256
input_1, input_2, input_3 = (torch.randn(1, 3, height, width) for _ in range(3))
output = tmo_net(input_1, input_2, input_3)
print(output.shape)


torch.Size([1, 3, 256, 256])


In [None]:
import numpy as np
import scipy.stats as st
import torch
import torch.nn.functional as F


def gaussian_kernel(nsig=2, filter_size=13):
    """Return a normalized (filter_size, filter_size) float32 smoothing kernel.

    The range [-nsig, nsig] (widened by half a bin on each side) is split into
    ``filter_size`` bins, the standard-normal probability mass of each bin forms a 1-D
    profile, and the 2-D kernel is sqrt(outer(profile, profile)) scaled to sum to 1.
    """
    bin_width = (2 * nsig + 1.0) / filter_size
    half = nsig + bin_width / 2.0
    edges = np.linspace(-half, half, filter_size + 1)
    bin_mass = np.diff(st.norm.cdf(edges))
    kernel_2d = np.sqrt(np.outer(bin_mass, bin_mass))
    return (kernel_2d / kernel_2d.sum()).astype(np.float32)


def feature_filtered(input_feature, sigma=2, kernel_size_num=13, kernel_size_m=3):
    """Local statistics of a BCHW feature map via per-channel (depthwise) convolutions.

    Args:
        input_feature: 4-D tensor (batch, channels, height, width).
        sigma: ``nsig`` passed to ``gaussian_kernel``.
        kernel_size_num: size of both the Gaussian and the box kernel; padding is
            ``kernel_size_num // 2``, so odd sizes keep the spatial shape.
        kernel_size_m: accepted but NOT used; the box filter uses ``kernel_size_num``.

    Returns:
        (gaussian_mean, local_std, box_mean), where local_std is
        sqrt(|box_mean(x**2) - box_mean(x)**2| + 1e-8).
    """
    _, channels, _, _ = input_feature.shape  # unpacking also enforces a 4-D input
    device = input_feature.device
    padding = kernel_size_num // 2

    def depthwise(x, kernel_2d):
        # One copy of the 2-D kernel per channel -> [C, 1, k, k]; groups=C filters each
        # channel independently.
        weight = kernel_2d.unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1, 1)
        return F.conv2d(x, weight, stride=1, padding=padding, groups=channels)

    gauss_2d = (
        torch.from_numpy(gaussian_kernel(nsig=sigma, filter_size=kernel_size_num))
        .to(device)
        .float()
    )
    box_count = torch.tensor(
        kernel_size_num * kernel_size_num, dtype=torch.float32, device=device
    )
    box_2d = (
        torch.ones(kernel_size_num, kernel_size_num, dtype=torch.float32, device=device)
        / box_count
    )

    gaussian_mean = depthwise(input_feature, gauss_2d)
    box_mean_of_square = depthwise(input_feature.pow(2), box_2d)
    box_mean = depthwise(input_feature, box_2d)

    # Exact zeros in the squared mean are replaced by 1e-8 before the subtraction.
    square_of_box_mean = box_mean.pow(2)
    square_of_box_mean = torch.where(
        square_of_box_mean == 0.0,
        torch.full_like(square_of_box_mean, 1e-8),
        square_of_box_mean,
    )

    variance_like = box_mean_of_square - square_of_box_mean
    local_std = torch.sqrt(torch.abs(variance_like) + 1e-8)

    return gaussian_mean, local_std, box_mean


def local_mean_std(input_feature, sigma=2, kernel_size_num=13, kernel_size_m=3):
    """Thin wrapper around ``feature_filtered``.

    Returns:
        (gaussian local mean, local std, box-filter local mean).
    """
    stats = feature_filtered(
        input_feature,
        sigma=sigma,
        kernel_size_num=kernel_size_num,
        kernel_size_m=kernel_size_m,
    )
    gaussian_mean, std_local, box_mean = stats
    return (gaussian_mean, std_local, box_mean)


def sign_num_den(input, gamma, beta, sigma=2, kernel_size_num=13, kernel_size_m=3):
    """Numerator and denominator of the contrast-masking response of a feature map.

    With ``c = (input - gaussian_mean) / (|gaussian_mean| + 1e-8)``:

    * numerator   = sign(c) * |c| ** gamma   (sign is +1 for c > 0, -1 otherwise)
    * denominator = 1 + (local_std / (|box_mean| + 1e-8)) ** beta

    Args:
        input: feature tensor accepted by ``local_mean_std`` (4-D, BCHW).
        gamma: exponent applied to the normalized contrast magnitude.
        beta: exponent applied to the normalized local standard deviation.
        sigma, kernel_size_num, kernel_size_m: forwarded to ``local_mean_std``.

    Returns:
        (norm_num, norm_den), shaped like ``input`` for odd ``kernel_size_num``.
    """
    # kernel_size_m is now forwarded instead of being silently dropped.
    local_mean, local_std, local_mean_box = local_mean_std(
        input,
        sigma=sigma,
        kernel_size_num=kernel_size_num,
        kernel_size_m=kernel_size_m,
    )

    # Numerator: signed, gamma-compressed contrast relative to the Gaussian local mean.
    gaussian_norm = (input - local_mean) / (torch.abs(local_mean) + 1e-8)
    # The sign mask must be +1 / -1. The previous `-torch.zeros_like(...)` produced 0 for
    # every non-positive contrast, so all negative contrast vanished from the loss.
    msk = torch.where(
        gaussian_norm > 0.0,
        torch.ones_like(gaussian_norm),
        -torch.ones_like(gaussian_norm),
    )
    # Replace exact zeros before pow() — presumably to keep |c| ** gamma differentiable
    # for gamma < 1.
    gaussian_norm = torch.where(
        gaussian_norm == 0.0, torch.full_like(gaussian_norm, 1e-8), gaussian_norm
    )
    norm_num = msk * torch.pow(torch.abs(gaussian_norm), gamma)

    # Denominator: 1 + (local std / local box mean) ** beta.
    local_norm = local_std / (torch.abs(local_mean_box) + 1e-8)
    norm_den = 1.0 + torch.pow(local_norm, beta)

    return norm_num, norm_den


def feature_contrast_masking(input, gamma, beta, sigma_num=2, kernel_size_num=13, kernel_size_den=13):
    """Contrast-masked response of a feature map: ``numerator / denominator``.

    The terms come from ``sign_num_den``; ``sigma_num`` is passed as its ``sigma`` and
    ``kernel_size_den`` as its ``kernel_size_m``.
    """
    numerator, denominator = sign_num_den(
        input,
        gamma=gamma,
        beta=beta,
        sigma=sigma_num,
        kernel_size_num=kernel_size_num,
        kernel_size_m=kernel_size_den,
    )
    return numerator / denominator

def masking_loss(input_1, input_2, gamma=0.5, beta=0.5, sigma_num=2.0, kernel_size_num=13, kernel_size_den=13):
    """Mean absolute difference between the contrast-masked responses of two feature maps.

    ``input_1`` is masked with a fixed contrast exponent of 1.0, while ``input_2`` uses
    ``gamma``; all other masking parameters are shared.

    Returns:
        Scalar tensor.
    """
    shared = dict(
        beta=beta,
        sigma_num=sigma_num,
        kernel_size_num=kernel_size_num,
        kernel_size_den=kernel_size_den,
    )
    # NOTE(review): only input_2 gets the `gamma` compression; input_1 is hard-coded to
    # 1.0 — confirm this asymmetry is intended.
    masked_1 = feature_contrast_masking(input_1, gamma=1.0, **shared)
    masked_2 = feature_contrast_masking(input_2, gamma=gamma, **shared)
    return (masked_1 - masked_2).abs().mean()


# Two random feature maps (batch=1, channels=3, 32x32) -> one scalar loss value.
input_1, input_2 = torch.randn(1, 3, 32, 32), torch.randn(1, 3, 32, 32)
loss = masking_loss(input_1, input_2, gamma=0.5, beta=0.5)

print(loss.item())


6.735195159912109


In [None]:
import torch
import torch.nn as nn
from torchvision import models


# Helper emulating TF-slim's VGG19 endpoints on top of torchvision's VGG19.
class VGG19FeatureExtractor(nn.Module):
    """Frozen ImageNet VGG19 ``features`` stack returning one named layer's activation.

    Names follow the TF-slim style ``VGG<block><conv>``. Each maps to the index of that
    ``nn.Conv2d`` inside ``torchvision.models.vgg19().features``, so the returned tensor
    is the convolution output before its ReLU.

    NOTE(review): inputs are not ImageNet mean/std normalized here — confirm this matches
    the reference implementation.
    """

    # torchvision VGG19 `features` layout: conv/ReLU pairs, with a MaxPool closing each
    # block at indices 4, 9, 18, 27 and 36.
    LAYER_MAP = {
        "VGG11": 0,  # conv1_1
        "VGG21": 5,  # conv2_1
        "VGG22": 7,  # conv2_2 (index 9 is the block-2 MaxPool)
        "VGG31": 10,  # conv3_1
        "VGG34": 16,  # conv3_4 (index 18 is the block-3 MaxPool)
        "VGG41": 19,  # conv4_1
        "VGG51": 28,  # conv5_1
        "VGG54": 34,  # conv5_4
    }

    def __init__(self, device):
        super(VGG19FeatureExtractor, self).__init__()
        vgg19 = self._load_vgg19_features().eval().to(device)
        # The network is only a fixed loss network. Freezing it keeps gradients from
        # being accumulated into VGG weights on every backward pass. No optimizer ever
        # zeroes those weights; gradients still flow to the input.
        for param in vgg19.parameters():
            param.requires_grad_(False)
        self.vgg19 = vgg19
        self.layer_map = dict(self.LAYER_MAP)

    @staticmethod
    def _load_vgg19_features():
        """Load ImageNet-pretrained VGG19 ``features``.

        Prefers the ``weights=`` API; falls back to ``pretrained=True`` on old torchvision.
        """
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            return models.vgg19(weights=weights_enum.IMAGENET1K_V1).features
        return models.vgg19(pretrained=True).features

    def forward(self, input, type):
        """Run ``input`` through ``self.vgg19`` up to and including layer ``type``.

        Raises:
            NotImplementedError: if ``type`` is not a key of ``self.layer_map``.
        """
        if type not in self.layer_map:
            raise NotImplementedError(f"Unknown perceptual type: {type}")

        target_layer_idx = self.layer_map[type]
        output = input
        for idx, layer in enumerate(self.vgg19):
            output = layer(output)
            if idx == target_layer_idx:
                break
        return output


# FCM_loss adapted to the VGG19_slim-style feature extractor.
def FCM_loss(
    input_1,
    input_2,
    feature_extractor,
    gamma=0.5,
    beta=0.5,
    sigma_num=2,
    kernel_size_num=13,
    kernel_size_den=13,
):
    """Feature-contrast-masking loss averaged over three VGG19 layers.

    Both inputs are moved to the device of the extractor's first layer, features are
    taken at ``VGG11``, ``VGG21`` and ``VGG31``, and ``masking_loss`` is applied per
    layer. The three costs are summed and divided by 3.

    Returns:
        Scalar tensor.
    """
    target_device = feature_extractor.vgg19[0].weight.device
    prediction = input_1.to(target_device)
    reference = input_2.to(target_device)

    per_layer_costs = [
        masking_loss(
            feature_extractor(prediction, layer_name),
            feature_extractor(reference, layer_name),
            gamma=gamma,
            beta=beta,
            sigma_num=sigma_num,
            kernel_size_num=kernel_size_num,
            kernel_size_den=kernel_size_den,
        )
        for layer_name in ("VGG11", "VGG21", "VGG31")
    ]
    return sum(per_layer_costs) / 3.0


In [None]:
def mul_exp(img, x_p=1.21497, eps=1e-8):
    """Create three exposure-scaled copies of ``img``.

    Exposure offsets are computed in stops (log2 units):

    * ``c_start = log2(x_p / max(img))``
    * ``c_end   = log2(x_p / median(img))``
    * their midpoint ``(c_start + c_end) / 2``

    Each copy is ``img * sqrt(2) ** c`` with values above 1.0 clipped to 1.0.
    Max and median are taken over the WHOLE tensor, i.e. jointly over every item of a
    batch. ``torch.quantile`` limits the input to 2**24 elements.

    Args:
        img: floating-point tensor, expected to be non-negative.
        x_p: target value the max / median is mapped to (default 1.21497, as before).
        eps: lower clamp for max and median so log2 stays finite for all-zero or
            non-positive inputs.

    Returns:
        List of three tensors shaped like ``img``.
    """
    img_max = torch.max(img).clamp_min(eps)
    img_median = torch.quantile(img, 0.5).clamp_min(eps)

    # The change of base belongs OUTSIDE the log: log2(v) == log(v) / log(2).
    # It was previously written as log(v / log(2)).
    c_start = torch.log2(x_p / img_max)
    c_end = torch.log2(x_p / img_median)
    exposures = [c_start, (c_end + c_start) / 2.0, c_end]

    sqrt2 = torch.sqrt(torch.tensor(2.0, device=img.device))
    # clamp(max=1.0) is the same clipping as torch.where(x > 1, 1, x).
    return [torch.clamp(img * torch.pow(sqrt2, c), max=1.0) for c in exposures]

# Usage example. mul_exp takes log2(x_p / median), so it needs non-negative intensities
# (like the mu-law data it receives in training). torch.randn would give a negative
# median about half the time and NaN exposures; torch.rand stays in [0, 1).
img = torch.rand(1, 3, 32, 32)
output_list = mul_exp(img)
print(len(output_list))  # 3
print(output_list[0].shape)  # [1, 3, 32, 32]


3
torch.Size([1, 3, 32, 32])


In [None]:
import os

# Enable OpenCV's OpenEXR codec. It is set BEFORE `import cv2` so it is already in
# place whenever OpenCV reads it; setting it after the import (as before) is fragile.
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = "1"

import cv2
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

class IOException(Exception):
    """Image read/write failure; ``str()`` yields ``repr(value)``."""

    def __init__(self, value):
        # The payload is kept on the instance; str() shows its repr.
        self.value = value

    def __str__(self):
        return f"{self.value!r}"


def writeLDR(img, file):
    """Write an RGB image with values in [0, 1] to ``file`` via OpenCV (scaled by 255).

    Raises:
        IOException: if conversion or encoding fails, or if ``cv2.imwrite`` reports
            failure. ``cv2.imwrite`` signals failure by returning False rather than
            raising.
    """
    try:
        # RGB -> BGR channel swap (same operation as COLOR_BGR2RGB) for OpenCV.
        img = cv2.cvtColor(img.astype(np.float32), cv2.COLOR_BGR2RGB)
        written = cv2.imwrite(file, img * 255.0)
    except Exception as e:
        raise IOException("Failed writing LDR image: %s" % e) from e
    if not written:
        raise IOException("Failed writing LDR image: cv2.imwrite returned False for %s" % file)


def norm(x):
    """Min-max normalize ``x`` to [0, 1].

    A constant input (max == min) returns zeros instead of the NaNs (and RuntimeWarning)
    that 0/0 would produce.
    """
    x_min = np.min(x)
    scale = np.max(x) - x_min
    shifted = x - x_min
    if scale == 0:
        return np.zeros_like(shifted, dtype=np.result_type(shifted, 1.0))
    return shifted / scale


def norm_mean(img):
    """Scale ``img`` so that its mean becomes 0.5.

    A zero-mean input (e.g. an all-black image after clamping negatives) returns zeros
    instead of 0/0 NaNs, which would otherwise propagate through the whole pipeline.
    """
    mean = img.mean()
    if mean == 0:
        return img * 0.0
    return 0.5 * img / mean


def ulaw_np(img, scale=10.0, eps=1e-8):
    """Log (mu-law-style) compression with a compression strength derived from the median.

    ``mu = 8.759 * m**2.148 + 0.1494 * m**-2.067``, where ``m`` is the median of ``img``;
    ``out = log(1 + mu * img) / log(1 + mu)``.

    Args:
        img: numpy array of non-negative intensities.
        scale: IGNORED. It is kept only for backward compatibility; ``mu`` is always
            derived from the median.
        eps: lower bound for the median. ``m**-2.067`` is inf at m == 0, which used to
            turn images with a zero median into all-NaN output.

    Returns:
        (out, mu): the compressed array and the compression strength used.
    """
    median_value = max(np.median(img), eps)
    mu = 8.759 * np.power(median_value, 2.148) + 0.1494 * np.power(
        median_value, -2.067
    )
    out = np.log(1 + mu * img) / np.log(1 + mu)
    return out, mu


def load_hdr_ldr_norm_ulaw(name_hdr):
    """Read an HDR image file, normalize its mean to 0.5, and mu-law compress it.

    Returns:
        (scale, y_ulaw, y_rgb): the compression strength from ``ulaw_np``, the compressed
        image, and the mean-normalized linear RGB image (negatives clamped to 0), both
        (H, W, 3).

    Raises:
        IOException: if OpenCV cannot read ``name_hdr``.
    """
    y = cv2.imread(name_hdr, flags=cv2.IMREAD_ANYDEPTH | cv2.IMREAD_ANYCOLOR)
    if y is None:
        # cv2.imread signals failure by returning None instead of raising; without this
        # check the failure surfaces later as an unrelated cvtColor error.
        raise IOException("Failed reading HDR image: %s" % name_hdr)
    # Single-channel files come back 2-D; expand them to RGB instead of failing.
    conversion = cv2.COLOR_GRAY2RGB if y.ndim == 2 else cv2.COLOR_BGR2RGB
    y_rgb = np.maximum(cv2.cvtColor(y, conversion), 0.0)
    y_rgb = norm_mean(y_rgb)
    y_ulaw, scale = ulaw_np(y_rgb)
    return scale, y_ulaw, y_rgb


class HDRDataset(Dataset):
    """Dataset of ``.exr`` HDR images found in a single directory.

    Each sample is a dict with the mu-law image, the linear RGB image (both float
    tensors in [C, H, W]), the scalar mu-law ``scale``, and the file name.
    """

    def __init__(self, directory, transform=None):
        """
        Args:
            directory (str): Path to the directory containing the HDR files.
            transform (callable, optional): Optional transform applied to each image
                tensor.

        Raises:
            ValueError: if the directory contains no ``.exr`` files.
        """
        self.directory = directory
        self.transform = transform

        # Sorted so the index -> file mapping is reproducible (os.listdir order is
        # arbitrary). The extension match is case-insensitive so ".EXR" files count too.
        self.hdr_files = sorted(
            f for f in os.listdir(directory) if f.lower().endswith(".exr")
        )
        if not self.hdr_files:
            raise ValueError(f"No HDR files found in directory: {directory}")

    def __len__(self):
        """Return the number of samples in the dataset."""
        return len(self.hdr_files)

    def __getitem__(self, idx):
        """Return the sample at the given index."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        hdr_path = os.path.join(self.directory, self.hdr_files[idx])

        # Load and preprocess the image.
        scale, y_ulaw, y_rgb = load_hdr_ldr_norm_ulaw(hdr_path)

        # Convert to PyTorch tensors.
        y_ulaw = torch.from_numpy(y_ulaw).float()  # [H, W, C]
        y_rgb = torch.from_numpy(y_rgb).float()  # [H, W, C]
        scale = torch.tensor(scale).float()  # scalar

        # Reorder axes to PyTorch's [C, H, W].
        y_ulaw = y_ulaw.permute(2, 0, 1)
        y_rgb = y_rgb.permute(2, 0, 1)

        # NOTE(review): the transform is invoked separately for each tensor, so a random
        # transform would draw different parameters for y_ulaw and y_rgb — confirm only
        # deterministic transforms are used.
        if self.transform:
            y_ulaw = self.transform(y_ulaw)
            y_rgb = self.transform(y_rgb)

        sample = {
            "scale": scale,
            "y_ulaw": y_ulaw,
            "y_rgb": y_rgb,
            "filename": self.hdr_files[idx],
        }
        return sample


# Directory with the HDR files — replace with your own path.
data_dir = "data_resized"

dataset = HDRDataset(directory=data_dir)
print(f"Loaded {len(dataset)} HDR images")
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)

# Inspect only the first batch as an example.
batch = next(iter(dataloader))
print("Batch keys:", batch.keys())
print("Scale shape:", batch["scale"].shape)
print("y_ulaw shape:", batch["y_ulaw"].shape)
print("y_rgb shape:", batch["y_rgb"].shape)
print("Filenames:", batch["filename"])


Loaded 181 HDR images
Batch keys: dict_keys(['scale', 'y_ulaw', 'y_rgb', 'filename'])
Scale shape: torch.Size([2])
y_ulaw shape: torch.Size([2, 3, 256, 256])
y_rgb shape: torch.Size([2, 3, 256, 256])
Filenames: ['174.exr', '033.exr']


In [None]:
import os

import brisque
import numpy as np
import torch
from brisque import BRISQUE
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data import DataLoader, random_split
from torchvision.utils import save_image
from tqdm.notebook import tqdm


# Luminance helper used by the Reinhard baseline and the validation color restoration.
def lum(img):
    """Luminance of an (H, W, C) RGB image: 0.299*R + 0.587*G + 0.114*B.

    Works for numpy arrays and torch tensors; returns an (H, W) result.
    """
    red, green, blue = img[:, :, 0], img[:, :, 1], img[:, :, 2]
    return 0.299 * red + 0.587 * green + 0.114 * blue


# Lazily-built BRISQUE scorer shared by all evaluate_image calls. Constructing
# BRISQUE(url=False) once avoids rebuilding the model for every image scored.
# NOTE(review): assumes BRISQUE.score keeps no per-call state — verify for the installed
# brisque version.
_BRISQUE_METRIC = None


def _get_brisque_metric():
    """Return the shared BRISQUE scorer, constructing it on first use."""
    global _BRISQUE_METRIC
    if _BRISQUE_METRIC is None:
        _BRISQUE_METRIC = BRISQUE(url=False)
    return _BRISQUE_METRIC


def evaluate_image(image) -> float:
    """Return the BRISQUE score of ``image``.

    A torch tensor is taken to be (C, H, W) in [0, 1]. It is converted to an (H, W, C)
    uint8 array (scaled by 255, clipped, truncated). Other inputs are passed to
    ``BRISQUE.score`` unchanged.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
        image = np.transpose(image, (1, 2, 0))  # (C, H, W) -> (H, W, C)
        image = (image * 255).clip(0, 255).astype(np.uint8)
    return _get_brisque_metric().score(img=image)


def reinhard_tone_mapping(hdr_image):
    """
    Applies Reinhard global tone mapping to an HDR image.
    hdr_image: numpy array of shape (H, W, C)

    Luminance is divided by its log-average (geometric mean) and compressed with
    L / (1 + L). The first three channels are then rescaled by compressed / original
    luminance, and the result is clipped to [0, 1].
    """
    eps = 1e-8
    luminance = lum(hdr_image)
    # Geometric mean of the luminance; eps keeps log() finite at zero.
    log_avg_lum = np.exp(np.mean(np.log(luminance + eps)))
    relative_lum = luminance * (1.0 / log_avg_lum)
    display_lum = relative_lum / (1.0 + relative_lum)

    # Channels beyond the first three (if any) are left at zero.
    tone_mapped = np.zeros_like(hdr_image)
    tone_mapped[:, :, :3] = (
        hdr_image[:, :, :3] / (luminance + eps)[:, :, None]
    ) * display_lum[:, :, None]
    return np.clip(tone_mapped, 0, 1)


# Training configuration consumed by train().
config = dict(
    epochs=10,
    batch_size=2,
    learning_rate=2e-4,  # Adam learning rate
    decay_rate=0.9,  # ExponentialLR gamma, stepped once per epoch
    data_dir="data_resized",
    validation_split=0.2,  # fraction of the dataset held out for validation
    valid_step=1,  # validate when epoch % valid_step == 0
    device="cuda" if torch.cuda.is_available() else "cpu",
    output_dir="output_images",
)


def _reinhard_reference_brisque(val_loader, reference_dir, batch_size, device):
    """Reinhard-tone-map every validation HDR image, save it, and return the mean BRISQUE.

    Images are written to ``reference_dir/val_<n>.png``; returns 0.0 for an empty loader.
    """
    brisque_total = 0.0
    num_samples = 0
    with torch.no_grad():
        for batch_idx, batch in enumerate(
            tqdm(val_loader, desc="Processing Validation Data", unit="batch")
        ):
            hdr_rgb = batch["y_rgb"].float().to(device)  # ground-truth HDR images

            for i in range(hdr_rgb.size(0)):
                hdr_np = hdr_rgb[i].cpu().numpy().transpose(1, 2, 0)  # (H, W, C)
                tone_mapped_np = reinhard_tone_mapping(hdr_np)
                tone_mapped_tensor = torch.from_numpy(
                    tone_mapped_np.transpose(2, 0, 1)
                ).float()  # back to (C, H, W) for save_image

                img_filename = os.path.join(
                    reference_dir, f"val_{batch_idx * batch_size + i}.png"
                )
                save_image(tone_mapped_tensor, img_filename)

                brisque_total += evaluate_image(tone_mapped_tensor)
                num_samples += 1

    return brisque_total / num_samples if num_samples > 0 else 0.0


def _restore_color(hdr_hwc, pred_hwc, saturation=0.6):
    """Recolor the predicted luminance with the HDR image's per-channel ratios.

    out_c = (C_hdr / (L_hdr + 1e-8)) ** saturation * L_pred for c in R, G, B,
    returned as a float64 (H, W, 3) array clipped to [0, 1].
    """
    hdr_lum = lum(hdr_hwc)
    pred_lum = lum(pred_hwc)
    img_out = np.zeros(np.shape(hdr_hwc))
    for c in range(3):
        img_out[:, :, c] = (hdr_hwc[:, :, c] / (hdr_lum + 1e-8)) ** saturation * pred_lum
    return np.clip(img_out, 0, 1)


def _validate_epoch(
    model, val_loader, feature_extractor, device, epoch_dir, batch_size, progress_bar
):
    """Run one validation pass, saving recolored predictions to ``epoch_dir``.

    Returns:
        (average FCM loss, average BRISQUE) per sample, or (nan, nan) if the loader is
        empty (previously a ZeroDivisionError).
    """
    model.eval()
    val_loss_total = 0.0
    val_brisque_total = 0.0
    num_samples = 0

    with torch.no_grad():
        for batch_idx, batch in enumerate(val_loader):
            hdr_ulaw = batch["y_ulaw"].float().to(device)
            hdr_rgb = batch["y_rgb"].float().to(device)

            exposures = mul_exp(hdr_ulaw)
            output = model(exposures[0], exposures[1], exposures[2])

            val_loss = FCM_loss(output, hdr_rgb, feature_extractor)
            val_loss_total += val_loss.item() * output.size(0)

            for i in range(output.size(0)):
                yy = hdr_rgb[i].cpu().numpy().transpose(1, 2, 0)  # (H, W, C)
                y_pred = output[i].cpu().numpy().transpose(1, 2, 0)  # (H, W, C)
                img_tensor = torch.from_numpy(
                    _restore_color(yy, y_pred).transpose(2, 0, 1)
                ).float()  # (C, H, W)

                val_brisque_total += evaluate_image(img_tensor)
                save_image(
                    img_tensor,
                    os.path.join(epoch_dir, f"{batch_idx * batch_size + i}.png"),
                )

            num_samples += output.size(0)
            progress_bar.set_postfix(stage="Validation")
            progress_bar.update(1)

    if num_samples == 0:
        return float("nan"), float("nan")
    return val_loss_total / num_samples, val_brisque_total / num_samples


def train(config: dict) -> torch.nn.Module:
    """Train TMONet with the FCM loss, validating every ``valid_step`` epochs.

    Config keys read: ``epochs``, ``batch_size``, ``learning_rate``, ``decay_rate``,
    ``data_dir``, ``validation_split``, ``valid_step`` (<= 0 disables validation),
    ``device``, ``output_dir``, and an optional integer ``seed``. When present, the seed
    is applied to torch, numpy and the train/val split.

    Writes Reinhard reference PNGs to ``<output_dir>/reference`` and per-epoch validation
    PNGs to ``<output_dir>/<epoch>/``.

    Returns:
        The trained model (previously discarded when train() returned).
    """
    print(config)
    device = torch.device(config["device"])
    batch_size = config["batch_size"]
    valid_step = config["valid_step"]

    seed = config.get("seed")
    split_kwargs = {}
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)

    # Dataset and dataloaders.
    base_dataset = HDRDataset(directory=config["data_dir"])
    train_size = int((1 - config["validation_split"]) * len(base_dataset))
    val_size = len(base_dataset) - train_size
    train_dataset, val_dataset = random_split(
        base_dataset, [train_size, val_size], **split_kwargs
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Model, loss network, optimizer, scheduler.
    model = TMONet(input_channels=3).to(device)
    feature_extractor = VGG19FeatureExtractor(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config["learning_rate"])
    scheduler = ExponentialLR(optimizer, gamma=config["decay_rate"])

    os.makedirs(config["output_dir"], exist_ok=True)
    reference_dir = os.path.join(config["output_dir"], "reference")
    os.makedirs(reference_dir, exist_ok=True)

    reinhard_brisque_score = _reinhard_reference_brisque(
        val_loader, reference_dir, batch_size, device
    )

    for epoch in range(1, config["epochs"] + 1):
        # valid_step <= 0 disables validation (previously `epoch % 0` raised).
        run_validation = valid_step > 0 and epoch % valid_step == 0
        progress_bar = tqdm(
            total=len(train_loader) + (len(val_loader) if run_validation else 0),
            desc=f"Epoch {epoch}/{config['epochs']}",
            unit="batch",
        )

        # Training phase.
        model.train()
        for batch in train_loader:
            hdr_ulaw = batch["y_ulaw"].float().to(device)
            hdr_rgb = batch["y_rgb"].float().to(device)

            optimizer.zero_grad()
            # Three exposure-scaled copies of the mu-law input feed the three inputs.
            exposures = mul_exp(hdr_ulaw)
            output = model(exposures[0], exposures[1], exposures[2])

            loss = FCM_loss(output, hdr_rgb, feature_extractor)
            loss.backward()
            optimizer.step()

            progress_bar.set_postfix(stage="Training", loss=loss.item())
            progress_bar.update(1)

        # Validation phase.
        if run_validation:
            epoch_dir = os.path.join(config["output_dir"], f"{epoch}")
            os.makedirs(epoch_dir, exist_ok=True)
            avg_val_loss, avg_brisque_score = _validate_epoch(
                model,
                val_loader,
                feature_extractor,
                device,
                epoch_dir,
                batch_size,
                progress_bar,
            )

            progress_bar.set_postfix(
                stage="Validation",
                val_loss=avg_val_loss,
                val_brisque=avg_brisque_score,
                reinhard_brisque_score=reinhard_brisque_score,
            )
            print(
                f"Epoch {epoch}/{config['epochs']} - Avg Val Loss: {avg_val_loss:.4f}, Avg BRISQUE: {avg_brisque_score:.4f}, Avg BRISQUE Reinhard: {reinhard_brisque_score:.4f}"
            )

        scheduler.step()
        progress_bar.close()

    return model


trained_model = train(config)


{'epochs': 10, 'batch_size': 2, 'learning_rate': 0.0002, 'decay_rate': 0.9, 'data_dir': 'data_resized', 'validation_split': 0.2, 'valid_step': 1, 'device': 'cuda', 'output_dir': 'output_images'}


Processing Validation Data:   0%|          | 0/19 [00:00<?, ?batch/s]

Epoch 1/10:   0%|          | 0/91 [00:00<?, ?batch/s]

Epoch 1/10 - Avg Val Loss: 0.7213, Avg BRISQUE: 22.9678, Avg BRISQUE Reinhard: 21.4222


Epoch 2/10:   0%|          | 0/91 [00:00<?, ?batch/s]

KeyboardInterrupt: 