In [1]:
!pip install monai



In [2]:
import re
import os
import requests
from typing import Dict, Optional
import numpy as np

try:
    from tqdm import tqdm
    _has_tqdm = True
except ImportError:
    _has_tqdm = False


import os
# CUDA caching-allocator options; must be set before CUDA is initialised to take effect.
# Only one value is assigned: a second assignment would silently overwrite the first.
# Alternative to try if fragmentation persists: 'max_split_size_mb:128'.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
import matplotlib.pyplot as plt
from sklearn.metrics import average_precision_score, f1_score, jaccard_score
import warnings
import os
import random
import inspect
from typing import List, Tuple, Optional, Union

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch import nn
from tqdm import tqdm

import monai
from monai.networks.nets import SwinUNETR
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from scipy import ndimage as ndi

In [3]:
import os
import re
import zipfile
import gdown
import numpy as np

def download_and_unpack_gdrive_gdown(url_or_id, dest_folder='.', out_filename=None,
                                     save_npy=False, quiet=False):
    """
    Download a file from Google Drive with gdown (no unpacking is performed).

    Despite the historical name, the file is only downloaded; ``.npz`` archives are
    kept as-is so they can be opened directly with ``np.load``.

    Args:
        url_or_id: A full Google Drive URL or a bare file id.
        dest_folder: Directory the file is saved into (created if missing).
        out_filename: Name of the saved file. If None, it is derived from the ``id=``
            query parameter of the URL, or from the URL basename.
        save_npy: Unused; kept for backward compatibility.
        quiet: Passed through to ``gdown.download``.

    Returns:
        Path of the downloaded (or already present) file.

    Raises:
        RuntimeError: if ``gdown.download`` reports failure by returning None.
    """
    os.makedirs(dest_folder, exist_ok=True)

    # A bare id (no scheme, no slashes) is turned into a direct-download URL.
    if re.fullmatch(r"[A-Za-z0-9_-]{10,}", url_or_id):
        url = f"https://drive.google.com/uc?id={url_or_id}"
    else:
        url = url_or_id

    if out_filename is None:
        # Keep only the id itself, not trailing query parameters such as "&export=download".
        id_match = re.search(r"[?&]id=([A-Za-z0-9_-]+)", url)
        if id_match:
            out_filename = f"downloaded_{id_match.group(1)}.bin"
        else:
            out_filename = os.path.basename(url) or "downloaded_file.bin"

    out_path = os.path.join(dest_folder, out_filename)

    # Skip the network round-trip when the file already exists (e.g. on notebook re-run).
    if not os.path.exists(out_path):
        result = gdown.download(url, out_path, quiet=quiet)
        if result is None:
            raise RuntimeError(f"gdown failed to download {url} to {out_path}")

    return out_path

In [4]:

# Training seismic volumes -> /content/seis_train.npz
file_id = "1NOUwAFtU_1KIlf8mQtB5xwSY-3NtYnDw"
url = f"https://drive.google.com/uc?id={file_id}"
download_and_unpack_gdrive_gdown(url_or_id=url,
                                 dest_folder="/content",
                                 out_filename="seis_train.npz",
                                 quiet=False)

# Must point at the same folder the file was downloaded into (dest_folder above).
out_path = os.path.join("/content", "seis_train.npz")


In [5]:

# Training fault labels -> /content/fault_train.npz
# (The result is not bound to a variable: nothing downstream reads it, and the
# previous name `fault_val` misdescribed this training archive.)
file_id = "1QHE1gn7B8Zt99Cw1KDwyfNv6wqhM5h49"
url = f"https://drive.google.com/uc?id={file_id}"
download_and_unpack_gdrive_gdown(url_or_id=url,
                                 dest_folder="/content",
                                 out_filename="fault_train.npz",
                                 quiet=False)


In [6]:

# Validation seismic volumes -> /content/seis_val.npz
# (The result is not bound to a variable: nothing downstream reads it, and the
# previous name `fault_val` misdescribed this seismic archive.)
file_id = "1vVOkYAZkq08CtWosLd27Py4quvAaTjOa"
url = f"https://drive.google.com/uc?id={file_id}"
download_and_unpack_gdrive_gdown(url_or_id=url,
                                 dest_folder="/content",
                                 out_filename="seis_val.npz",
                                 quiet=False)


In [7]:

