In [1]:
import os
import random
from abc import abstractmethod

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms.functional as F1
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset
from torchvision.transforms import ToTensor, transforms
from tqdm import tqdm
from types_ import *
from torchmetrics import MultiScaleStructuralSimilarityIndexMeasure
from pytorch_msssim import ssim, ms_ssim, SSIM, MS_SSIM
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim_skimage

In [2]:
class BaseVAE(nn.Module):
    """Interface shared by the VAE models defined in this notebook.

    The first four methods raise ``NotImplementedError`` until a subclass
    overrides them. ``forward`` and ``loss_function`` carry ``@abstractmethod``
    but have empty bodies, so calling them on this class returns ``None``.
    """

    def __init__(self) -> None:
        super().__init__()

    def encode(self, input: Tensor) -> List[Tensor]:
        """Map an input batch to latent distribution parameters."""
        raise NotImplementedError

    def decode(self, input: Tensor) -> Any:
        """Map latent codes back to the data space."""
        raise NotImplementedError

    def sample(self, batch_size: int, current_device: int, **kwargs) -> Tensor:
        """Draw ``batch_size`` samples on ``current_device``."""
        raise NotImplementedError

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """Produce a reconstruction of ``x``."""
        raise NotImplementedError

    @abstractmethod
    def forward(self, *inputs: Tensor) -> Tensor:
        """Forward pass; to be provided by subclasses."""

    @abstractmethod
    def loss_function(self, *inputs: Any, **kwargs) -> Tensor:
        """Training loss; to be provided by subclasses."""

