<a href="https://www.kaggle.com/code/shokhjahonisroilov/upsolving-kazakhstan-25?scriptVersionId=294153117" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Baseline Solution

### Imports

In [None]:
import os
import random
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import torch.nn.functional as F
import torchvision.models as models

In [None]:
# Kaggle input directories. `train_path` is handed to CombinedPairedImageDataset
# (which reads its `real_images/` and `filtered_images/` sub-folders) and
# `test_path` to FilteredImageDataset (which reads `filtered_images/`).
train_path = "/kaggle/input/up-solving-tst-day-1/train/train/"
test_path = "/kaggle/input/up-solving-tst-day-1/test/test/"

In [None]:
class CombinedPairedImageDataset(Dataset):
    """(filtered, original) image pairs: pairs stored on disk plus synthetic ones.

    Expected directory layout::

        root_dir/
            real_images/
            filtered_images/

    With ``N`` files in ``real_images`` and ``K`` entries in ``all_filters``:

    * indices ``0 .. N-1`` return the pair stored on disk, looked up by the
      same file name in both folders;
    * indices ``N .. N + N*K - 1`` return synthetic pairs, image-major: the
      real image is loaded and ``apply_filter`` is called on it with one of
      the entries of ``all_filters``.

    NOTE(review): ``apply_filter`` is not defined in the visible part of this
    file; it is called on the already-transformed image — confirm it accepts
    whatever ``transform`` returns.
    """

    def __init__(self, root_dir, all_filters, transform=None):
        self.real_images_dir = os.path.join(root_dir, 'real_images')
        self.filtered_images_dir = os.path.join(root_dir, 'filtered_images')

        # File names are taken from real_images only; the stored filtered
        # counterpart is looked up under the same name.
        self.image_filenames = sorted(os.listdir(self.real_images_dir))
        self.filters = all_filters
        self.num_filters = len(all_filters)

        # Fall back to ToTensor when no transform is supplied.
        self.transform = transform if transform else transforms.ToTensor()

        # Dataset sizes.
        self.num_real_pairs = len(self.image_filenames)
        self.num_augmented_pairs = self.num_filters * self.num_real_pairs

    def __len__(self):
        return self.num_real_pairs + self.num_augmented_pairs

    def _open_rgb(self, directory, image_name):
        """Open ``directory/image_name`` and convert it to RGB."""
        return Image.open(os.path.join(directory, image_name)).convert('RGB')

    def _synthetic_pair(self, augmented_idx):
        """Build the synthetic pair for a zero-based index into the augmented range."""
        image_index = augmented_idx // self.num_filters
        filter_index = augmented_idx % self.num_filters

        clean = self.transform(
            self._open_rgb(self.real_images_dir, self.image_filenames[image_index])
        )
        degraded = apply_filter(clean, self.filters[filter_index])
        return degraded, clean

    def __getitem__(self, idx):
        # ---------- 2. Synthetic pairs ----------
        if idx >= self.num_real_pairs:
            return self._synthetic_pair(idx - self.num_real_pairs)

        # ---------- 1. Real pairs ----------
        image_name = self.image_filenames[idx]
        degraded = self._open_rgb(self.filtered_images_dir, image_name)
        clean = self._open_rgb(self.real_images_dir, image_name)
        return self.transform(degraded), self.transform(clean)


In [None]:
class FilteredImageDataset(Dataset):
    """Filtered images only, read from ``root_dir/filtered_images``.

    Each item is ``(image, filename)``: the image is opened as RGB and passed
    through ``transform`` (``transforms.ToTensor()`` when none is given), and
    the file name is returned alongside it.
    """

    def __init__(self, root_dir, transform=None):
        self.filtered_dir = os.path.join(root_dir, 'filtered_images')
        # Sorted so that index -> file mapping is deterministic.
        self.image_filenames = sorted(os.listdir(self.filtered_dir))
        self.transform = transform if transform else transforms.ToTensor()

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        name = self.image_filenames[idx]
        rgb = Image.open(os.path.join(self.filtered_dir, name)).convert('RGB')
        return self.transform(rgb), name


In [None]:
# Training set: stored pairs plus synthetic pairs generated from `all_filters`.
# NOTE(review): `all_filters` is not defined in the visible part of this file —
# confirm it is created in an earlier cell before this one runs.
train_dataset = CombinedPairedImageDataset(train_path, all_filters)  # path/filtered_images and path/real_images
# Test set: filtered images only, returned together with their file names.
test_dataset = FilteredImageDataset(test_path)  # path_test/filtered_images

In [None]:
# Shuffled mini-batches of 16 (filtered, original) pairs for training.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