# Validation fault labels -> /content/fault_val.npz
file_id = "1YIeB6J69RUozpKwWa84m8ZXgpaQMWmP7"
url = "https://drive.google.com/uc?id=" + file_id
fault_val = download_and_unpack_gdrive_gdown(
    url_or_id=url,
    dest_folder="/content",
    out_filename="fault_val.npz",
    quiet=False,
)


In [8]:
from typing import Union, Sequence, Tuple, Optional, Dict, Iterable
import numpy as np
import os
import random

class SeisFaultDataset:
    """
    Map-style dataset of (seismic, fault) pairs stored as 3D numpy arrays.

    Supported sources (for both seismic and fault):
    - a path (str or os.PathLike) to a ``.npz`` file, opened with
      ``np.load(..., mmap_mode='r')``. NOTE: numpy does not memory-map members of
      ``.npz`` archives, so each ``__getitem__`` reads the full volumes before cropping.
    - a dict-like mapping name -> np.ndarray.

    Pairing:
    - if ``pairs`` is given, it is used as-is (sequence of (seis_key, fault_key)).
    - otherwise keys present in both sources are paired by name, in the order they
      appear in the seismic source (unmatched keys are skipped with a warning).
    - if the sources share no key names but have the same number of keys, keys are
      paired by position (with a warning); with differing counts ValueError is raised.
    - if no pair results, ValueError("No pairs available") is raised.

    Options:
    - target_shape: crop size (z, y, x).
    - random_crop: True -> random crop origin; False -> centred crop. Axes that are not
      larger than the target start at 0 and are NOT padded, so the crop can be smaller
      than target_shape along such axes.
    - seed: seed for the crop RNG (deterministic crop sequence).
    - normalize_seis: apply per-crop z-score normalisation to the seismic crop.

    ``__getitem__`` returns ``(seis_crop, fault_crop, meta)``.
    """

    def __init__(
        self,
        seis_source: Union[str, os.PathLike, Dict[str, np.ndarray]],
        fault_source: Union[str, os.PathLike, Dict[str, np.ndarray]],
        pairs: Optional[Sequence[Tuple[str, str]]] = None,
        target_shape: Tuple[int, int, int] = (128, 128, 128),
        random_crop: bool = True,
        seed: Optional[int] = None,
        normalize_seis: bool = True,
    ):
        # Set the archive handles first so close()/__del__ work even if validation below raises.
        self._seis_npz = None
        self._fault_npz = None

        self.target_shape = tuple(int(x) for x in target_shape)
        if len(self.target_shape) != 3:
            raise ValueError("target_shape должен быть кортежем длины 3")
        self.random_crop = bool(random_crop)
        self.normalize_seis = bool(normalize_seis)
        self._rng = random.Random(seed)

        # Paths are opened as .npz archives; mappings are used directly.
        if isinstance(seis_source, (str, os.PathLike)):
            self._seis_npz = self._open_npz(seis_source)
            seis_source = self._seis_npz
        if isinstance(fault_source, (str, os.PathLike)):
            self._fault_npz = self._open_npz(fault_source)
            fault_source = self._fault_npz

        self.seis_source = seis_source
        self.fault_source = fault_source

        if pairs is not None:
            self.pairs = list(pairs)
        else:
            self.pairs = self._match_pairs(
                self._keys_from_source(self.seis_source),
                self._keys_from_source(self.fault_source),
            )

        if len(self.pairs) == 0:
            raise ValueError("No pairs available")

    # ---- helpers ----

    @staticmethod
    def _open_npz(path):
        """Open a .npz archive, raising FileNotFoundError if ``path`` does not exist."""
        if not os.path.exists(path):
            raise FileNotFoundError(path)
        return np.load(path, mmap_mode='r')

    @staticmethod
    def _match_pairs(seis_keys, fault_keys):
        """Pair seismic and fault keys by name; fall back to position only if no names match."""
        if not seis_keys or not fault_keys:
            return []
        fault_set = set(fault_keys)
        common = [k for k in seis_keys if k in fault_set]
        if common:
            if len(common) != len(seis_keys) or len(common) != len(fault_keys):
                warnings.warn(
                    f"SeisFaultDataset: {len(seis_keys) - len(common)} seismic and "
                    f"{len(fault_keys) - len(common)} fault keys have no counterpart and are skipped"
                )
            return [(k, k) for k in common]
        if len(seis_keys) == len(fault_keys):
            warnings.warn("SeisFaultDataset: no shared key names; pairing seismic and fault keys by position")
            return list(zip(seis_keys, fault_keys))
        raise ValueError(
            f"Cannot pair sources: no shared keys and different key counts "
            f"({len(seis_keys)} seismic vs {len(fault_keys)} fault)"
        )

    def _keys_from_source(self, src):
        """Return the key list of a source: ``.files`` for an NpzFile, otherwise ``.keys()``."""
        if hasattr(src, 'files'):
            return list(src.files)
        if isinstance(src, dict):
            return list(src.keys())
        # Any other mapping-like object exposing keys().
        try:
            return list(src.keys())
        except Exception as exc:
            raise ValueError(f"Unsupported source type: {type(src)}") from exc

    def _get_array(self, src, key):
        """Fetch ``src[key]`` as an ndarray and require it to be 3D."""
        arr = src[key]
        if not isinstance(arr, np.ndarray):
            arr = np.asarray(arr)
        if arr.ndim != 3:
            raise ValueError(f"Ключ {key}: ожидается 3D массив, получено {arr.ndim}D")
        return arr

    def _compute_crop_start(self, shape: Tuple[int, int, int]) -> Tuple[int, int, int]:
        """Crop origin per axis: random or centred when the axis is larger than the target, else 0."""
        starts = []
        for i in range(3):
            max_start = shape[i] - self.target_shape[i]
            if max_start <= 0:
                starts.append(0)
            elif self.random_crop:
                starts.append(self._rng.randint(0, max_start))
            else:
                starts.append(max_start // 2)
        return tuple(starts)

    def _zscore_normalize(self, vol: np.ndarray) -> np.ndarray:
        """Z-score normalisation using the crop's own mean and std (float32)."""
        v = vol.astype(np.float32, copy=False)
        mu = v.mean()
        sigma = v.std()
        eps = 1e-8
        return (v - mu) / (sigma + eps)

    # ---- Dataset interface ----

    def __len__(self) -> int:
        return len(self.pairs)

    def __getitem__(self, idx: int):
        """Return ``(seis_crop, fault_crop, meta)`` for pair ``idx`` (negative indices allowed)."""
        n = len(self)
        if idx < 0:
            idx += n
        if not 0 <= idx < n:
            raise IndexError(f"index out of range for dataset of length {n}")
        seis_key, fault_key = self.pairs[idx]

        seis = self._get_array(self.seis_source, seis_key)
        fault = self._get_array(self.fault_source, fault_key)
        # The fault crop reuses the seismic crop origin, so both volumes must be aligned.
        if seis.shape != fault.shape:
            raise ValueError(
                f"Shape mismatch for pair ({seis_key!r}, {fault_key!r}): "
                f"seismic {seis.shape} vs fault {fault.shape}"
            )

        starts = self._compute_crop_start(seis.shape)
        slices = tuple(slice(st, st + ts) for st, ts in zip(starts, self.target_shape))

        seis_crop = seis[slices]
        fault_crop = fault[slices]

        if self.normalize_seis:
            seis_crop = self._zscore_normalize(seis_crop)

        meta = {
            "seis_key": seis_key,
            "fault_key": fault_key,
            "crop_start": starts,
            "original_shapes": (seis.shape, fault.shape),
        }
        return seis_crop, fault_crop, meta

    def close(self):
        """Close any .npz archives opened by this dataset (safe to call repeatedly)."""
        for attr in ("_seis_npz", "_fault_npz"):
            handle = getattr(self, attr, None)
            if handle is None:
                continue
            try:
                handle.close()
            except Exception:
                # Best effort: this is cleanup only and may also run from __del__.
                pass
            setattr(self, attr, None)

    def __del__(self):
        self.close()


In [9]:
# Training uses random crops (augmentation); validation uses deterministic centred
# crops so the validation metric is comparable from epoch to epoch.
train_ds = SeisFaultDataset("/content/seis_train.npz", "/content/fault_train.npz",
                            target_shape=(128, 128, 128), seed=42)

val_ds = SeisFaultDataset("/content/seis_val.npz", "/content/fault_val.npz",
                          target_shape=(128, 128, 128), random_crop=False, seed=42)



In [10]:
def save_checkpoint(path, model, optimizer=None, epoch=None, extra=None):
    """
    Save a training checkpoint with ``torch.save``.

    When ``model`` is an ``nn.DataParallel`` wrapper, the inner module's state_dict is
    stored, so saved keys carry no ``module.`` prefix.

    Args:
        path: Destination file.
        model: Model, optionally wrapped in ``nn.DataParallel``.
        optimizer: If given, its state_dict is stored under ``optimizer_state_dict``.
        epoch: Stored as-is under ``epoch``.
        extra: Optional mapping merged in last (its keys override the ones above).
    """
    core = model.module if isinstance(model, nn.DataParallel) else model
    payload = dict(epoch=epoch, model_state_dict=core.state_dict())
    if optimizer is not None:
        payload["optimizer_state_dict"] = optimizer.state_dict()
    payload.update(extra or {})
    torch.save(payload, path)


def _strip_module_prefix(state_dict):
    # удаляем все префиксы "module." (на случай многократных вхождений)
    new_state = {}
    for k, v in state_dict.items():
        if k.startswith("module."):
            new_state[k[len("module."):]] = v
        else:
            new_state[k] = v
    return new_state


def _move_optimizer_state_to_device(opt_state, device):
    # переносим все тензоры в optimizer.state на нужное устройство
    for state in opt_state.values():
        for k, v in list(state.items()):
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)


