In [None]:
# MIT License
# Copyright (c) 2017 Vooban Inc.
# Coded by: Guillaume Chevalier (original)
# Adapted for PyTorch by [Your Name]

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
import cv2
import numpy as np
from torch.cuda.amp import autocast
import segmentation_models_pytorch as smp

# When True, the helpers below draw diagnostic matplotlib figures
# (window function, padded image, dummy/cheap predictions) as they run.
PLOT_PROGRESS: bool = True

# Memo of 2D blending windows built by _window_2D, keyed by "<window_size>_<power>".
cached_2d_windows: dict = dict()

def _spline_window(window_size, power=2):
    """
    Create a 1D squared-spline window.
    """
    n = torch.arange(window_size, dtype=torch.float32)
    if window_size % 2 == 1:
        center = (window_size - 1) / 2
        window = 1 - torch.abs(n - center) / (center + 1)
    else:
        center = (window_size - 1) / 2
        window = 1 - torch.abs(n - center) / (window_size / 2)

    intersection = window_size // 4
    wind_outer = (torch.abs(2 * window) ** power) / 2
    wind_outer[intersection:window_size - intersection] = 0
    wind_inner = 1 - (torch.abs(2 * (window - 1)) ** power) / 2
    wind_inner[:intersection] = 0
    wind_inner[window_size - intersection:] = 0
    wind = wind_inner + wind_outer
    wind = wind / torch.mean(wind)
    return wind

def _window_2D(window_size, power=2):
    """
    Return the 2D blending window (outer product of the 1D spline window),
    memoised in ``cached_2d_windows``.

    Args:
        window_size: side length of the square window.
        power: spline exponent forwarded to ``_spline_window``.

    Returns:
        Tensor of shape (1, 1, window_size, window_size), so it broadcasts over
        batch and channel dimensions. The cached tensor itself is returned, so
        callers should not modify it in place.
    """
    cache_key = f"{window_size}_{power}"
    cached = cached_2d_windows.get(cache_key)
    if cached is not None:
        return cached

    profile = _spline_window(window_size, power)
    window = torch.outer(profile, profile)[None, None, :, :]

    if PLOT_PROGRESS:
        plt.imshow(window.squeeze().cpu().numpy(), cmap="viridis")
        plt.title("2D Windowing Function for a Smooth Blending of Overlapping Patches")
        plt.show()

    cached_2d_windows[cache_key] = window
    return window

def _pad_img(img, window_size, subdivisions):
    """
    Reflect-pad a (C, H, W) image by the same border on all four sides.

    The border width is ``round(window_size * (1 - 1/subdivisions))``.

    Args:
        img: tensor of shape (C, H, W).
        window_size: patch side length.
        subdivisions: overlap factor.

    Returns:
        Tensor of shape (C, H + 2*border, W + 2*border).
    """
    border = int(round(window_size * (1 - 1.0 / subdivisions)))
    # F.pad order is (left, right, top, bottom); all four get the same border.
    padded = F.pad(img, (border,) * 4, mode='reflect')
    if PLOT_PROGRESS:
        plt.imshow(padded.permute(1, 2, 0).cpu().numpy())
        plt.title("Padded Image for Using Tiled Prediction Patches (reflect)")
        plt.show()
    return padded

def _unpad_img(padded_img, window_size, subdivisions):
    """
    Remove the extra padding added by _pad_img.
    """
    aug = int(round(window_size * (1 - 1.0/subdivisions)))
    return padded_img[:, aug:-aug, aug:-aug]

def _rotate_mirror_do(im):
    """
    Create the 8 transformations (rotations and horizontal flips) of the image.
    Input is assumed to be a tensor of shape (C, H, W).
    """
    transforms = []
    transforms.append(im.clone())
    transforms.append(torch.rot90(im, k=1, dims=(1, 2)))
    transforms.append(torch.rot90(im, k=2, dims=(1, 2)))
    transforms.append(torch.rot90(im, k=3, dims=(1, 2)))
    im_flipped = torch.flip(im, dims=[2])  # horizontal flip
    transforms.append(im_flipped)
    transforms.append(torch.rot90(im_flipped, k=1, dims=(1, 2)))
    transforms.append(torch.rot90(im_flipped, k=2, dims=(1, 2)))
    transforms.append(torch.rot90(im_flipped, k=3, dims=(1, 2)))
    return transforms

