In [None]:
import itertools
import os
import warnings

import cv2
import numpy as np
import segmentation_models_pytorch as smp
import torch
from sklearn.exceptions import UndefinedMetricWarning
from torch.optim import Adam
from torch.utils.data._utils.collate import default_collate
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

from coco_classes import (
    kidneys_base_classes,
    kidneys_pat_out_classes,
)
from coco_dataloaders import SINUSITE_COCODataLoader
from metrics import DetectionMetrics
from transforms import SegTransform
from utils import ExperimentSetup, iou_metric, save_best_metrics_to_csv, set_seed

# Seed once at import time, before any data loader or model below is built.
# NOTE(review): presumably set_seed (from utils) seeds the Python/NumPy/torch
# RNGs for reproducibility — confirm against utils.set_seed.
set_seed(64)


def get_direct_subdirectories(directory):
    """Return the paths of the immediate sub-directories of *directory*.

    Each returned path is ``directory`` joined with the entry name; regular
    files are skipped. Entries appear in the order ``os.listdir`` reports them.
    """
    candidate_paths = (
        os.path.join(directory, entry) for entry in os.listdir(directory)
    )
    return [path for path in candidate_paths if os.path.isdir(path)]


def custom_collate_fn(batch):
    """Collate a list of sample dicts into one batched dict.

    Only the ``"images"`` and ``"masks"`` entries of each sample are used; each
    group is stacked independently with ``default_collate``. Any other keys in
    the samples are dropped.
    """
    return {
        "images": default_collate([sample["images"] for sample in batch]),
        "masks": default_collate([sample["masks"] for sample in batch]),
    }

def min_max_normalize(tensor):
    """Rescale *tensor* using its global minimum and maximum.

    Computes ``(tensor - min) / (max - min + 1e-8)``; the small epsilon keeps
    the division defined when every element has the same value (the result is
    then all zeros).
    """
    lowest = tensor.min()
    value_range = tensor.max() - lowest
    return (tensor - lowest) / (value_range + 1e-8)

# Loss flavours train_model knows how to call; each one has its own criterion
# call signature (see _compute_loss).
_SUPPORTED_LOSS_TYPES = ("weak", "strong", "focus", "bce")

# Row order of the pixel-weight tensor: positive, negative and class weights.
_PIXEL_WEIGHT_ROW_LABELS = (
    "pixel_pos_weights",
    "pixel_neg_weights",
    "pixel_class_weights",
)


def _compute_loss(
    criterion,
    loss_type,
    outputs,
    masks,
    class_weights,
    pixel_weights,
    global_stats,
    train_mode,
):
    """Call *criterion* with the argument layout required by *loss_type*."""
    if loss_type in ("weak", "strong"):
        return criterion(outputs, masks, class_weights, pixel_weights)

    if loss_type == "focus":
        if not train_mode:
            loss, _, _ = criterion(
                outputs,
                masks,
                global_loss_sum=None,
                global_loss_numel=None,
                train_mode=False,
                mode="ML",
            )
            return loss

        loss, loss_sum, loss_numel = criterion(
            outputs,
            masks,
            global_loss_sum=global_stats["global_loss_sum"],
            global_loss_numel=global_stats["global_loss_numel"],
            train_mode=True,
            mode="ML",
        )
        # The running statistics returned by the criterion are fed back into
        # it on the next batch.
        global_stats["global_loss_sum"] = loss_sum
        global_stats["global_loss_numel"] = loss_numel
        # tqdm.write keeps the active progress bar intact, unlike print().
        tqdm.write(f"global_stats {global_stats}")
        return loss

    if loss_type == "bce":
        return criterion(outputs, masks)

    raise ValueError(
        f"Unsupported loss_type {loss_type!r}; expected one of {_SUPPORTED_LOSS_TYPES}"
    )