def load_checkpoint(
    path,
    model,
    optimizer=None,
    device=torch.device("cpu"),
    strict=False,
    verbose=True,
):
    """
    Load a checkpoint into ``model`` (and optionally ``optimizer``), handling DataParallel.

    Accepted formats: a dict containing ``model_state_dict`` or ``state_dict``, or a bare
    state_dict (non-empty dict whose values are all tensors). ``module.`` prefixes are
    stripped before loading.

    Args:
        path: Checkpoint file; it is loaded onto CPU first.
        model: Target model, optionally wrapped in ``nn.DataParallel``.
        optimizer: Restored when given and the checkpoint has ``optimizer_state_dict``.
        device: Device the model (and optimizer state) is moved to after loading.
        strict: Passed to ``load_state_dict``. When False, tensors whose shape differs
            from the model's are dropped before loading, because ``load_state_dict``
            raises on size mismatches even with ``strict=False``; those parameters keep
            their current values. When True, mismatches are left for PyTorch to raise.
        verbose: Print diagnostics.

    Returns:
        The raw object returned by ``torch.load``.

    Raises:
        RuntimeError: if the checkpoint is not a dict or holds no recognisable model state.
    """
    # Load onto CPU first, then move model/optimizer to `device`.
    ckpt = torch.load(path, map_location="cpu")
    if verbose:
        print(f"[ckpt] loaded checkpoint type={type(ckpt)}")

    if not isinstance(ckpt, dict):
        raise RuntimeError("Checkpoint seems to contain a pickled model object. "
                           "Prefer saving state_dict instead of full model object.")

    state = ckpt.get("model_state_dict", ckpt.get("state_dict", None))
    if state is None:
        # A bare state_dict: a non-empty dict whose values are all tensors.
        if ckpt and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
            state = ckpt
        else:
            raise RuntimeError("Checkpoint dict doesn't contain recognizable model state.")

    state = _strip_module_prefix(state)

    model_to_load = model.module if isinstance(model, nn.DataParallel) else model
    model_state = model_to_load.state_dict()
    mismatched_shapes = [
        (k, v.shape, model_state[k].shape)
        for k, v in state.items()
        if k in model_state and isinstance(v, torch.Tensor) and v.shape != model_state[k].shape
    ]
    if mismatched_shapes and verbose:
        print("[ckpt] WARNING: found tensors with mismatched shapes (ckpt vs model):")
        for k, s_ckpt, s_model in mismatched_shapes[:10]:
            print(f"   {k}: ckpt{tuple(s_ckpt)} != model{tuple(s_model)}")
        if len(mismatched_shapes) > 10:
            print(f"   ... and {len(mismatched_shapes)-10} more")

    if mismatched_shapes and not strict:
        # Without this, load_state_dict would raise "size mismatch" despite strict=False.
        skipped = {k for k, _, _ in mismatched_shapes}
        state = {k: v for k, v in state.items() if k not in skipped}
        if verbose:
            print(f"[ckpt] skipping {len(skipped)} mismatched tensor(s) (strict=False)")

    load_res = model_to_load.load_state_dict(state, strict=strict)
    if verbose:
        # NamedTuple(missing_keys, unexpected_keys); skipped tensors show up as missing.
        print("[ckpt] load_state_dict result:", load_res)

    model.to(device)

    if optimizer is not None and "optimizer_state_dict" in ckpt:
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        _move_optimizer_state_to_device(optimizer.state, device)
        if verbose:
            print("[ckpt] optimizer state loaded and moved to", device)

    return ckpt