In [3]:
class PairedTopoDatasetImg(Dataset):
    """Plain images that have a matching ``_topo`` counterpart.

    Files are paired by their name with the extension and every ``_topo``
    fragment removed. Only the plain image is loaded and returned; the topo
    file merely has to exist for the pair to be listed.
    """

    def __init__(self, image_dir, topo_dir, transform=None):
        """
        Args:
            image_dir (str): Folder with the plain images (no ``_topo`` in the name).
            topo_dir (str): Folder with the topo images (``_topo`` in the name).
            transform (callable, optional): Extra transform applied to the tensor.
        """
        self.image_dir = image_dir
        self.topo_dir = topo_dir
        self.transform = transform
        self.to_tensor = ToTensor()

        # Tuples of (image_path, topo_path, image_name, topo_name).
        self.pairs = self._find_paired_files()

    def _extract_base(self, filename):
        """Return the pairing key: file name without extension and without ``_topo``.

        Example:
            "cor_000_Pomona_..._0.png"      -> "cor_000_Pomona_..._0"
            "cor_000_Pomona_..._topo_5.png" -> "cor_000_Pomona_..._5"
        """
        return os.path.splitext(filename)[0].replace("_topo", "")

    def _find_paired_files(self):
        """Return all (image, topo) pairs that share a key, sorted by that key.

        Directory listings are sorted first so that both the pair order and the
        winner among files mapping to the same key are deterministic.
        """
        images_by_base = {
            self._extract_base(name): name for name in sorted(os.listdir(self.image_dir))
        }

        topo_by_base = {}
        for name in sorted(os.listdir(self.topo_dir)):
            base = self._extract_base(name)
            if base in images_by_base:  # topo files without an image are skipped
                topo_by_base[base] = name

        return [
            (
                os.path.join(self.image_dir, images_by_base[base]),  # plain image path
                os.path.join(self.topo_dir, topo_by_base[base]),  # topo image path
                images_by_base[base],  # plain image file name
                topo_by_base[base],  # topo image file name
            )
            for base in sorted(topo_by_base)
        ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(image_tensor, image_file_name)`` for pair ``idx``."""
        img_path, _topo_path, img_name, _topo_name = self.pairs[idx]

        # The topo image is not part of the returned sample, so it is not decoded.
        with Image.open(img_path) as img_file:
            img = img_file.convert("L")

        # Scale to [0, 1] here; the array is float32 by the time it reaches ToTensor.
        img = np.array(img).astype(np.float32) / 255

        # 2-D array -> tensor of shape [1, H, W]
        img_tensor = self.to_tensor(img)

        if self.transform:
            img_tensor = self.transform(img_tensor)

        return img_tensor, img_name

In [4]:
class PairedTopoDatasetTopo(Dataset):
    """Down-scaled topo images that have a matching plain image.

    Files are paired by their name with the extension and every ``_topo``
    fragment removed. Each sample is the topo image resized to the plain
    image's dimensions divided by ``downscale``.
    """

    def __init__(self, image_dir, topo_dir, transform=None, downscale=50):
        """
        Args:
            image_dir (str): Folder with the plain images (no ``_topo`` in the name).
            topo_dir (str): Folder with the topo images (``_topo`` in the name).
            transform (callable, optional): Extra transform applied to the tensor.
            downscale (int): Factor by which the plain image size is divided to
                obtain the topo target size.
        """
        self.image_dir = image_dir
        self.topo_dir = topo_dir
        self.transform = transform
        self.downscale = downscale
        self.to_tensor = ToTensor()
        self.pairs = self._find_paired_files()

    def _extract_base_name(self, filename):
        """Return the file name without extension and without ``_topo``."""
        name = os.path.splitext(filename)[0]
        return name.replace("_topo", "")

    def _find_paired_files(self):
        """Return (image, topo) pairs that share a base name, sorted by that name.

        Sorting replaces iteration over a ``set`` of strings, whose order is
        not stable between interpreter runs.
        """
        image_files = {
            self._extract_base_name(f): f for f in sorted(os.listdir(self.image_dir))
        }
        topo_files = {
            self._extract_base_name(f): f for f in sorted(os.listdir(self.topo_dir))
        }

        common_bases = sorted(image_files.keys() & topo_files.keys())

        return [
            (
                os.path.join(self.image_dir, image_files[base]),
                os.path.join(self.topo_dir, topo_files[base]),
                image_files[base],  # plain image file name
                topo_files[base],  # topo image file name
            )
            for base in common_bases
        ]

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        """Return ``(topo_tensor, topo_file_name)`` for pair ``idx``."""
        img_path, topo_path, _img_name, topo_name = self.pairs[idx]

        # Only the plain image's dimensions are used, so its pixels are never decoded.
        with Image.open(img_path) as img_file:
            new_size = (
                max(1, int(img_file.width / self.downscale)),
                max(1, int(img_file.height / self.downscale)),
            )

        with Image.open(topo_path) as topo_file:
            topo = topo_file.convert("L").resize(new_size, Image.Resampling.LANCZOS)

        # Scale to [0, 1] here; the array is float32 by the time it reaches ToTensor.
        topo = np.array(topo).astype(np.float32) / 255

        # 2-D array -> tensor of shape [1, H, W]
        topo_tensor = self.to_tensor(topo)

        if self.transform:
            topo_tensor = self.transform(topo_tensor)

        return topo_tensor, topo_name

In [5]:
class VanillaVAEImg(BaseVAE):
    """Convolutional VAE for single-channel image tiles.

    The encoder applies one stride-2 convolution per entry of ``hidden_dims``.
    The bottleneck is fixed to a 16x16 feature map, which is what the five
    default stages produce for a 512x512 input; other input sizes or stage
    counts will not match the linear layers.
    """

    # Side length of the encoder's last feature map (512 / 2**5 with the defaults).
    _BOTTLENECK_SIZE = 16

    def __init__(
        self, in_channels: int, latent_dim: int, hidden_dims: List = None, **kwargs
    ) -> None:
        """
        :param in_channels: (Int) Number of channels of the input tensor
        :param latent_dim: (Int) Size of the latent vector
        :param hidden_dims: (List) Output channels of each encoder stage;
            defaults to [32, 64, 128, 256, 512]
        """
        super(VanillaVAEImg, self).__init__()

        self.latent_dim = latent_dim

        # Work on a private copy: the list is reversed below and the caller's
        # list must not be modified.
        hidden_dims = [32, 64, 128, 256, 512] if hidden_dims is None else list(hidden_dims)

        # Remembered so decode() can reshape without a hard-coded channel count.
        self._bottleneck_channels = hidden_dims[-1]
        flat_features = (
            self._bottleneck_channels * self._BOTTLENECK_SIZE * self._BOTTLENECK_SIZE
        )

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels,
                        out_channels=h_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, flat_features)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i],
                        hidden_dims[i + 1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        # One more x2 upsampling followed by a projection to a single channel.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                hidden_dims[-1],
                hidden_dims[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            ),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=1, kernel_size=3, padding=1),
        )

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Two linear heads give the mean and log-variance of the latent Gaussian
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x 1 x H x W]
        """
        result = self.decoder_input(z)
        result = result.view(
            -1, self._bottleneck_channels, self._BOTTLENECK_SIZE, self._BOTTLENECK_SIZE
        )
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Return [reconstruction, input, mu, log_var].

        The latent code is sampled, so repeated calls give different
        reconstructions even in eval mode.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the VAE loss: weighted MSE + weighted (1 - MS-SSIM) + weighted KLD.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: (recons, input, mu, log_var), i.e. the output of forward()
        :param kwargs: optional ``kld_weight`` (default 0.0001) and
            ``recons_weight`` (default 0.8); the MS-SSIM term is weighted by
            ``1 - recons_weight``
        :return: (dict) keys "loss", "Reconstruction_Loss", "KLD", "MSSIM"
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        # Clamp so exp(log_var) in the KL term stays finite.
        log_var = torch.clamp(args[3], min=-10, max=10)

        kld_weight = kwargs.get("kld_weight", 0.0001)
        recons_weight = kwargs.get("recons_weight", 0.8)
        mssim_weight = 1 - recons_weight

        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        # The reconstruction has one channel; compare against channel 0 of the input.
        target = input[:, 0, :, :].unsqueeze(1)

        recons_loss = F.mse_loss(recons, target)
        mssim = 1 - ms_ssim(recons, target, data_range=1.0, size_average=True)

        loss = recons_loss * recons_weight + kld_loss * kld_weight + mssim * mssim_weight

        return {"loss": loss, "Reconstruction_Loss": recons_loss, "KLD": kld_loss, "MSSIM": mssim}

    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [6]:
class VanillaVAETopo(BaseVAE):
    """Small convolutional VAE for single-channel topo tiles.

    The bottleneck is fixed to a 2x2 feature map, which is what the three
    default stride-2 stages produce for a 10x10 input (10 -> 5 -> 3 -> 2).
    The decoder output is resized to 10x10 by the final ``Upsample`` layer.
    """

    # Side length of the encoder's last feature map for a 10x10 input.
    _BOTTLENECK_SIZE = 2

    def __init__(
        self, in_channels: int, latent_dim: int, hidden_dims: List = None, **kwargs
    ) -> None:
        """
        :param in_channels: (Int) Number of channels of the input tensor
        :param latent_dim: (Int) Size of the latent vector
        :param hidden_dims: (List) Output channels of each encoder stage;
            defaults to [16, 32, 64]
        """
        super(VanillaVAETopo, self).__init__()

        self.latent_dim = latent_dim

        # Work on a private copy: the list is reversed below and the caller's
        # list must not be modified.
        hidden_dims = [16, 32, 64] if hidden_dims is None else list(hidden_dims)

        # Remembered so decode() can reshape without a hard-coded channel count.
        self._bottleneck_channels = hidden_dims[-1]
        flat_features = (
            self._bottleneck_channels * self._BOTTLENECK_SIZE * self._BOTTLENECK_SIZE
        )

        # Build Encoder
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(
                        in_channels,
                        out_channels=h_dim,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                    ),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU(),
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)
        self.fc_mu = nn.Linear(flat_features, latent_dim)
        self.fc_var = nn.Linear(flat_features, latent_dim)

        # Build Decoder
        modules = []

        self.decoder_input = nn.Linear(latent_dim, flat_features)

        hidden_dims.reverse()

        for i in range(len(hidden_dims) - 1):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(
                        hidden_dims[i],
                        hidden_dims[i + 1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        # Every stage but the last one rounds the size up by one pixel.
                        output_padding=1 if i < len(hidden_dims) - 2 else 0,
                    ),
                    nn.BatchNorm2d(hidden_dims[i + 1]),
                    nn.LeakyReLU(),
                )
            )

        self.decoder = nn.Sequential(*modules)

        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(
                hidden_dims[-1],
                hidden_dims[-1],
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=0,
            ),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], out_channels=1, kernel_size=3, padding=1),
            nn.Upsample(size=(10, 10), mode='bilinear')  # final resize to 10x10
        )

    def encode(self, input: Tensor) -> List[Tensor]:
        """
        Encodes the input by passing through the encoder network
        and returns the latent codes.
        :param input: (Tensor) Input tensor to encoder [N x C x H x W]
        :return: (List[Tensor]) [mu, log_var], each [N x latent_dim]
        """
        result = self.encoder(input)
        result = torch.flatten(result, start_dim=1)

        # Two linear heads give the mean and log-variance of the latent Gaussian
        mu = self.fc_mu(result)
        log_var = self.fc_var(result)

        return [mu, log_var]

    def decode(self, z: Tensor) -> Tensor:
        """
        Maps the given latent codes onto the image space.
        :param z: (Tensor) [B x D]
        :return: (Tensor) [B x 1 x 10 x 10]
        """
        result = self.decoder_input(z)
        result = result.view(
            -1, self._bottleneck_channels, self._BOTTLENECK_SIZE, self._BOTTLENECK_SIZE
        )
        result = self.decoder(result)
        result = self.final_layer(result)
        return result

    def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
        """
        Reparameterization trick to sample from N(mu, var) from N(0,1).
        :param mu: (Tensor) Mean of the latent Gaussian [B x D]
        :param logvar: (Tensor) Log-variance of the latent Gaussian [B x D]
        :return: (Tensor) [B x D]
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return eps * std + mu

    def forward(self, input: Tensor, **kwargs) -> List[Tensor]:
        """Return [reconstruction, input, mu, log_var].

        The latent code is sampled, so repeated calls give different
        reconstructions even in eval mode.
        """
        mu, log_var = self.encode(input)
        z = self.reparameterize(mu, log_var)
        return [self.decode(z), input, mu, log_var]

    def loss_function(self, *args, **kwargs) -> dict:
        r"""
        Computes the VAE loss: weighted MSE + weighted KLD.
        KL(N(\mu, \sigma), N(0, 1)) = \log \frac{1}{\sigma} + \frac{\sigma^2 + \mu^2}{2} - \frac{1}{2}
        :param args: (recons, input, mu, log_var), i.e. the output of forward()
        :param kwargs: optional ``kld_weight`` (default 0.0001) and
            ``recons_weight`` (default 0.99)
        :return: (dict) keys "loss", "Reconstruction_Loss", "KLD"
        """
        recons = args[0]
        input = args[1]
        mu = args[2]
        # Clamp so exp(log_var) in the KL term stays finite.
        log_var = torch.clamp(args[3], min=-10, max=10)

        kld_weight = kwargs.get("kld_weight", 0.0001)
        recons_weight = kwargs.get("recons_weight", 0.99)

        kld_loss = torch.mean(
            -0.5 * torch.sum(1 + log_var - mu**2 - log_var.exp(), dim=1), dim=0
        )

        # The reconstruction has one channel; compare against channel 0 of the input.
        recons_loss = F.mse_loss(recons, input[:, 0, :, :].unsqueeze(1))

        loss = recons_loss * recons_weight + kld_loss * kld_weight

        return {"loss": loss, "Reconstruction_Loss": recons_loss, "KLD": kld_loss}

    def sample(self, num_samples: int, current_device: int, **kwargs) -> Tensor:
        """
        Samples from the latent space and return the corresponding
        image space map.
        :param num_samples: (Int) Number of samples
        :param current_device: (Int) Device to run the model
        :return: (Tensor)
        """
        z = torch.randn(num_samples, self.latent_dim)

        z = z.to(current_device)

        samples = self.decode(z)
        return samples

    def generate(self, x: Tensor, **kwargs) -> Tensor:
        """
        Given an input image x, returns the reconstructed image
        :param x: (Tensor) [B x C x H x W]
        :return: (Tensor) [B x C x H x W]
        """

        return self.forward(x)[0]

In [7]:
# Folders holding the plain images and their topo counterparts
image_dir = r"/home/jupyter/datasphere/filestore/fulltiles/images"
topo_dir = r"/home/jupyter/datasphere/filestore/fulltiles/topo"

# Image-side dataset; no extra transform is applied
datasetImg = PairedTopoDatasetImg(image_dir=image_dir, topo_dir=topo_dir, transform=None)

In [8]:
# Loader for the image dataset: shuffled batches of four, eight worker processes
batch_size_Img = 4
dataloaderImg = torch.utils.data.DataLoader(
    dataset=datasetImg,
    batch_size=batch_size_Img,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [9]:
# Topo-side dataset over the same two folders; no extra transform is applied
datasetTopo = PairedTopoDatasetTopo(image_dir=image_dir, topo_dir=topo_dir, transform=None)

In [10]:
# Loader for the topo dataset: shuffled batches of four, eight worker processes
batch_size_Topo = 4
dataloaderTopo = torch.utils.data.DataLoader(
    dataset=datasetTopo,
    batch_size=batch_size_Topo,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
)

In [11]:
# Use the GPU when torch can see one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Image VAE with a 512-dimensional latent space, moved to the chosen device
modelImg = VanillaVAEImg(in_channels=1, latent_dim=512)
modelImg = modelImg.to(device)
optimizer = torch.optim.Adam(modelImg.parameters(), lr=1e-4)

In [12]:
# Topo VAE with a 16-dimensional latent space, moved to the chosen device
modelTopo = VanillaVAETopo(in_channels=1, latent_dim=16)
modelTopo = modelTopo.to(device)
# NOTE(review): this rebinds `optimizer`, replacing the one created for modelImg
optimizer = torch.optim.Adam(modelTopo.parameters(), lr=1e-4)

In [13]:
weights_path_Img = '/home/jupyter/datasphere/project/vanilla_vae_weights120.pth'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads on a CPU-only machine.
modelImg.load_state_dict(torch.load(weights_path_Img, map_location=device))

<All keys matched successfully>

In [14]:
weights_path_Topo = '/home/jupyter/datasphere/project/vanilla_vae_topo_weights50.pth'
# map_location remaps the stored tensors onto `device`, so a checkpoint saved
# on a GPU also loads on a CPU-only machine.
modelTopo.load_state_dict(torch.load(weights_path_Topo, map_location=device))

<All keys matched successfully>

In [15]:
# Put the image VAE into evaluation mode before running inference
modelImg.eval()

VanillaVAEImg(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (3): Sequential(
      (0): Conv2d(128, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (4):

In [16]:
# Put the topo VAE into evaluation mode before running inference
modelTopo.eval()

VanillaVAETopo(
  (encoder): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (1): Sequential(
      (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
    (2): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01)
    )
  )
  (fc_mu): Linear(in_features=256, out_features=16, bias=True)
  (fc_var): Linear(in_features=256, out_features=16, bias=True)
  (decoder_input): Linear(in_features=16, out_features=256, bias=True)
  (decoder): Sequential(
    (0): Sequential(
 

In [17]:
def process_large_image_with_overlap(model, large_img, tile_size=256, overlap=128, device='cuda'):
    """Reconstruct a large image tile by tile and blend the overlapping tiles.

    Args:
        model: Callable whose first return element is the reconstruction of a
            ``[B, C, tile_size, tile_size]`` batch.
        large_img: Tensor of shape ``[H, W]``, ``[C, H, W]`` or ``[B, C, H, W]``.
        tile_size: Side of the square tile.
        overlap: Overlap between neighbouring tiles (half of ``tile_size`` is
            the recommended value).
        device: Device on which tiles are processed and blended.

    Returns:
        CPU tensor with the blended reconstruction cropped to ``H x W``. All
        singleton dimensions are squeezed out, so a batch of one comes back
        without its batch axis.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than ``tile_size``.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must satisfy 0 <= overlap < tile_size")

    # Normalise the input to [B, C, H, W]
    if large_img.dim() == 2:
        large_img = large_img.unsqueeze(0).unsqueeze(0)
    elif large_img.dim() == 3:
        large_img = large_img.unsqueeze(0)

    h, w = large_img.shape[-2:]
    stride = tile_size - overlap

    def _padding(size):
        # Historical rule first, so sizes that already worked give the same tiles.
        pad = (tile_size - size % stride) if size % stride != 0 else 0
        padded = size + pad
        # Then extend until the tile grid reaches the far edge; uncovered pixels
        # would otherwise end up as 0 / 0 = NaN.
        if padded < tile_size:
            pad += tile_size - padded
        else:
            pad += (-(padded - tile_size)) % stride
        return pad

    pad_h, pad_w = _padding(h), _padding(w)

    if pad_h > 0 or pad_w > 0:
        # Reflect padding needs pad < dimension; fall back to edge replication.
        mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
        large_img = F.pad(large_img, (0, pad_w, 0, pad_h), mode=mode)

    # Gaussian-shaped weights: tile centres count more than tile borders
    axis = torch.linspace(-1, 1, tile_size)
    grid_y, grid_x = torch.meshgrid(axis, axis, indexing='ij')
    mask = torch.exp(-(grid_x**2 + grid_y**2)).to(device)

    # Weighted sum of reconstructions and the sum of weights per pixel
    result = torch.zeros_like(large_img, device=device)
    count = torch.zeros_like(large_img, device=device)

    with torch.no_grad():
        for top in range(0, h + pad_h - tile_size + 1, stride):
            for left in range(0, w + pad_w - tile_size + 1, stride):
                window = (..., slice(top, top + tile_size), slice(left, left + tile_size))
                tile = large_img[window].to(device)

                recon_tile = model(tile)[0].clamp(0, 1)  # keep values in [0, 1]

                result[window] += recon_tile * mask
                count[window] += mask

    # Normalise by the accumulated weights and drop the padding
    result = (result / count)[..., :h, :w]

    return result.squeeze().cpu()