def _rotate_mirror_undo(im_mirrs):
    """
    Invert the 8 transformations and average the results.
    """
    originals = []
    originals.append(im_mirrs[0])
    originals.append(torch.rot90(im_mirrs[1], k=3, dims=(1, 2)))
    originals.append(torch.rot90(im_mirrs[2], k=2, dims=(1, 2)))
    originals.append(torch.rot90(im_mirrs[3], k=1, dims=(1, 2)))
    originals.append(torch.flip(im_mirrs[4], dims=[2]))
    originals.append(torch.flip(torch.rot90(im_mirrs[5], k=3, dims=(1, 2)), dims=[2]))
    originals.append(torch.flip(torch.rot90(im_mirrs[6], k=2, dims=(1, 2)), dims=[2]))
    originals.append(torch.flip(torch.rot90(im_mirrs[7], k=1, dims=(1, 2)), dims=[2]))
    return torch.stack(originals, dim=0).mean(dim=0)

def _windowed_subdivs(padded_img, window_size, subdivisions, nb_classes, pred_func):
    """
    Tile a padded (C, H, W) image into overlapping patches, predict them in a
    single batch, and weight every prediction by the 2D spline window.

    Args:
        padded_img: tensor of shape (C, H, W).
        window_size: patch side length.
        subdivisions: patches are spaced ``window_size // subdivisions`` apart.
        nb_classes: number of channels returned by ``pred_func``.
        pred_func: callable taking (B, C, window_size, window_size) and
            returning (B, nb_classes, window_size, window_size).

    Returns:
        Tensor of shape (rows, cols, nb_classes, window_size, window_size).
    """
    spline = _window_2D(window_size, power=2).to(padded_img.device)
    stride = window_size // subdivisions
    _, height, width = padded_img.shape

    row_starts = range(0, height - window_size + 1, stride)
    col_starts = range(0, width - window_size + 1, stride)
    grid = [
        [padded_img[:, top:top + window_size, left:left + window_size] for left in col_starts]
        for top in row_starts
    ]
    n_rows, n_cols = len(grid), len(grid[0])

    flat = torch.stack([tile for row in grid for tile in row], dim=0)
    weighted = pred_func(flat) * spline  # window broadcasts over batch and classes
    return weighted.view(n_rows, n_cols, nb_classes, window_size, window_size)

def _recreate_from_subdivs(subdivs, window_size, subdivisions, padded_out_shape):
    """
    Merge the weighted overlapping patches back into a full image.
    """
    step = window_size // subdivisions
    nb_classes, H, W = padded_out_shape
    y = torch.zeros(padded_out_shape, device=subdivs.device)
    num_rows = subdivs.shape[0]
    num_cols = subdivs.shape[1]
    row_idx = 0
    for i in range(0, H - window_size + 1, step):
        col_idx = 0
        for j in range(0, W - window_size + 1, step):
            y[:, i:i+window_size, j:j+window_size] += subdivs[row_idx, col_idx]
            col_idx += 1
        row_idx += 1
    y = y / (subdivisions ** 2)
    return y

def _blend_rotation_batched(rot_img, window, window_size, step, nb_classes, pred_func, batch_size):
    """Tile one (C, H, W) image, predict its patches batch by batch, and return the window-normalised blend (nb_classes, H, W)."""
    _, height, width = rot_img.shape
    # Top-left corners of every overlapping patch, row-major.
    coords = [
        (i, j)
        for i in range(0, height - window_size + 1, step)
        for j in range(0, width - window_size + 1, step)
    ]
    if not coords:
        raise ValueError(
            f"Padded image ({height}x{width}) is smaller than window_size={window_size}"
        )

    pred_img = torch.zeros(nb_classes, height, width, device=rot_img.device)
    # The window is identical for every class, so a single weight channel (broadcast) suffices.
    weight_img = torch.zeros(1, height, width, device=rot_img.device)

    for start in range(0, len(coords), batch_size):
        batch_coords = coords[start:start + batch_size]
        batch = torch.stack(
            [rot_img[:, i:i + window_size, j:j + window_size] for i, j in batch_coords], dim=0
        )  # (B, C, window_size, window_size)
        batch_pred = pred_func(batch)  # expected (B, nb_classes, window_size, window_size)
        if batch_pred.shape[0] != len(batch_coords):
            raise ValueError(
                f"pred_func returned {batch_pred.shape[0]} predictions for {len(batch_coords)} patches"
            )
        # Accumulate immediately so only one batch of predictions is alive at a time.
        for (i, j), patch_pred in zip(batch_coords, batch_pred):
            pred_img[:, i:i + window_size, j:j + window_size] += patch_pred * window
            weight_img[:, i:i + window_size, j:j + window_size] += window

    # Pixels reached by no patch keep weight 0; clamping turns the former 0/0 = NaN into 0
    # while leaving every covered pixel's value unchanged.
    return pred_img / weight_img.clamp_min(torch.finfo(weight_img.dtype).tiny)


