# Lab 02: Segmentation

In this laboratory work you will build a pipeline for cancer cell segmentation, covering everything from reading and preprocessing the data to creating a training setup and experimenting with models.

## Part 1: Reading dataset

Write a Dataset class that inherits from the regular `torch` Dataset.

In this task we use a small dataset just to make this homework accessible to everyone, so please **do not** read all the data in the constructor, because that is not how it works for real-life datasets. Read an image from disk only when it is requested (in `__getitem__`).

Split the data (persistently across runs) into train, val and test sets. Add a corresponding parameter to the dataset constructor.

In [None]:
import os
# Download and unpack the dataset archive only if the zip is not already present.
# NOTE: the `!` lines are IPython shell escapes, so this cell only runs inside Jupyter/IPython.
# NOTE(review): only the zip is checked; if it exists but was never extracted, unzip is skipped.
if not os.path.exists('breast-cancer-cells-segmentation.zip'):
    # curl flags: -J use the server-provided filename, -L follow redirects, -O save to a file.
    !curl -JLO 'https://www.dropbox.com/scl/fi/gs3kzp6b8k6faf667m5tt/breast-cancer-cells-segmentation.zip?rlkey=md3mzikpwrvnaluxnhms7r4zn'
    !unzip breast-cancer-cells-segmentation.zip
else:
    print('Dataset is already downloaded')

## Part 1.1: Analyzing dataset

Each time you build a model, you should first perform EDA (exploratory data analysis) to understand your data.

You should answer the following questions:
- how many classes do you have?
- what is class balance?
- how many cells (roughly) do you have in train data?

Advanced part: think of questions that could help you when building your models later, and then answer them below.

In [None]:
import os
import glob
import cv2
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from collections import Counter
from tqdm.notebook import tqdm
import random

In [None]:
# Directories (relative to the working directory) with the raw images and their masks.
IMG_DIR, MASK_DIR = 'Images', 'Masks'

In [None]:
# Pair every "*_ccd.tif" image with its mask "<same stem>.TIF" and record the tumor type,
# which is taken from the last "_"-separated token of the image stem.
all_files_with_type = []
image_files = sorted(glob.glob(os.path.join(IMG_DIR, '*_ccd.tif')))
# The mask set is only used for O(1) membership checks, so sorting it first is wasted work.
mask_files_set = set(glob.glob(os.path.join(MASK_DIR, '*.TIF')))

for img_path in image_files:
    basename = os.path.basename(img_path)
    parts = basename.replace('_ccd.tif', '').split('_')
    if len(parts) >= 2:
        tumor_type = parts[-1]
    else:
        print(f"Warning: Could not extract tumor type from {basename}")
        tumor_type = "unknown"

    mask_basename = basename.replace('_ccd.tif', '.TIF')
    expected_mask_path = os.path.join(MASK_DIR, mask_basename)

    if expected_mask_path in mask_files_set:
        all_files_with_type.append((img_path, expected_mask_path, tumor_type))
    else:
        print(f"Warning: Mask not found for image {img_path}")

print(f"Найдено {len(image_files)} изображений.")
print(f"Найдено {len(mask_files_set)} масок.")
print(f"Создано {len(all_files_with_type)} пар изображение/маска с типом опухоли.")

# Fixed seed: the shuffled order, and therefore every split derived from it, is reproducible.
random.seed(42)
random.shuffle(all_files_with_type)

In [None]:
# Number of images per tumor type (the third field of every record).
tumor_type_counts = Counter(record[2] for record in all_files_with_type)

print("Распределение типов опухолей во всем датасете:")
for tumor_type, count in tumor_type_counts.most_common():
    print(f"- {tumor_type}: {count}")

total_images = len(all_files_with_type)
num_tumor_types = len(tumor_type_counts)
print(f"Всего изображений: {total_images}")
print(f"Количество уникальных типов опухолей: {num_tumor_types}")

In [None]:
# Bar chart of images per tumor type, with the types in alphabetical order.
types = np.array(list(tumor_type_counts.keys()))
counts = np.array(list(tumor_type_counts.values()))
sorted_indices = np.argsort(types)
types, counts = types[sorted_indices], counts[sorted_indices]

plt.figure(figsize=(10, 6))
plt.bar(types, counts)
plt.title("Распределение типов опухолей в датасете")
plt.xlabel("Тип опухоли")
plt.ylabel("Количество изображений")
plt.xticks(rotation=45, ha='right')
plt.tight_layout()
plt.show()

In [None]:
# 15% of the data is held out for test; val is then carved out of the remainder with
# fraction 0.15 / 0.85, so it is also ~15% of the whole dataset.
train_val_data, test_data = train_test_split(all_files_with_type, test_size=0.15, random_state=42)

# Stratify train/val by tumor type only when every type has at least two samples
# (sklearn cannot stratify a class with a single member).
train_val_types = [tumor for _, _, tumor in train_val_data]
type_counts_train_val = Counter(train_val_types)
can_stratify_val = all(count > 1 for count in type_counts_train_val.values())

val_size_fraction = 0.15 / 0.85
stratify_val = train_val_types if can_stratify_val else None

train_data, val_data = train_test_split(
    train_val_data,
    test_size=val_size_fraction,
    random_state=42,
    stratify=stratify_val,
)

print(
    f"Размер обучающей выборки: {len(train_data)}",
    f"Размер валидационной выборки: {len(val_data)}",
    f"Размер тестовой выборки: {len(test_data)}",
    sep="\n",
)

print("\nРаспределение типов в обучающей выборке:", Counter(item[2] for item in train_data), sep="\n")
print("\nРаспределение типов в валидационной выборке:", Counter(item[2] for item in val_data), sep="\n")
print("\nРаспределение типов в тестовой выборке:", Counter(item[2] for item in test_data), sep="\n")

In [None]:
# Binary segmentation: pixels equal to 255 in a mask are treated as cells, all others as background.
num_segmentation_classes = 2
foreground_class_value = 255
background_class_value = 0

print(f"Задача сегментации: {num_segmentation_classes} класса (Фон: {background_class_value}, Клетки: {foreground_class_value}).")

# --- Pixel-level class balance over the training masks ---
total_pixels = 0
foreground_pixels = 0

print("\nПодсчет баланса классов сегментации по пикселям в обучающей выборке...")
for _, mask_path, _ in tqdm(train_data):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        print(f"Warning: Не удалось прочитать маску {mask_path}")
        continue

    foreground_pixels += np.sum(mask == foreground_class_value)
    total_pixels += mask.size

background_pixels = 0
if total_pixels > 0:
    background_pixels = total_pixels - foreground_pixels
    print(f"Всего пикселей в обучающих масках: {total_pixels}")
    print(f"Пикселей клеток: {foreground_pixels} ({foreground_pixels / total_pixels:.4f})")
    print(f"Пикселей фона: {background_pixels} ({background_pixels / total_pixels:.4f})")
    # Guard: masks without any cell pixels would otherwise divide by zero here.
    if foreground_pixels > 0:
        print(f"Соотношение фон/клетки примерно: {background_pixels/foreground_pixels:.1f} : 1")
    else:
        print("Соотношение фон/клетки: пикселей клеток не найдено.")
else:
    print("Не удалось посчитать пиксели.")

# --- Approximate cell count via connected components ---
total_cells = 0
min_cell_area = 10  # components with fewer pixels than this are not counted as cells

print("\nПодсчет примерного количества клеток в обучающей выборке...")
for _, mask_path, _ in tqdm(train_data):
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        continue

    binary_mask = (mask == foreground_class_value).astype(np.uint8)
    num_labels, labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, connectivity=8)

    # Row 0 of `stats` is the background label; count the remaining components large enough.
    cells_in_mask = int(np.sum(stats[1:, cv2.CC_STAT_AREA] >= min_cell_area))
    total_cells += cells_in_mask

print(f"Примерное общее количество клеток в обучающей выборке: {total_cells}")
if len(train_data) > 0:
    print(f"Среднее количество клеток на изображение (train): {total_cells / len(train_data):.1f}")

In [None]:
# Consolidated EDA report built from the values computed in the previous cells.
print(
    "Анализ на уровне изображений:",
    f"- Количество уникальных типов опухолей (по именам файлов): {num_tumor_types}",
    "- Распределение по типам (весь датасет): см. график выше и счетчики.",
    f"- Общее количество изображений: {total_images}",
    f"- Разделение на выборки: Train={len(train_data)}, Val={len(val_data)}, Test={len(test_data)}",
    sep="\n",
)

print(
    "\nАнализ для задачи сегментации (на уровне пикселей):",
    f"- Количество классов сегментации: {num_segmentation_classes} (Фон: {background_class_value}, Клетки: {foreground_class_value})",
    sep="\n",
)
if total_pixels <= 0:
    print("- Баланс классов сегментации: Не посчитан.")
else:
    print(f"- Баланс классов сегментации (попиксельно, train): ~{background_pixels / total_pixels:.2%} фон / {foreground_pixels / total_pixels:.2%} клетки")

if total_cells <= 0:
    print("- Примерное количество клеток: Не посчитано.")
else:
    print(
        f"- Примерное количество объектов (клеток) в обучающей выборке: {total_cells}",
        f"- Среднее количество клеток на изображение (train): {total_cells / len(train_data):.1f}",
        sep="\n",
    )

print("---------------------------------")

## Part 2: Unet model

Implement a U-Net model class according to [the original paper](https://arxiv.org/pdf/1505.04597).
Adjust the size of the network according to your input data.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"CUDA доступна. Используется GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA недоступна. Используется CPU.")

In [None]:
# your code here

In [None]:
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; padding=1 keeps H and W.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the second convolution.
        mid_channels: channels between the two convolutions; any falsy value
            (None, 0) falls back to ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        mid_channels = mid_channels or out_channels
        layers = []
        for c_in, c_out in ((in_channels, mid_channels), (mid_channels, out_channels)):
            # Conv bias is redundant because BatchNorm immediately re-centres the output.
            layers += [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*layers)

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Encoder step: halve H and W with 2x2 max-pooling, then apply DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Decoder step: upsample by 2, concatenate the skip connection, apply DoubleConv.

    With ``bilinear=True`` the upsampling has no parameters and keeps the channel
    count; otherwise a transposed convolution halves the channels while upsampling.
    """

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        """Upsample ``x1`` and fuse it with the skip tensor ``x2`` (NCHW)."""
        x1 = self.up(x1)
        # Pad the upsampled map to the skip tensor's exact size (odd sizes lose a pixel on pooling).
        diff_h = x2.size(2) - x1.size(2)
        diff_w = x2.size(3) - x1.size(3)
        padding = [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        return self.conv(torch.cat([x2, F.pad(x1, padding)], dim=1))


class OutConv(nn.Module):
    """1x1 convolution that maps feature channels to per-class logits."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)

In [None]:
class UNet(nn.Module):
    """U-Net with four down/up levels and skip connections.

    Args:
        n_channels: number of input image channels.
        n_classes: number of output logit channels.
        bilinear: use parameter-free bilinear upsampling instead of transposed convs.
    """

    def __init__(self, n_channels, n_classes, bilinear=True):
        super().__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        # Bilinear Upsample keeps the channel count, so the bottleneck and decoder outputs
        # are halved to make each concatenation match the next Up block's in_channels.
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        """Return raw per-pixel logits with ``n_classes`` channels."""
        skips = [self.inc(x)]
        for down in (self.down1, self.down2, self.down3):
            skips.append(down(skips[-1]))
        out = self.down4(skips[-1])
        # Decoder consumes the skip tensors deepest-first.
        for up, skip in zip((self.up1, self.up2, self.up3, self.up4), reversed(skips)):
            out = up(out, skip)
        return self.outc(out)

In [None]:
input_channels = 3  # RGB images
num_classes = 2     # background + cells

# nn.Module.to moves parameters in place and returns the same module.
model = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True).to(device)

print(f"Модель U-Net создана и перемещена на {device}.")

In [None]:
from torch.utils.data import Dataset
import albumentations as A

In [None]:
class CancerCellDataset(Dataset):
    """Lazily loaded (image, mask) pairs for binary cell segmentation.

    Files are read from disk only in ``__getitem__``; the constructor just stores
    the file list.

    Args:
        data_list: sequence of ``(img_path, mask_path, tumor_type)`` tuples
            (the tumor type is ignored here).
        foreground_value: mask pixel value that marks cells; every other value
            becomes background (label 0).
        target_size: ``(height, width)`` to resize to, or a falsy value to keep
            the original size.
    """

    def __init__(self, data_list, foreground_value=255, target_size=(512, 512)):
        self.data_list = data_list
        self.foreground_value = foreground_value
        self.target_size = target_size

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return ``(image, mask)``.

        ``image`` is a float32 tensor ``(3, H, W)`` in RGB order scaled to [0, 1];
        ``mask`` is an int64 tensor ``(H, W)`` with 1 for cells and 0 elsewhere.

        Raises:
            IOError: if the image or mask file cannot be read.
        """
        img_path, mask_path, _ = self.data_list[idx]

        try:
            image = cv2.imread(img_path, cv2.IMREAD_COLOR)
            if image is None:
                raise IOError(f"Не удалось загрузить изображение: {img_path}")

            # OpenCV loads BGR; convert to RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:
                raise IOError(f"Не удалось загрузить маску: {mask_path}")

            if self.target_size:
                # target_size is (height, width), but cv2.resize expects dsize=(width, height).
                height, width = self.target_size
                image = cv2.resize(image, (width, height), interpolation=cv2.INTER_LINEAR)
                # Nearest-neighbour does not invent intermediate label values.
                mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)

            image = image.astype(np.float32) / 255.0
            mask = (mask == self.foreground_value).astype(np.int64)

            image_tensor = torch.from_numpy(image.transpose((2, 0, 1)))  # HWC -> CHW
            mask_tensor = torch.from_numpy(mask)

            return image_tensor, mask_tensor

        except Exception as e:
            print(f"Ошибка при обработке индекса {idx}, путь к изображению: {img_path}")
            print(e)
            raise

In [None]:
TARGET_H, TARGET_W = 256, 256
FG_VALUE = 255

# All three splits share the same preprocessing; only their file lists differ.
train_dataset, val_dataset, test_dataset = (
    CancerCellDataset(split, foreground_value=FG_VALUE, target_size=(TARGET_H, TARGET_W))
    for split in (train_data, val_data, test_data)
)

print("Датасеты:")
print(f"- Обучающий: {len(train_dataset)} сэмплов")

# Sanity check on one training sample.
img, msk = train_dataset[0]
print("\nПример элемента из датасета:")
print(f"- Тип изображения: {img.dtype}, Форма: {img.shape}")
print(f"- Тип маски: {msk.dtype}, Форма: {msk.shape}")
assert img.shape[0] == 3, "Изображение должно иметь 3 канала"

In [None]:
import torch.optim as optim
import time
import copy
from torch.utils.data import DataLoader

In [None]:
# Training hyper-parameters shared by every experiment in this notebook.
NUM_EPOCHS = 25
BATCH_SIZE = 3
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-5   # passed to Adam as weight_decay
NUM_WORKERS = 2       # DataLoader worker processes

In [None]:
# Only the training loader shuffles; val/test keep a fixed order between runs.
train_loader, val_loader, test_loader = (
    DataLoader(ds, batch_size=BATCH_SIZE, shuffle=shuffle, num_workers=NUM_WORKERS, pin_memory=True)
    for ds, shuffle in ((train_dataset, True), (val_dataset, False), (test_dataset, False))
)
print("Пересозданы DataLoader'ы.")

In [None]:
import torch.nn.functional as F



# Baseline: unweighted pixel-wise cross-entropy, optimized with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(
    model.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)

print("Функция потерь и оптимизатор определены.")

In [None]:
def train_model(model, criterion, optimizer, train_loader, val_loader, device, num_epochs=25):
    """Train ``model`` and restore the weights with the lowest validation loss.

    Each epoch runs one pass over ``train_loader`` with gradient updates and one
    pass over ``val_loader`` without gradients.

    Args:
        model: network to train (modified in place).
        criterion: callable ``criterion(outputs, labels) -> scalar tensor``.
        optimizer: optimizer over ``model``'s parameters.
        train_loader, val_loader: iterables of ``(inputs, labels)`` batches.
        device: device the batches are moved to.
        num_epochs: number of train/val epochs.

    Returns:
        ``(model, history)``: the same ``model`` object, now holding the best
        validation weights, and ``{'train_loss': [...], 'val_loss': [...]}`` with one
        value per epoch. Each value is the sample-weighted average of batch losses
        (a per-sample mean assuming ``criterion`` averages over the batch), or NaN
        if the loader yielded no batches.
    """
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_loss = float('inf')

    history = {'train_loss': [], 'val_loss': []}

    for epoch in range(num_epochs):
        print(f'Эпоха {epoch+1}/{num_epochs}')
        print('-' * 10)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Learning Rate: {current_lr:.1e}")

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
                dataloader = train_loader
            else:
                model.eval()
                dataloader = val_loader

            running_loss = 0.0
            processed_samples = 0

            pbar = tqdm(dataloader, desc=f"{phase.capitalize()} Эпоха {epoch+1}")
            for inputs, labels in pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                running_loss += loss.item() * batch_size
                processed_samples += batch_size
                pbar.set_postfix({'loss': running_loss / processed_samples})

            # Average over the samples actually seen; an empty loader gives NaN instead of
            # raising ZeroDivisionError (NaN never beats best_loss).
            epoch_loss = running_loss / processed_samples if processed_samples else float('nan')

            print(f'{phase} Потеря: {epoch_loss:.4f}')

            if is_train:
                history['train_loss'].append(epoch_loss)
            else:
                history['val_loss'].append(epoch_loss)
                if epoch_loss < best_loss:
                    best_loss = epoch_loss
                    best_model_wts = copy.deepcopy(model.state_dict())
                    print("Сохранена лучшая модель (по валидационной потере)")

        print()

    time_elapsed = time.time() - since
    print(f'Обучение завершено за {time_elapsed // 60:.0f}м {time_elapsed % 60:.0f}с')
    print(f'Лучшая валидационная потеря: {best_loss:.4f}')

    model.load_state_dict(best_model_wts)
    return model, history

# Baseline run: U-Net with plain cross-entropy (no augmentations, no LR scheduler).
try:
    model_ft, history = train_model(
        model,
        criterion,
        optimizer,
        train_loader,
        val_loader,
        device,
        num_epochs=NUM_EPOCHS
    )
    print("Обучение с CrossEntropy Loss завершено.")
except NameError as e:
     print(f"Ошибка: Необходимая переменная не найдена ({e}). Убедись, что модель, оптимизатор, лосс и т.д. созданы.")
except Exception as e:
    print(f"Ошибка во время обучения: {e}")

In [None]:
import matplotlib.pyplot as plt

def plot_history(history, loss_name='CrossEntropy'):
    """Plot per-epoch training and validation loss curves.

    Args:
        history: dict with 'train_loss' and 'val_loss' lists, as returned by train_model.
        loss_name: loss name shown in the y-axis label. Defaults to 'CrossEntropy' for
            backward compatibility; pass e.g. 'Dice' when plotting other runs.
    """
    plt.figure(figsize=(10, 5))
    plt.plot(history['train_loss'], label='Train Loss')
    plt.plot(history['val_loss'], label='Validation Loss')
    plt.xlabel('Эпоха')
    plt.ylabel(f'Потеря ({loss_name})')
    plt.title('История обучения')
    plt.legend()
    plt.grid(True)
    plt.show()

plot_history(history)

In [None]:
save_path = 'best_unet_model_ce_loss.pth'
# Only the weights (state_dict) are saved; loading requires building the same UNet first.
torch.save(model_ft.state_dict(), save_path)
print(f"Модель сохранена в: {save_path}")

In [None]:
# Alias used for evaluation; eval() switches BatchNorm layers to their running statistics.
eval_model = model_ft
eval_model.eval()


In [None]:
all_inputs = []
all_preds = []
all_labels = []


print("Начинаем оценку на тестовой выборке...")
with torch.no_grad():
    pbar = tqdm(test_loader, desc="Тестирование")
    for inputs, labels in pbar:
        inputs = inputs.to(device)

        outputs = eval_model(inputs)

        # Softmax is monotonic, so argmax over the raw logits gives the same class.
        preds = torch.argmax(outputs, dim=1)

        all_inputs.append(inputs.cpu().numpy())
        all_preds.append(preds.cpu().numpy())
        # Labels are only needed on the CPU for the metrics, so skip the device round-trip.
        all_labels.append(labels.numpy())

if all_inputs:
    all_inputs = np.concatenate(all_inputs, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    print(f"Оценка завершена. Получено {len(all_preds)} предсказаний.")
    print(f"Форма массива изображений: {all_inputs.shape}")
    print(f"Форма массива предсказаний: {all_preds.shape}")
    print(f"Форма массива истинных масок: {all_labels.shape}")
else:
    print("Оценка не была проведена или не получено результатов.")

In [None]:
def dice_coefficient(pred, target, smooth=1e-6):
    """Return the Dice score 2|A∩B| / (|A| + |B|) of two masks.

    Inputs are cast to bool (non-zero = foreground). ``smooth`` keeps the ratio
    defined, so two empty masks score 1.0.
    """
    a = pred.astype(bool)
    b = target.astype(bool)
    overlap = (a & b).sum()
    return (2. * overlap + smooth) / (a.sum() + b.sum() + smooth)

def iou_score(pred, target, smooth=1e-6):
    """Return the Jaccard index |A∩B| / |A∪B| of two masks (same conventions as Dice)."""
    a = pred.astype(bool)
    b = target.astype(bool)
    overlap = (a & b).sum()
    union = (a | b).sum()
    return (overlap + smooth) / (union + smooth)

# Per-image Dice / IoU on the test set, for the foreground (cell) class only.
foreground_class_index = 1
dice_scores = []
iou_scores = []

for pred_labels, true_labels in zip(all_preds, all_labels):
    pred_mask = pred_labels == foreground_class_index
    true_mask = true_labels == foreground_class_index
    dice_scores.append(dice_coefficient(pred_mask, true_mask))
    iou_scores.append(iou_score(pred_mask, true_mask))

avg_dice = np.mean(dice_scores)
avg_iou = np.mean(iou_scores)

print(
    "--- Метрики на тестовой выборке ---",
    f"Средний Dice Coefficient: {avg_dice:.4f}",
    f"Средний IoU (Jaccard Index): {avg_iou:.4f}",
    "---------------------------------",
    sep="\n",
)

In [None]:
import gc

num_samples_to_show = 0
if 'all_inputs' in locals() and len(all_inputs) > 0:
     num_samples_to_show = min(5, len(all_inputs))

if num_samples_to_show > 0:
    print("\n--- Примеры сегментации на тестовой выборке ---")
    plt.figure(figsize=(12, num_samples_to_show * 4))

    # One row per sample: input image | predicted mask | ground-truth mask.
    for i in range(num_samples_to_show):
        input_img = all_inputs[i].transpose((1, 2, 0))  # CHW -> HWC for imshow
        pred_mask = all_preds[i]
        true_mask = all_labels[i]

        plt.subplot(num_samples_to_show, 3, i * 3 + 1)
        plt.imshow(input_img)
        plt.title(f"Input Image #{i+1}")
        plt.axis('off')

        plt.subplot(num_samples_to_show, 3, i * 3 + 2)
        plt.imshow(pred_mask, cmap='gray')
        dice_val = dice_scores[i] if i < len(dice_scores) else float('nan')
        iou_val = iou_scores[i] if i < len(iou_scores) else float('nan')
        plt.title(f"Predicted Mask #{i+1}\nDice: {dice_val:.3f}, IoU: {iou_val:.3f}")
        plt.axis('off')

        plt.subplot(num_samples_to_show, 3, i * 3 + 3)
        plt.imshow(true_mask, cmap='gray')
        plt.title(f"Ground Truth Mask #{i+1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("Нет данных для визуализации.")


# Free the evaluation arrays. The per-sample names below exist only if the metric or
# plotting loops ran at least once, so delete them conditionally to avoid a NameError.
del all_inputs, all_preds, all_labels, dice_scores, iou_scores
if 'pred_mask' in locals(): del pred_mask
if 'true_mask' in locals(): del true_mask
if 'input_img' in locals(): del input_img
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Результат:
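Модель научилась просто выдавать чёрную маску, так как это даёт хороший результат по функции потерь (по-видимому, из-за сильного дисбаланса классов: фон занимает большую часть пикселей). В следующей части надо попробовать другую функцию потерь, чтобы отучить её от этого.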

## Part 3: Unet training with different losses

Train model in three setups:
- Crossentropy loss
- Dice loss
- Composition of CE and Dice

Advanced:\
For training procedure use one of frameworks for models training - Lightning, (Hugging Face, Catalyst, Ignite).\
_Hint: this will make your life easier!_

Save all three trained models to disk!

Use validation set to evaluate models.

In [None]:
# your code here

class DiceLoss(nn.Module):
    """Soft Dice loss for a single foreground class, averaged over the batch.

    Args:
        smooth: additive constant that keeps the ratio defined for empty masks.
        foreground_index: logit channel / label value treated as foreground
            (default 1, the "cells" class in this notebook's 2-class setup).
    """

    def __init__(self, smooth=1e-6, foreground_index=1):
        super().__init__()
        self.smooth = smooth
        self.foreground_index = foreground_index

    def forward(self, logits, targets):
        """Return ``1 - mean(dice)`` over the batch.

        ``logits``: raw scores with classes on dim 1, e.g. ``(B, C, H, W)`` from UNet.
        ``targets``: integer labels without the class dim, e.g. ``(B, H, W)``.
        """
        fg = self.foreground_index
        probs = F.softmax(logits, dim=1)
        # reshape rather than view: the channel slice is a strided view of `probs`.
        probs_flat = probs[:, fg].reshape(probs.size(0), -1)
        targets_flat = (targets == fg).float().reshape(targets.size(0), -1)

        intersection = (probs_flat * targets_flat).sum(1)
        pred_sum = probs_flat.sum(1)
        target_sum = targets_flat.sum(1)

        dice = (2. * intersection + self.smooth) / (pred_sum + target_sum + self.smooth)
        return 1 - dice.mean()

class CombinedLoss(nn.Module):
    """Weighted sum ``weight_ce * CrossEntropy + weight_dice * Dice``.

    Args:
        weight_ce, weight_dice: mixing weights of the two terms.
        dice_smooth: ``smooth`` passed to DiceLoss.
        ce_weight: optional per-class weights passed to nn.CrossEntropyLoss.
    """

    def __init__(self, weight_ce=0.5, weight_dice=0.5, dice_smooth=1e-6, ce_weight=None):
        super().__init__()
        self.ce_loss = nn.CrossEntropyLoss(weight=ce_weight)
        self.dice_loss = DiceLoss(smooth=dice_smooth)
        self.weight_ce = weight_ce
        self.weight_dice = weight_dice
        print(f"CombinedLoss: weight_ce={weight_ce}, weight_dice={weight_dice}")

    def forward(self, logits, targets):
        ce = self.ce_loss(logits, targets)
        dice = self.dice_loss(logits, targets)
        return self.weight_ce * ce + self.weight_dice * dice

print("Классы DiceLoss и CombinedLoss определены.")



In [None]:
print("\n--- Начинаем обучение с Dice Loss ---")

# Fresh model, optimizer and loss so this run shares no state with the CE baseline.
model_dice = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True).to(device)
optimizer_dice = optim.Adam(model_dice.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion_dice = DiceLoss()

try:
    model_dice_ft, history_dice = train_model(
        model_dice, criterion_dice, optimizer_dice,
        train_loader, val_loader, device,
        num_epochs=NUM_EPOCHS,
    )
    print("\n--- Обучение с Dice Loss завершено ---")

    save_path_dice = 'unet_dice_loss.pth'
    torch.save(model_dice_ft.state_dict(), save_path_dice)
    print(f"Модель сохранена в: {save_path_dice}")

    plot_history(history_dice)
except Exception as e:
    print(f"Ошибка во время обучения с Dice Loss: {e}")

# Drop references and release cached GPU memory before the next experiment.
del model_dice, optimizer_dice, criterion_dice
if 'model_dice_ft' in locals():
    del model_dice_ft
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
print("\n--- Начинаем обучение с Combined Loss (CE + Dice) ---")

# Fresh model and optimizer; the loss mixes cross-entropy and Dice with equal weights.
model_combined = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True).to(device)
optimizer_combined = optim.Adam(model_combined.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
criterion_combined = CombinedLoss(weight_ce=0.5, weight_dice=0.5)

try:
    model_combined_ft, history_combined = train_model(
        model_combined, criterion_combined, optimizer_combined,
        train_loader, val_loader, device,
        num_epochs=NUM_EPOCHS,
    )
    print("\n--- Обучение с Combined Loss завершено ---")

    save_path_combined = 'unet_combined_loss.pth'
    torch.save(model_combined_ft.state_dict(), save_path_combined)
    print(f"Модель сохранена в: {save_path_combined}")

    plot_history(history_combined)
except Exception as e:
    print(f"Ошибка во время обучения с Combined Loss: {e}")

# Drop references and release cached GPU memory.
del model_combined, optimizer_combined, criterion_combined
if 'model_combined_ft' in locals():
    del model_combined_ft
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
def dice_coefficient(pred, target, smooth=1e-6):
    """Dice score of two masks (non-zero = foreground); two empty masks give 1.0."""
    pred_fg, target_fg = pred.astype(bool), target.astype(bool)
    numerator = 2. * (pred_fg & target_fg).sum() + smooth
    denominator = pred_fg.sum() + target_fg.sum() + smooth
    return numerator / denominator

def iou_score(pred, target, smooth=1e-6):
    """Intersection-over-union of two masks (non-zero = foreground); two empty masks give 1.0."""
    pred_fg, target_fg = pred.astype(bool), target.astype(bool)
    numerator = (pred_fg & target_fg).sum() + smooth
    denominator = (pred_fg | target_fg).sum() + smooth
    return numerator / denominator


def evaluate_model(model_path, val_loader, device):
    """Load a saved UNet state_dict and compute mean Dice / IoU on ``val_loader``.

    Args:
        model_path: path to a state_dict saved with ``torch.save(model.state_dict(), ...)``.
        val_loader: iterable of ``(inputs, labels)`` batches.
        device: device used for inference.

    Returns:
        ``(avg_dice, avg_iou, details)`` with
        ``details = (all_inputs, all_preds, all_labels, dice_scores, iou_scores)``.
        On failure (missing file, load error, empty loader) returns
        ``(None, None, None)``, the same arity as the success path, so callers can
        always unpack three values.
    """
    print(f"\nОценка модели: {model_path}")

    model = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True)
    try:
        model.load_state_dict(torch.load(model_path, map_location=device))
    except FileNotFoundError:
        print(f"Ошибка: Файл {model_path} не найден!")
        return None, None, None
    except Exception as e:
        print(f"Ошибка при загрузке модели {model_path}: {e}")
        return None, None, None

    model.to(device)
    model.eval()

    all_inputs, all_preds, all_labels = [], [], []
    with torch.no_grad():
        pbar = tqdm(val_loader, desc=f"Оценка {os.path.basename(model_path)}")
        for inputs, labels in pbar:
            inputs = inputs.to(device)

            outputs = model(inputs)
            # Softmax is monotonic, so argmax over the logits gives the same classes.
            preds = torch.argmax(outputs, dim=1)

            all_inputs.append(inputs.cpu().numpy())
            all_preds.append(preds.cpu().numpy())
            # Labels stay on the CPU; they are only used for the numpy metrics.
            all_labels.append(labels.numpy())

    if not all_preds:
        print("Нет данных для оценки.")
        return None, None, None

    all_inputs = np.concatenate(all_inputs, axis=0)
    all_preds = np.concatenate(all_preds, axis=0)
    all_labels = np.concatenate(all_labels, axis=0)

    dice_scores_val, iou_scores_val = [], []
    foreground_class_index = 1
    for pred_labels, true_labels in zip(all_preds, all_labels):
        pred_mask = pred_labels == foreground_class_index
        true_mask = true_labels == foreground_class_index
        dice_scores_val.append(dice_coefficient(pred_mask, true_mask))
        iou_scores_val.append(iou_score(pred_mask, true_mask))

    avg_dice = np.mean(dice_scores_val)
    avg_iou = np.mean(iou_scores_val)

    print(f"Средний Dice на валидации: {avg_dice:.4f}")
    print(f"Средний IoU на валидации: {avg_iou:.4f}")

    # Drop the model and last-batch tensors before emptying the CUDA cache.
    del model, inputs, labels, outputs, preds
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return avg_dice, avg_iou, (all_inputs, all_preds, all_labels, dice_scores_val, iou_scores_val)



# Checkpoints written by the three training runs, keyed by the loss used.
model_paths = {
    "CE": "best_unet_model_ce_loss.pth",
    "Dice": "unet_dice_loss.pth",
    "Combined": "unet_combined_loss.pth",
}

results = {}
vis_data = {}

if 'val_loader' not in locals():
    print("Ошибка: val_loader не найден. Не могу выполнить оценку.")
else:
    for name, path in model_paths.items():
        if not os.path.exists(path):
            print(f"Файл для модели '{name}' ({path}) не найден, пропускаем оценку.")
            continue
        avg_dice, avg_iou, eval_results = evaluate_model(path, val_loader, device)
        if avg_dice is not None:
            results[name] = {'Dice': avg_dice, 'IoU': avg_iou}
            vis_data[name] = eval_results



# Markdown-style table of mean validation metrics per loss.
print(
    "\n--- Сводка результатов на валидационной выборке ---",
    "| Модель (Loss) | Средний Dice | Средний IoU  |",
    "|---------------|--------------|--------------|",
    sep="\n",
)
for name, metrics in results.items():
    dice_val, iou_val = metrics['Dice'], metrics['IoU']
    print(f"| {name:<13} | {dice_val:.4f}       | {iou_val:.4f}       |")
print("----------------------------------------------------")



# For each evaluated model, show up to 5 validation samples: input | prediction | ground truth.
for model_name, eval_results in vis_data.items():
    print(f"\n--- Визуализация для модели: {model_name} ---")

    all_inputs, all_preds, all_labels, dice_scores_val, iou_scores_val = eval_results

    num_samples_to_show = min(5, len(all_inputs))
    if not num_samples_to_show:
        print("Нет данных для визуализации этой модели.")
        continue

    plt.figure(figsize=(12, num_samples_to_show * 4))
    plt.suptitle(f"Результаты модели: {model_name}", fontsize=16, y=1.02)

    for i in range(num_samples_to_show):
        input_img = all_inputs[i].transpose((1, 2, 0))  # CHW -> HWC for imshow
        pred_mask = all_preds[i]
        true_mask = all_labels[i]
        dice_val = dice_scores_val[i] if i < len(dice_scores_val) else float('nan')
        iou_val = iou_scores_val[i] if i < len(iou_scores_val) else float('nan')

        panels = (
            (input_img, None, f"Input Image #{i+1}"),
            (pred_mask, 'gray', f"Predicted Mask ({model_name})\nDice: {dice_val:.3f}, IoU: {iou_val:.3f}"),
            (true_mask, 'gray', f"Ground Truth Mask #{i+1}"),
        )
        for column, (panel_img, panel_cmap, panel_title) in enumerate(panels, start=1):
            plt.subplot(num_samples_to_show, 3, i * 3 + column)
            plt.imshow(panel_img, cmap=panel_cmap)
            plt.title(panel_title)
            plt.axis('off')

    # Leave room at the top for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.98])
    plt.show()


# Drop the evaluation results, then release PyTorch's cached GPU memory (if any).
del results, vis_data
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

## Part 3.1: Losses conclusion

Analyse results of the three models above using metrics, losses and visualizations you know (all three parts are required).

Draw a well-reasoned conclusion about which setup is better, and provide your arguments.

Calculate loss and metrics of the best model on test set.

## Результаты
С новыми функциями потерь обучение идёт заметно лучше. Однако при большом числе эпох модель из-за малого разнообразия данных (40 сэмплов) сильно переобучается и начинает просто заучивать обучающую выборку.
Лучше всего себя показывает Dice Loss.

In [None]:
# your code here

## Part 4: Augmentations and advanced model

Choose set of augmentations relevant for this case (at least 5 of them) using [Albumentations library](https://albumentations.ai/).
Apply them to the dataset (dynamically, of course, while reading from disk).

One more thing to improve is the model: use [PSPNet](https://arxiv.org/pdf/1612.01105v2) (either use the library [implementation](https://smp.readthedocs.io/en/latest/models.html#pspnet) or implement it yourself) as an improved version of U-Net.

Alternatively you may use model of your choice (it should be more advanced than Unet ofc).

Train Unet and second model on augmented data.

In [None]:
# your code here

class CancerCellDatasetAug(Dataset):
    """Lazily loaded cell-segmentation dataset with optional augmentation.

    Images and masks are read from disk only inside ``__getitem__``.

    Args:
        data_list: Sequence of ``(image_path, mask_path, extra)`` tuples; the
            third element is ignored here.
        foreground_value: Mask pixel value mapped to class 1; every other value
            becomes class 0.
        target_size: ``(height, width)`` to resize image and mask to, or a falsy
            value to keep the original resolution.
        transform: Optional Albumentations-style callable, invoked as
            ``transform(image=..., mask=...)`` and returning a dict with
            ``'image'`` and ``'mask'`` entries. Its output may be NumPy arrays
            (HWC image) or torch tensors (CHW image, e.g. from ``ToTensorV2``).

    Each item is ``(image_tensor, mask_tensor)``: a float32 CHW image (uint8
    input is scaled to [0, 1]; float input is assumed already normalized and
    kept as is) and an int64 class-index mask.
    """

    def __init__(self, data_list, foreground_value=255, target_size=(512, 512), transform=None):
        self.data_list = data_list
        self.foreground_value = foreground_value
        self.target_size = target_size
        self.transform = transform

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        img_path, mask_path, _ = self.data_list[idx]

        try:
            image = cv2.imread(img_path, cv2.IMREAD_COLOR)
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            # Check for failed reads before any cv2 call uses the result;
            # cvtColor(None) would raise an opaque cv2.error instead.
            if image is None: raise IOError(f"Не удалось загрузить изображение: {img_path}")
            if mask is None: raise IOError(f"Не удалось загрузить маску: {mask_path}")

            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

            if self.target_size:
                target_h, target_w = self.target_size
                # cv2.resize takes dsize as (width, height).
                image = cv2.resize(image, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
                mask = cv2.resize(mask, (target_w, target_h), interpolation=cv2.INTER_NEAREST)

            mask = (mask == self.foreground_value).astype(np.uint8)

            if self.transform:
                augmented = self.transform(image=image, mask=mask)
                image = augmented['image']
                mask = augmented['mask']

            return self._to_tensors(image, mask)

        except Exception as e:
            print(f"Ошибка при обработке индекса {idx}, путь к изображению: {img_path}")
            print(e)
            raise

    @staticmethod
    def _to_tensors(image, mask):
        """Convert an image/mask pair (NumPy or tensor) to ``(float CHW image, int64 mask)``."""
        if isinstance(image, torch.Tensor):
            # Tensor output (e.g. ToTensorV2) is already channel-first.
            image_tensor = image.float()
            if image.dtype == torch.uint8:
                image_tensor = image_tensor / 255.0
        else:
            image = np.asarray(image)
            if image.dtype == np.uint8:
                image = image.astype(np.float32) / 255.0
            else:
                # Already float (e.g. after A.Normalize): do not rescale again.
                image = image.astype(np.float32)
            # Contiguous copy: torch.from_numpy rejects negative strides (flips).
            image_tensor = torch.from_numpy(np.ascontiguousarray(image.transpose((2, 0, 1))))

        if isinstance(mask, torch.Tensor):
            mask_tensor = mask.long()
        else:
            mask_tensor = torch.from_numpy(np.ascontiguousarray(mask)).long()

        return image_tensor, mask_tensor

In [None]:
# Training-time augmentation: flips/rotations, a mild affine jitter with
# zero-filled borders, and brightness/contrast/colour jitter.
train_transform = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.06,
            scale_limit=0.1,
            rotate_limit=15,
            p=0.7,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            mask_value=0,
        ),
        A.RandomBrightnessContrast(
            brightness_limit=0.2,
            contrast_limit=0.2,
            p=0.6,
        ),
        A.HueSaturationValue(
            hue_shift_limit=20,
            sat_shift_limit=30,
            val_shift_limit=20,
            p=0.6,
        ),
    ]
)

# Validation and test data are not augmented.
val_test_transform = None

In [None]:
# Only the training split receives random augmentation; every split is
# resized to (TARGET_H, TARGET_W).
train_dataset_aug = CancerCellDatasetAug(
    train_data,
    transform=train_transform,
    foreground_value=FG_VALUE,
    target_size=(TARGET_H, TARGET_W),
)
val_dataset_aug = CancerCellDatasetAug(
    val_data,
    transform=val_test_transform,
    foreground_value=FG_VALUE,
    target_size=(TARGET_H, TARGET_W),
)
test_dataset_aug = CancerCellDatasetAug(
    test_data,
    transform=val_test_transform,
    foreground_value=FG_VALUE,
    target_size=(TARGET_H, TARGET_W),
)

print(f"Созданы датасеты (train с аугментацией, размер {TARGET_H}x{TARGET_W}):")
print(f"- Обучающий: {len(train_dataset_aug)} сэмплов")
print(f"- Валидационный: {len(val_dataset_aug)} сэмплов")
print(f"- Тестовый: {len(test_dataset_aug)} сэмплов")

# Shuffle only the training loader; evaluation order stays fixed.
train_loader_aug = DataLoader(
    train_dataset_aug,
    shuffle=True,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
val_loader_aug = DataLoader(
    val_dataset_aug,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)
test_loader_aug = DataLoader(
    test_dataset_aug,
    shuffle=False,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
)

print("Созданы DataLoader'ы.")

In [None]:
import torch.optim.lr_scheduler as lr_scheduler

# Fresh U-Net for the "Dice loss + augmentation" experiment, moved to `device`.
model_dice_aug = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True).to(device)

criterion_dice_aug = DiceLoss()
print("Используется Dice Loss.")

optimizer_dice_aug = optim.Adam(
    model_dice_aug.parameters(),
    lr=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
)
print("Создан оптимизатор Adam.")

# NOTE(review): this scheduler is not passed to train_model in the next cell,
# so it is never stepped there — confirm whether that is intended.
scheduler_dice_aug = lr_scheduler.ReduceLROnPlateau(
    optimizer_dice_aug,
    mode='min',
    factor=0.1,
    patience=10,
    verbose=True,
)
print("Создан планировщик ReduceLROnPlateau.")

In [None]:
# Train the U-Net with Dice loss on augmented data, then save and plot it.
try:
    model_dice_aug_ft, history_dice_aug = train_model(
        model_dice_aug,
        criterion_dice_aug,
        optimizer_dice_aug,
        train_loader_aug,
        val_loader_aug,
        device,
        num_epochs=NUM_EPOCHS,
    )
    print("\n--- Обучение с Dice Loss и Аугментациями завершено ---")

    save_path_dice_aug = 'unet_dice_augmented.pth'
    torch.save(model_dice_aug_ft.state_dict(), save_path_dice_aug)
    print(f"Модель сохранена в: {save_path_dice_aug}")

    plot_history(history_dice_aug)

except Exception as e:
    print(f"Ошибка во время обучения: {e}")
    # The evaluation cell below relies on model_dice_aug_ft; re-raise (keeping
    # the traceback) so a failure stops here instead of resurfacing later as a
    # confusing NameError.
    raise

In [None]:
print("\n--- Оценка модели (Dice + Augmentation)---")

# Run the Dice+Aug U-Net over the test split, collecting inputs, predicted
# class maps and ground-truth masks for scoring and plotting.
eval_model = model_dice_aug_ft
eval_model.eval()
test_inputs, test_preds, test_labels = [], [], []
with torch.no_grad():
    pbar = tqdm(test_loader_aug, desc="Тестирование (Dice + Aug)")
    for inputs, labels in pbar:
        inputs = inputs.to(device)
        # Labels are only consumed on the CPU below, so they are not copied
        # to the device.
        outputs = eval_model(inputs)
        probabilities = torch.softmax(outputs, dim=1)
        preds = torch.argmax(probabilities, dim=1)

        test_inputs.append(inputs.cpu().numpy())
        test_preds.append(preds.cpu().numpy())
        test_labels.append(labels.numpy())

test_inputs = np.concatenate(test_inputs, axis=0)
test_preds = np.concatenate(test_preds, axis=0)
test_labels = np.concatenate(test_labels, axis=0)

# Score each test image on its foreground class only.
dice_scores_test, iou_scores_test = [], []
foreground_class_index = 1
for pred, true in zip(test_preds, test_labels):
    pred_mask = (pred == foreground_class_index)
    true_mask = (true == foreground_class_index)
    dice_scores_test.append(dice_coefficient(pred_mask, true_mask))
    iou_scores_test.append(iou_score(pred_mask, true_mask))

avg_dice_test = np.mean(dice_scores_test)
avg_iou_test = np.mean(iou_scores_test)

print("\n--- Метрики (Dice + Augmentation) ---")
print(f"Средний Dice Coefficient: {avg_dice_test:.4f}")
print(f"Средний IoU (Jaccard Index): {avg_iou_test:.4f}")
print("-------------------------------------------------------")


print("\n--- Примеры сегментации (Dice + Augmentation) ---")
num_samples_to_show = min(5, len(test_inputs))
if num_samples_to_show > 0:
    plt.figure(figsize=(12, num_samples_to_show * 4))
    for i in range(num_samples_to_show):
        input_img = test_inputs[i].transpose((1, 2, 0))
        pred_mask = test_preds[i]
        true_mask = test_labels[i]

        plt.subplot(num_samples_to_show, 3, i * 3 + 1)
        plt.imshow(input_img)
        plt.title(f"Input Image #{i+1}")
        plt.axis('off')

        # One score per test image was computed above, so index i is valid.
        plt.subplot(num_samples_to_show, 3, i * 3 + 2)
        plt.imshow(pred_mask, cmap='gray')
        plt.title(f"Predicted Mask (Dice+Aug)\nDice: {dice_scores_test[i]:.3f}, IoU: {iou_scores_test[i]:.3f}")
        plt.axis('off')

        plt.subplot(num_samples_to_show, 3, i * 3 + 3)
        plt.imshow(true_mask, cmap='gray')
        plt.title(f"Ground Truth Mask #{i+1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()
else:
    print("Нет данных для визуализации.")

## Part 4.2: Augmentations and advanced model conclusion

Compare three setups:
- Unet without augmentations (with the best loss)
- Unet with augmentations
- Advanced model with augmentations

_Hint: with augmentations and a more complex model you may need more training iterations._

Save all three trained models to disk!

Once again, provide comprehensive arguments and your insights.

Which setup is better?

Compute losses and metrics on the test set. Measure the improvement over the first test evaluation.

In [None]:
# # Импорт модулей
# import segmentation_models_pytorch as smp
# import torch
# import albumentations as A
# from albumentations.pytorch import ToTensorV2

# # Аугментации (без нормализации под ImageNet)
# advanced_augmentations = A.Compose([
#     A.HorizontalFlip(p=0.5),
#     A.VerticalFlip(p=0.5),
#     A.RandomRotate90(p=0.5),
#     A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=30, p=0.8,
#                        border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0),
#     A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
#     A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=40, val_shift_limit=30, p=0.7),
#     A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
#     A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
#     ToTensorV2()
# ])

# # Создание датасетов
# train_dataset_psp = CancerCellDatasetAug(
#     train_data, 
#     foreground_value=FG_VALUE, 
#     target_size=(TARGET_H, TARGET_W), 
#     transform=advanced_augmentations
# )

# # Инициализация PSPNet со случайными весами
# psp_model = smp.PSPNet(
#     encoder_name="resnet18",        # Базовая архитектура
#     encoder_weights=None,           # Веса НЕ загружаем!
#     in_channels=3,                  
#     classes=num_classes,            
#     activation="softmax2d",         
#     psp_out_channels=256,           # Уменьшено для экономии памяти
#     psp_use_batchnorm=True,         
#     psp_dropout=0.1                 
# ).to(device)

# # Функция потерь (Dice + CE)
# class HybridLoss(nn.Module):
#     def __init__(self, alpha=0.5):
#         super().__init__()
#         self.ce = nn.CrossEntropyLoss()
#         self.dice = smp.losses.DiceLoss(mode="multiclass")
#         self.alpha = alpha

#     def forward(self, pred, target):
#         return self.alpha * self.ce(pred, target) + (1 - self.alpha) * self.dice(pred, target)

# # Оптимизатор (одинаковый lr для всех слоев)
# optimizer = torch.optim.AdamW(psp_model.parameters(), lr=1e-3)
# scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

# # Обучение (используйте вашу функцию train_model или аналогичную)
# # Обучение модели с распаковкой результатов
# psp_model, history_psp = train_model(
#     psp_model,
#     HybridLoss(alpha=0.5),
#     optimizer,
#     train_loader_aug,
#     val_loader_aug,
#     device,
#     num_epochs=30
# )

# # Визуализация истории обучения
# plot_history(history_psp)

# # Сохранение модели
# torch.save(psp_model.state_dict(), "pspnet_custom.pth")
# print("Модель успешно сохранена!")

In [None]:
import segmentation_models_pytorch as smp
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
from scipy.ndimage import distance_transform_edt

# Metric evaluation helper
def evaluate_model_metrics(model, loader, device, foreground_class_index=1):
    """Compute per-image Dice and IoU of ``model`` over all batches of ``loader``.

    Predictions are the argmax over the class dimension (dim 1) of the model
    output. Prediction and ground truth are both reduced to boolean foreground
    masks (``== foreground_class_index``) before scoring, the same way the
    test-set evaluation cells compute these metrics.

    Args:
        model: segmentation network producing class scores along dim 1.
        loader: iterable of ``(images, masks)`` batches.
        device: device the images are moved to for inference.
        foreground_class_index: class index treated as foreground (default 1).

    Returns:
        dict with mean ``'dice'`` and ``'iou'`` plus the per-image lists
        ``'dice_scores'`` and ``'iou_scores'``. The means are NaN when the
        loader yields no images.
    """
    model.eval()
    dice_scores = []
    iou_scores = []

    with torch.no_grad():
        for images, masks in tqdm(loader, desc="Evaluation"):
            outputs = model(images.to(device))
            probs = F.softmax(outputs, dim=1)
            preds_np = torch.argmax(probs, dim=1).cpu().numpy()
            # Ground truth is only used by the NumPy metrics, so keep it on the CPU.
            masks_np = masks.cpu().numpy()

            for pred, true in zip(preds_np, masks_np):
                pred_fg = pred == foreground_class_index
                true_fg = true == foreground_class_index
                dice_scores.append(dice_coefficient(pred_fg, true_fg))
                iou_scores.append(iou_score(pred_fg, true_fg))

    return {
        'dice': np.mean(dice_scores) if dice_scores else float('nan'),
        'iou': np.mean(iou_scores) if iou_scores else float('nan'),
        'dice_scores': dice_scores,
        'iou_scores': iou_scores,
    }

# Prediction visualization helper
def visualize_predictions(model, dataset, device, num_samples=5):
    """Plot input, prediction and ground truth for random items of ``dataset``.

    Each row shows the input image, the argmax prediction of ``model`` and the
    ground-truth mask. Items are sampled without replacement, so
    ``num_samples`` is capped at ``len(dataset)`` (otherwise
    ``np.random.choice`` would raise). An empty dataset only prints a message.

    Args:
        model: segmentation network; its output is soft-maxed and arg-maxed over dim 1.
        dataset: indexable dataset returning ``(image_tensor, mask_tensor)``.
        device: device the model runs on.
        num_samples: requested number of rows.
    """
    num_samples = min(num_samples, len(dataset))
    if num_samples <= 0:
        print("Нет данных для визуализации.")
        return

    model.eval()
    indices = np.random.choice(len(dataset), num_samples, replace=False)

    plt.figure(figsize=(15, 3 * num_samples))

    for i, idx in enumerate(indices):
        image, mask = dataset[idx]

        with torch.no_grad():
            output = model(image.unsqueeze(0).to(device))
            probs = F.softmax(output, dim=1)
            pred = torch.argmax(probs, dim=1).squeeze().cpu().numpy()

        image_np = image.cpu().numpy().transpose(1, 2, 0)
        if image_np.shape[2] > 3:
            image_np = image_np[..., :3]
        # Inputs may be normalized and fall outside [0, 1]; clip for display
        # (imshow would clip anyway, with a warning). No de-normalization is done.
        image_np = np.clip(image_np, 0.0, 1.0)

        plt.subplot(num_samples, 3, i * 3 + 1)
        plt.imshow(image_np)
        plt.title(f"Input {i+1}")
        plt.axis('off')

        plt.subplot(num_samples, 3, i * 3 + 2)
        plt.imshow(pred, cmap='gray')
        plt.title(f"Prediction {i+1}")
        plt.axis('off')

        plt.subplot(num_samples, 3, i * 3 + 3)
        plt.imshow(mask.cpu().numpy(), cmap='gray')
        plt.title(f"Ground Truth {i+1}")
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Training loop with per-epoch history and best-checkpoint tracking
def train_advanced_model(model, criterion, optimizer, train_loader, val_loader, scheduler, device,
                         num_epochs=30, checkpoint_path='best_psp_model.pth', foreground_class_index=1):
    """Train ``model`` and return the weights with the best validation Dice.

    Each epoch runs one pass over ``train_loader``, then evaluates on
    ``val_loader`` (loss plus per-image Dice/IoU on the foreground class) and
    steps ``scheduler`` with the validation loss. Whenever the mean validation
    Dice improves, the weights are kept in memory and written to
    ``checkpoint_path``.

    Args:
        model: network to train (updated in place).
        criterion: loss called as ``criterion(outputs, masks)``.
        optimizer: optimizer over the model's parameters.
        train_loader, val_loader: iterables of ``(inputs, masks)`` batches.
        scheduler: object whose ``step(val_loss)`` is called once per epoch,
            or ``None`` to skip scheduling.
        device: device inputs and masks are moved to.
        num_epochs: number of epochs.
        checkpoint_path: file the best weights are saved to.
        foreground_class_index: class scored by Dice/IoU.

    Returns:
        ``(model, history)`` where ``model`` holds the best weights and
        ``history`` has per-epoch lists ``'train_loss'``, ``'val_loss'``,
        ``'val_dice'`` and ``'val_iou'``.
    """
    history = {
        'train_loss': [],
        'val_loss': [],
        'val_dice': [],
        'val_iou': []
    }

    best_model_wts = copy.deepcopy(model.state_dict())
    # Start below any attainable score so the first epoch is always recorded;
    # with 0.0 a run whose Dice never exceeds 0 would silently return the
    # untrained weights and never write a checkpoint.
    best_dice = float('-inf')

    for epoch in range(num_epochs):
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        model.train()
        epoch_train_loss = 0.0
        for inputs, masks in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
            inputs = inputs.to(device)
            masks = masks.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, masks)
            loss.backward()
            optimizer.step()

            epoch_train_loss += loss.item() * inputs.size(0)

        # Validation phase
        model.eval()
        epoch_val_loss = 0.0
        dice_scores = []
        iou_scores = []
        with torch.no_grad():
            for inputs, masks in tqdm(val_loader, desc="Validation"):
                inputs = inputs.to(device)
                masks = masks.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, masks)
                epoch_val_loss += loss.item() * inputs.size(0)

                preds = torch.argmax(F.softmax(outputs, dim=1), dim=1)

                # One device->host copy per batch rather than per sample; score
                # the foreground class as boolean masks.
                preds_np = preds.cpu().numpy()
                masks_np = masks.cpu().numpy()
                for pred, true in zip(preds_np, masks_np):
                    pred_fg = pred == foreground_class_index
                    true_fg = true == foreground_class_index
                    dice_scores.append(dice_coefficient(pred_fg, true_fg))
                    iou_scores.append(iou_score(pred_fg, true_fg))

        # Update history
        train_loss = epoch_train_loss / len(train_loader.dataset)
        val_loss = epoch_val_loss / len(val_loader.dataset)
        val_dice = np.mean(dice_scores)
        val_iou = np.mean(iou_scores)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_dice'].append(val_dice)
        history['val_iou'].append(val_iou)

        if scheduler is not None:
            scheduler.step(val_loss)

        # Keep the best weights by validation Dice
        if val_dice > best_dice:
            best_dice = val_dice
            best_model_wts = copy.deepcopy(model.state_dict())
            torch.save(model.state_dict(), checkpoint_path)
            print(f'New best model saved with Dice: {best_dice:.4f}')

        print(f'Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}')
        print(f'Val Dice: {val_dice:.4f} | Val IoU: {val_iou:.4f}\n')

    model.load_state_dict(best_model_wts)
    return model, history

# Stronger augmentation pipeline intended for PSPNet. It ends with ImageNet
# normalization and ToTensorV2, so whatever consumes it must accept an
# already-normalized, channel-first tensor image.
# Previously tried (disabled):
#   A.ElasticTransform(alpha=120, sigma=120*0.05, alpha_affine=120*0.03, p=0.3)
#   A.RandomGamma(gamma_limit=(70, 130), p=0.5)
#   A.GaussianBlur(blur_limit=(3, 7), p=0.3)
#   A.MaskDropout(max_objects=5, mask_fill_value=0, p=0.5)  # randomly removes regions
advanced_augmentations = A.Compose(
    [
        # geometric
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.RandomRotate90(p=0.5),
        A.ShiftScaleRotate(
            shift_limit=0.1,
            scale_limit=0.2,
            rotate_limit=30,
            p=0.8,
            border_mode=cv2.BORDER_CONSTANT,
            value=0,
            mask_value=0,
        ),
        # photometric
        A.RandomBrightnessContrast(brightness_limit=0.3, contrast_limit=0.3, p=0.7),
        A.HueSaturationValue(hue_shift_limit=30, sat_shift_limit=40, val_shift_limit=30, p=0.7),
        # distortion / occlusion
        A.GridDistortion(num_steps=5, distort_limit=0.3, p=0.5),
        A.CoarseDropout(max_holes=8, max_height=32, max_width=32, p=0.3),
        # ImageNet mean/std, matching the pretrained encoder
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Training dataset wired to the advanced augmentations.
# NOTE(review): the training run in this cell uses train_loader_aug, not this
# dataset — confirm whether that is intended.
train_dataset_psp = CancerCellDatasetAug(
    train_data,
    transform=advanced_augmentations,
    foreground_value=FG_VALUE,
    target_size=(TARGET_H, TARGET_W),
)

# PSPNet with an ImageNet-pretrained ResNet-50 encoder.
# activation=None so the model returns raw logits. nn.CrossEntropyLoss (used
# in HybridLoss) expects logits, and smp's DiceLoss applies softmax itself by
# default (presumably the custom DiceLoss used below does too — confirm).
# The evaluation helpers also apply softmax before argmax. A built-in
# "softmax2d" activation would therefore apply softmax twice and flatten the
# training gradients.
# NOTE(review): the ImageNet encoder weights expect ImageNet-normalized inputs;
# train_loader_aug's transform does not normalize — confirm this is intended.
psp_model = smp.PSPNet(
    encoder_name="resnet50",
    encoder_weights="imagenet",
    in_channels=3,              # RGB images
    classes=num_classes,
    activation=None,            # logits; softmax is applied by losses / evaluation
    psp_out_channels=512,       # width of the PSP module output
    psp_use_batchnorm=True,
    psp_dropout=0.2,            # regularization
).to(device)

# psp_model = smp.PSPNet(
#     encoder_name="resnet34",
#     encoder_weights="imagenet",
#     in_channels=3,
#     classes=num_classes,
#     activation="softmax2d",
#     psp_use_attention=True,
#     psp_dropout=0.5,
#     decoder_attention_type="scse"
# ).to(device)

# Combined loss
class HybridLoss(nn.Module):
    """Weighted sum of cross-entropy and multiclass Dice loss.

    ``loss = alpha * CE(pred, target) + (1 - alpha) * Dice(pred, target)``

    Args:
        alpha: weight of the cross-entropy term; the Dice term gets ``1 - alpha``.
    """

    def __init__(self, alpha=0.7):
        super(HybridLoss, self).__init__()
        self.ce = nn.CrossEntropyLoss()
        self.dice = smp.losses.DiceLoss(mode="multiclass")
        self.alpha = alpha

    def forward(self, pred, target):
        ce_term = self.ce(pred, target)
        dice_term = self.dice(pred, target)
        ce_weight = self.alpha
        dice_weight = 1 - self.alpha
        return ce_weight * ce_term + dice_weight * dice_term

# class BoundaryLoss(nn.Module):
#     def __init__(self, epsilon=1e-5):
#         super().__init__()
#         self.epsilon = epsilon

#     def _one_hot(self, mask, num_classes):
#         return F.one_hot(mask.long(), num_classes).permute(0, 3, 1, 2).float()

#     def _compute_distance_map(self, mask):
#         mask_np = mask.cpu().numpy().astype(bool)
#         distance_map = np.zeros_like(mask_np, dtype=np.float32)
        
#         for b in range(mask_np.shape[0]):
#             for c in range(mask_np.shape[1]):
#                 pos_mask = mask_np[b, c].astype(bool)
#                 if pos_mask.any():
#                     neg_mask = ~pos_mask
#                     pos_dist = distance_transform_edt(neg_mask)
#                     neg_dist = distance_transform_edt(pos_mask)
#                     distance_map[b, c] = pos_dist - neg_dist
                
#         return torch.from_numpy(distance_map).to(mask.device)

#     def forward(self, pred, target):
#         num_classes = pred.shape[1]
#         target_oh = self._one_hot(target, num_classes)
#         target_dist = self._compute_distance_map(target_oh)
        
#         pred_soft = F.softmax(pred, dim=1)
#         loss = (pred_soft * target_dist).abs().mean()
        
#         return loss

# class HybridLoss(nn.Module):
#     def __init__(self, alpha=0.4, class_weights=[0.2, 0.8]):
#         super().__init__()
#         self.ce = nn.CrossEntropyLoss(
#             weight=torch.tensor(class_weights).to(device)
#         )
#         self.dice = smp.losses.DiceLoss(mode="multiclass", classes=[1])
#         self.focal = smp.losses.FocalLoss(mode="multiclass", gamma=2.0)
#         self.boundary = BoundaryLoss()
#         self.alpha = alpha

#     def forward(self, pred, target):
#         return (
#             0.3 * self.ce(pred, target) +
#             0.3 * self.dice(pred, target) +
#             0.2 * self.focal(pred, target) +
#             0.2 * self.boundary(pred, target)
#         )

# Optimizer and scheduler with discriminative learning rates: a small LR for
# the pretrained encoder, a larger one for the randomly initialised parts.
# The segmentation head must be listed too; otherwise this optimizer would
# never update its parameters.
# NOTE(review): the training run below uses optimizer_dice_aug /
# scheduler_dice_aug, not these two objects.
optimizer = torch.optim.AdamW([
    {'params': psp_model.encoder.parameters(), 'lr': 1e-4},            # pretrained encoder
    {'params': psp_model.decoder.parameters(), 'lr': 1e-3},            # decoder
    {'params': psp_model.segmentation_head.parameters(), 'lr': 1e-3},  # final head
])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=3, factor=0.5, verbose=True)

# # Модифицированная функция обучения
# def train_advanced_model(model, criterion, optimizer, train_loader, val_loader, scheduler, device, num_epochs=30):
#     best_model_wts = copy.deepcopy(model.state_dict())
#     best_dice = 0.0
    
#     for epoch in range(num_epochs):
#         print(f'Epoch {epoch+1}/{num_epochs}')
#         print('-' * 10)
        
#         # Фаза обучения
#         model.train()
#         running_loss = 0.0
#         pbar = tqdm(train_loader, desc=f"Train Epoch {epoch+1}")
#         for inputs, masks in pbar:
#             inputs = inputs.to(device)
#             masks = masks.to(device)
            
#             optimizer.zero_grad()
#             outputs = model(inputs)
#             loss = criterion(outputs, masks)
#             loss.backward()
#             optimizer.step()
            
#             running_loss += loss.item() * inputs.size(0)
#             pbar.set_postfix({'loss': running_loss/(pbar.n+1)})
        
#         # Фаза валидации
#         model.eval()
#         val_loss = 0.0
#         dice_scores = []
#         with torch.no_grad():
#             for inputs, masks in val_loader:
#                 inputs = inputs.to(device)
#                 masks = masks.to(device)
                
#                 outputs = model(inputs)
#                 loss = criterion(outputs, masks)
#                 val_loss += loss.item() * inputs.size(0)
                
#                 # Расчет Dice
#                 probs = torch.softmax(outputs, dim=1)
#                 preds = torch.argmax(probs, dim=1)
#                 dice = dice_coefficient(preds.cpu().numpy(), masks.cpu().numpy())
#                 dice_scores.extend(dice)
        
#         # Обновление планировщика
#         scheduler.step(val_loss)
        
#         # Сохранение лучшей модели
#         epoch_dice = np.mean(dice_scores)
#         if epoch_dice > best_dice:
#             best_dice = epoch_dice
#             best_model_wts = copy.deepcopy(model.state_dict())
#             torch.save(model.state_dict(), 'best_psp_model.pth')
#             print(f'New best model saved with Dice: {best_dice:.4f}')
    
#     print(f'Best Validation Dice: {best_dice:.4f}')
#     model.load_state_dict(best_model_wts)
#     return model

# # Запуск обучения
# psp_model, psp_model_history = train_advanced_model(
#     psp_model,
#     HybridLoss(alpha=0.6),
#     optimizer,
#     train_loader_aug,
#     val_loader_aug,
#     scheduler,
#     device,
#     num_epochs=40
# )

# model_dice_aug = UNet(n_channels=input_channels, n_classes=num_classes, bilinear=True)
# model_dice_aug.to(device)

# Configuration actually used by the PSPNet training run below: Dice loss,
# one Adam learning rate for all parameters, and ReduceLROnPlateau (which the
# training loop steps with the validation loss).
criterion_dice_aug = DiceLoss()
print("Используется Dice Loss.")

optimizer_dice_aug = optim.Adam(
    psp_model.parameters(),
    weight_decay=WEIGHT_DECAY,
    lr=LEARNING_RATE,
)
print("Создан оптимизатор Adam.")

scheduler_dice_aug = lr_scheduler.ReduceLROnPlateau(
    optimizer_dice_aug,
    mode='min',
    patience=10,
    factor=0.1,
    verbose=True,
)
print("Создан планировщик ReduceLROnPlateau.")

# Train PSPNet for 40 epochs; train_advanced_model restores the weights with
# the best validation Dice before returning.
# NOTE(review): this trains on train_loader_aug (train_transform), not on
# train_dataset_psp / advanced_augmentations — confirm intent.
psp_model, psp_model_history = train_advanced_model(
    model=psp_model,
    criterion=criterion_dice_aug,
    optimizer=optimizer_dice_aug,
    train_loader=train_loader_aug,
    val_loader=val_loader_aug,
    scheduler=scheduler_dice_aug,
    device=device,
    num_epochs=40,
)

# Training-history plot
def plot_history(history):
    """Draw loss curves (left) and validation Dice/IoU curves (right).

    ``history`` must provide the lists ``'train_loss'``, ``'val_loss'``,
    ``'val_dice'`` and ``'val_iou'``, indexed by epoch.
    """
    panels = [
        ('Training and Validation Loss', 'Loss',
         [('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')]),
        ('Validation Metrics', 'Score',
         [('val_dice', 'Val Dice'), ('val_iou', 'Val IoU')]),
    ]

    plt.figure(figsize=(12, 5))
    for position, (title, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        for key, label in curves:
            plt.plot(history[key], label=label)
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# Final test-set evaluation of the trained PSPNet.
test_metrics = evaluate_model_metrics(psp_model, test_loader_aug, device)
print("Test Dice: {:.4f}, Test IoU: {:.4f}".format(test_metrics['dice'], test_metrics['iou']))


# Learning curves from the training run.
plot_history(psp_model_history)

# Persist the (best-validation) PSPNet weights.
torch.save(psp_model.state_dict(), "pspnet_custom_adv.pth")
print("Модель успешно сохранена!")

# Qualitative check on a few random test images.
visualize_predictions(psp_model, test_dataset_aug, device, num_samples=5)