In [18]:
def process_large_topo_with_overlap(model, large_img, tile_size=10, overlap=5, device='cuda'):
    """Reconstruct a topo map tile by tile and blend the overlapping tiles.

    Args:
        model: Callable whose first return element is the reconstruction of a
            ``[N, C, tile_size, tile_size]`` batch.
        large_img: Tensor of shape ``[H, W]``, ``[C, H, W]`` or ``[B, C, H, W]``.
        tile_size: Side of the square tile.
        overlap: Overlap between neighbouring tiles (``tile_size // 2`` is the
            recommended value).
        device: Device on which tiles are processed and blended.

    Returns:
        CPU tensor with the blended reconstruction cropped to ``H x W``. All
        singleton dimensions are squeezed out, so a batch of one comes back
        without its batch axis.

    Raises:
        ValueError: If ``overlap`` is negative or not smaller than ``tile_size``.
    """
    if not 0 <= overlap < tile_size:
        raise ValueError("overlap must satisfy 0 <= overlap < tile_size")

    # Normalise the input to [B, C, H, W]
    if large_img.dim() == 2:
        large_img = large_img.unsqueeze(0).unsqueeze(0)
    elif large_img.dim() == 3:
        large_img = large_img.unsqueeze(0)

    batch = large_img.shape[0]
    h, w = large_img.shape[-2:]
    stride = tile_size - overlap

    def _padding(size):
        # Historical rule first, so sizes that already worked give the same tiles.
        pad = (tile_size - size % stride) if size % stride != 0 else 0
        padded = size + pad
        # Then extend until the tile grid reaches the far edge; uncovered pixels
        # would otherwise end up as 0 / 0 = NaN.
        if padded < tile_size:
            pad += tile_size - padded
        else:
            pad += (-(padded - tile_size)) % stride
        return pad

    pad_h, pad_w = _padding(h), _padding(w)

    if pad_h > 0 or pad_w > 0:
        # Reflect padding needs pad < dimension; fall back to edge replication.
        mode = 'reflect' if (pad_h < h and pad_w < w) else 'replicate'
        large_img = F.pad(large_img, (0, pad_w, 0, pad_h), mode=mode)

    # Blending weights: after the crop only the top and left borders
    # (overlap // 2 pixels wide) carry the small 1e-6 weight.
    border = overlap // 2
    mask = torch.ones((1, 1, tile_size, tile_size), device=device)
    mask = F.pad(mask, (border, border, border, border), value=1e-6)
    mask = mask[:, :, :tile_size, :tile_size]

    # Weighted sum of reconstructions and the sum of weights per pixel
    result = torch.zeros_like(large_img, device=device)
    count = torch.zeros_like(large_img, device=device)

    # Top-left corner of every tile
    coords = [
        (y, x)
        for y in range(0, h + pad_h - tile_size + 1, stride)
        for x in range(0, w + pad_w - tile_size + 1, stride)
    ]

    coords_per_pass = 32  # tile positions sent through the model at once
    for start in range(0, len(coords), coords_per_pass):
        batch_coords = coords[start:start + coords_per_pass]
        tiles = torch.cat(
            [large_img[..., y:y + tile_size, x:x + tile_size] for y, x in batch_coords],
            dim=0,
        ).to(device)

        with torch.no_grad():
            recon_tiles = model(tiles)[0]

        # torch.cat stacked `batch` tiles per coordinate; regroup them so that
        # index j selects every sample's tile at coordinate j.
        recon_tiles = recon_tiles.reshape(len(batch_coords), batch, *recon_tiles.shape[1:])

        for j, (y, x) in enumerate(batch_coords):
            result[..., y:y + tile_size, x:x + tile_size] += recon_tiles[j] * mask
            count[..., y:y + tile_size, x:x + tile_size] += mask

    # Normalise by the accumulated weights and drop the padding
    result = (result / count)[..., :h, :w]

    return result.squeeze().cpu()

