In [1]:
# pip install rasterio

In [3]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import rasterio
import numpy as np
import torch.nn.functional as F

# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Usando dispositivo:", device)

RuntimeError: duplicate registrations for aten.linspace.Tensor_Tensor

In [None]:
# preprocessamento
class SuperResTiffDataset(Dataset):
    """Paired low-/high-resolution single-band raster dataset for super-resolution.

    Each item is a ``(low_res, high_res)`` pair of float32 tensors of shape
    ``[1, *target_size]``. Band 1 of each raster is min-max scaled to [0, 1]
    independently and then bilinearly resized to ``target_size``.

    Args:
        low_res_paths: Paths to the low-resolution (model input) rasters.
        high_res_paths: Paths to the high-resolution (target) rasters; entry i
            is paired with ``low_res_paths[i]``.
        target_size: ``(height, width)`` both rasters are resized to.

    Raises:
        ValueError: If the two path lists have different lengths.
    """

    def __init__(self, low_res_paths, high_res_paths, target_size=(128,128)):
        # Fail fast: with mismatched lists __getitem__ would raise IndexError
        # mid-epoch or silently ignore the unpaired tail.
        if len(low_res_paths) != len(high_res_paths):
            raise ValueError(
                f"low_res_paths ({len(low_res_paths)}) and high_res_paths "
                f"({len(high_res_paths)}) must have the same length"
            )
        self.low_res_paths = low_res_paths
        self.high_res_paths = high_res_paths
        self.target_size = target_size

    def __len__(self):
        return len(self.low_res_paths)

    @staticmethod
    def _normalize(arr):
        """Min-max scale ``arr`` to [0, 1]; the epsilon avoids 0/0 on constant rasters."""
        lo, hi = np.min(arr), np.max(arr)
        return (arr - lo) / (hi - lo + 1e-8)

    def _load(self, path):
        """Read band 1 of ``path`` and return it normalized and resized as ``[1, H, W]``."""
        with rasterio.open(path) as src:
            band = src.read(1).astype(np.float32)
        # NOTE(review): nodata pixels are not masked; if the rasters carry a
        # nodata value or NaNs they distort (or poison) the min-max scaling —
        # confirm the inputs are gap-free.
        tensor = torch.from_numpy(self._normalize(band))[None, None]  # [1, 1, H, W] for interpolate
        tensor = F.interpolate(tensor, size=self.target_size, mode="bilinear", align_corners=False)
        return tensor[0]  # drop the batch dim -> [1, H', W']

    def __getitem__(self, idx):
        low_res = self._load(self.low_res_paths[idx])
        high_res = self._load(self.high_res_paths[idx])
        return low_res, high_res