### Model

In [None]:


class ResBlock(nn.Module):
    """Residual block: conv -> GroupNorm -> ReLU -> conv -> GroupNorm, plus identity.

    Both convolutions are 3x3 with ``padding=1`` and keep ``c`` channels, so the
    output has the same shape as the input. GroupNorm uses 8 groups, hence
    ``c`` must be divisible by 8.
    """

    def __init__(self, c):
        super().__init__()
        layers = [
            nn.Conv2d(c, c, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=c),
            nn.ReLU(inplace=True),
            nn.Conv2d(c, c, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=c),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        residual = self.net(x)
        return x + residual

class Down(nn.Module):
    """Encoder stage: stride-2 3x3 conv (c_in -> c_out), GroupNorm(8), ReLU, ResBlock.

    The strided convolution reduces each spatial dimension to ``ceil(size / 2)``.
    ``c_out`` must be divisible by 8 because of the 8-group GroupNorm.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        stages = [
            nn.Conv2d(c_in, c_out, kernel_size=3, stride=2, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=c_out),
            nn.ReLU(inplace=True),
            ResBlock(c_out),
        ]
        self.down = nn.Sequential(*stages)

    def forward(self, x):
        return self.down(x)

class Up(nn.Module):
    """Decoder stage: x2 nearest upsample, 3x3 conv, GroupNorm(8), ReLU, ResBlock, then add skip.

    Args:
        c_in: number of channels of the incoming feature map.
        c_out: number of channels produced; must equal the skip tensor's
            channel count and be divisible by 8 (GroupNorm with 8 groups).
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        # Layer order and indices are kept as-is so state_dict keys
        # ("up.1.weight", ...) stay compatible with existing checkpoints.
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(c_in, c_out, 3, padding=1),
            nn.GroupNorm(8, c_out),
            nn.ReLU(inplace=True),
            ResBlock(c_out)
        )

    def forward(self, x, skip):
        """Upsample ``x`` and fuse it with ``skip`` by element-wise addition.

        A stride-2 encoder conv maps an odd size ``H`` to ``ceil(H / 2)``, so the
        fixed x2 upsample can overshoot the skip tensor by one pixel. In that
        case the upsampled map is resized to the skip's spatial size instead of
        failing on the addition. When the sizes already match (e.g. inputs whose
        height and width are divisible by 4) the result is unchanged.
        """
        x = self.up(x)
        if x.shape[-2:] != skip.shape[-2:]:
            x = F.interpolate(x, size=skip.shape[-2:], mode='nearest')
        return x + skip

class CNN(nn.Module):
    """Small encoder-decoder with additive skips that predicts a residual on the input.

    Three encoder levels (``base``, ``2*base``, ``4*base`` channels; the last
    two are stride-2 ``Down`` stages), a two-ResBlock bottleneck, two ``Up``
    stages that add the matching encoder features, and a 3x3 conv back to
    3 channels. The prediction is added to the input and clamped to [0, 1].

    Args:
        base: channel width of the first level; must be divisible by 8
            because every stage uses GroupNorm with 8 groups.
    """

    def __init__(self, base=64):
        super().__init__()
        narrow, mid, wide = base, base * 2, base * 4

        # Encoder
        stem = [
            nn.Conv2d(3, narrow, kernel_size=3, padding=1),
            nn.GroupNorm(num_groups=8, num_channels=narrow),
            nn.ReLU(inplace=True),
            ResBlock(narrow),
        ]
        self.enc1 = nn.Sequential(*stem)
        self.enc2 = Down(narrow, mid)
        self.enc3 = Down(mid, wide)

        # Bottleneck
        self.center = nn.Sequential(ResBlock(wide), ResBlock(wide))

        # Decoder
        self.dec2 = Up(wide, mid)
        self.dec1 = Up(mid, narrow)

        # Projection back to 3 image channels
        self.tail = nn.Sequential(nn.Conv2d(narrow, 3, kernel_size=3, padding=1))

    def forward(self, x):
        skip_full = self.enc1(x)
        skip_half = self.enc2(skip_full)
        bottleneck = self.center(self.enc3(skip_half))

        decoded = self.dec2(bottleneck, skip_half)
        decoded = self.dec1(decoded, skip_full)

        correction = self.tail(decoded)
        # The network output is a correction added to the input image.
        return torch.clamp(correction + x, 0.0, 1.0)

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
def compute_psnr(img1, img2):
    """Return the PSNR (in dB) between two images after 8-bit quantisation.

    Args:
        img1: array laid out as [C, H, W] with intensities nominally in [0, 1].
        img2: array with the same layout and shape as ``img1``.

    Returns:
        ``float('inf')`` when the quantised images are identical, otherwise
        ``10 * log10(255**2 / mse)``.

    Values outside [0, 1] are clipped to the valid 8-bit range before the cast;
    without the clip they would wrap around in ``astype(np.uint8)`` (e.g. 306
    becoming 50) and silently corrupt the metric.
    """

    def _to_uint8_hwc(img):
        # [C, H, W] -> [H, W, C], scale to 0..255, saturate, then quantise.
        scaled = (np.transpose(img, (1, 2, 0)) * 255).round()
        return np.clip(scaled, 0, 255).astype(np.uint8)

    img1 = _to_uint8_hwc(img1)
    img2 = _to_uint8_hwc(img2)

    # Compare in float so that the uint8 difference cannot underflow.
    mse = np.mean((img1.astype(np.float32) - img2.astype(np.float32)) ** 2)
    if mse == 0:
        return float('inf')

    psnr = 10 * np.log10((255 ** 2) / mse)
    return psnr

