In [1]:
import os
import random
import numpy as np
from glob import glob
from tqdm.auto import tqdm
from scipy.optimize import linear_sum_assignment
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights

import os
import time
import numpy as np
from datetime import timedelta
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torch.nn import InstanceNorm2d, LeakyReLU, Tanh, Sigmoid
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import deeplay as dl
import matplotlib.pyplot as plt
from glob import glob
import os


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
class ColorizationDataset(Dataset):
    def __init__(self, image_paths, image_size=(256, 256)):
        self.image_paths = image_paths
        self.image_size = image_size
        self.rgb_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5)),
        ])
        self.gray_transform = transforms.Compose([
            transforms.Resize(self.image_size),
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),
            transforms.Normalize((0.5,),(0.5,)),
        ])
    def __len__(self):
        return len(self.image_paths)
    def __getitem__(self, idx):
        img = Image.open(self.image_paths[idx]).convert('RGB')
        gray = self.gray_transform(img)
        rgb = self.rgb_transform(img)
        return gray, rgb

In [4]:
from pathlib import Path
class GrayDataset(Dataset):
    """
    A dataset that loads images from a directory (and subdirectories) as single-channel grayscale tensors.

    Args:
        root_dir (str or Path): Path to the directory containing images.
        extensions (tuple of str): Allowed file extensions, e.g. ('.jpg', '.png').
    """
    def __init__(self, root_dir, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
        self.root_dir = Path(root_dir)
        # Recursively collect image file paths
        self.paths = sorted(
            [p for p in self.root_dir.rglob('*') if p.suffix.lower() in extensions]
        )
        if not self.paths:
            raise FileNotFoundError(f"No images found in {root_dir} with extensions {extensions}")

        # Grayscale + ToTensor transform
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),  # outputs [1, H, W] in [0,1]
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        with Image.open(img_path) as img:
            img = img.convert('RGB')  # ensure 3-channel input for Grayscale
            clean = self.transform(img)
        return clean