In [None]:
def conv_block(in_ch, out_ch):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU.

    Spatial size is preserved; channels go ``in_ch -> out_ch -> out_ch``.
    Layers occupy Sequential indices 0-3, so the parameters are named
    ``0.weight``/``0.bias`` and ``2.weight``/``2.bias``.
    """
    layers = []
    for conv_in in (in_ch, out_ch):
        layers.append(nn.Conv2d(conv_in, out_ch, kernel_size=3, padding=1))
        layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*layers)

class UNetFinal(nn.Module):
    """Four-level U-Net mapping an image to an image of the same spatial size.

    Encoder: four ``conv_block`` stages, each followed by 2x2 max-pooling,
    then a ``conv_block`` bottleneck. Decoder: four stages that upsample with
    a stride-2 transposed convolution, concatenate the matching encoder output
    (skip connection) along channels and apply a ``conv_block``. A final 1x1
    convolution projects to ``out_channels``; no output activation is applied.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels of the output tensor.
        base_filters: Filters ``f`` of the first stage; doubled at each level
            up to ``16 * f`` in the bottleneck.

    Input height and width must be at least 16 (four 2x2 poolings).
    """

    def __init__(self, in_channels=1, out_channels=1, base_filters=16):
        super().__init__()
        f = base_filters

        # Encoder (c_i outputs feed the skip connections). Shapes are each
        # stage's output for a [B, in_channels, H, W] input (floor division).
        self.conv1 = conv_block(in_channels, f)  # [B, f,   H,    W]
        self.conv2 = conv_block(f, f*2)          # [B, 2f,  H/2,  W/2]
        self.conv3 = conv_block(f*2, f*4)        # [B, 4f,  H/4,  W/4]
        self.conv4 = conv_block(f*4, f*8)        # [B, 8f,  H/8,  W/8]

        # Shared parameter-free 2x2 max-pool between encoder stages.
        self.pool = nn.MaxPool2d(2, 2)

        # Bottleneck
        self.bottleneck = conv_block(f*8, f*16)  # [B, 16f, H/16, W/16]

        # Decoder: each conv_up* sees the upsampled tensor concatenated with the
        # same-level encoder output, hence twice the upsampled channel count.
        # Attribute names and construction order are kept stable so saved
        # state_dicts load and seeded initialization is reproducible.
        self.up4 = nn.ConvTranspose2d(f*16, f*8, kernel_size=2, stride=2)
        self.conv_up4 = conv_block(f*16, f*8)  # cat([8f, 8f])

        self.up3 = nn.ConvTranspose2d(f*8, f*4, kernel_size=2, stride=2)
        self.conv_up3 = conv_block(f*8, f*4)   # cat([4f, 4f])

        self.up2 = nn.ConvTranspose2d(f*4, f*2, kernel_size=2, stride=2)
        self.conv_up2 = conv_block(f*4, f*2)   # cat([2f, 2f])

        self.up1 = nn.ConvTranspose2d(f*2, f, kernel_size=2, stride=2)
        self.conv_up1 = conv_block(f*2, f)     # cat([f, f])

        # Final 1x1 projection to the output channels.
        self.final_conv = nn.Conv2d(f, out_channels, kernel_size=1)

    @staticmethod
    def _match_size(x, ref):
        """Bilinearly resize ``x`` to ``ref``'s H/W, skipping the no-op case."""
        # Pooling an odd size floors it, so doubling with the transposed conv
        # can come up one pixel short of the skip tensor.
        if x.shape[2:] == ref.shape[2:]:
            return x
        return F.interpolate(x, size=ref.shape[2:], mode='bilinear', align_corners=False)

    def _decode(self, x, skip, up, conv):
        """One decoder stage: upsample, align to ``skip``, concatenate, convolve."""
        x = self._match_size(up(x), skip)
        return conv(torch.cat([x, skip], 1))

    def forward(self, x):
        """Map ``x`` of shape [B, in_channels, H, W] to [B, out_channels, H, W]."""
        # Encoder
        c1 = self.conv1(x)
        c2 = self.conv2(self.pool(c1))
        c3 = self.conv3(self.pool(c2))
        c4 = self.conv4(self.pool(c3))

        # Bottleneck
        b = self.bottleneck(self.pool(c4))

        # Decoder, from the coarsest level back to full resolution
        d4 = self._decode(b, c4, self.up4, self.conv_up4)
        d3 = self._decode(d4, c3, self.up3, self.conv_up3)
        d2 = self._decode(d3, c2, self.up2, self.conv_up2)
        d1 = self._decode(d2, c1, self.up1, self.conv_up1)

        return self.final_conv(d1)


In [None]:
# Training setup: data, model, optimizer and loss.
SEED = 42
torch.manual_seed(SEED)  # fixes weight init and DataLoader shuffling so runs are reproducible

low_res_files = ["recorte_anadem.tif"]
high_res_files = ["recorte_geosampa.tif"]

# The dataset opens files lazily, so check here instead of failing on the first batch.
missing = [p for p in low_res_files + high_res_files if not os.path.exists(p)]
if missing:
    raise FileNotFoundError(f"Arquivos não encontrados: {missing}")

dataset = SuperResTiffDataset(low_res_files, high_res_files, target_size=(128,128))
loader = DataLoader(dataset, batch_size=1, shuffle=True)

model = UNetFinal().to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.MSELoss()