def _train_one_epoch(
    model,
    optimizer,
    criterion,
    train_loader,
    device,
    num_classes,
    loss_type,
    all_class_weights,
    alpha_no_fon,
    global_stats,
    seg_transform,
    metrics_calculator,
    progress_desc,
):
    """Run one optimisation pass over *train_loader*; return (mean loss, mean IoU)."""
    model.train()
    loss_sum = 0.0
    iou_sum = torch.zeros(num_classes)
    batches_done = 0

    with tqdm(total=len(train_loader), desc=progress_desc, unit="batch") as pbar:
        for train_batch in train_loader:
            optimizer.zero_grad()

            images = train_batch["images"].to(device)
            # Channel 0 of the masks (background) is dropped.
            masks = train_batch["masks"][:, 1:, :, :].to(device)

            if seg_transform is not None:
                images, masks = seg_transform.apply_transform(images, masks)

            outputs = torch.sigmoid(model(images))

            # NOTE(review): the original code built a background-free copy of
            # all_class_weights but never used it; the criterion still receives
            # the full all_class_weights. Confirm which one it expects.
            loss = _compute_loss(
                criterion,
                loss_type,
                outputs,
                masks,
                all_class_weights,
                alpha_no_fon,
                global_stats,
                train_mode=True,
            )
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()

            # Metrics are bookkeeping only: detach so the accumulators never
            # hold on to the autograd graph of a finished batch.
            with torch.no_grad():
                detached_outputs = outputs.detach()
                iou_sum += iou_metric(detached_outputs, masks, num_classes)
                metrics_calculator.update_counter(masks, detached_outputs)

            # Running mean of the loss shown in the progress bar.
            batches_done += 1
            pbar.set_postfix(loss=loss_sum / batches_done)
            pbar.update(1)

    batch_count = len(train_loader)
    return loss_sum / batch_count, iou_sum / batch_count


def _validate_one_epoch(
    model,
    criterion,
    val_loader,
    device,
    num_classes,
    loss_type,
    metrics_calculator,
):
    """Evaluate *model* on *val_loader* without gradients; return (mean loss, mean IoU)."""
    model.eval()
    loss_sum = 0.0
    iou_sum = torch.zeros(num_classes)

    with torch.no_grad():
        for val_batch in val_loader:
            images_val = val_batch["images"].to(device)
            masks_val = val_batch["masks"][:, 1:].to(device)
            outputs_val = torch.sigmoid(model(images_val))

            # Validation never applies class or pixel weights.
            val_loss = _compute_loss(
                criterion,
                loss_type,
                outputs_val,
                masks_val,
                None,
                None,
                None,
                train_mode=False,
            )
            loss_sum += val_loss.item()
            iou_sum += iou_metric(outputs_val, masks_val, num_classes)
            metrics_calculator.update_counter(masks_val, outputs_val)

    batch_count = len(val_loader)
    return loss_sum / batch_count, iou_sum / batch_count


def _calc_metrics_quietly(metrics_calculator):
    """Return calc_metrics() with sklearn's UndefinedMetricWarning suppressed."""
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
        return metrics_calculator.calc_metrics()


def _log_metrics(writer, split, metrics, class_names_dict, epoch):
    """Write every tensor-valued metric to TensorBoard under the *split* prefix."""
    for key, value in metrics.items():
        if not isinstance(value, torch.Tensor):
            continue
        if value.dim() > 0:
            # Per-class vector: log the mean plus one scalar per class.
            writer.add_scalar(f"{split}/Mean/{key}", value.mean().item(), epoch)
            for i, val in enumerate(value):
                # Class ids in the out-classes start at 1.
                class_name = class_names_dict[i + 1]
                writer.add_scalar(f"{split}/{key}/{class_name}", val.item(), epoch)
        else:
            writer.add_scalar(f"{split}/{key}", value.item(), epoch)


def _print_pixel_weights(alpha_no_fon, class_names_dict):
    """Print the pixel weights of every class, one group per weight row."""
    for row_idx, label in enumerate(_PIXEL_WEIGHT_ROW_LABELS):
        lines = [
            f"class: {class_names_dict[col + 1]}, {label} {alpha_no_fon[row_idx][col]}"
            for col in range(len(alpha_no_fon[row_idx]))
        ]
        leading = "\n" if row_idx == 0 else ""
        print(leading + "\n".join(lines) + "\n")


def _metrics_record(experiment_name, epoch, train_loss, val_loss, val_metrics):
    """Build the dict handed to save_best_metrics_to_csv."""
    return {
        "experiment": experiment_name.split("_")[0],
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_metrics": {
            "IOU": val_metrics["IOU"],
            "F1": val_metrics["F1"],
            "area_probs_F1": val_metrics["area_probs_F1"],
        },
    }


def _save_checkpoint(model, directory, prefix, experiment_name, record):
    """Save the model weights and append *record* to the CSV inside *directory*."""
    os.makedirs(directory, exist_ok=True)
    torch.save(
        model.state_dict(),
        f"{directory}/{prefix}_{experiment_name}_model.pth",
    )
    save_best_metrics_to_csv(record, f"{directory}/{prefix}_metrics.csv")