In [11]:
import inspect
from typing import Dict, Any, Optional
import torch
import torch.nn as nn

def _get_device(device: Optional[torch.device] = None) -> torch.device:
    if device is None:
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(device, str):
        return torch.device(device)
    return device

def build_swinunetr(params: Dict[str, Any],
                    in_ch: int = 1,
                    out_ch: int = 1,
                    device: Optional[torch.device] = None,
                    wrap_dataparallel: bool = True) -> nn.Module:
    """
    Build a MONAI SwinUNETR from a parameter dict.

    - ``params`` is filtered against the installed ``SwinUNETR.__init__`` signature;
      keys it does not accept are dropped with a warning, so typos or MONAI-version
      differences are visible instead of silently ignored.
    - ``in_ch``/``out_ch`` are passed by keyword, so they bind correctly even on MONAI
      releases where ``img_size`` is the first positional argument.
    - If ``wrap_dataparallel`` and more than one CUDA device is visible, the model is
      wrapped in ``nn.DataParallel``.

    Returns:
        The (possibly wrapped) model moved to ``device`` (resolved by ``_get_device``).
    """
    device = _get_device(device)
    sig = inspect.signature(SwinUNETR.__init__)
    allowed = {p for p in sig.parameters if p not in ("self", "in_channels", "out_channels")}
    call_kwargs = {k: v for k, v in params.items() if k in allowed}
    ignored = sorted(set(params) - allowed)
    if ignored:
        warnings.warn(f"build_swinunetr: ignoring parameters not accepted by SwinUNETR: {ignored}")
    model = SwinUNETR(in_channels=in_ch, out_channels=out_ch, **call_kwargs)
    if wrap_dataparallel and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    return model.to(device)