In [None]:
# Training loop: one Adam step per batch on the MSE loss, then save the weights.
epochs = 300
for epoch in range(epochs):
    model.train()
    epoch_loss = 0.0
    for batch_idx, (low_res, high_res) in enumerate(loader):
        low_res, high_res = low_res.to(device), high_res.to(device)

        preds = model(low_res)
        loss = criterion(preds, high_res)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Sanity-check tensor shapes once, on the very first batch.
        if batch_idx == 0 and epoch == 0:
            print(f"Low shape: {low_res.shape}, High shape: {high_res.shape}, Preds shape: {preds.shape}")

    # Report the mean per-batch loss: the raw sum grows with the number of
    # batches and would not be comparable once more files are added.
    mean_loss = epoch_loss / max(len(loader), 1)
    print(f"Epoch [{epoch+1}/{epochs}] | Loss: {mean_loss:.6f}")

torch.save(model.state_dict(), "unet_superres.pth")
print("Treinamento concluído e modelo salvo!")

In [None]:
def gerar_superres_array(model, img, device, target_size=(128, 128)):
    """Run ``model`` on one raster band and return the prediction in ``img``'s units.

    The band is min-max scaled to [0, 1] (as during training), bilinearly
    resized to ``target_size``, passed through ``model`` without gradients, and
    mapped back to the original value range using ``img``'s min and max.

    Args:
        model: Module mapping a [1, 1, H, W] tensor to [1, C, H, W]; channel 0 is used.
        img: 2-D array holding a single raster band.
        device: Device to run inference on.
        target_size: ``(height, width)`` fed to the model.

    Returns:
        float32 NumPy array of shape ``target_size``.
    """
    model.eval()
    img = np.asarray(img, dtype=np.float32)
    lo, hi = np.min(img), np.max(img)
    img_norm = (img - lo) / (hi - lo + 1e-8)
    img_tensor = torch.from_numpy(img_norm).unsqueeze(0).unsqueeze(0).to(device)
    img_tensor = F.interpolate(img_tensor, size=target_size, mode="bilinear", align_corners=False)

    with torch.no_grad():
        # Index instead of squeeze() so a size-1 target dimension is not dropped.
        pred = model(img_tensor)[0, 0].cpu().numpy()

    return (pred * (hi - lo) + lo).astype(np.float32)


def gerar_superres_tif(model, low_res_tif, out_path, device, target_size=(128, 128)):
    """Super-resolve band 1 of ``low_res_tif`` and write it to ``out_path``.

    The output keeps the input's profile (driver, CRS, nodata, ...) but is a
    single float32 band on a ``target_size`` grid. The geotransform is
    rescaled so that grid covers the same extent as the input.

    Args:
        model: Trained model passed to ``gerar_superres_array``.
        low_res_tif: Path of the input raster.
        out_path: Path of the raster to create.
        device: Device to run inference on.
        target_size: ``(height, width)`` of the model input and of the output raster.
    """
    out_h, out_w = target_size
    with rasterio.open(low_res_tif) as src:
        img = src.read(1).astype(np.float32)
        profile = src.profile.copy()
        # Same extent, different pixel count: stretch the pixel size so the
        # output stays correctly georeferenced.
        transform = src.transform * src.transform.scale(src.width / out_w, src.height / out_h)

    pred = gerar_superres_array(model, img, device, target_size)

    # The source's width/height/transform/count describe the input grid; without
    # updating them the 128x128 prediction is written against the wrong shape/georeference.
    profile.update(
        dtype=rasterio.float32,
        count=1,
        height=out_h,
        width=out_w,
        transform=transform,
    )
    with rasterio.open(out_path, "w", **profile) as dst:
        dst.write(pred, 1)

    print(f"Super-resolução salva em: {out_path}")



In [None]:
# Example inference: super-resolve a test raster and write the result to disk.
gerar_superres_tif(
    model,
    low_res_tif="testeeee.tif",
    out_path="saida_testeee.tif",
    device=device,
)

In [None]:
from rasterio.plot import show

# Open the prediction in a context manager so the file handle is released
# after plotting (the bare rasterio.open() was never closed).
with rasterio.open("saida_testeee.tif") as raster:
    show(raster, title="Super-resolução (saida_testeee.tif)")