def train_model(
    model,
    optimizer,
    criterion,
    lr_sched,
    num_epochs,
    train_loader,
    val_loader,
    device,
    num_classes,
    experiment_name,
    all_class_weights,
    alpha,
    use_opt_pixel_weight,
    num_cyclic_steps,  # number of cyclic steps on validation (currently unused)
    max_n=3,
    max_k=3,
    use_augmentation=False,
    loss_type="weak",
):
    """Train *model*, validate after every epoch, and save checkpoints.

    Per epoch the function runs one training pass and one validation pass,
    logs losses, IoU and the DetectionMetrics output to TensorBoard
    (``runs_kidneys/<experiment_name>_logs``), and saves the weights plus a CSV
    row to ``kidneys_best_models`` whenever the validation loss improves. After
    the last epoch the final weights go to ``kidneys_last_models``.

    Args:
        model: network whose raw outputs are passed through a sigmoid.
        optimizer: optimiser stepped once per training batch.
        criterion: loss callable; its call signature depends on ``loss_type``.
        lr_sched: optional scheduler, stepped once per epoch after validation.
        num_epochs: number of epochs, must be at least 1.
        train_loader, val_loader: non-empty loaders yielding dicts with
            ``"images"`` and ``"masks"``; mask channel 0 is dropped.
        device: device the model and batches are moved to.
        num_classes: number of foreground classes (length of the IoU vectors).
        experiment_name: used in log, checkpoint and CSV names.
        all_class_weights: passed to the criterion for "weak"/"strong" losses.
        alpha: optional pixel weights, one row per weight kind; the first
            element of every row is dropped before use.
        use_opt_pixel_weight: if true, pixel weights are re-tuned after each
            training pass from the training metrics.
        num_cyclic_steps, max_n, max_k: accepted for compatibility, not used.
        use_augmentation: apply ``SegTransform`` to every training batch.
        loss_type: one of "weak", "strong", "focus" or "bce".

    Raises:
        ValueError: for an unknown ``loss_type``, ``num_epochs < 1`` or an
            empty loader.
    """
    # Fail fast instead of crashing mid-training on an undefined variable.
    if loss_type not in _SUPPORTED_LOSS_TYPES:
        raise ValueError(
            f"Unsupported loss_type {loss_type!r}; expected one of {_SUPPORTED_LOSS_TYPES}"
        )
    if num_epochs < 1:
        raise ValueError("num_epochs must be at least 1")
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError("train_loader and val_loader must each yield at least one batch")

    writer = SummaryWriter(log_dir=f"runs_kidneys/{experiment_name}_logs")
    try:
        # NOTE(review): one DetectionMetrics instance receives both training
        # and validation batches; this is only correct if calc_metrics()
        # resets its counters — confirm in metrics.DetectionMetrics.
        metrics_calculator = DetectionMetrics(mode="ML", num_classes=num_classes)

        class_names_dict = {
            class_info["id"]: class_info["name"]
            for class_info in kidneys_pat_out_classes
        }
        print("class_names_dict", class_names_dict)

        weight_opt = Weight_opt_class(criterion, list(class_names_dict.keys()), None)

        model = model.to(device)

        # Any finite first validation loss must count as an improvement.
        best_loss = float("inf")

        global_stats = {
            "global_loss_sum": torch.tensor(0.0, dtype=torch.double),
            "global_loss_numel": torch.tensor(0.0, dtype=torch.double),
        }

        if alpha is not None:
            # Drop the first (background) entry of every weight row.
            alpha_no_fon = torch.tensor(np.array([arr[1:] for arr in alpha])).to(device)
        else:
            alpha_no_fon = None

        seg_transform = SegTransform() if use_augmentation else None

        for epoch in range(num_epochs):
            train_loss_avg, train_iou_avg = _train_one_epoch(
                model,
                optimizer,
                criterion,
                train_loader,
                device,
                num_classes,
                loss_type,
                all_class_weights,
                alpha_no_fon,
                global_stats,
                seg_transform,
                metrics_calculator,
                f"Epoch {epoch + 1}/{num_epochs}",
            )

            train_metrics = _calc_metrics_quietly(metrics_calculator)
            _log_metrics(writer, "Train", train_metrics, class_names_dict, epoch)
            writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)

            print("alpha_no_fon", alpha_no_fon)
            if alpha_no_fon is not None:
                _print_pixel_weights(alpha_no_fon, class_names_dict)

            print("class_names_dict", class_names_dict)

            # Pixel-weight optimisation driven by the training metrics.
            if use_opt_pixel_weight:
                alpha_no_fon = weight_opt.opt_pixel_weight(train_metrics, alpha_no_fon)

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss_avg}, Train IoU: {train_iou_avg}")

            val_loss_avg, val_iou_avg = _validate_one_epoch(
                model,
                criterion,
                val_loader,
                device,
                num_classes,
                loss_type,
                metrics_calculator,
            )

            print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss_avg}, Val Loss: {val_loss_avg},  Val IoU: {val_iou_avg}")

            if lr_sched is not None:
                # The scheduler rewrites the learning rate stored in the optimizer.
                lr_sched.step()

            val_metrics = _calc_metrics_quietly(metrics_calculator)
            _log_metrics(writer, "Val", val_metrics, class_names_dict, epoch)

            for class_idx, iou_value in enumerate(train_iou_avg):
                # Class ids in the out-classes start at 1.
                class_name = class_names_dict[class_idx + 1]
                writer.add_scalar(f"My_train_IoU/{class_name}", iou_value, epoch)

            for class_idx, iou_value in enumerate(val_iou_avg):
                class_name = class_names_dict[class_idx + 1]
                writer.add_scalar(f"My_val_IoU/{class_name}", iou_value, epoch)

            writer.add_scalar("Loss/train", train_loss_avg, epoch)
            writer.add_scalar("Loss/validation", val_loss_avg, epoch)

            # Keep the weights with the lowest validation loss seen so far.
            if val_loss_avg < best_loss:
                best_loss = val_loss_avg
                _save_checkpoint(
                    model,
                    "kidneys_best_models",
                    "best",
                    experiment_name,
                    _metrics_record(
                        experiment_name, epoch, train_loss_avg, val_loss_avg, val_metrics
                    ),
                )

        # Always keep the state reached after the final epoch as well.
        _save_checkpoint(
            model,
            "kidneys_last_models",
            "last",
            experiment_name,
            _metrics_record(
                experiment_name, epoch, train_loss_avg, val_loss_avg, val_metrics
            ),
        )
    finally:
        # Flush and close the event file even when training is interrupted.
        writer.close()