In [None]:

class CombinedLoss(nn.Module):
    """
    Weighted sum of BCE-with-logits and soft Dice loss for binary segmentation.

    loss = weight_bce * BCEWithLogits(logits, target; pos_weight)
           + weight_dice * Dice(sigmoid(logits), target)

    Args:
        weight_bce: Weight of the BCE term.
        weight_dice: Weight of the Dice term.
        pos_weight: Weight of positive voxels in the BCE term (default 13.3;
            presumably the negative/positive voxel ratio of the labels — TODO confirm).
        device: Device for the ``pos_weight`` tensor; defaults to CUDA when available,
            otherwise CPU. It is a buffer of ``nn.BCEWithLogitsLoss``, so ``.to()``
            on this module also moves it.
    """

    def __init__(self, weight_bce=1.0, weight_dice=1.0, pos_weight=13.3, device=None):
        super().__init__()
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.bce = nn.BCEWithLogitsLoss(
            pos_weight=torch.tensor([float(pos_weight)], device=device)
        )
        self.dice = monai.losses.DiceLoss(sigmoid=True, squared_pred=False, reduction="mean")
        self.w_bce = weight_bce
        self.w_dice = weight_dice

    def forward(self, logits, target):
        """Combined loss; a channel dim is inserted into ``target`` when it has one dim fewer."""
        if target.dim() == logits.dim() - 1:
            target = target.unsqueeze(1)
        return self.w_bce * self.bce(logits, target) + self.w_dice * self.dice(logits, target)

