In [None]:
# Mount Google Drive at /content/drive; the dataset archive and checkpoints used below live under
# /content/drive/MyDrive/datasets_and_models.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install dependencies not preinstalled on Colab: segmentation_models_pytorch (PAN model, DiceLoss)
# and pydensecrf (dense-CRF post-processing), the latter built from its GitHub repository.
# NOTE(review): versions are unpinned (the captured log resolved segmentation_models_pytorch 0.5.0);
# consider pinning both for reproducible re-runs.
!pip install segmentation_models_pytorch
!pip install git+https://github.com/lucasb-eyer/pydensecrf.git

Collecting segmentation_models_pytorch
  Downloading segmentation_models_pytorch-0.5.0-py3-none-any.whl.metadata (17 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch>=1.8->segmentation_models_pytorch)
  Downloading nvidia_cublas_cu12-12.4.5.8-

In [None]:
import numpy as np
import segmentation_models_pytorch as smp
import torch
import os
import cv2
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.nn.functional as F
import glob
from tqdm import tqdm
import albumentations as A
import tifffile
from torch.amp import autocast
from torch.amp import GradScaler
import logging
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_softmax, create_pairwise_gaussian, create_pairwise_bilateral

  check_for_updates()


In [None]:
# Copy the hand-labelled dataset archive from Drive to the local Colab disk
# (presumably so extraction and training read from local storage rather than the Drive mount).
!cp -r /content/drive/MyDrive/datasets_and_models/HandLabeled.zip /content/

In [None]:
!unzip /content/HandLabeled.zip -d /