class Weight_opt_class:
    """Re-tunes per-class pixel weights from recall / precision / F1 metrics."""

    def __init__(self, loss, classes, b=None):
        self.classes = classes
        self.loss = self.loss_class = loss
        self.b = b

    def opt_pixel_weight(self, metrics, pixel_all_class_weights=None):
        """Scale the pixel weights of every class in place and return them.

        ``metrics`` must provide ``"advanced_recall"``, ``"advanced_precision"``
        and ``"advanced_F1"``, indexable per class with elements exposing
        ``.item()``. Rows 0, 1 and 2 of ``pixel_all_class_weights`` are
        multiplied by the positive, negative and class coefficient
        respectively. When the weights are ``None`` the coefficients are only
        printed and ``None`` is returned.
        """
        recall_by_class = metrics["advanced_recall"]
        precision_by_class = metrics["advanced_precision"]
        f1_by_class = metrics["advanced_F1"]

        beta = 1 if self.b is None else self.b

        for idx, _ in enumerate(self.classes):
            rec = recall_by_class[idx].item()
            prec = precision_by_class[idx].item()

            if rec == 0 or prec == 0:
                # A degenerate metric: use fixed coefficients (2.0 / 0.5).
                print("recall или precision == 0")
                print("recall[image_class].item()", rec)
                print("precession[image_class].item()", prec)

                pos_coef, class_coef, neg_coef = 2.0, 2.0, 0.5
                print("neg_coef", neg_coef)
                print("pos_coef", pos_coef)
                print("class_coef", class_coef)
            else:
                print("recall и precision != 0")
                print("recall[image_class].item()", rec)
                print("precession[image_class].item()", prec)

                f1 = f1_by_class[idx].item()
                neg_coef = (1 / beta) * rec / f1
                pos_coef = beta * prec / f1
                print("neg_coef", neg_coef)
                print("pos_coef", pos_coef)

                xsd = rec / prec
                print("xsd", xsd)
                # Recall and precision within 10% of each other: leave the
                # weights untouched.
                if 0.9 < xsd < 1.1:
                    neg_coef = pos_coef = 1
                class_coef = pos_coef
                print("вот после изменений")
                print("neg_coef", neg_coef)
                print("pos_coef", pos_coef)
                print("class_coef", class_coef)

            if pixel_all_class_weights is None:
                continue
            for row, coef in enumerate((pos_coef, neg_coef, class_coef)):
                pixel_all_class_weights[row][idx] *= coef

        return pixel_all_class_weights