def train_epoch(model, dataloader, optimizer, loss_fn, epoch, device=None):
    """
    Run one training epoch and return the mean batch loss.

    Batches may be ``(imgs, masks)`` or ``(imgs, masks, meta)``; a channel dim is
    inserted at position 1 of both ``imgs`` and ``masks`` (assumes (B, D, H, W)
    batches — TODO confirm for other datasets).

    Args:
        model: Network whose output is passed to ``loss_fn`` together with the masks.
        dataloader: Iterable of batches supporting ``len()``.
        optimizer: Stepped once per batch.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        epoch: Epoch number, used only in the progress-bar label.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        Mean loss over processed batches (0.0 for an empty dataloader).

    Raises:
        RuntimeError: if the forward pass fails because the input is spatially too small.
    """
    if device is None:
        device = DEVICE
    model.train()
    running_loss = 0.0
    n_batches = 0
    pbar = tqdm(dataloader, total=len(dataloader), desc=f"Train E{epoch}")

    for batch in pbar:
        imgs, masks = batch[0], batch[1]  # optional third element (meta) is unused
        imgs = imgs.unsqueeze(1).float().to(device, non_blocking=True)
        masks = masks.unsqueeze(1).float().to(device, non_blocking=True)

        optimizer.zero_grad()
        try:
            outputs = model(imgs)
        except ValueError as e:
            msg = str(e)
            # Raised by normalisation layers when a feature map collapses to 1 voxel.
            if "Expected more than 1 spatial element" in msg:
                raise RuntimeError("Runtime forward error: spatial dims too small. Consider increasing ROI or manual model changes.\n" + msg) from e
            raise

        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        n_batches += 1
        pbar.set_postfix(loss=running_loss / n_batches)

    return running_loss / max(1, n_batches)

def validate(model, dataloader, loss_fn=None, threshold=0.5, device=None):
    """
    Evaluate ``model`` with sliding-window inference; return ``(mean_loss, voxel_f1)``.

    F1 is accumulated over all voxels of the whole dataloader (micro-averaged),
    ``F1 = 2*TP / (2*TP + FP + FN)``, and is NaN when neither predictions nor ground
    truth contain a positive voxel.

    Args:
        model: Network returning logits.
        dataloader: Yields ``(imgs, masks)`` or ``(imgs, masks, meta)`` batches; a
            channel dim is inserted at position 1 of imgs and masks.
        loss_fn: Loss used for reporting; defaults to a fresh ``CombinedLoss()``.
        threshold: Sigmoid-probability threshold for a positive prediction.
        device: Target device; defaults to the module-level ``DEVICE``.

    Ground-truth voxels with value > 0.5 count as positive. Sliding-window inference
    uses the module-level ``ROI_SIZE``, ``SW_BATCH_SIZE`` and ``OVERLAP``.
    """
    if device is None:
        device = DEVICE
    if loss_fn is None:
        loss_fn = CombinedLoss()
    model.eval()

    val_loss = 0.0
    tp = fp = fn = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, total=len(dataloader), desc="Val"):
            imgs, masks = batch[0], batch[1]  # optional meta element is unused
            imgs = imgs.unsqueeze(1).float().to(device)
            masks = masks.unsqueeze(1).float().to(device)

            logits = sliding_window_inference(
                inputs=imgs,
                roi_size=ROI_SIZE,
                sw_batch_size=SW_BATCH_SIZE,
                predictor=model,
                overlap=OVERLAP,
            )

            val_loss += float(loss_fn(logits, masks).item())

            # Boolean voxel masks; reshape (not view) also works for non-contiguous tensors.
            pred_pos = (torch.sigmoid(logits) >= threshold).reshape(-1)
            true_pos = (masks > 0.5).reshape(-1)

            # Counting stays on the device; only three scalars are synced per batch.
            tp += int((pred_pos & true_pos).sum().item())
            fp += int((pred_pos & ~true_pos).sum().item())
            fn += int((~pred_pos & true_pos).sum().item())

    val_loss /= max(1, len(dataloader))

    denom = 2 * tp + fp + fn
    # No positive voxels in GT or predictions -> F1 is undefined.
    val_f1 = float("nan") if denom == 0 else 2.0 * tp / denom
    return val_loss, val_f1


# ---- Inference / optimisation settings ----
ROI_SIZE = (128, 128, 128)    # sliding-window size; same as the training crop size
SW_BATCH_SIZE = 2             # sw_batch_size for sliding_window_inference
OVERLAP = 0.25                # overlap for sliding_window_inference
LEARNING_RATE = 1e-4
BATCH_SIZE = 1
ROOT_DIR = "./checkpoints"    # checkpoints and the loss plot are written here

