In [None]:
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from pathlib import Path
import extrair_zip_train_dir as zipService


class ImageStitchingDatasetFiles(Dataset):
    """Dataset backed by one ``.pt`` file per sample inside a folder.

    Each file is loaded with ``torch.load`` and indexed with the keys
    ``"parte1"``, ``"parte2"``, ``"groundtruth"`` and, when
    ``use_gradiente`` is true, ``"gradiente"``. Every tensor read is cast to
    float and divided by 255.

    Args:
        folder_path: Directory that contains the ``*.pt`` sample files.
        use_gradiente: When true, ``__getitem__`` also returns the
            ``"gradiente"`` tensor as a third element.
    """

    def __init__(self, folder_path, use_gradiente=False):
        self.use_gradiente = use_gradiente
        self.folder = Path(folder_path)
        # Sorted so that an index always maps to the same file.
        self.files = sorted(self.folder.glob("*.pt"))

    def __len__(self):
        """Return the number of ``.pt`` files found in the folder."""
        return len(self.files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return its tensors scaled by 1/255.

        Returns:
            ``((parte1, parte2), groundtruth)``, or
            ``((parte1, parte2), groundtruth, gradiente)`` when
            ``use_gradiente`` is true.
        """
        record = torch.load(self.files[idx])

        wanted = ["parte1", "parte2", "groundtruth"]
        if self.use_gradiente:
            wanted.append("gradiente")

        # Presumably the stored tensors are uint8 in [0, 255], which this
        # maps to float32 in [0, 1] -- confirm against the file writer.
        scaled = {key: record[key].float() / 255.0 for key in wanted}

        inputs = (scaled["parte1"], scaled["parte2"])
        if self.use_gradiente:
            return inputs, scaled["groundtruth"], scaled["gradiente"]
        return inputs, scaled["groundtruth"]

# Presumably extracts ./train.zip into ./train while reporting progress;
# the helper lives in the project module extrair_zip_train_dir -- verify there.
zipService.descompactar_zip_com_progresso("./train.zip", "./train")

# Samples are read from the extracted folder; the "gradiente" tensor is not requested.
dataset = ImageStitchingDatasetFiles("./train", use_gradiente=False)

# Shuffled batches of 256 samples loaded by 4 worker processes.
dataloader = DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
    prefetch_factor=2,
    pin_memory=True,
)


Entrada 32×48
| Bloco                    | Altura × Largura | Canais p/ encoder | Canais Pós-concatenação |
|--------------------------|------------------|--------------------|--------------------------|
| Entrada                  | 32×48            | 3                  | —                        |
| `enc1`                   | 32×48            | 32                 | 64 (concat)              |
| `pool1`                  | 16×24            | 32                 | 64 (concat)              |
| `enc2`                   | 16×24            | 64                 | 128 (concat)             |
| `pool2`                  | 8×12             | 64                 | 128 (concat)             |
| Bottleneck (concat)      | 8×12             | —                  | 128                      |
| `dec2` entrada           | 8×12             | 128 + 128 = 256    | —                        |
| `dec2` saída             | 16×24            | 64                 | —                        |
| `dec1` entrada           | 16×24            | 64 + 64 = 128      | —                        |
| `dec1` saída             | 32×48            | 32                 | —                        |
| Saída final              | 32×48            | 3                  | —                        |

Entrada 64×96
| Bloco                    | Altura × Largura | Canais p/ encoder | Canais Pós-concatenação |
|--------------------------|------------------|--------------------|--------------------------|
| Entrada                  | 64×96            | 3                  | —                        |
| `enc1`                   | 64×96            | 32                 | 64 (concat)              |
| `pool1`                  | 32×48            | 32                 | 64 (concat)              |
| `enc2`                   | 32×48            | 64                 | 128 (concat)             |
| `pool2`                  | 16×24            | 64                 | 128 (concat)             |
| Bottleneck (concat)      | 16×24            | —                  | 128                      |
| `dec2` entrada           | 16×24            | 128 + 128 = 256    | —                        |
| `dec2` saída             | 32×48            | 64                 | —                        |
| `dec1` entrada           | 32×48            | 64 + 64 = 128      | —                        |
| `dec1` saída             | 64×96            | 32                 | —                        |
| Saída final              | 64×96            | 3                  | —                        |


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# CBAM (Convolutional Block Attention Module)
# Applies channel attention followed by spatial attention
class CBAM(nn.Module):
    """Channel attention followed by spatial attention on a feature map.

    The output has the same shape as the input: the input is first scaled
    per channel, then the result is scaled per spatial position.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Divisor applied to ``channels`` to size the hidden layer
            of the shared channel-attention MLP. The hidden width is clamped
            to at least 1 so that ``channels < reduction`` still yields a
            valid layer.
    """

    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Guard against a zero-width hidden layer for small channel counts.
        hidden = max(channels // reduction, 1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channels, 1, bias=False)
        )

        self.sigmoid_channel = nn.Sigmoid()
        self.conv_spatial = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` re-weighted by channel and spatial attention."""
        # Channel attention: the same MLP scores the avg- and max-pooled descriptors.
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        x = x * self.sigmoid_channel(avg_out + max_out)

        # Spatial attention: the 2-channel [mean, max] descriptor is only used
        # to compute a single-channel mask. The mask multiplies the feature
        # map itself, so the channel count of ``x`` is preserved.
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        descriptor = torch.cat([avg_map, max_map], dim=1)
        spatial_mask = self.sigmoid_spatial(self.conv_spatial(descriptor))
        return x * spatial_mask

# Simple self-attention used at the bottleneck
class SelfAttention(nn.Module):
    """Residual self-attention across all spatial positions of a feature map.

    ``query`` and ``key`` are 1x1 convolutions projecting to ``in_dim // 8``
    channels; ``value`` is a 1x1 convolution that keeps ``in_dim`` channels.
    The attended features are scaled by the learnable scalar ``gamma``
    (initialised to zero) and added back to the input.

    Args:
        in_dim: Number of channels of the input feature map.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced_dim = in_dim // 8
        self.query = nn.Conv2d(in_dim, reduced_dim, 1)
        self.key = nn.Conv2d(in_dim, reduced_dim, 1)
        self.value = nn.Conv2d(in_dim, in_dim, 1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x`` with the same shape as ``x``."""
        batch, channels, height, width = x.size()
        positions = height * width

        # (B, N, C') @ (B, C', N) -> (B, N, N) pairwise position scores.
        queries = self.query(x).view(batch, -1, positions).transpose(1, 2)
        keys = self.key(x).view(batch, -1, positions)
        weights = F.softmax(torch.bmm(queries, keys), dim=-1)

        # (B, C, N) @ (B, N, N) -> (B, C, N), then back to the image layout.
        values = self.value(x).view(batch, -1, positions)
        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(batch, channels, height, width)
        return self.gamma * attended + x

# Standard encoder block
class EncoderBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Padding of 1 is used on both convolutions, so the spatial size of the
    input is preserved.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of channels produced by both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # The first stage maps in_ch -> out_ch, the second keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages.append(nn.Conv2d(stage_in, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        features = self.conv(x)
        return features

# Decoder block: upsample + concatenate skip connection + convolutions
class DecoderBlock(nn.Module):
    """Upsample by 2, concatenate a skip connection, then apply two 3x3 convs.

    Args:
        in_ch: Number of channels fed to the convolution stack, i.e. the
            channel count after the skip connection has been concatenated.
        out_ch: Number of channels produced by the block.
        skip_ch: Optional number of channels of the skip connection.
            When given, the incoming feature map is expected to have
            ``in_ch - skip_ch`` channels and the transposed convolution keeps
            that channel count. When omitted, the incoming feature map is
            expected to have ``in_ch`` channels, the transposed convolution
            reduces it to ``out_ch``, and the skip connection must therefore
            carry ``in_ch - out_ch`` channels.

    Raises:
        ValueError: If ``skip_ch`` is given and is not in ``(0, in_ch)``.
    """

    def __init__(self, in_ch, out_ch, skip_ch=None):
        super().__init__()
        if skip_ch is None:
            up_in, up_out = in_ch, out_ch
        else:
            if not 0 < skip_ch < in_ch:
                raise ValueError(
                    f"skip_ch must be in (0, in_ch); got skip_ch={skip_ch}, in_ch={in_ch}"
                )
            up_in = up_out = in_ch - skip_ch
        self.up = nn.ConvTranspose2d(up_in, up_out, kernel_size=2, stride=2)
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x, skip):
        """Upsample ``x``, align it to ``skip`` spatially, concatenate and convolve."""
        x = self.up(x)
        # The transposed conv doubles H and W; resize if the skip is not exactly 2x.
        if x.shape[2:] != skip.shape[2:]:
            x = F.interpolate(x, size=skip.shape[2:], mode='bilinear', align_corners=False)
        x = torch.cat([x, skip], dim=1)  # concatenate the skip connection
        return self.conv(x)