if __name__ == "__main__":
    # Script entry point: build the kidney COCO data loaders, a Linknet model
    # and the experiment-specific criterion, then run training.
    batch_size = 24
    num_classes = 3

    params = {
        "json_file_path": "/home/imran-nasyrov/json_pochki",
        "delete_list": [],
        "base_classes": kidneys_base_classes,
        "out_classes": kidneys_pat_out_classes,
        "dataloader": True,
        "resize": (512, 512),
        "recalculate": False,
        "delete_null": False,
    }

    coco_dataloader = SINUSITE_COCODataLoader(params)

    (
        train_loader,
        val_loader,
        total_train,
        pixel_total_train,
        list_of_name_out_classes,
    ) = coco_dataloader.make_dataloaders(batch_size=batch_size, train_val_ratio=0.8)

    print("total_train", total_train)
    print("len total_train", len(total_train))
    print("list_of_name_out_classes", list_of_name_out_classes)
    print("pixel_TotalTrain", pixel_total_train)
    print("len val_loader", len(val_loader))
    print("len train_loader", len(train_loader))

    # NOTE(review): training is pinned to the CPU even though the script
    # reports the CUDA device below — confirm this is intentional.
    device = torch.device("cpu")
    print(device)
    # Querying the CUDA device name raises on hosts without CUDA, so guard it.
    if torch.cuda.is_available():
        print(torch.cuda.get_device_name(torch.cuda.current_device()))

    model = smp.Linknet(
        encoder_name="efficientnet-b7",
        encoder_weights="imagenet",
        in_channels=1,  # +num_classes for diffusion
        classes=num_classes,
    )

    learning_rate = 3e-4
    num_epochs = 120

    optimizer = Adam(model.parameters(), lr=learning_rate)
    lr_sched = None

    use_class_weight = False
    use_pixel_weight = False
    use_pixel_opt = False
    power = "1.7.2_kidneys_weak"  # focus or weak

    # The loss flavour is encoded as the last "_"-separated token of `power`.
    loss_type = power.split("_")[-1]
    print("loss_type", loss_type)

    exp_setup = ExperimentSetup(
        train_loader, total_train, pixel_total_train, batch_size, num_classes
    )

    (
        all_class_weights,
        pixel_all_class_weights,
        experiment_name,
        criterion,
    ) = exp_setup.setup_experiment(
        use_class_weight, use_pixel_weight, use_pixel_opt, power
    )

    print("experiment_name", experiment_name)
    print("criterion", criterion)

    train_model(
        model,
        optimizer,
        criterion,
        lr_sched,
        num_epochs,
        train_loader,
        val_loader,
        device,
        num_classes,
        experiment_name,
        all_class_weights=all_class_weights,
        alpha=pixel_all_class_weights,
        use_opt_pixel_weight=use_pixel_opt,
        num_cyclic_steps=0,
        max_n=3,
        max_k=3,
        use_augmentation=False,
        loss_type=loss_type,
    )

    # train_model writes its best checkpoint into "kidneys_best_models", so the
    # weights must be looked up there (the old "sinusite_best_models" path
    # pointed at a directory this script never writes).
    model_weight = f"kidneys_best_models/best_{experiment_name}_model.pth"

    val_predict_path = f"diff_predict_sinusite/predict_{experiment_name}/val"
    train_predict_path = f"diff_predict_sinusite/predict_{experiment_name}/train"

    # Only the first 6 batches of each loader are taken for prediction dumps.
    limited_train_loader = itertools.islice(train_loader, 6)
    limited_val_loader = itertools.islice(val_loader, 6)



subdir train /home/imran-nasyrov/json_pochki/1_194518_zno_1_194518_zno_AF_frontal
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
subdir train /home/imran-nasyrov/json_pochki/3_2327723_cyst_3_2327723_cyst_AF_sagital
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
sct_coco._total_train [512.   0.   0.  42.]
subdir train /home/imran-nasyrov/json_pochki/3_2327723_cyst_3_2327723_cyst_NF_frontal
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
sct_coco._total_train [512.   0.   0.  42.]
subdir train /home/imran-nasyrov/json_pochki/1_194518_zno_1_194518_zno_VF_frontal
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
sct_coco._total_train [512.  28.   0.   0.]
subdir train /home/imran-nasyrov/json_pochki/5_2229997_norma_5_2229997_norma_AF_axial
loading annotations into memory...
Done (t=0.00s)
creating index...
index created!
sct_coco._total_train [460.   0

Epoch 1/120:   1%|▉                                                                                                                                           | 40/6068 [39:35<97:35:58, 58.29s/batch, loss=0.0988]