In [19]:
def extract_latent_features_img(autoencoder, dataloader, device=device):
    """Run the tiled image reconstruction over a whole dataloader.

    NOTE(review): despite the name, the stored values are the blended outputs
    of ``process_large_image_with_overlap`` (reconstructions), not latent
    vectors. Confirm this is what downstream cells expect.

    Args:
        autoencoder: Model passed to ``process_large_image_with_overlap``.
        dataloader: Yields ``(image_batch, file_names)``.
        device: Device used for processing.

    Returns:
        Tuple of a float16 CPU tensor with all results concatenated along
        dim 0 and the list of file names in the same order.
    """
    features = []
    filenames = []

    with torch.no_grad():
        for img, img_name in tqdm(dataloader, desc="Processing images"):
            img = img.to(device)
            latent = process_large_image_with_overlap(
                autoencoder, img, tile_size=512, overlap=256, device=device
            )
            # The helper squeezes singleton dims, which drops the batch axis of
            # a final batch of size one; restore it so torch.cat stays valid.
            if img.shape[0] == 1:
                latent = latent.unsqueeze(0)
            features.append(latent.to(torch.float16))  # already on the CPU
            filenames.extend(img_name)
            torch.cuda.empty_cache()

    return torch.cat(features), filenames

# Tiled pass of the image VAE over the whole image dataloader
featuresImg, filenamesImg = extract_latent_features_img(modelImg, dataloaderImg)