# UNet with two encoders, CBAM on the skips and self-attention at the bottleneck
class DualEncoderUNet_CBAM_SA_Small(nn.Module):
    """Two-level UNet that fuses two input images into one 3-channel image.

    Each input goes through its own two-stage encoder. The deepest features
    of both encoders are concatenated, pooled and passed through a
    bottleneck with self-attention. The decoder consumes the concatenated
    encoder features of both branches as skip connections, each passed
    through a CBAM module first.

    Channel bookkeeping (multiples of ``base_ch``):
        bottleneck 4, skip at level 2: 2 + 2 = 4, skip at level 1: 1 + 1 = 2.
        ``dec2`` convolves 4 + 4 = 8 -> 2 and ``dec1`` convolves 2 + 2 = 4 -> 1.
    The decoder wiring assumes each CBAM module returns as many channels as
    it receives.

    Args:
        in_channels: Number of channels of each input image.
        base_ch: Channel width of the first encoder stage.
    """

    def __init__(self, in_channels=3, base_ch=32):
        super().__init__()

        # Two independent encoders (part 1 and part 2)
        self.enc1_1 = EncoderBlock(in_channels, base_ch)
        self.enc2_1 = EncoderBlock(base_ch, base_ch * 2)

        self.enc1_2 = EncoderBlock(in_channels, base_ch)
        self.enc2_2 = EncoderBlock(base_ch, base_ch * 2)

        self.pool = nn.MaxPool2d(2)

        # Bottleneck with self-attention
        self.bottleneck = EncoderBlock(base_ch * 4, base_ch * 4)
        self.attn = SelfAttention(base_ch * 4)

        # CBAM on the skip connections (sized for the two concatenated branches)
        self.cbam2 = CBAM(base_ch * 4)
        self.cbam1 = CBAM(base_ch * 2)

        # Two-level decoder. ``skip_ch`` makes the transposed conv match the
        # channels that actually arrive (bottleneck: 4*base_ch, dec2 output:
        # 2*base_ch); without it the upsampling layer would expect the
        # post-concatenation width and fail on the first forward pass.
        self.dec2 = DecoderBlock(base_ch * 8, base_ch * 2, skip_ch=base_ch * 4)  # 4 (bottleneck) + 4 (skip)
        self.dec1 = DecoderBlock(base_ch * 4, base_ch, skip_ch=base_ch * 2)      # 2 + 2

        self.final = nn.Conv2d(base_ch, 3, kernel_size=1)

    def forward(self, x1, x2):
        """Fuse ``x1`` and ``x2`` into a 3-channel image with values in (0, 1)."""
        # Encoder for part 1
        e1_1 = self.enc1_1(x1)
        e2_1 = self.enc2_1(self.pool(e1_1))

        # Encoder for part 2
        e1_2 = self.enc1_2(x2)
        e2_2 = self.enc2_2(self.pool(e1_2))

        # Make sure both branches have the same spatial size before concatenating
        if e1_1.shape[2:] != e1_2.shape[2:]:
            e1_2 = F.interpolate(e1_2, size=e1_1.shape[2:], mode='bilinear', align_corners=False)
        if e2_1.shape[2:] != e2_2.shape[2:]:
            e2_2 = F.interpolate(e2_2, size=e2_1.shape[2:], mode='bilinear', align_corners=False)

        # Bottleneck on the concatenated, pooled features, then self-attention
        b = self.bottleneck(torch.cat([self.pool(e2_1), self.pool(e2_2)], dim=1))
        b = self.attn(b)

        # Decoder with CBAM applied to the concatenated skip connections
        d2 = self.dec2(b, self.cbam2(torch.cat([e2_1, e2_2], dim=1)))
        d1 = self.dec1(d2, self.cbam1(torch.cat([e1_1, e1_2], dim=1)))

        return torch.sigmoid(self.final(d1))  # sigmoid keeps the output in (0, 1)