In [None]:
def show_image(i: int, inputs, outputs, targets):
    """Print the PSNR of sample ``i`` and plot input, model output and target side by side.

    Args:
        i: index of the sample inside the batch.
        inputs: batch of filtered images (indexed as ``inputs[i]``).
        outputs: batch of model predictions; clamped to [0, 1] before use.
        targets: batch of ground-truth images.

    The PSNR is computed between the clamped output and the target.
    """
    filtered_img = inputs[i].cpu().detach().numpy()
    output_img = outputs[i].cpu().detach().clamp(0,1).numpy()
    real_img = targets[i].cpu().detach().numpy()

    print(f"Image PSNR {compute_psnr(output_img, real_img)}")

    panels = (
        ('Filtered (Input)', filtered_img),
        ('Model Output', output_img),
        ('Real (Target)', real_img),
    )

    plt.figure(figsize=(12,4))
    for position, (title, chw_image) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        plt.axis('off')
        # imshow expects channels last: [C, H, W] -> [H, W, C]
        plt.imshow(np.transpose(chw_image, (1, 2, 0)))

    plt.show()

In [None]:
# Build the network with its default width and move it to the selected device.
model = CNN().to(device)


# L1 loss is kept per element (reduction='none'); the training loop reduces it
# to a scalar with .mean() before calling backward().
criterion = nn.L1Loss(reduction='none')
# Adam over all model parameters with a learning rate of 1e-3.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

### Train

In [None]:
epochs = 7

for epoch in range(epochs):
    model.train()
    # Accumulates batch-mean loss weighted by batch size, so dividing by the
    # dataset length below gives a per-sample average.
    running_loss = 0.0

    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        # criterion returns per-element losses; reduce them to a scalar here.
        loss = criterion(outputs, targets).mean()
        loss.backward()
        optimizer.step()

        running_loss += inputs.size(0) * loss.item()

    epoch_loss = running_loss / len(train_loader.dataset)
    print(f"Epoch [{epoch+1}/{epochs}] Loss: {epoch_loss:.4f}")

    # Visual check on the first sample of the last batch seen in this epoch.
    show_image(0, inputs.detach().cpu(), outputs.detach().cpu(), targets.detach().cpu())

### Predict

In [None]:
submission_rows = []

model.eval()
# Inference only: no gradients are needed for any of the test images.
with torch.no_grad():
    for idx in range(len(test_dataset)):
        img_tensor, filename = test_dataset[idx]
        img_id = os.path.splitext(filename)[0]  # file name without extension
        img_tensor = img_tensor.unsqueeze(0).to(device)  # add batch dim: [1, 3, H, W]

        output = model(img_tensor)
        output = output.squeeze(0).clamp(0, 1).cpu().numpy()  # drop batch dim: [3, H, W]

        # Reorder channels RGB -> BGR, quantise to 8 bits, then flatten in
        # [H, W, C] order to get one row of pixel values per image.
        output_bgr = output[[2, 1, 0]]
        output_bgr = np.rint(output_bgr * 255).astype(np.uint8)
        output_flat = np.transpose(output_bgr, (1, 2, 0)).reshape(-1)

        submission_rows.append([img_id, *output_flat.tolist()])

In [None]:
# One column per flattened pixel value, named "0", "1", ...; the count is taken
# from the last prediction produced by the inference loop.
pixel_columns = list(map(str, range(output_flat.size)))
df = pd.DataFrame(submission_rows, columns=["id", *pixel_columns])

df.to_csv("laplacian.csv", index=False)