[1;30;43mВыходные данные были обрезаны до нескольких последних строк (5000).[0m
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Paraguay_2_791364_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Ghana_2_146222_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Somalia_3_989553_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Spain_2_3285448_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/USA_2_994009_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Ghana_3_187318_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/Sri-Lanka_2_163406_LabelHand.tif  
 extracting: /content/dataset_sen1floods11_RGB_NIR_VV_VH/train/Pseudo-labels/Label/India_1_1072277_LabelHand.tif

In [None]:
# Configuration parameters.
# All dataset splits live under the root extracted from HandLabeled.zip.
_DATASET_ROOT = "/content/dataset_sen1floods11_RGB_NIR_VV_VH"
train_dir = f"{_DATASET_ROOT}/train/Handlabeled"
validation_dir = f"{_DATASET_ROOT}/validation"
test_dir = f"{_DATASET_ROOT}/test"
test_Bolivia_dir = f"{_DATASET_ROOT}/test_Bolivia"

# Optimisation settings.
local_batch_size = 64
learning_rate = 1e-3
num_epochs = 200

# Checkpoint naming: files are written as "{model_serialization}_{epoch}_{channels}_{data}.pth".
model_serialization = "PAN_EfficientNet-B7"
folder = model_serialization  # Drive sub-folder that holds this model's checkpoints
channels = "RGB_NIR"
data = "HandLabeled"
# NOTE(review): addtrain()/test() also read `folder_of_weights`, and train()/addtrain() read
# `logger1`, `validation_iou` and `num_best_epoch`; none of these are defined in the visible part of
# this notebook — TODO confirm they are set before training/testing.

In [None]:
# The dihedral group D4 acting on the spatial axes (dims 2 and 3) of a 4-D tensor:
# four rotations, then four reflections (mirror along dim 3 followed by a rotation).
# INVERSE_TRANSFORMS[k] undoes D4_TRANSFORMS[k], i.e. INVERSE_TRANSFORMS[k](D4_TRANSFORMS[k](x)) == x.


def _d4_rotation(k):
    """Return a callable that rotates a 4-D tensor by k quarter turns in the (dim 2, dim 3) plane."""
    return lambda x: torch.rot90(x, k, [2, 3])


def _d4_flip_rotation(k):
    """Return a callable that mirrors along dim 3 and then rotates by k quarter turns."""
    return lambda x: torch.rot90(torch.flip(x, [3]), k, [2, 3])


def _d4_inverse_flip_rotation(k):
    """Return the inverse of ``_d4_flip_rotation(k)``: rotate back by k quarter turns, then mirror."""
    return lambda x: torch.flip(torch.rot90(x, -k, [2, 3]), [3])


D4_TRANSFORMS = [
    lambda x: x,                     # I
    _d4_rotation(1),                 # R90
    _d4_rotation(2),                 # R180
    _d4_rotation(3),                 # R270
    lambda x: torch.flip(x, [3]),    # F
    _d4_flip_rotation(1),            # F_R90
    _d4_flip_rotation(2),            # F_R180
    _d4_flip_rotation(3),            # F_R270
]

INVERSE_TRANSFORMS = [
    lambda x: x,                     # I
    _d4_rotation(-1),                # undo R90
    _d4_rotation(-2),                # undo R180
    _d4_rotation(-3),                # undo R270
    lambda x: torch.flip(x, [3]),    # F is its own inverse
    _d4_inverse_flip_rotation(1),    # undo F_R90
    _d4_inverse_flip_rotation(2),    # undo F_R180
    _d4_inverse_flip_rotation(3),    # undo F_R270
]

In [None]:
# Training-time augmentation: random flips and quarter-turn rotations. The same geometric op is
# applied to the image and the mask when both are passed to the pipeline.
_train_augmentations = [
    A.HorizontalFlip(p=0.5),
    A.RandomRotate90(p=0.5),
    A.VerticalFlip(p=0.5),
]
transform = A.Compose(_train_augmentations)

In [None]:
def get_filename(filepath):
    """Return the final path component of ``filepath`` ('' if it ends with a separator)."""
    return os.path.basename(filepath)

In [None]:
def create_df(main_dir):
    """
    Pair every image in ``<main_dir>/RGB+NIR/*.tif`` with its hand-labelled mask path.

    The mask name is built from the first three ``_``-separated tokens of the image file name:
    ``A_B_C_<rest>.tif`` -> ``<main_dir>/Label/A_B_C_LabelHand.tif``. Whether the mask exists is
    not checked here.

    :param main_dir: split directory containing ``RGB+NIR/`` and ``Label/`` sub-folders.
    :return: DataFrame with columns ``rgb_nir_image_path`` and ``mask_path``, one row per image,
        ordered by sorted image path (no rows if nothing matches).
    """
    # recursive=True has no effect here because the pattern contains no '**'.
    images = sorted(glob.glob(main_dir + '/RGB+NIR/*.tif', recursive=True))

    def mask_for(image_path):
        stem = '_'.join(os.path.basename(image_path).split('_')[:3]) + '_'
        return os.path.join(main_dir, "Label", f"{stem}LabelHand.tif")

    return pd.DataFrame({
        "rgb_nir_image_path": images,
        "mask_path": [mask_for(path) for path in images],
    })

In [None]:
class MyDataset(Dataset):
    """
    Dataset over a DataFrame of (image, mask) TIFF paths, e.g. the output of ``create_df``.

    Each item is a dict:
      * ``"image"`` — float32 tensor (C, H, W): raw values divided by 10000 and clipped to [0, 1]
        (per the original loader comment, a 4-band R, G, B, NIR tile);
      * ``"mask"`` — int64 tensor holding the mask read from ``mask_path``.

    ``transform`` (called as ``transform(image=..., mask=...)`` and returning a dict with the same
    keys) is applied only when ``split == "train"``.
    """

    def __init__(self, dataframe, split, transform=None):
        self.dataset = dataframe
        self.split = split
        self.transform = transform

    def __len__(self):
        return self.dataset.shape[0]

    def __getitem__(self, index):
        record = self.dataset.iloc[index]

        # NOTE(review): assumes the TIFF is stored band-first (C, H, W); the transpose moves bands
        # last (H, W, C), the layout the augmentation pipeline works on.
        pixels = tifffile.imread(record["rgb_nir_image_path"]).astype(np.float32)
        pixels = np.clip(pixels / 10000.0, 0.0, 1.0).transpose(1, 2, 0)
        label = tifffile.imread(record["mask_path"]).astype(np.int64)

        wants_augmentation = self.transform is not None and self.split == "train"
        if wants_augmentation:
            out = self.transform(image=pixels, mask=label)
            pixels, label = out["image"], out["mask"]

        # Back to channel-first (C, H, W) for PyTorch.
        chw = pixels.transpose(2, 0, 1).astype(np.float32)
        return {
            "image": torch.from_numpy(chw),
            "mask": torch.from_numpy(label).long(),
        }

In [None]:
def calculate_TP_FP_TN_FN(pred, target):
    """
    Confusion counts for the binary flood (1) / background (0) problem.

    :param pred: per-class scores with classes on dim 1 (e.g. logits (B, K, H, W)); argmax over
        dim 1 gives the predicted label.
    :param target: label tensor broadcast-compatible with the argmax result (e.g. (B, H, W)).
        Pixels whose target is neither 0 nor 1 (e.g. -1) are not counted.
    :return: 8 floats — flood TP, FP, TN, FN followed by background TP, FP, TN, FN. The background
        counts are the flood counts with positive/negative roles swapped.
    """
    labels = pred.argmax(dim=1)
    pred_flood, pred_background = labels == 1, labels == 0
    true_flood, true_background = target == 1, target == 0

    def _count(selection):
        return float(selection.sum().item())

    tp = _count(pred_flood & true_flood)
    fp = _count(pred_flood & true_background)
    tn = _count(pred_background & true_background)
    fn = _count(pred_background & true_flood)

    return tp, fp, tn, fn, tn, fn, tp, fp

In [None]:
def calculate_metrics(TP_flood, FP_flood, TN_flood, FN_flood, TP_background, FP_background, TN_background, FN_background):
    """
    Derive segmentation metrics from flood/background confusion counts.

    Every ratio adds 1e-6 to its denominator, so all-zero counts give 0.0 instead of raising.
    Class-averaged metrics are the plain mean of the flood and background values.

    :return: tuple (Accuracy, IoU_flood, IoU_background, IoU, Dice_flood, Dice_background, Dice,
        Precision_flood, Precision_background, Precision, Recall_flood, Recall_background, Recall,
        BalancedAccuracy). BalancedAccuracy is the mean of flood recall and flood specificity
        TN/(TN+FP).
    """
    eps = 1e-6

    def iou(tp, fp, fn):
        return tp / (tp + fp + fn + eps)

    def dice(tp, fp, fn):
        return 2 * tp / (2 * tp + fp + fn + eps)

    def ratio(hits, misses):
        return hits / (hits + misses + eps)

    def mean_of_two(a, b):
        return (a + b) / 2

    accuracy = (TP_flood + TN_flood) / (TP_flood + FP_flood + TN_flood + FN_flood + eps)

    iou_flood = iou(TP_flood, FP_flood, FN_flood)
    iou_background = iou(TP_background, FP_background, FN_background)

    dice_flood = dice(TP_flood, FP_flood, FN_flood)
    dice_background = dice(TP_background, FP_background, FN_background)

    precision_flood = ratio(TP_flood, FP_flood)
    precision_background = ratio(TP_background, FP_background)

    recall_flood = ratio(TP_flood, FN_flood)
    recall_background = ratio(TP_background, FN_background)

    balanced_accuracy = mean_of_two(ratio(TP_flood, FN_flood), ratio(TN_flood, FP_flood))

    return (
        accuracy,
        iou_flood, iou_background, mean_of_two(iou_flood, iou_background),
        dice_flood, dice_background, mean_of_two(dice_flood, dice_background),
        precision_flood, precision_background, mean_of_two(precision_flood, precision_background),
        recall_flood, recall_background, mean_of_two(recall_flood, recall_background),
        balanced_accuracy,
    )

In [None]:
def d4_mask_batch(images: torch.Tensor, model: torch.nn.Module, device="cuda" if torch.cuda.is_available() else "cpu",
                  transforms=None, inverse_transforms=None) -> torch.Tensor:
    """
    Test-time augmentation: average the model's class probabilities over a set of transforms.

    Each view ``transforms[k](images)`` is run through the model. Its softmax probabilities are
    mapped back with ``inverse_transforms[k]``, and the results are averaged.

    :param images: input batch (B, C, H, W); moved to ``device``.
    :param model: segmentation model returning per-class logits with classes on dim 1. It is put
        into eval mode (this persists after the call).
    :param device: device to run on.
    :param transforms: forward transforms; defaults to the module-level ``D4_TRANSFORMS``.
    :param inverse_transforms: matching inverses; defaults to ``INVERSE_TRANSFORMS``.
    :return: float32 tensor (B, K, H, W) of averaged probabilities, where K is the number of
        classes the model outputs (2 for the models in this notebook).
    :raises ValueError: if the transform lists are empty or have different lengths.
    """
    if transforms is None:
        transforms = D4_TRANSFORMS
    if inverse_transforms is None:
        inverse_transforms = INVERSE_TRANSFORMS
    transforms = list(transforms)
    inverse_transforms = list(inverse_transforms)
    if not transforms or len(transforms) != len(inverse_transforms):
        raise ValueError("transforms and inverse_transforms must be non-empty and of equal length")

    model.eval()
    images = images.to(device)

    # The accumulator is created from the first prediction, so any number of classes works.
    acc = None
    with torch.no_grad():
        for forward, inverse in zip(transforms, inverse_transforms):
            logits = model(forward(images))
            prob = inverse(F.softmax(logits, dim=1)).to(torch.float32)
            acc = prob if acc is None else acc + prob

    return acc / len(transforms)

In [None]:
def apply_crf_to_batch(images, probs, n_iters=10):
    """
    Refine per-pixel class probabilities of a batch with a dense (fully-connected) CRF.

    :param images: tensor (B, C, H, W) with C >= 4. Channels 1..3 are used as the 3-channel guide
        image for the bilateral kernel. Values are expected in [0, 1]; they are scaled to [0, 255],
        clipped and converted to uint8.
    :param probs: tensor (B, K, H, W) of class probabilities (e.g. the output of d4_mask_batch).
    :param n_iters: number of CRF inference iterations.
    :return: float32 tensor (B, K, H, W) of refined class probabilities on ``images.device``.
        Take ``argmax(dim=1)`` to obtain a label mask.
    :raises ValueError: if ``images`` has fewer than 4 channels.
    """
    B, n_classes, H, W = probs.shape
    if images.shape[1] < 4:
        raise ValueError("apply_crf_to_batch expects images with at least 4 channels (uses channels 1..3)")
    device = images.device
    # float32, not uint8: the CRF marginals lie in [0, 1] and would truncate to 0 in an integer tensor.
    refined_masks = torch.zeros((B, n_classes, H, W), dtype=torch.float32, device=device)
    for i in range(B):
        # NOTE(review): slice 1:4 selects channels 1, 2, 3; with the dataset's (R, G, B, NIR) band
        # order this is G, B, NIR rather than RGB — confirm this choice is intended.
        guide = images[i, 1:4].detach().float().cpu().numpy()
        guide = np.clip(guide * 255, 0, 255).transpose(1, 2, 0)
        # pydensecrf needs a C-contiguous uint8 (H, W, 3) array.
        image_rgb = np.ascontiguousarray(guide.astype(np.uint8))

        prob = np.ascontiguousarray(probs[i].detach().float().cpu().numpy())

        d = dcrf.DenseCRF2D(W, H, n_classes)
        d.setUnaryEnergy(unary_from_softmax(prob))

        # Pairwise terms: a position-only smoothness kernel and an appearance (bilateral) kernel.
        d.addPairwiseGaussian(sxy=3, compat=3)
        d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=image_rgb, compat=10)

        Q = d.inference(n_iters)
        refined_mask = np.array(Q).reshape((n_classes, H, W)).astype(np.float32)

        refined_masks[i] = torch.from_numpy(refined_mask).to(device)

    return refined_masks

In [None]:
def calculate_flood_percentage(mask):
    """
    Percentage (0-100) of pixels in ``mask`` whose value equals 1 (flood/water).

    Any array-like is accepted. Every element counts toward the total, including ignored pixels
    (e.g. -1). An empty mask yields 0.0 rather than NaN with a RuntimeWarning.
    """
    mask = np.asarray(mask)
    total_pixels = mask.size
    if total_pixels == 0:
        return 0.0
    flood_pixels = np.count_nonzero(mask == 1)
    return (flood_pixels / total_pixels) * 100

In [None]:
def visualize_predictions(images, masks, predictions):
    """
    Plot each sample of a batch: input image, NIR channel, reference mask and model prediction.

    One figure per sample with four panels, in this order: RGB composite, NIR channel, mask,
    prediction. The prediction title shows the percentage of pixels predicted as class 1 (water).

    :param images: tensor (B, C, H, W). Channels 0..2 are shown as RGB (brightened x6 and clipped
        to [0, 1]); channel 3, if present, is shown as NIR.
    :param masks: tensor (B, H, W) with labels -1 (black), 0 (pink) or 1 (blue).
    :param predictions: tensor (B, K, H, W) of per-class scores; argmax over K is plotted.
    """
    # Discrete colormap: -1 -> black, 0 -> pink, 1 -> blue.
    cmap = ListedColormap(['black', '#ffc0cb', '#0000FF'])
    bounds = [-1.5, -0.5, 0.5, 1.5]
    norm = BoundaryNorm(bounds, cmap.N)

    batch_size = images.shape[0]
    for i in range(batch_size):
        # detach() so tensors that still require grad can be converted to NumPy.
        image = images[i].detach().permute(1, 2, 0).cpu().numpy()
        mask = masks[i].detach().cpu().numpy()
        prediction = predictions[i].detach().argmax(dim=0).cpu().numpy()

        rgb_image = image[:, :, :3]
        nir_channel = image[:, :, 3] if image.shape[2] > 3 else None

        flood_percentage = calculate_flood_percentage(prediction)

        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # RGB composite (reflectances are small, so brighten before clipping).
        rgb_image = np.clip(rgb_image * 6.0, 0, 1)
        axes[0].imshow(rgb_image)
        axes[0].set_title("RGB изображение")
        axes[0].axis("off")

        # NIR channel, or a placeholder when the input has only 3 channels.
        if nir_channel is not None:
            axes[1].imshow(nir_channel, cmap="gray")
            axes[1].set_title("NIR канал")
        else:
            axes[1].text(0.5, 0.5, 'NIR канал отсутствует', horizontalalignment='center', verticalalignment='center')
            axes[1].set_title("NIR канал")
        axes[1].axis("off")

        # Reference mask.
        axes[2].imshow(mask, cmap=cmap, norm=norm)
        axes[2].set_title("Маска")
        axes[2].axis("off")

        # Model prediction.
        axes[3].imshow(prediction, cmap=cmap, norm=norm)
        axes[3].set_title(f"Сегментация модели\n{flood_percentage:.2f}% воды")
        axes[3].axis("off")

        plt.tight_layout()
        plt.show()
        # One figure is created per sample; close it so they do not accumulate in memory.
        plt.close(fig)

In [None]:
def train(num_epochs = 200):
    """
    Train a PAN segmentation model (EfficientNet-B7 encoder, ImageNet-initialised) from scratch.

    Reads module-level configuration: ``train_dir``, ``validation_dir``, ``local_batch_size``,
    ``learning_rate``, ``model_serialization``, ``channels``, ``data`` and the augmentation
    ``transform``. ``logger1`` and the globals ``validation_iou`` / ``num_best_epoch`` must exist
    before the call; the last two are updated whenever the validation mean IoU improves.

    Each epoch runs these steps:
      1. A training pass (Dice loss; mixed precision on CUDA).
      2. One scheduler step (cosine annealing, restarted every 25 epochs).
      3. A validation pass.
      4. Logging of all metrics via ``logger1``.
      5. From epoch 51 on, a checkpoint saved to
         ``/content/{model_serialization}_{epoch}_{channels}_{data}.pth``.

    :param num_epochs: number of epochs to train.
    :raises ValueError: if the train/validation DataFrame or DataLoader is empty.
    """

    global validation_iou
    global num_best_epoch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and pinned memory only help on CUDA. On a CPU runtime they are disabled
    # (hard-coding 'cuda' would only produce warnings there).
    use_cuda = device.type == "cuda"

    # Model: PAN decoder, EfficientNet-B7 encoder with ImageNet weights, 4 input bands, 2 classes.
    model = smp.PAN(
        encoder_name= "efficientnet-b7", encoder_weights='imagenet', in_channels=4, classes=2
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # Cosine annealing with warm restarts every 25 scheduler steps (stepped once per epoch below).
    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=25,
        T_mult=1,
        eta_min=5e-6
    )

    scaler = GradScaler(device.type, enabled=use_cuda)

    # Multiclass Dice loss; pixels labelled -1 are ignored.
    criterion_dice = smp.losses.DiceLoss(mode="multiclass", ignore_index=-1)

    # Data
    train_df = create_df(train_dir)
    validation_df = create_df(validation_dir)
    if train_df.empty:
        raise ValueError("Train DataFrame is empty!")
    if validation_df.empty:
        raise ValueError("Validation DataFrame is empty!")

    train_dataset = MyDataset(train_df, split="train", transform = transform)
    validation_dataset = MyDataset(validation_df, split="validation", transform=None)

    # os.cpu_count() may return None or 1; DataLoader rejects persistent_workers=True with 0 workers.
    num_workers = max((os.cpu_count() or 1) - 1, 0)
    loader_options = dict(
        batch_size=local_batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    validation_loader = DataLoader(validation_dataset, shuffle=False, **loader_options)
    if len(train_loader) == 0:
        raise ValueError("Train DataLoader is empty!")
    if len(validation_loader) == 0:
        raise ValueError("Validation DataLoader is empty!")

    def _run_epoch(loader, desc, training):
        """One pass over ``loader`` (optimising when ``training``); returns (mean loss, 8 confusion counts)."""
        model.train(training)
        losses = []
        # Flood TP, FP, TN, FN followed by background TP, FP, TN, FN (calculate_metrics order).
        counts = [0.0] * 8
        with torch.set_grad_enabled(training):
            for batch in tqdm(loader, desc=desc, unit="batch", leave=True):
                image = batch["image"].to(device, non_blocking=True)
                mask = batch["mask"].to(device, non_blocking=True)

                if training:
                    optimizer.zero_grad()

                with autocast(device.type, enabled=use_cuda):
                    pred = model(image)
                    loss = criterion_dice(pred, mask)

                if training:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                losses.append(loss.item())
                counts = [total + part for total, part in zip(counts, calculate_TP_FP_TN_FN(pred, mask))]
        return sum(losses) / (len(losses) + 1e-6), counts

    ## Training ##
    for epoch in range(num_epochs):
        # Training pass
        Avg_loss, counts = _run_epoch(train_loader, "Train", training=True)
        (Accuracy, IoU_flood, IoU_background, IoU, Dice_flood, Dice_background, Dice,
         Precision_flood, Precision_background, Precision, Recall_flood, Recall_background,
         Recall, BalancedAccuracy) = calculate_metrics(*counts)
        logger1.info(f"Epoch {epoch+1}, LR: {optimizer.param_groups[0]['lr']:.8f}, Train Loss: {Avg_loss:.4f}, Train IoU: {IoU:.4f}, Train Accuracy: {Accuracy:.4f}, Train Dice: {Dice:.4f}, Train Precision: {Precision:.4f}, Train Recall: {Recall:.4f}, Train BalancedAccuracy: {BalancedAccuracy:.4f}, Train IoU_flood: {IoU_flood:.4f}, Train IoU_background: {IoU_background:.4f}, Train Dice_flood: {Dice_flood:.4f}, Train Dice_background: {Dice_background:.4f}, Train Precision_flood: {Precision_flood:.4f}, Train Precision_background: {Precision_background:.4f}, Train Recall_flood: {Recall_flood:.4f}, Train Recall_background: {Recall_background:.4f}")

        # Stepped after logging, so the logged LR is the one used during this epoch.
        scheduler.step()

        # Validation pass
        Avg_loss, counts = _run_epoch(validation_loader, "Valid", training=False)
        (Accuracy, IoU_flood, IoU_background, IoU, Dice_flood, Dice_background, Dice,
         Precision_flood, Precision_background, Precision, Recall_flood, Recall_background,
         Recall, BalancedAccuracy) = calculate_metrics(*counts)
        if (IoU > validation_iou):
            num_best_epoch = epoch + 1
            validation_iou = IoU
        logger1.info(f"Epoch {epoch+1}, Valid Loss: {Avg_loss:.4f}, Valid IoU: {IoU:.4f}, Valid Accuracy: {Accuracy:.4f}, Valid Dice: {Dice:.4f}, Valid Precision: {Precision:.4f}, Valid Recall: {Recall:.4f}, Valid BalancedAccuracy: {BalancedAccuracy:.4f}, Valid IoU_flood: {IoU_flood:.4f}, Valid IoU_background: {IoU_background:.4f}, Valid Dice_flood: {Dice_flood:.4f}, Valid Dice_background: {Dice_background:.4f}, Valid Precision_flood: {Precision_flood:.4f}, Valid Precision_background: {Precision_background:.4f}, Valid Recall_flood: {Recall_flood:.4f}, Valid Recall_background: {Recall_background:.4f}")

        # From epoch 51 on, checkpoint every epoch ('loss' is this epoch's validation loss).
        # The scaler state is included so a resumed run keeps its loss scale.
        if (epoch+1)>50:
            torch.save({
                'epoch': epoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': Avg_loss
            }, f"/content/{model_serialization}_{epoch+1}_{channels}_{data}.pth")


In [None]:
def addtrain(num_epochs = 200, initEpoch = 100):
    """
    Resume training of the PAN model from the checkpoint saved at epoch ``initEpoch``.

    The checkpoint is read from
    ``/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}/{model_serialization}_{initEpoch}_{channels}_{data}.pth``.
    It must contain ``model_state_dict``, ``optimizer_state_dict`` and ``scheduler_state_dict``;
    ``scaler_state_dict`` is restored when present. Training continues for ``num_epochs`` more
    epochs, numbered ``initEpoch + 1`` onwards in the logs, best-epoch tracking and checkpoint
    file names.

    Also requires module-level ``logger1``, ``folder_of_weights``, ``validation_iou`` and
    ``num_best_epoch``; the last two are updated whenever the validation mean IoU improves.

    :param num_epochs: number of additional epochs to run.
    :param initEpoch: epoch number of the checkpoint to resume from.
    :raises ValueError: if the train/validation DataFrame or DataLoader is empty.
    """

    global validation_iou
    global num_best_epoch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Mixed precision and pinned memory only on CUDA; disabled on a CPU runtime.
    use_cuda = device.type == "cuda"
    checkpoint = torch.load(f"/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}/{model_serialization}_{initEpoch}_{channels}_{data}.pth", map_location=device)

    # ImageNet encoder weights are not downloaded: load_state_dict (strict by default) overwrites
    # every parameter and buffer with the checkpoint's values right away.
    model = smp.PAN(
        encoder_name= "efficientnet-b7", encoder_weights=None, in_channels=4, classes=2
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
        optimizer,
        T_0=25,
        T_mult=1,
        eta_min=5e-6
    )
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    # Constructing the scheduler performs an initial step that resets every param group's lr to its
    # base value. load_state_dict restores the schedule position but not the optimizer's lr, so
    # re-apply the lr the schedule had reached when the checkpoint was written.
    for group, lr in zip(optimizer.param_groups, scheduler.get_last_lr()):
        group['lr'] = lr

    scaler = GradScaler(device.type, enabled=use_cuda)
    # Older checkpoints have no scaler state; an empty dict comes from a disabled scaler.
    if checkpoint.get('scaler_state_dict'):
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    # Multiclass Dice loss; pixels labelled -1 are ignored.
    criterion_dice = smp.losses.DiceLoss(mode="multiclass", ignore_index=-1)

    # Data
    train_df = create_df(train_dir)
    validation_df = create_df(validation_dir)
    if train_df.empty:
        raise ValueError("Train DataFrame is empty!")
    if validation_df.empty:
        raise ValueError("Validation DataFrame is empty!")

    train_dataset = MyDataset(train_df, split="train", transform = transform)
    validation_dataset = MyDataset(validation_df, split="validation", transform=None)

    # os.cpu_count() may return None or 1; DataLoader rejects persistent_workers=True with 0 workers.
    num_workers = max((os.cpu_count() or 1) - 1, 0)
    loader_options = dict(
        batch_size=local_batch_size,
        num_workers=num_workers,
        pin_memory=use_cuda,
        persistent_workers=num_workers > 0,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    validation_loader = DataLoader(validation_dataset, shuffle=False, **loader_options)
    if len(train_loader) == 0:
        raise ValueError("Train DataLoader is empty!")
    if len(validation_loader) == 0:
        raise ValueError("Validation DataLoader is empty!")

    def _run_epoch(loader, desc, training):
        """One pass over ``loader`` (optimising when ``training``); returns (mean loss, 8 confusion counts)."""
        model.train(training)
        losses = []
        # Flood TP, FP, TN, FN followed by background TP, FP, TN, FN (calculate_metrics order).
        counts = [0.0] * 8
        with torch.set_grad_enabled(training):
            for batch in tqdm(loader, desc=desc, unit="batch", leave=True):
                image = batch["image"].to(device, non_blocking=True)
                mask = batch["mask"].to(device, non_blocking=True)

                if training:
                    optimizer.zero_grad()

                with autocast(device.type, enabled=use_cuda):
                    pred = model(image)
                    loss = criterion_dice(pred, mask)

                if training:
                    scaler.scale(loss).backward()
                    scaler.step(optimizer)
                    scaler.update()

                losses.append(loss.item())
                counts = [total + part for total, part in zip(counts, calculate_TP_FP_TN_FN(pred, mask))]
        return sum(losses) / (len(losses) + 1e-6), counts

    ## Training ##
    for epoch in range(num_epochs):
        # Training pass
        Avg_loss, counts = _run_epoch(train_loader, "Train", training=True)
        (Accuracy, IoU_flood, IoU_background, IoU, Dice_flood, Dice_background, Dice,
         Precision_flood, Precision_background, Precision, Recall_flood, Recall_background,
         Recall, BalancedAccuracy) = calculate_metrics(*counts)
        logger1.info(f"Epoch {epoch+initEpoch+1}, LR: {optimizer.param_groups[0]['lr']:.8f}, Train Loss: {Avg_loss:.4f}, Train IoU: {IoU:.4f}, Train Accuracy: {Accuracy:.4f}, Train Dice: {Dice:.4f}, Train Precision: {Precision:.4f}, Train Recall: {Recall:.4f}, Train BalancedAccuracy: {BalancedAccuracy:.4f}, Train IoU_flood: {IoU_flood:.4f}, Train IoU_background: {IoU_background:.4f}, Train Dice_flood: {Dice_flood:.4f}, Train Dice_background: {Dice_background:.4f}, Train Precision_flood: {Precision_flood:.4f}, Train Precision_background: {Precision_background:.4f}, Train Recall_flood: {Recall_flood:.4f}, Train Recall_background: {Recall_background:.4f}")

        # Stepped after logging, so the logged LR is the one used during this epoch.
        scheduler.step()

        # Validation pass
        Avg_loss, counts = _run_epoch(validation_loader, "Valid", training=False)
        (Accuracy, IoU_flood, IoU_background, IoU, Dice_flood, Dice_background, Dice,
         Precision_flood, Precision_background, Precision, Recall_flood, Recall_background,
         Recall, BalancedAccuracy) = calculate_metrics(*counts)
        if (IoU > validation_iou):
            num_best_epoch = epoch + initEpoch +1
            validation_iou = IoU
        logger1.info(f"Epoch {epoch+initEpoch+1}, Valid Loss: {Avg_loss:.4f}, Valid IoU: {IoU:.4f}, Valid Accuracy: {Accuracy:.4f}, Valid Dice: {Dice:.4f}, Valid Precision: {Precision:.4f}, Valid Recall: {Recall:.4f}, Valid BalancedAccuracy: {BalancedAccuracy:.4f}, Valid IoU_flood: {IoU_flood:.4f}, Valid IoU_background: {IoU_background:.4f}, Valid Dice_flood: {Dice_flood:.4f}, Valid Dice_background: {Dice_background:.4f}, Valid Precision_flood: {Precision_flood:.4f}, Valid Precision_background: {Precision_background:.4f}, Valid Recall_flood: {Recall_flood:.4f}, Valid Recall_background: {Recall_background:.4f}")

        # From absolute epoch 51 on, checkpoint every epoch ('loss' is this epoch's validation loss).
        if (epoch+1+initEpoch)>50:
            torch.save({
                'epoch': epoch+initEpoch+1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'scaler_state_dict': scaler.state_dict(),
                'loss': Avg_loss
            }, f"/content/{model_serialization}_{epoch+initEpoch+1}_{channels}_{data}.pth")


In [None]:
# Confusion-count names, in the order calculate_TP_FP_TN_FN returns them and as
# calculate_metrics accepts them as keyword arguments.
_CONFUSION_KEYS = (
    "TP_flood", "FP_flood", "TN_flood", "FN_flood",
    "TP_background", "FP_background", "TN_background", "FN_background",
)
# Names of the 14 values returned by calculate_metrics, in return order.
_METRIC_NAMES = (
    "Accuracy", "IoU_flood", "IoU_background", "IoU",
    "Dice_flood", "Dice_background", "Dice",
    "Precision_flood", "Precision_background", "Precision",
    "Recall_flood", "Recall_background", "Recall", "BalancedAccuracy",
)
# Metrics written after the leading "<head> IoU" entry of each test log line, in log order.
_LOGGED_METRICS = (
    "Accuracy", "Dice", "Precision", "Recall", "BalancedAccuracy",
    "IoU_flood", "IoU_background", "Dice_flood", "Dice_background",
    "Precision_flood", "Precision_background", "Recall_flood", "Recall_background",
)
# Prediction variants evaluated by test():
# (key, suffix appended to the log head, suffix used in the printed count names).
_TEST_VARIANTS = (
    ("plain", "", ""),
    ("aug", " Aug", "_aug"),
    ("aug_crf", " Aug CRF", "aug_crf"),
)


def _make_test_loader(dataset, num_workers, pin_memory):
    """Build a non-shuffling evaluation DataLoader with batch size 16."""
    return DataLoader(
        dataset, batch_size=16, shuffle=False, num_workers=num_workers,
        pin_memory=pin_memory,
        # persistent_workers=True raises ValueError when num_workers == 0.
        persistent_workers=num_workers > 0,
    )


def _accumulate_test_counts(model, loader, device, desc):
    """Run `model` over `loader` and sum confusion counts for every prediction variant.

    Returns a dict mapping each key of _TEST_VARIANTS to a list of 8 summed counts
    in _CONFUSION_KEYS order. Every call starts from zero, so the counts of one
    split never leak into another split's metrics.
    """
    totals = {key: [0.0] * len(_CONFUSION_KEYS) for key, _, _ in _TEST_VARIANTS}
    use_amp = device.type == "cuda"
    model.eval()
    with torch.no_grad():
        for batch in tqdm(loader, desc=desc, unit="batch", leave=True):
            image = batch["image"].to(device, non_blocking=True)
            mask = batch["mask"].to(device, non_blocking=True)

            # Mixed precision only makes sense on CUDA; on a CPU-only runtime autocast is disabled.
            with autocast(device.type, enabled=use_amp):
                pred = model(image)
            pred_aug = d4_mask_batch(images=image, model=model, device=device)
            pred_aug_crf = apply_crf_to_batch(images=image, probs=pred_aug)

            #visualize_predictions(images = image , masks = mask, predictions = pred)
            #visualize_predictions(images = image , masks = mask, predictions = pred_aug)
            #visualize_predictions(images = image , masks = mask, predictions = pred_aug_crf)

            predictions = {"plain": pred, "aug": pred_aug, "aug_crf": pred_aug_crf}
            for key, prediction in predictions.items():
                batch_counts = calculate_TP_FP_TN_FN(prediction, mask)
                totals[key] = [total + count for total, count in zip(totals[key], batch_counts)]
    return totals


def _test_metrics(counts):
    """Return {metric name: value} computed by calculate_metrics from 8 summed confusion counts."""
    values = calculate_metrics(**dict(zip(_CONFUSION_KEYS, counts)))
    return dict(zip(_METRIC_NAMES, values))


def _format_test_log(weights_epoch, head, prefix, metrics):
    """Format one line: 'WeightsEpoch: E, <head> IoU: x, <prefix> Accuracy: y, ...' (4 decimals)."""
    parts = [f"WeightsEpoch: {weights_epoch}", f"{head} IoU: {metrics['IoU']:.4f}"]
    parts.extend(f"{prefix} {name}: {metrics[name]:.4f}" for name in _LOGGED_METRICS)
    return ", ".join(parts)


def _log_test_split(weights_epoch, split_head, prefix, totals):
    """Log one metrics line per prediction variant of a split through logger2."""
    for key, head_suffix, _ in _TEST_VARIANTS:
        metrics = _test_metrics(totals[key])
        logger2.info(_format_test_log(weights_epoch, f"{split_head}{head_suffix}", prefix, metrics))


def _print_flood_counts(totals):
    """Print the summed flood TP/FP/TN/FN counts of every prediction variant."""
    for key, _, count_suffix in _TEST_VARIANTS:
        tp, fp, tn, fn = totals[key][:4]
        print(f"TP_flood{count_suffix} = {tp}, FP_flood{count_suffix} = {fp}, "
              f"TN_flood{count_suffix} = {tn}, FN_flood{count_suffix} = {fn}")


def test(TestEpoch = 200):
    """Evaluate a saved checkpoint on the test split and on the Bolivia test split.

    Loads ``{model_serialization}_{TestEpoch}_{channels}_{data}.pth`` from
    ``/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}`` and
    rebuilds the 4-channel, 2-class PAN / efficientnet-b7 model from it. For each
    split it sums confusion counts for three prediction variants:

    * the plain model output,
    * the output of ``d4_mask_batch`` (presumably D4 test-time augmentation — confirm),
    * ``apply_crf_to_batch`` applied to the ``d4_mask_batch`` output.

    Each split's counts start from zero. ``calculate_metrics`` is applied to each
    variant and one line per variant is logged through ``logger2``. The summed flood
    counts of the main test split are also printed.

    Relies on notebook globals: folder, folder_of_weights, model_serialization,
    channels, data, test_dir, test_Bolivia_dir, create_df, MyDataset,
    d4_mask_batch, apply_crf_to_batch, calculate_TP_FP_TN_FN, calculate_metrics,
    logger2.

    Args:
        TestEpoch: epoch number embedded in the checkpoint filename.

    Raises:
        ValueError: if either test DataFrame or DataLoader is empty.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    checkpoint_path = f"/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}/{model_serialization}_{TestEpoch}_{channels}_{data}.pth"
    # The checkpoint also stores optimizer/scheduler state, so a full load is needed.
    # NOTE: weights_only=False unpickles arbitrary objects; only load checkpoints written by train().
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # encoder_weights=None: load_state_dict (strict) overwrites every parameter below,
    # so downloading ImageNet weights first is wasted work.
    model = smp.PAN(
        encoder_name="efficientnet-b7", encoder_weights=None, in_channels=4, classes=2
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Build the DataLoaders
    test_df = create_df(test_dir)
    test_Bolivia_df = create_df(test_Bolivia_dir)
    if test_df.empty:
        raise ValueError("Test DataFrame is empty!")
    if test_Bolivia_df.empty:
        raise ValueError("Test Bolivia DataFrame is empty!")

    test_dataset = MyDataset(test_df, split="test", transform=None)
    test_Bolivia_dataset = MyDataset(test_Bolivia_df, split="test", transform=None)

    # os.cpu_count() may return None; keep one core free for the main process.
    num_workers = max((os.cpu_count() or 1) - 1, 0)
    pin_memory = device.type == "cuda"
    test_loader = _make_test_loader(test_dataset, num_workers, pin_memory)
    test_Bolivia_loader = _make_test_loader(test_Bolivia_dataset, num_workers, pin_memory)
    if len(test_loader) == 0:
        raise ValueError("Test DataLoader is empty!")
    if len(test_Bolivia_loader) == 0:
        raise ValueError("Test Bolivia DataLoader is empty!")

    weights_epoch = checkpoint['epoch']

    # Evaluate on the test split
    test_totals = _accumulate_test_counts(model, test_loader, device, "Test")
    _print_flood_counts(test_totals)
    _log_test_split(weights_epoch, "Test", "Test", test_totals)

    # Evaluate on the Bolivia test split (fresh counters for every variant)
    bolivia_totals = _accumulate_test_counts(model, test_Bolivia_loader, device, "Test Bolivia")
    _log_test_split(weights_epoch, "TestBolivia", "TestB", bolivia_totals)


In [None]:
def _configure_file_logger(name, log_path):
    """Return logger `name`, writing bare messages to the console and appending to `log_path`.

    Handlers left over from a previous run of this cell are closed and replaced.
    Re-running the cell therefore neither duplicates output nor keeps writing
    through a handler whose file a later `!mv` cell has already moved away.
    Propagation is disabled because the root logger in this runtime has its own
    handler. Without this, every record is echoed a second time as
    "INFO:<name>:<message>".
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter('%(message)s')
    for handler in (logging.StreamHandler(), logging.FileHandler(log_path, mode='a')):
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger


if __name__ == "__main__":

    folder_of_weights = "Weights_RGB_NIR_HandLabeled"

    # Best-epoch bookkeeping; train() updates these when the validation IoU improves.
    num_best_epoch = 0
    validation_iou = 0

    # Training log (console + /content/<channels>_<data>_train.txt)
    logger1 = _configure_file_logger("training_logger", f"/content/{channels}_{data}_train.txt")
    # Test log (console + /content/<channels>_<data>_test.txt)
    logger2 = _configure_file_logger("testing_logger", f"/content/{channels}_{data}_test.txt")

    train(num_epochs)

    txt_train = f"{channels}_{data}_train.txt"
    txt_test = f"{channels}_{data}_test.txt"
    # NOTE(review): train() only saves checkpoints for epochs > 50, so this file does
    # not exist when the best validation epoch is <= 50 (or no epoch improved).
    best_weights = f"{model_serialization}_{num_best_epoch}_{channels}_{data}.pth"
    # Final checkpoint; named with num_epochs to match the later test(num_epochs) call
    # (assumes initEpoch == 0, as that call does).
    weights = f"{model_serialization}_{num_epochs}_{channels}_{data}.pth"
    txt_dir = f"/content/drive/MyDrive/datasets_and_models/{folder}"
    weights_dir = f"/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}"


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/106 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/267M [00:00<?, ?B/s]

Train: 100%|██████████| 32/32 [00:22<00:00,  1.42batch/s]
Epoch 1, LR: 0.00050000, Train Loss: 0.2681, Train IoU: 0.6577, Train Accuracy: 0.9180, Train Dice: 0.7642, Train Precision: 0.7615, Train Recall: 0.7669, Train BalancedAccuracy: 0.7669, Train IoU_flood: 0.4023, Train IoU_background: 0.9132, Train Dice_flood: 0.5737, Train Dice_background: 0.9547, Train Precision_flood: 0.5673, Train Precision_background: 0.9558, Train Recall_flood: 0.5804, Train Recall_background: 0.9535
INFO:training_logger:Epoch 1, LR: 0.00050000, Train Loss: 0.2681, Train IoU: 0.6577, Train Accuracy: 0.9180, Train Dice: 0.7642, Train Precision: 0.7615, Train Recall: 0.7669, Train BalancedAccuracy: 0.7669, Train IoU_flood: 0.4023, Train IoU_background: 0.9132, Train Dice_flood: 0.5737, Train Dice_background: 0.9547, Train Precision_flood: 0.5673, Train Precision_background: 0.9558, Train Recall_flood: 0.5804, Train Recall_background: 0.9535
Valid: 100%|██████████| 12/12 [00:02<00:00,  4.05batch/s]
Epoch 1, Va

In [None]:
# Archive the training log in the project folder on the mounted Google Drive.
!mv {txt_train} {txt_dir}/

In [None]:
# Move the best-validation checkpoint that train() saved under /content to the Drive weights folder.
# Relative path: assumes the working directory is /content.
!mv {best_weights} {weights_dir}/

In [None]:
!mv {weights} {weights_dir}/

In [None]:
# Evaluate the best-validation checkpoint first, then the final-epoch checkpoint.
for checkpoint_epoch in (num_best_epoch, num_epochs):
    test(checkpoint_epoch)

  checkpoint = torch.load(f"/content/drive/MyDrive/datasets_and_models/{folder}/{folder_of_weights}/{model_serialization}_{TestEpoch}_{channels}_{data}.pth", map_location=device)
Test: 100%|██████████| 23/23 [01:17<00:00,  3.38s/batch]
WeightsEpoch: 140, Test IoU: 0.8727, Test Accuracy: 0.9703, Test Dice: 0.9293, Test Precision: 0.9484, Test Recall: 0.9124, Test BalancedAccuracy: 0.9124, Test IoU_flood: 0.7786, Test IoU_background: 0.9668, Test Dice_flood: 0.8755, Test Dice_background: 0.9831, Test Precision_flood: 0.9200, Test Precision_background: 0.9767, Test Recall_flood: 0.8351, Test Recall_background: 0.9896
INFO:testing_logger:WeightsEpoch: 140, Test IoU: 0.8727, Test Accuracy: 0.9703, Test Dice: 0.9293, Test Precision: 0.9484, Test Recall: 0.9124, Test BalancedAccuracy: 0.9124, Test IoU_flood: 0.7786, Test IoU_background: 0.9668, Test Dice_flood: 0.8755, Test Dice_background: 0.9831, Test Precision_flood: 0.9200, Test Precision_background: 0.9767, Test Recall_flood: 0.8351, Tes

TP_flood = 2143023.0, FP_flood = 186311.0, TN_flood = 17764955.0, FN_flood = 423078.0
TP_flood_aug = 2158295.0, FP_flood_aug = 179313.0, TN_flood_aug = 17771953.0, FN_flood_aug = 407806.0
TP_floodaug_crf = 1656131.0, FP_floodaug_crf = 22930.0, TN_floodaug_crf = 17928336.0, FN_floodaug_crf = 909970.0


Test Bolivia: 100%|██████████| 4/4 [00:13<00:00,  3.30s/batch]
WeightsEpoch: 140, TestBolivia IoU: 0.8760, TestB Accuracy: 0.9650, TestB Dice: 0.9318, TestB Precision: 0.9504, TestB Recall: 0.9154, TestB BalancedAccuracy: 0.9154, TestB IoU_flood: 0.7924, TestB IoU_background: 0.9596, TestB Dice_flood: 0.8842, TestB Dice_background: 0.9794, TestB Precision_flood: 0.9299, TestB Precision_background: 0.9709, TestB Recall_flood: 0.8428, TestB Recall_background: 0.9880
INFO:testing_logger:WeightsEpoch: 140, TestBolivia IoU: 0.8760, TestB Accuracy: 0.9650, TestB Dice: 0.9318, TestB Precision: 0.9504, TestB Recall: 0.9154, TestB BalancedAccuracy: 0.9154, TestB IoU_flood: 0.7924, TestB IoU_background: 0.9596, TestB Dice_flood: 0.8842, TestB Dice_background: 0.9794, TestB Precision_flood: 0.9299, TestB Precision_background: 0.9709, TestB Recall_flood: 0.8428, TestB Recall_background: 0.9880
WeightsEpoch: 140, TestBolivia Aug IoU: 0.8777, TestB Accuracy: 0.9708, TestB Dice: 0.9324, TestB Precisi

TP_flood = 2173538.0, FP_flood = 199973.0, TN_flood = 17751293.0, FN_flood = 392563.0
TP_flood_aug = 2172196.0, FP_flood_aug = 189409.0, TN_flood_aug = 17761857.0, FN_flood_aug = 393905.0
TP_floodaug_crf = 1688438.0, FP_floodaug_crf = 29267.0, TN_floodaug_crf = 17921999.0, FN_floodaug_crf = 877663.0


Test Bolivia: 100%|██████████| 4/4 [00:13<00:00,  3.26s/batch]
WeightsEpoch: 200, TestBolivia IoU: 0.8773, TestB Accuracy: 0.9649, TestB Dice: 0.9326, TestB Precision: 0.9437, TestB Recall: 0.9223, TestB BalancedAccuracy: 0.9223, TestB IoU_flood: 0.7952, TestB IoU_background: 0.9593, TestB Dice_flood: 0.8859, TestB Dice_background: 0.9792, TestB Precision_flood: 0.9136, TestB Precision_background: 0.9739, TestB Recall_flood: 0.8599, TestB Recall_background: 0.9847
INFO:testing_logger:WeightsEpoch: 200, TestBolivia IoU: 0.8773, TestB Accuracy: 0.9649, TestB Dice: 0.9326, TestB Precision: 0.9437, TestB Recall: 0.9223, TestB BalancedAccuracy: 0.9223, TestB IoU_flood: 0.7952, TestB IoU_background: 0.9593, TestB Dice_flood: 0.8859, TestB Dice_background: 0.9792, TestB Precision_flood: 0.9136, TestB Precision_background: 0.9739, TestB Recall_flood: 0.8599, TestB Recall_background: 0.9847
WeightsEpoch: 200, TestBolivia Aug IoU: 0.8792, TestB Accuracy: 0.9710, TestB Dice: 0.9334, TestB Precisi

In [None]:
# Archive the test log (written through logger2 by the test() runs above) to the Drive project folder.
!mv {txt_test} {txt_dir}/