# ---- SwinUNETR hyper-parameters (filtered against the MONAI signature in build_swinunetr) ----
MODEL_PARAMS = {
    "patch_size": (2, 2, 2),
    "depths": (2, 2, 2, 1),
    "num_heads": (3, 6, 12, 24),
    "window_size": (7, 7, 7),
    "qkv_bias": True,
    "mlp_ratio": 4.0,
    "feature_size": 48,
    "drop_rate": 0.0,
    "attn_drop_rate": 0.0,
    "dropout_path_rate": 0.1,
    "patch_norm": True,
    "spatial_dims": 3,
}

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---- Reproducibility ----
SEED = 42


def _seed_everything(seed):
    """Seed the global RNGs of Python, NumPy and PyTorch (CPU and, if present, all CUDA devices)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


_seed_everything(SEED)

# Dedicated, explicitly seeded generator, intended for reproducible sampling
# (e.g. a DataLoader `generator=` argument).
generator = torch.Generator()
generator.manual_seed(SEED)

# Pin host memory only when batches are actually copied to a GPU.
PIN_MEMORY = torch.cuda.is_available()

# Shuffle with the seeded `generator`, so the training order does not depend on how
# much of torch's global RNG was consumed before the first epoch (e.g. by model init).
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=0,
                          pin_memory=PIN_MEMORY, generator=generator)
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0,
                        pin_memory=PIN_MEMORY)

model = build_swinunetr(MODEL_PARAMS, in_ch=1, out_ch=1)
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

# Set to a checkpoint path (e.g. "/content/checkpoint.pth") to resume training.
# The optimizer is created first so its state can be restored as well.
RESUME_CKPT_PATH = None
if RESUME_CKPT_PATH:
    load_checkpoint(RESUME_CKPT_PATH, model, optimizer=optimizer, device=DEVICE,
                    strict=False, verbose=True)

loss_fn = CombinedLoss()

# Per-epoch history, appended to by the training loop.
train_losses = []
val_losses = []
val_f1s = []
# Not populated by the training loop (it only tracks loss and F1).
val_dices = []
val_aps = []
val_ious = []

best_val_dice = -1.0  # not used by the training loop, which selects on F1
best_val_f1 = 0.0

NUM_EPOCHS = 1  # increase for a full training run

os.makedirs(ROOT_DIR, exist_ok=True)
loss_plot_path = os.path.join(ROOT_DIR, "training_loss.png")

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss = train_epoch(model, train_loader, optimizer, loss_fn, epoch)
    val_loss, val_f1 = validate(model, val_loader)

    # ':.4f' renders NaN as 'nan', so no special case is needed.
    print(
        f"Epoch {epoch} | Train loss: {train_loss:.4f} | Val loss: {val_loss:.4f} | "
        f"F1: {val_f1:.4f}"
    )

    # Checkpoint keys: epoch, model_state_dict, optimizer_state_dict, val_f1, val_loss.
    metrics = {"val_f1": val_f1, "val_loss": val_loss}
    save_checkpoint(os.path.join(ROOT_DIR, f"checkpoint_epoch_{epoch}.pth"),
                    model, optimizer, epoch=epoch, extra=metrics)

    # A NaN F1 (no positives anywhere) never counts as an improvement.
    if not np.isnan(val_f1) and val_f1 > best_val_f1:
        best_val_f1 = val_f1
        save_checkpoint(os.path.join(ROOT_DIR, "best_checkpoint.pth"),
                        model, optimizer, epoch=epoch, extra=metrics)
        print("Best model updated:", best_val_f1)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    val_f1s.append(val_f1)

    # Redraw the train/val loss curves and save them under ROOT_DIR every epoch.
    fig = None
    try:
        fig, ax = plt.subplots(figsize=(8, 5))
        epochs_range = range(1, len(train_losses) + 1)
        ax.plot(epochs_range, train_losses, label="Train Loss")
        ax.plot(epochs_range, val_losses, label="Val Loss")
        ax.set(title="Training and Validation Loss", xlabel="Epoch", ylabel="Loss")
        ax.legend()
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(loss_plot_path)
        plt.show()
    except Exception as e:
        warnings.warn(f"Failed to plot/save loss figure: {e}")
    finally:
        # Always release the figure so repeated epochs do not accumulate open figures.
        if fig is not None:
            plt.close(fig)

print("Training finished. Best val f1:", best_val_f1)

Train E1:   2%|▏         | 31/1596 [01:39<1:21:39,  3.13s/it, loss=1.64]

----