Processing images: 100%|██████████| 172/172 [07:23<00:00,  2.58s/it]


In [20]:
def extract_latent_features_topo(autoencoder, dataloader, device=device):
    """Run the tiled topo reconstruction over a whole dataloader.

    NOTE(review): despite the name, the stored values are the blended outputs
    of ``process_large_topo_with_overlap`` (reconstructions), not latent
    vectors. Confirm this is what downstream cells expect.

    Args:
        autoencoder: Model passed to ``process_large_topo_with_overlap``.
        dataloader: Yields ``(topo_batch, file_names)``.
        device: Device used for processing.

    Returns:
        Tuple of a float16 CPU tensor with all results concatenated along
        dim 0 and the list of file names in the same order.
    """
    features = []
    filenames = []

    with torch.no_grad():
        for topo, topo_name in tqdm(dataloader, desc="Processing topo"):
            topo = topo.to(device)
            latent = process_large_topo_with_overlap(
                autoencoder, topo, tile_size=10, overlap=5, device=device
            )
            # The helper squeezes singleton dims, which drops the batch axis of
            # a final batch of size one; restore it so torch.cat stays valid.
            if topo.shape[0] == 1:
                latent = latent.unsqueeze(0)
            features.append(latent.to(torch.float16))  # already on the CPU
            filenames.extend(topo_name)
            torch.cuda.empty_cache()

    return torch.cat(features), filenames