class NoisyGrayDataset(Dataset):
    """
    A dataset that wraps GrayDataset and returns noisy and clean grayscale tensors.

    Args:
        root_dir (str or Path): Path to the directory containing images.
        noise_std (float): Standard deviation of Gaussian noise added to clean image.
    """
    def __init__(self, root_dir, noise_std=0.3, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
        self.clean_ds = GrayDataset(root_dir, extensions)
        self.noise_std = noise_std

    def __len__(self):
        return len(self.clean_ds)

    def __getitem__(self, idx):
        clean = self.clean_ds[idx]  # [1, H, W]
        noise = torch.randn_like(clean) * self.noise_std
        noisy = torch.clamp(clean + noise, 0.0, 1.0)
        return noisy, clean

In [5]:
import os
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# Assume your dataset classes are already defined: GrayDataset, NoisyGrayDataset

import os
import random
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader

def visualize_and_save_noisy_dataset(dataset, out_dir='output_images', max_images=10, seed=None):
    os.makedirs(out_dir, exist_ok=True)

    # Optional: set random seed for reproducibility
    if seed is not None:
        random.seed(seed)

    # Randomly select `max_images` indices from the dataset
    total = len(dataset)
    indices = random.sample(range(total), k=min(max_images, total))

    for i, idx in enumerate(indices):
        noisy, clean = dataset[idx]  # Directly index into dataset

        # Save using torchvision
        #save_image(clean, os.path.join(out_dir, f'gray_{i:03d}.png'), normalize=False)
        save_image(noisy, os.path.join(out_dir, f'noisy_{i:03d}.png'), normalize=False)

        # Optional: Also visualize inline with matplotlib
        # fig, axs = plt.subplots(1, 2, figsize=(6, 3))
        # axs[0].imshow(clean.squeeze().cpu(), cmap='gray')
        # axs[0].set_title('Clean')
        # axs[0].axis('off')
        # axs[1].imshow(noisy.squeeze().cpu(), cmap='gray')
        # axs[1].set_title('Noisy')
        # axs[1].axis('off')
        # plt.tight_layout()
        # plt.show()


root_dir = 'data/sub'  # Replace with actual path
dataset = NoisyGrayDataset(root_dir, noise_std=0.05)
visualize_and_save_noisy_dataset(dataset, out_dir='Noisy_gray', max_images=100)


## jigsaw


In [6]:
from PIL import Image
import torch
from torchvision import transforms

def build_and_save_image(tiles: torch.Tensor,
                         perm: torch.Tensor,
                         grid_size: tuple[int, int],
                         tile_size: int,
                         save_path: str,
                         fixed: False):
    """
    Reconstructs a single image from shuffled tiles and saves it to disk.

    Args:
        tiles (Tensor[N, C, T, T]): Shuffled tiles (torch tensor).
        perm  (Tensor[N]): Permutation such that tiles[k] came from original slot perm[k].
        grid_size (H, W): Number of tiles vertically (H) and horizontally (W).
        tile_size   (int): Width & height of each square tile in pixels.
        save_path   (str): Path where the reconstructed image will be saved.
    """
    H, W = grid_size
    N = H * W
    assert tiles.shape[0] == N, f"Expected {N} tiles but got {tiles.shape[0]}"
    # Compute inverse permutation: inv_perm[original_idx] = position in `tiles`
    if fixed:
        inv_perm = perm.argsort()
        # Reorder tiles back to their original scanning order
        ordered_tiles = tiles[inv_perm]
          # shape [N, C, T, T]
    else:
        # Reorder tiles back to their original scanning order
        ordered_tiles = tiles[perm]
          # shape [N, C, T, T]
    # Convert each tensor tile to a PIL image
    to_pil = transforms.ToPILImage()
    pil_tiles = [to_pil(t) for t in ordered_tiles.cpu()]

    # Create a blank canvas
    mode = pil_tiles[0].mode
    canvas = Image.new(mode, (W * tile_size, H * tile_size))

    # Paste each tile into its spot
    for idx, tile in enumerate(pil_tiles):
        i = idx // W  # row
        j = idx % W   # col
        canvas.paste(tile, (j * tile_size, i * tile_size))

    # Save to disk
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    canvas.save(save_path)


In [7]:
# Main
import os
import random
import numpy as np
from glob import glob
from tqdm.auto import tqdm
from scipy.optimize import linear_sum_assignment
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torchvision.models import ResNet50_Weights
from pathlib import Path


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class GrayDataset(Dataset):
    """
    A dataset that loads images from a directory (and subdirectories) as single-channel grayscale tensors.

    Args:
        root_dir (str or Path): Path to the directory containing images.
        extensions (tuple of str): Allowed file extensions, e.g. ('.jpg', '.png').
    """
    def __init__(self, root_dir, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
        self.root_dir = Path(root_dir)
        # Recursively collect image file paths
        self.paths = sorted(
            [p for p in self.root_dir.rglob('*') if p.suffix.lower() in extensions]
        )
        if not self.paths:
            raise FileNotFoundError(f"No images found in {root_dir} with extensions {extensions}")

        # Grayscale + ToTensor transform
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.ToTensor(),  # outputs [1, H, W] in [0,1]
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_path = self.paths[idx]
        with Image.open(img_path) as img:
            img = img.convert('RGB')  # ensure 3-channel input for Grayscale
            clean = self.transform(img)
        return clean


class NoisyGrayDataset(Dataset):
    """
    A dataset that wraps GrayDataset and returns noisy and clean grayscale tensors.

    Args:
        root_dir (str or Path): Path to the directory containing images.
        noise_std (float): Standard deviation of Gaussian noise added to clean image.
    """
    def __init__(self, root_dir, noise_std=0.3, extensions=('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')):
        self.clean_ds = GrayDataset(root_dir, extensions)
        self.noise_std = noise_std

    def __len__(self):
        return len(self.clean_ds)

    def __getitem__(self, idx):
        clean = self.clean_ds[idx]  # [1, H, W]
        noise = torch.randn_like(clean) * self.noise_std
        noisy = torch.clamp(clean + noise, 0.0, 1.0)
        return noisy, clean
# Define transformation for tile extraction
def get_tile_transform(tile_size, noise_std, augment=False):
    transform_list = [
        transforms.Grayscale(num_output_channels=1),
        transforms.ToTensor(),  # [1, H, W], range [0,1]
    ]
    return transforms.Compose(transform_list)

# Tile extraction as given
def extract_tiles(img, grid_size, tile_size, overlap, augment, noise_std):
    H, W = grid_size
    w, h = img.size
    cell_w, cell_h = w / W, h / H
    transform = get_tile_transform(tile_size, noise_std, augment)
    tiles = []
    for i in range(H):
        for j in range(W):
            x0 = int(max(0, j*cell_w - overlap))
            y0 = int(max(0, i*cell_h - overlap))
            x1 = int(min(w, (j+1)*cell_w + overlap))
            y1 = int(min(h, (i+1)*cell_h + overlap))
            crop = img.crop((x0, y0, x1, y1)).resize((tile_size, tile_size), Image.BICUBIC)
            tiles.append(transform(crop))
    return torch.stack(tiles)  # shape: [H*W, 1, tile_size, tile_size]

# Dataset that returns tiles with noise
class TileNoisyGrayDataset(Dataset):
    def __init__(self, root_dir, grid_size=(4, 4), tile_size=64, overlap=0, noise_std=0.3, augment=False):
        self.base_dataset = GrayDataset(root_dir)
        self.grid_size = grid_size
        self.tile_size = tile_size
        self.overlap = overlap
        self.noise_std = noise_std
        self.augment = augment

    def __len__(self):
        return len(self.base_dataset)

    def __getitem__(self, idx):
        clean_img = self.base_dataset[idx]  # [1, H, W]
        img_pil = transforms.ToPILImage()(clean_img)

        clean_tiles = extract_tiles(img_pil, self.grid_size, self.tile_size, self.overlap, self.augment, noise_std=0)
        noise = torch.randn_like(clean_tiles) * self.noise_std
        noisy_tiles = torch.clamp(clean_tiles + noise, 0.0, 1.0)

        # Shuffle the tiles and get permutation
        num_tiles = noisy_tiles.size(0)
        perm = torch.randperm(num_tiles)
        shuffled_noisy_tiles = noisy_tiles[perm]

        return shuffled_noisy_tiles, perm

    
class LSCE(nn.Module):
    def __init__(self, eps=0.05):
        super().__init__()
        self.eps = eps

    def forward(self, logits, target):
        """
        logits: [B*N, N]
        target: [B*N]
        """
        logp = F.log_softmax(logits, dim=1)
        nll = -logp[torch.arange(target.size(0)), target]
        smooth = -logp.mean(dim=1)
        return ((1 - self.eps) * nll + self.eps * smooth).mean()
    
class JigsawModel(nn.Module):
    def __init__(
        self,
        grid_size=(3, 3),
        backbone='resnet50',
        weights=ResNet50_Weights.IMAGENET1K_V2,
        nhead=8,
        num_layers=2,
        dropout=0.1,
        noise_std=0.05,
    ):
        super().__init__()
        H, W = grid_size
        self.N = H * W

        # --- Gaussian noise augmentation for noisy tiles ---
        #self.noise = GaussianNoise(std=noise_std)

        # --- CNN encoder (ResNet50 trunk w/o final FC) adapted for 1-channel input ---
        base = getattr(models, backbone)(weights=weights)
        # Replace first conv to accept single-channel input
        orig_conv = base.conv1
        new_conv = nn.Conv2d(
            in_channels=1,
            out_channels=orig_conv.out_channels,
            kernel_size=orig_conv.kernel_size,
            stride=orig_conv.stride,
            padding=orig_conv.padding,
            bias=(orig_conv.bias is not None)
        )
        # Initialize new conv weights by averaging across original RGB channels
        with torch.no_grad():
            new_conv.weight[:] = orig_conv.weight.mean(dim=1, keepdim=True)
            if orig_conv.bias is not None:
                new_conv.bias[:] = orig_conv.bias
        base.conv1 = new_conv

        # Remove final FC
        self.encoder = nn.Sequential(*list(base.children())[:-1])
        self.embed_dim = base.fc.in_features

        # --- Learnable positional embeddings, one per slot ---
        self.pos_emb = nn.Parameter(torch.zeros(1, self.N, self.embed_dim))
        nn.init.trunc_normal_(self.pos_emb, std=0.02)

        # --- Transformer Encoder over the sequence of tile embeddings ---
        enc_layer = nn.TransformerEncoderLayer(
            d_model=self.embed_dim,
            nhead=nhead,
            dropout=dropout,
            batch_first=True
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # --- Per-token MLP head to predict slot index 0…N-1 ---
        self.head = nn.Sequential(
            nn.Linear(self.embed_dim, self.embed_dim),
            nn.ReLU(inplace=True),
            nn.Linear(self.embed_dim, self.N)
        )

    def forward(self, x):
        """
        x: Tensor[B, N, 1, T, T] (grayscale tiles)
        returns logits: Tensor[B, N, N]
        """
        B, N, C, T, _ = x.shape
        # 1) Add noise to each tile
        x = x.view(B * N, C, T, T).view(B, N, C, T, T)

        # 2) CNN encode each tile → (B*N, D)
        feats = self.encoder(x.view(B * N, C, T, T)).view(B, N, -1)

        # 3) Add positional embedding
        feats = feats + self.pos_emb

        # 4) Contextualize with Transformer
        feats = self.transformer(feats)

        # 5) Predict a distribution over target slot for each input tile
        logits = self.head(feats)
        return logits

    @torch.no_grad()
    def hungarian_assign(self, logits, temperature=1.0):
        """
        Solve assignment per batch element via Hungarian on -logit scores.
        Returns: Tensor[B, N] giving the assigned slot index for each tile.
        """
        scores = (logits / temperature).cpu().numpy()
        perms = []
        for mat in scores:
            _, col_ind = linear_sum_assignment(-mat)
            perms.append(torch.tensor(col_ind, dtype=torch.long))
        return torch.stack(perms, dim=0)


def train_one_epoch(model, loader, optim, scheduler, epoch, freeze_epochs=5):
    model.train()
    # freeze encoder if desired
    for p in model.encoder.parameters():
        p.requires_grad = (epoch > freeze_epochs)

    running_loss = 0.0
    running_correct = 0
    running_tiles = 0

    for tiles, perm in tqdm(loader, desc=f"Train Epoch {epoch}"):
        tiles, perm = tiles.to(device), perm.to(device)

        optim.zero_grad()
        logits = model(tiles)                # [B, N, N]
        B, N, _ = logits.shape

        # --- loss ---
        loss = LSCE()(logits.view(B*N, -1), perm.view(-1))
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()

        # --- accuracy ---
        with torch.no_grad():
            preds = model.hungarian_assign(logits)  # [B, N], on CPU
            preds = preds.to(device)
            running_correct += (preds == perm).sum().item()
            running_tiles   += B * N

        running_loss += loss.item() * B

    scheduler.step()

    avg_loss = running_loss / len(loader.dataset)
    accuracy = running_correct / running_tiles
    return avg_loss, accuracy


@torch.no_grad()
def validate_one_epoch_imagewise(model, loader, device):
    model.eval()
    criterion = LSCE()

    total_loss = 0.0
    total_tiles = correct_tiles = total_images = correct_images = 0
    i = 0

    for tiles, perm in loader:
        tiles, perm = tiles.to(device), perm.to(device)
        B, N = perm.shape

        # Reconstruct and save the *first* image in the batch:
        # build_and_save_image(
        #     tiles[0].cpu(),     # [N, C, T, T]
        #     perm[0].cpu(),      # [N]
        #     grid_size=(3,3),
        #     tile_size=96,
        #     save_path=f'reconstructions/original_{i}.png',
        #     fixed=False
        # )

        # Forward & loss
        logits = model(tiles)
        total_loss += criterion(
            logits.view(B*N, -1),
            perm.view(-1)
        ).item() * B

        # Hungarian assignment & save reconstruction
        preds = model.hungarian_assign(logits).to(device)
        # build_and_save_image(
        #     tiles[0].cpu(),
        #     preds[0].cpu(),
        #     grid_size=(3,3),
        #     tile_size=96,
        #     save_path=f'reconstructions/pred_{i}.png',
        #     fixed=True
        # )

        # Accuracy
        eq = (preds == perm)
        correct_tiles += eq.sum().item()
        total_tiles   += B * N
        correct_images += eq.all(dim=1).sum().item()
        total_images   += B

        i += 1

    avg_loss = total_loss / len(loader.dataset)
    tile_acc = correct_tiles / total_tiles
    image_acc = correct_images / total_images

    return avg_loss, tile_acc, image_acc


def main():

    epochs, bs = 300, 1
    grid_size  = (3,3)
    tile_size  = 96
    overlap    = 0       # fixed
    noise_std  = 0.1

    # Model, optimizer, scheduler
    model = JigsawModel(grid_size=grid_size, num_layers=2, nhead=8, dropout=0.1).to(device)
    model.load_state_dict(
        torch.load('checks/best_jigsaw.pth', map_location=device)
    )

    val_ds   = TileNoisyGrayDataset('data/test',   grid_size, tile_size, overlap, noise_std)
    val_dl   = DataLoader(val_ds,   batch_size=bs, shuffle=False, num_workers=0, pin_memory=True)

    val_loss,   val_acc   , image_acc= validate_one_epoch_imagewise(model, val_dl, device)

    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4%} | "
    f"image_acc for val: {image_acc:.4f}")


main()


Val Loss: 0.3456 | Val Acc: 98.1720% | image_acc for val: 0.9386


## denoise


In [11]:
import torch

def evaluate_loss(model, data_loader, device):
    """
    Evaluates the average MSE loss of a model over a dataset.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): DataLoader returning (noisy, clean) image pairs.
        device (torch.device): Device to perform computation on.

    Returns:
        float: Average MSE loss over the dataset.
    """
    model.eval()
    loss_fn = torch.nn.MSELoss()
    total_loss = 0.0
    count = 0

    with torch.no_grad():
        for noisy, clean in data_loader:
            noisy = noisy.to(device)
            clean = clean.to(device)

            output = model(noisy).clamp(0, 1)
            loss = loss_fn(output, clean)

            total_loss += loss.item()
            count += 1

    avg_loss = total_loss / count if count > 0 else float('nan')
    return avg_loss



In [None]:
import torch
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import deeplay as dl  # same library you used for UNet2d

# 1. Device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 1. Recreate the same Regressor + UNet template
unet = dl.UNet2d(
    in_channels=1,
    channels=[16, 32, 65, 64, 128],
    out_channels=1,
    skip=dl.Cat(),
)
regressor = dl.Regressor(
    model=unet,
    loss=torch.nn.MSELoss(),
    optimizer=dl.Adam(lr=1e-3),
).create().to(device)

# 2. Load the entire trained state
regressor.load_state_dict(torch.load("denoiser_model_01.pth", map_location=device))
regressor.eval()

# 3. The denoiser UNet is now at
denoiser_net = regressor.model

root_data_dir = 'data/test'
test_set    = NoisyGrayDataset(root_data_dir, noise_std=0.1)
test_loader = DataLoader(test_set, batch_size=32, shuffle=False,num_workers=0)

evaluate_loss(denoiser_net,test_loader, device)



0.0015966438404575456

## color


In [14]:
import os
import time
import numpy as np
from datetime import timedelta
from PIL import Image
from sklearn.model_selection import train_test_split
import torch
import torch.backends.cudnn as cudnn
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from torch.utils.data import Dataset, DataLoader
from torch.nn import InstanceNorm2d, LeakyReLU, Tanh, Sigmoid
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import deeplay as dl
import matplotlib.pyplot as plt
from glob import glob
import os
gen = dl.UNet2d(in_channels=1, channels=[16,32,64,128], out_channels=3)
# Gradient checkpointing to save memory
try:
    gen.enable_gradient_checkpointing()
except AttributeError:
    pass
# Norm + activations
from torch.nn import LeakyReLU, Tanh
gen['decoder','blocks',:-1].all.normalized(InstanceNorm2d)
gen['decoder','blocks',:-1,'activation'].configure(LeakyReLU, negative_slope=0.2)
gen['decoder','blocks',-1,'activation'].configure(Tanh)
# Build
gen.build().to(device)

# Discriminator
disc = dl.ConvolutionalNeuralNetwork(in_channels=4, hidden_channels=[8,16,32], out_channels=1)
disc['blocks',...,'layer'].configure(kernel_size=4,stride=2,padding=1)
disc['blocks',...,'activation#-1'].configure(LeakyReLU, negative_slope=0.2)
disc['blocks',1:-1].all.normalized(InstanceNorm2d)
disc['blocks',-1,'activation'].configure(Sigmoid)
disc.build().to(device)

# ------------------------
# Losses & optimizers with mixed precision
# ------------------------
loss_disc = torch.nn.MSELoss()
loss_recon = torch.nn.L1Loss()
loss_percep = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(device)
optim_g = torch.optim.Adam(gen.parameters(), lr=2e-4, betas=(0.5,0.999))
optim_d = torch.optim.Adam(disc.parameters(), lr=5e-5, betas=(0.5,0.999))
scaler = torch.amp.GradScaler()


In [15]:
def evaluate_generator_loss(model, loader, device, loss_recon=None, loss_percep=None):
    model.eval()
    recon_loss_fn = loss_recon or torch.nn.L1Loss()
    total_recon = 0.0
    total_percep = 0.0
    count = 0

    with torch.no_grad(), torch.cuda.amp.autocast():
        for gray, rgb in loader:
            gray, rgb = gray.to(device), rgb.to(device)
            fake_rgb = model(gray).clamp(-1, 1)

            recon = recon_loss_fn(fake_rgb, rgb)
            total_recon += recon.item()

            if loss_percep is not None:
                perc = loss_percep(fake_rgb, rgb)
                total_percep += perc.item()

            count += 1

    avg_recon = total_recon / count
    avg_percep = total_percep / count if loss_percep is not None else None

    return avg_recon, avg_percep


In [None]:
import os
from PIL import Image
import torch
import torchvision.transforms as transforms
import numpy as np
from glob import glob

# 0) Load model
ckpt_path = 'color.pth'
device    = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
gen.to(device)
gen.load_state_dict(torch.load(ckpt_path, map_location=device))
gen.eval()

test_paths  = glob('data/test/*.png')

img_size=(256,256)
batch_size=16
test_ds = ColorizationDataset(test_paths, image_size=img_size)
test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, pin_memory=True,  num_workers=0)

percep_loss = LearnedPerceptualImagePatchSimilarity(net_type='vgg').to(device)

# Evaluate
avg_l1, avg_percep = evaluate_generator_loss(gen, test_loader, device, loss_recon=torch.nn.L1Loss(), loss_percep=percep_loss)

print(f"Average L1 Loss: {avg_l1:.4f}")
if avg_percep is not None:
    print(f"Average LPIPS Loss: {avg_percep:.4f}")


  with torch.no_grad(), torch.cuda.amp.autocast():


Average L1 Loss: 0.0861
Average LPIPS Loss: 0.1755