In [None]:
import math
import os
import time
from datetime import datetime

import torch
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import ToPILImage
from torchvision.utils import make_grid

from generator import DualEncoderUNet_CBAM_SA_Small
from discriminator import Discriminator
from losses import GANLoss, compute_gradient_penalty
from metrics import compute_all_metrics

# === Hyperparameters ===
# Number of generator updates performed per discriminator update in train().
gen_steps = 20
# Learning-rate bounds: max_lr is the initial optimizer LR and the top of the
# cosine schedule in train(); min_lr is the value the schedule decays towards.
max_lr = 0.0002
min_lr = 0.00001


def _cosine_lr(epoch, num_epochs, lr_min, lr_max):
    """Return the cosine-decayed learning rate for ``epoch``.

    Equals ``lr_max`` at epoch 0 and decays towards ``lr_min`` as ``epoch``
    approaches ``num_epochs``.
    """
    progress = epoch / num_epochs
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


def _set_lr(optimizer, lr):
    """Set ``lr`` on every parameter group of ``optimizer``."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


def train(dataloader, device, num_epochs=100, log_interval=600):
    """Adversarially train the generator and discriminator.

    For every batch the discriminator is updated once (adversarial loss plus
    ``10 *`` gradient penalty) and the generator ``gen_steps`` times
    (``8 *`` adversarial loss ``+ 2 *`` L1 loss). Losses, metrics and the
    learning rate are written to TensorBoard under ``runs/<timestamp>``.
    Weights are saved at the end of every epoch to ``checkpoints_epoch`` and,
    additionally, to ``checkpoints_batch`` whenever more than
    ``log_interval`` seconds have passed since the previous time-based save.

    Args:
        dataloader: Iterable yielding ``((p1, p2), gt)`` batches of tensors.
        device: Device the models and batches are moved to.
        num_epochs: Number of passes over ``dataloader``.
        log_interval: Minimum number of seconds between time-based checkpoints.

    Raises:
        ValueError: If the module-level ``gen_steps`` is smaller than 1, since
            the generator losses logged for every batch would not exist.
    """
    if gen_steps < 1:
        raise ValueError(f"gen_steps must be >= 1, got {gen_steps}")

    # Output directories
    os.makedirs("checkpoints_epoch", exist_ok=True)
    os.makedirs("checkpoints_batch", exist_ok=True)
    logdir = os.path.join("runs", datetime.now().strftime("%Y%m%d-%H%M%S"))
    writer = SummaryWriter(logdir)

    # try/finally so the event file is flushed and closed even if training fails.
    try:
        # Models
        G = DualEncoderUNet_CBAM_SA_Small().to(device)
        D = Discriminator().to(device)

        # Optimizers start at the maximum learning rate
        opt_G = torch.optim.Adam(G.parameters(), lr=max_lr, betas=(0.5, 0.999))
        opt_D = torch.optim.Adam(D.parameters(), lr=max_lr, betas=(0.5, 0.999))

        # Adversarial criterion and reconstruction criterion (built once)
        gan_loss = GANLoss().to(device)
        l1_loss = nn.L1Loss()

        # Timer for the time-based checkpoint
        last_log_time = time.time()

        global_step = 0
        for epoch in range(num_epochs):
            # === LR schedule (cosine decay) ===
            # The schedule only depends on the epoch, so it is applied once
            # here, before any optimizer step of the epoch.
            lr = _cosine_lr(epoch, num_epochs, min_lr, max_lr)
            _set_lr(opt_G, lr)
            _set_lr(opt_D, lr)

            for batch in dataloader:
                G.train(); D.train()
                (p1, p2), gt = batch
                p1, p2, gt = p1.to(device), p2.to(device), gt.to(device)

                # === Train the discriminator ===
                # The generator is not updated here, so no graph is needed for it.
                with torch.no_grad():
                    fake = G(p1, p2)
                real_input = torch.cat([p1, p2, gt], dim=1)
                fake_input = torch.cat([p1, p2, fake], dim=1)

                pred_real = D(real_input)
                pred_fake = D(fake_input)

                loss_D = gan_loss.discriminator_loss(pred_real, pred_fake)
                gp = compute_gradient_penalty(D, real_input, fake_input, device)
                loss_D_total = loss_D + 10 * gp

                opt_D.zero_grad()
                loss_D_total.backward()
                opt_D.step()

                # === Train the generator ===
                for _ in range(gen_steps):
                    fake = G(p1, p2)
                    fake_input = torch.cat([p1, p2, fake], dim=1)

                    pred_fake = D(fake_input)
                    loss_G_GAN = gan_loss.generator_loss(pred_fake)
                    loss_G_L1 = l1_loss(fake, gt)
                    loss_G = 8 * loss_G_GAN + 2 * loss_G_L1

                    opt_G.zero_grad()
                    loss_G.backward()
                    opt_G.step()

                # Logging below must not keep the generator's autograd graph alive.
                fake = fake.detach()

                # === TensorBoard ===
                writer.add_scalar("Loss/Discriminator", loss_D.item(), global_step)
                writer.add_scalar("Loss/Generator", loss_G.item(), global_step)
                writer.add_scalar("Loss/L1", loss_G_L1.item(), global_step)
                writer.add_scalar("Loss/GAN", loss_G_GAN.item(), global_step)
                writer.add_scalar("GP", gp.item(), global_step)
                writer.add_scalar("LR", lr, global_step)

                # === Metrics ===
                with torch.no_grad():
                    metrics = compute_all_metrics(fake, gt)
                for k, v in metrics.items():
                    writer.add_scalar(f"Metrics/{k}", v, global_step)

                # === Visualisation ===
                if global_step % 50 == 0:
                    # One row per source: p1, p2, generated, ground truth.
                    grid = make_grid(torch.cat([p1, p2, fake, gt], dim=0), nrow=p1.size(0))
                    writer.add_image("Comparison", grid, global_step)

                # === Time-based checkpoint ===
                if time.time() - last_log_time > log_interval:
                    torch.save(G.state_dict(), f"checkpoints_batch/G_step{global_step}.pt")
                    torch.save(D.state_dict(), f"checkpoints_batch/D_step{global_step}.pt")
                    last_log_time = time.time()

                global_step += 1

            # === Per-epoch checkpoint ===
            torch.save(G.state_dict(), f"checkpoints_epoch/G_epoch{epoch}.pt")
            torch.save(D.state_dict(), f"checkpoints_epoch/D_epoch{epoch}.pt")
    finally:
        writer.close()