def predict_img_with_smooth_windowing_batched(input_img, window_size, subdivisions, nb_classes, pred_func, batch_size=16):
    """
    Predict a full image with smooth blending of overlapping patches, running
    the model on mini-batches of patches to bound memory use.

    The image is reflect-padded and expanded into 8 rotations/mirrors. Each
    variant is tiled with stride ``window_size // subdivisions``; each patch
    prediction is weighted by the 2D spline window, and the sum is normalised by
    the accumulated window weight. The 8 variants are mapped back, averaged,
    and the padding is removed.

    Args:
        input_img: image tensor of shape (C, H, W).
        window_size: patch side length (e.g. 256).
        subdivisions: overlap factor (determines the patch stride).
        nb_classes: number of output channels of ``pred_func``.
        pred_func: callable taking a (B, C, window_size, window_size) batch and
            returning (B, nb_classes, window_size, window_size).
        batch_size: number of patches passed to ``pred_func`` per call.

    Returns:
        Tensor with the final prediction, shape (nb_classes, H, W).

    Raises:
        ValueError: if the stride or ``batch_size`` is not positive, the padded
            image is smaller than one window, or ``pred_func`` returns the wrong
            number of predictions.
    """
    step = window_size // subdivisions
    if step <= 0:
        raise ValueError(f"window_size // subdivisions must be positive, got {step}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    padded = _pad_img(input_img, window_size, subdivisions)
    # Same blending window for all 8 variants: build and move it once, shape (window_size, window_size).
    window = _window_2D(window_size, power=2).to(input_img.device).squeeze(0).squeeze(0)

    results = [
        _blend_rotation_batched(rot_img, window, window_size, step, nb_classes, pred_func, batch_size)
        for rot_img in _rotate_mirror_do(padded)
    ]

    final_pred = _rotate_mirror_undo(results)
    return _unpad_img(final_pred, window_size, subdivisions)


def cheap_tiling_prediction(img, window_size, nb_classes, pred_func):
    """
    Baseline prediction on non-overlapping tiles, without any blending.

    The (C, H, W) image is zero-padded on the bottom/right up to the next
    multiple of ``window_size``. Each tile is predicted on its own with a batch
    of 1, and the stitched result is cropped back to (nb_classes, H, W).

    Args:
        img: tensor of shape (C, H, W).
        window_size: tile side length.
        nb_classes: number of channels returned by ``pred_func``.
        pred_func: callable taking (1, C, window_size, window_size) and
            returning (1, nb_classes, window_size, window_size).

    Returns:
        Float32 tensor of shape (nb_classes, H, W).
    """
    channels, height, width = img.shape
    # Ceil each spatial size up to a multiple of window_size.
    padded_h = -(-height // window_size) * window_size
    padded_w = -(-width // window_size) * window_size

    prd = torch.zeros((nb_classes, padded_h, padded_w), device=img.device)
    canvas = torch.zeros((channels, padded_h, padded_w), device=img.device)
    canvas[:, :height, :width] = img

    for top in range(0, padded_h, window_size):
        for left in range(0, padded_w, window_size):
            tile = canvas[:, top:top + window_size, left:left + window_size]
            tile_pred = pred_func(tile.unsqueeze(0))
            prd[:, top:top + window_size, left:left + window_size] = tile_pred.squeeze(0)

    prd = prd[:, :height, :width]
    if PLOT_PROGRESS:
        plt.imshow(prd.permute(1, 2, 0).cpu().numpy())
        plt.title("Cheaply Merged Patches")
        plt.show()
    return prd

def get_dummy_img(xy_size=128, nb_channels=3):
    """
    Build a random (nb_channels, xy_size, xy_size) test image with spatial structure.

    Uniform noise offset by 1 is modulated by an outer-product ramp, made
    point-symmetric by adding its 180-degree flip, then rescaled to [0, 0.5].
    Uses the global torch RNG (no seed is set here).
    """
    noise = torch.rand(nb_channels, xy_size, xy_size)
    ramp = torch.linspace(0, 1, xy_size)
    gradient = torch.outer(ramp, ramp)

    img = (noise + torch.ones((1, xy_size, xy_size))) * gradient.unsqueeze(0)
    img = img + img.flip(dims=[1, 2])

    shifted = img - img.min()
    img = shifted / shifted.max() / 2

    if PLOT_PROGRESS:
        plt.imshow(img.permute(1, 2, 0).cpu().numpy())
        plt.title("Random image for a test")
        plt.show()
    return img

def round_predictions(prd, nb_channels_out, thresholds):
    """
    Binarise the first ``nb_channels_out`` channels of ``prd`` IN PLACE.

    Channel k becomes 1.0 where it exceeds ``thresholds[k]`` and 0.0 elsewhere.
    Any later channels are left untouched.

    Returns:
        The same tensor object ``prd``, after mutation.
    """
    for channel in range(nb_channels_out):
        above = prd[channel] > thresholds[channel]
        prd[channel] = above.float()
    return prd

In [None]:

def infer_and_save_full_image_smooth(model, device, image_path, output_path, nb_classes, window_size=256, subdivisions=2, preprocess_fn=None):
    """
    Predict a segmentation mask for a full-size image (e.g. 6200x4000) using
    smooth blending of overlapping patches, and write it to disk.

    The image is tiled into overlapping ``window_size`` x ``window_size`` patches,
    which are fed to the model at their native size (no up/down-scaling of the
    input). Softmax probabilities are blended over 8 rotations/mirrors. The
    per-pixel argmax is written as an 8-bit image: class 0 -> 0, class 1 -> 127,
    class 2 -> 255, and any other class -> 0.

    Args:
        model: PyTorch model used for inference (not switched to eval mode here).
        device: 'cpu', 'cuda' (or a torch.device) for the input tensor.
        image_path: path of the image to read with OpenCV.
        output_path: path where the predicted mask is written.
        nb_classes: number of output channels of the model.
        window_size: patch side length (default: 256).
        subdivisions: overlap factor for the blending (default: 2).
        preprocess_fn: callable applied to the float32 RGB (H, W, 3) array before
            inference. Defaults to the notebook-level ``preprocess_input``.

    Raises:
        ValueError: if the image cannot be read or the mask cannot be written.
    """
    # OpenCV returns None instead of raising when the file cannot be read.
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Erro ao carregar a imagem: {image_path}")

    # OpenCV loads BGR; the model presumably expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Resolve the global lazily: it is defined in a later notebook cell.
    if preprocess_fn is None:
        preprocess_fn = preprocess_input
    image = preprocess_fn(image.astype('float32'))

    # (H, W, C) numpy -> (C, H, W) float tensor on the target device.
    image_tensor = torch.tensor(image).permute(2, 0, 1).to(device).float()

    # torch.cuda.amp mixed precision only applies on CUDA devices.
    use_amp = torch.device(device).type == 'cuda'

    def pred_func(patches):
        """Run the model on a (B, C, h, w) batch and return class probabilities of the same spatial size."""
        with torch.no_grad(), autocast(enabled=use_amp):
            logits = model(patches)
        probs = torch.softmax(logits, dim=1)
        # Match the incoming patch size (was hard-coded to 256, breaking any other window_size).
        return F.interpolate(probs, size=tuple(patches.shape[-2:]), mode='nearest')

    prediction = predict_img_with_smooth_windowing_batched(
        image_tensor, window_size, subdivisions, nb_classes, pred_func
    )  # (nb_classes, H, W)

    # Per-pixel class index, then map to grey levels for saving.
    mask = np.argmax(prediction.cpu().numpy(), axis=0)  # (H, W)
    mask = np.where(mask == 1, 127, np.where(mask == 2, 255, 0)).astype(np.uint8)

    # cv2.imwrite reports failure through its return value rather than raising.
    if not cv2.imwrite(output_path, mask):
        raise ValueError(f"Erro ao salvar a máscara: {output_path}")
    print(f"Salvo: {output_path}")

In [None]:
# Preprocessing function matching an EfficientNet-B8 encoder trained on ImageNet
preprocess_input = smp.encoders.get_preprocessing_fn('timm-efficientnet-b8', pretrained='imagenet')
# is_available must be *called*: the bare function object is always truthy, which forced 'cuda' on CPU-only machines
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# U-Net with an EfficientNet-B8 encoder. encoder_weights=None because load_state_dict (strict by default)
# overwrites every parameter below, so downloading the ImageNet weights first would be wasted work.
model = smp.Unet(encoder_name='timm-efficientnet-b8', encoder_weights=None, in_channels=3, classes=3)
# map_location lets a checkpoint saved from a GPU load on whichever device was selected
model.load_state_dict(torch.load("uNetB8CombinedCP_BlurPequeno7.pth", map_location=device))
model.to(device)
model.eval();  # inference mode; trailing ';' suppresses the very long model repr in the cell output

In [None]:
# Run smooth-blended inference on the sample image and write the 3-class mask.
infer_and_save_full_image_smooth(
    model=model,
    device=device,
    image_path='inputImage.png',
    output_path='outputMask.png',
    nb_classes=3,
)