# Tiled pass of the topo VAE over the whole topo dataloader
featuresTopo, filenamesTopo = extract_latent_features_topo(modelTopo, dataloaderTopo)

Processing topo: 100%|██████████| 172/172 [01:09<00:00,  2.48it/s]


In [21]:
def get_img_name(topo_name):
    """Return ``topo_name`` with every ``_topo`` fragment removed."""
    pieces = topo_name.split("_topo")
    return "".join(pieces)

# Rewrite the topo file names in place so they match the image file names
for i, topo_name in enumerate(filenamesTopo):
    filenamesTopo[i] = get_img_name(topo_name)

In [25]:
import csv

# Number of files handled per progress-bar step. Clamped to at least 1,
# because a step of 0 makes range() raise ValueError.
CHUNK_SIZE = max(1, (featuresImg.shape[1] + featuresTopo.shape[1]) // 1000)

# File name -> numpy array of that file's features
latent_dictImg = {name: feat.numpy() for name, feat in zip(filenamesImg, featuresImg)}
latent_dictTopo = {name: feat.numpy() for name, feat in zip(filenamesTopo, featuresTopo)}

# Column headers, created up front
sar_columns = [f'features_SAR_{i}' for i in range(featuresImg.shape[1])]
topo_columns = [f'features_topo_{i}' for i in range(featuresTopo.shape[1])]
fieldnames = ['filename'] + sar_columns + topo_columns

# Stream the rows to the CSV one at a time
with open('latent_features_expanded.csv', 'w', newline='') as csvfile:
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
    writer.writeheader()

    for i in tqdm(range(0, len(filenamesImg), CHUNK_SIZE), desc="Processing chunks"):
        for img_name in filenamesImg[i:i+CHUNK_SIZE]:
            # Files without a topo counterpart are skipped
            if img_name not in latent_dictTopo:
                continue

            sar_features = latent_dictImg[img_name].flatten()
            topo_features = latent_dictTopo[img_name].flatten()

            # NOTE(review): zip() stops at the shorter sequence, so only the
            # first len(sar_columns) / len(topo_columns) flattened values are
            # written. Confirm this truncation is intended.
            record = {'filename': img_name}
            record.update(zip(sar_columns, sar_features))
            record.update(zip(topo_columns, topo_features))

            writer.writerow(record)

print("Processing completed and saved to latent_features_expanded.csv")

Processing chunks: 100%|██████████| 230/230 [00:04<00:00, 49.32it/s]


Processing completed and saved to latent_features_expanded.csv


In [4]:
import pandas as pd
from pathlib import Path

# 1. Load the main feature table
df_main = pd.read_csv('latent_features_expanded.csv')

# 2. Extract the ID prefix (cor_N, vlc_N or vlcInt_N) from the file name
df_main['uniq_ID'] = df_main['filename'].str.extract(r'^(cor_\d+|vlc_\d+|vlcInt_\d+)')

# 3. Load the additional metadata files
required_columns = ['uniq_ID', 'Latitude', 'Longitude', 'Max_D_km', 'Min_D_km']  # required columns

extra_frames = []
# Sorted so the row order of the combined table does not depend on the file system
for file in sorted(Path('/home/jupyter/datasphere/project/Volc_centers_data/').glob('*.csv')):
    df_temp = pd.read_csv(file)

    # Missing required columns are added and filled with zeros
    for col in required_columns:
        if col not in df_temp.columns:
            df_temp[col] = 0

    extra_frames.append(df_temp)

# Concatenate once instead of growing a frame inside the loop; with no input
# files, fall back to an empty table that still has the required columns.
if extra_frames:
    df_extra_combined = pd.concat(extra_frames, ignore_index=True)
else:
    df_extra_combined = pd.DataFrame(columns=required_columns)

# 4. Left join onto the main table
# NOTE(review): an ID that occurs several times in the metadata duplicates the
# matching rows of df_main. Confirm the IDs are unique.
df_main = pd.merge(
    df_main,
    df_extra_combined[required_columns],
    on='uniq_ID',
    how='left'
)

# 5. Drop the temporary 'uniq_ID' join key
df_main = df_main.drop(columns='uniq_ID')

✅ Файл успешно дополнен, отсутствующие столбцы заполнены нулями!


In [None]:
# Функция для извлечения типа файла
def extract_type(filename):
    """Return the part of ``filename`` before its first underscore.

    The argument is converted with ``str`` first; a name without an underscore
    is returned whole.
    """
    prefix, _sep, _rest = str(filename).partition('_')
    return prefix

# Derive the file type from the file name
df_main['type'] = df_main['filename'].apply(extract_type)

# Reorder the columns so that 'type' comes right after 'filename'.
# (The column list is taken from df_main; the former `data` name was undefined.)
cols = ['filename', 'type'] + [col for col in df_main.columns if col not in ['filename', 'type']]
df_main = df_main[cols]

# Show the first rows
print(df_main.head())

file_name = 'latent_features_expanded_with_type.csv'

# Save to a new file
df_main.to_csv(file_name, index=False)

print(f"Датасет успешно создан и сохранен в файл {file_name}!")