RootDataset

In [None]:
# !pip install einops -U
# !pip install -U ipywidgets==8.0.0
# !pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118 -U
# !pip install -U ray[train]==2.8.1

In [1]:
import torch

base_TBAD_csv_path = r"D:/dataset/med/imageTBAD/dataframe.csv"

In [2]:
import torch.nn as nn
import torch
from torchinfo import summary


class ResidualConv(nn.Module):
    """Residual unit made of a BN-ReLU-Conv x2 main path and a conv+BN shortcut.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels of both paths.
        stride: Stride of the first main-path convolution and of the shortcut
            convolution.
        padding: Padding of the first main-path convolution only. The shortcut
            and the second main-path convolution always use ``padding=1``, so
            the two paths only produce matching spatial sizes when
            ``padding == 1``.
    """

    def __init__(self, input_dim, output_dim, stride, padding):
        super().__init__()

        main_path = [
            nn.BatchNorm2d(input_dim),
            nn.ReLU(),
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=padding),
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
        ]
        self.conv_block = nn.Sequential(*main_path)

        shortcut_path = [
            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(output_dim),
        ]
        self.conv_skip = nn.Sequential(*shortcut_path)

    def forward(self, x):
        """Return the element-wise sum of the main path and the shortcut."""
        residual = self.conv_block(x)
        shortcut = self.conv_skip(x)
        return residual + shortcut


class Upsample(nn.Module):
    """Learned upsampling implemented as a single transposed convolution.

    Args:
        input_dim: Number of input channels.
        output_dim: Number of output channels.
        kernel: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution.
    """

    def __init__(self, input_dim, output_dim, kernel, stride):
        super().__init__()
        deconv = nn.ConvTranspose2d(input_dim, output_dim, kernel_size=kernel, stride=stride)
        self.upsample = deconv

    def forward(self, x):
        """Apply the transposed convolution to ``x``."""
        upsampled = self.upsample(x)
        return upsampled


class Squeeze_Excite_Block(nn.Module):
    """Channel re-weighting block: global average pool, bottleneck MLP, sigmoid gate.

    Args:
        channel: Number of channels of the input feature map.
        reduction: Divisor applied to ``channel`` to size the hidden layer
            (``channel // reduction``).
    """

    def __init__(self, channel, reduction=16):
        super().__init__()
        hidden = channel // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        gate_layers = [
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        ]
        self.fc = nn.Sequential(*gate_layers)

    def forward(self, x):
        """Scale every channel of the 4-D input ``x`` by its learned gate."""
        batch, channels, _height, _width = x.size()
        squeezed = self.avg_pool(x).view(batch, channels)
        gates = self.fc(squeezed).view(batch, channels, 1, 1)
        return x * gates.expand_as(x)


class ASPP(nn.Module):
    """Parallel dilated 3x3 convolutions fused by a 1x1 convolution.

    One ``Conv2d -> ReLU -> BatchNorm2d`` branch is created per dilation rate.
    Branches are registered as ``aspp_block1``, ``aspp_block2``, ... so the
    parameter names for the default three rates are unchanged.

    Args:
        in_dims: Number of input channels.
        out_dims: Number of output channels of every branch and of the block.
        rate: Iterable of dilation rates, one branch per entry. Padding equals
            the dilation so each branch keeps the input's spatial size.

    Raises:
        ValueError: If ``rate`` is empty.
    """

    def __init__(self, in_dims, out_dims, rate=(6, 12, 18)):
        super().__init__()

        # Copy into a tuple so a caller-owned list can never be shared or mutated.
        rates = tuple(rate)
        if not rates:
            raise ValueError("ASPP requires at least one dilation rate")
        self.rates = rates

        for index, dilation in enumerate(rates, start=1):
            branch = nn.Sequential(
                nn.Conv2d(
                    in_dims, out_dims, 3, stride=1, padding=dilation, dilation=dilation
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(out_dims),
            )
            # setattr registers the branch as a sub-module under its legacy name.
            setattr(self, f"aspp_block{index}", branch)

        # The fusion conv is sized from the number of branches actually built.
        self.output = nn.Conv2d(len(rates) * out_dims, out_dims, 1)
        self._init_weights()

    def _branches(self):
        """Return the dilated branches in registration order."""
        return [getattr(self, f"aspp_block{index}") for index in range(1, len(self.rates) + 1)]

    def forward(self, x):
        """Run every branch on ``x``, concatenate on the channel axis and fuse."""
        out = torch.cat([branch(x) for branch in self._branches()], dim=1)
        return self.output(out)

    def _init_weights(self):
        """Kaiming-normal init for conv weights; BatchNorm set to weight 1, bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()


class Upsample_(nn.Module):
    """Parameter-free bilinear upsampling by a fixed scale factor.

    Args:
        scale: Factor passed to ``nn.Upsample`` as ``scale_factor``.
    """

    def __init__(self, scale=2):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=scale, mode="bilinear")

    def forward(self, x):
        """Return ``x`` resized by the configured scale factor."""
        resized = self.upsample(x)
        return resized


class AttentionBlock(nn.Module):
    """Gate decoder features with a single-channel map built from encoder + decoder.

    The encoder branch ends in a 2x2 max-pool, so the encoder input is expected
    at twice the spatial resolution of the decoder input.

    Args:
        input_encoder: Channels of the encoder feature map (``x1``).
        input_decoder: Channels of the decoder feature map (``x2``).
        output_dim: Channels both branches are projected to before they are summed.
    """

    def __init__(self, input_encoder, input_decoder, output_dim):
        super().__init__()

        encoder_layers = [
            nn.BatchNorm2d(input_encoder),
            nn.ReLU(),
            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
            nn.MaxPool2d(2, 2),
        ]
        self.conv_encoder = nn.Sequential(*encoder_layers)

        decoder_layers = [
            nn.BatchNorm2d(input_decoder),
            nn.ReLU(),
            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
        ]
        self.conv_decoder = nn.Sequential(*decoder_layers)

        attention_layers = [
            nn.BatchNorm2d(output_dim),
            nn.ReLU(),
            nn.Conv2d(output_dim, 1, 1),
        ]
        self.conv_attn = nn.Sequential(*attention_layers)

    def forward(self, x1, x2):
        """Return ``x2`` multiplied by the attention map computed from ``x1`` and ``x2``."""
        encoded = self.conv_encoder(x1)
        decoded = self.conv_decoder(x2)
        attention = self.conv_attn(encoded + decoded)
        return attention * x2


class ResUnetPlusPlus(nn.Module):
    """Encoder/decoder segmentation network built from the blocks defined above.

    The encoder is a stem followed by three squeeze-excite + strided residual
    stages, bridged by an ASPP block. Each of the three decoder stages applies
    an attention gate, bilinear upsampling, concatenation with the matching
    encoder feature map and a residual block. A final ASPP and a 1x1
    convolution followed by ``Sigmoid`` produce the output.

    Args:
        in_channel: Number of channels of the input image.
        out_channel: Number of output channels.
        filters: Channel widths; indices 0..4 are read.
    """

    def __init__(self, in_channel: int = 3, out_channel: int = 1, filters: list[int] = [32, 64, 128, 256, 512]):
        super().__init__()
        widths = filters

        # Stem: two 3x3 convolutions plus a single-convolution shortcut.
        stem_layers = [
            nn.Conv2d(in_channel, widths[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(widths[0]),
            nn.ReLU(),
            nn.Conv2d(widths[0], widths[0], kernel_size=3, padding=1),
        ]
        self.input_layer = nn.Sequential(*stem_layers)
        self.input_skip = nn.Sequential(
            nn.Conv2d(in_channel, widths[0], kernel_size=3, padding=1)
        )

        # Encoder: every residual stage uses stride 2.
        self.squeeze_excite1 = Squeeze_Excite_Block(widths[0])
        self.residual_conv1 = ResidualConv(widths[0], widths[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(widths[1])
        self.residual_conv2 = ResidualConv(widths[1], widths[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(widths[2])
        self.residual_conv3 = ResidualConv(widths[2], widths[3], 2, 1)

        # Bridge between encoder and decoder.
        self.aspp_bridge = ASPP(widths[3], widths[4])

        # Decoder: attention gate, x2 upsampling, residual block over the concatenation.
        self.attn1 = AttentionBlock(widths[2], widths[4], widths[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(widths[4] + widths[2], widths[3], 1, 1)

        self.attn2 = AttentionBlock(widths[1], widths[3], widths[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(widths[3] + widths[1], widths[2], 1, 1)

        self.attn3 = AttentionBlock(widths[0], widths[2], widths[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(widths[2] + widths[0], widths[1], 1, 1)

        # Head.
        self.aspp_out = ASPP(widths[1], widths[0])
        head_layers = [nn.Conv2d(widths[0], out_channel, 1), nn.Sigmoid()]
        self.output_layer = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map an image batch to a per-pixel output with ``out_channel`` channels."""
        stem = self.input_layer(x) + self.input_skip(x)

        enc1 = self.residual_conv1(self.squeeze_excite1(stem))
        enc2 = self.residual_conv2(self.squeeze_excite2(enc1))
        enc3 = self.residual_conv3(self.squeeze_excite3(enc2))

        bridge = self.aspp_bridge(enc3)

        dec1 = self.upsample1(self.attn1(enc2, bridge))
        dec1 = self.up_residual_conv1(torch.cat([dec1, enc2], dim=1))

        dec2 = self.upsample2(self.attn2(enc1, dec1))
        dec2 = self.up_residual_conv2(torch.cat([dec2, enc1], dim=1))

        dec3 = self.upsample3(self.attn3(stem, dec2))
        dec3 = self.up_residual_conv3(torch.cat([dec3, stem], dim=1))

        return self.output_layer(self.aspp_out(dec3))


# Build the network for a parameter/shape summary. Use the GPU when one is
# present and fall back to the CPU otherwise, so this cell also runs on
# machines without CUDA instead of raising.
_summary_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResUnetPlusPlus(3).to(_summary_device)
summary(model, input_size=(1, 3, 512, 512))


Layer (type:depth-idx)                   Output Shape              Param #
ResUnetPlusPlus                          [1, 1, 512, 512]          --
├─Sequential: 1-1                        [1, 32, 512, 512]         --
│    └─Conv2d: 2-1                       [1, 32, 512, 512]         896
│    └─BatchNorm2d: 2-2                  [1, 32, 512, 512]         64
│    └─ReLU: 2-3                         [1, 32, 512, 512]         --
│    └─Conv2d: 2-4                       [1, 32, 512, 512]         9,248
├─Sequential: 1-2                        [1, 32, 512, 512]         --
│    └─Conv2d: 2-5                       [1, 32, 512, 512]         896
├─Squeeze_Excite_Block: 1-3              [1, 32, 512, 512]         --
│    └─AdaptiveAvgPool2d: 2-6            [1, 32, 1, 1]             --
│    └─Sequential: 2-7                   [1, 32]                   --
│    │    └─Linear: 3-1                  [1, 2]                    64
│    │    └─ReLU: 3-2                    [1, 2]                    --
│    │    

Config

In [3]:
import os.path

from torchvision.transforms import v2
from torch.utils.data import Dataset
import pandas as pd

import nibabel as nib
import numpy as np
from rich.console import Console

console = Console()


class TBADaRoottaset(Dataset):
    """Index dataset over a CSV file; each item is a pair of file paths.

    The CSV is read with a comma delimiter and must provide the columns
    ``img`` and ``mask``.

    Args:
        csv_path: Location of the CSV file.
    """

    def __init__(self, csv_path: str = base_TBAD_csv_path):
        frame: pd.DataFrame = pd.read_csv(csv_path, delimiter=',')
        self.df = frame

    def __len__(self):
        """Number of rows in the CSV."""
        return len(self.df)

    def __getitem__(self, idx):
        """Return ``{"image": <img path>, "mask": <mask path>}`` for row ``idx``."""
        record = self.df.iloc[idx]
        # console.log({"image": record['img'], 'mask': record['mask']})
        return {
            "image": os.path.join(record['img']),
            'mask': os.path.join(record['mask']),
        }


class DatasetofEachNII(Dataset):
    """Slice-wise dataset over one NIfTI image volume and its label volume.

    Both files are loaded fully into memory with ``nibabel``. Items are the
    slices along axis 2 of the image volume.

    Args:
        nii_image_path: Path of the image volume. A sequence of strings is
            accepted as well and is concatenated with ``''.join`` (this is how
            a batch of size one from a ``DataLoader`` arrives).
        nii_label_path: Path of the label volume, same convention.
    """

    def __init__(self, nii_image_path: str, nii_label_path: str):
        console.print(f'[bold green]Loading {nii_image_path} and {nii_label_path}[/bold green]')
        img = nib.load(''.join(nii_image_path))
        self.img_data = img.get_fdata()

        mask = nib.load(''.join(nii_label_path))
        # Make sure every image slice has a mask slice: missing trailing
        # slices are filled with zeros (background).
        self.mask_data = self._pad_mask_depth(mask.get_fdata(), self.img_data.shape[2])

        self.trans_train_data = v2.Compose([
            v2.ToImage(),
            v2.ToDtype(torch.float16, scale=True),
            v2.Normalize([1], [10.0]),
            # v2.Lambda(lambda x: x / 255),
            #             v2.Resize((280, 280)),
            # v2.Resize((384, 384)),
            v2.ConvertImageDtype(torch.float),
        ])
        self.trans_label = v2.Compose([
            v2.ToImage(),
            # v2.
            # v2.Resize((384, 384)),
            #             v2.Resize((280, 280)),
            v2.ToDtype(torch.long)
        ])

    @staticmethod
    def _pad_mask_depth(mask_data, depth):
        """Zero-pad ``mask_data`` along axis 2 until it holds ``depth`` slices.

        Only the slice axis is padded; the in-plane axes keep their size. A
        mask with fewer than three axes is treated as a single slice by
        appending singleton axes first. A mask that already has at least
        ``depth`` slices is returned unchanged.
        """
        if mask_data.ndim < 3:
            mask_data = mask_data.reshape(mask_data.shape + (1,) * (3 - mask_data.ndim))
        missing = depth - mask_data.shape[2]
        if missing <= 0:
            return mask_data
        # np.pad needs one (before, after) pair per axis; a flat (0, k) pair
        # would be broadcast to every axis and grow the in-plane size too.
        pad_width = [(0, 0)] * mask_data.ndim
        pad_width[2] = (0, missing)
        return np.pad(mask_data, pad_width, 'constant', constant_values=0)

    def __len__(self):
        """Number of slices along axis 2 of the image volume."""
        return self.img_data.shape[2]

    def __getitem__(self, idx):
        """Return the transformed image slice and label slice at index ``idx``."""
        image = torch.as_tensor(self.img_data[:, :, idx])
        mask_d = torch.as_tensor(self.mask_data[:, :, idx])
        return {'image': self.trans_train_data(image), 'label': self.trans_label(mask_d)}


In [4]:
import torch
from torch import nn
import torch.nn.functional as F
from einops import reduce

# Defaults for the focal term (ALPHA, GAMMA), the Dice exponent (pp), the Dice
# smoothing constant (smoo) and the generalized-Dice stabiliser (epsilon).
ALPHA = 0.8
GAMMA = 2
pp = 2
smoo = 1
epsilon = 1e-4


class SoftDiceFocalLoss(nn.Module):
    """Sum of a per-class soft Dice loss and a focal loss built on cross-entropy.

    ``forward(inputs, targets, alpha=ALPHA, gamma=GAMMA)``:
        inputs: Class scores of shape ``[b, C, h, w]``.
        targets: Integer class map of shape ``[b, 1, h, w]``. The value ``3``
            is remapped to ``0`` before the loss is computed.
        alpha: Weight of the focal term.
        gamma: Focusing exponent of the focal term.

    NOTE(review): ``Dice`` applies a sigmoid and ``CrossEntropyLoss`` applies a
    log-softmax to ``inputs``, so ``inputs`` should presumably be raw logits -
    confirm that the model feeding this loss does not already end in a sigmoid.
    """

    def __init__(self):
        super().__init__()
        self.p = pp
        self.smooth = smoo
        self.cross = nn.CrossEntropyLoss()

    def forward(self, inputs, targets, alpha=ALPHA, gamma=GAMMA):
        """Return ``dice + focal`` as a scalar tensor."""
        # [b, 1, h, w] -> [b, h, w]; label 3 is folded into class 0.
        targets = torch.squeeze(targets, 1)
        targets = targets.masked_fill(targets.eq(3), 0)

        # Focal term computed from the batch-mean cross-entropy.
        CE = self.cross(inputs, targets)
        CE_EXP = torch.exp(-CE)
        focal_loss = alpha * (1 - CE_EXP) ** gamma * CE

        # Dice term: one binary problem per channel, accumulated over all
        # channels (the channel count is taken from ``inputs``).
        dice_loss = 0
        for class_index in range(inputs.shape[1]):
            input_channel = inputs[:, class_index, :, :]
            target_channel = targets.eq(class_index)
            dice_loss = dice_loss + self.Dice(input_channel, target_channel)

        return dice_loss + focal_loss

    def GenerilizeDiceLoss(self, dice_loss, input_channel, target_channel):
        """Add a generalized Dice term (inverse squared class volume weights) to ``dice_loss``.

        Both tensors are reduced over axes 0, 2 and 3, so they are expected to
        be 4-D with the class axis at position 1.
        """
        wei = torch.sum(target_channel, dim=(0, 2, 3))  # (n_class,)
        wei = 1 / (wei ** 2 + epsilon)
        intersection = torch.sum(wei * torch.sum(input_channel * target_channel, dim=(0, 2, 3)))
        union = torch.sum(wei * torch.sum(input_channel + target_channel, dim=(0, 2, 3)))
        dice_loss += 1 - (2. * intersection) / (union + epsilon)
        return dice_loss

    def Dice(self, inputs, target_dice):
        """Soft Dice loss between ``sigmoid(inputs)`` and a binary target mask.

        Args:
            inputs: Scores for one class; any shape.
            target_dice: Binary mask (bool or 0/1) of the same shape.

        Returns:
            ``1 - (2 * sum(p * t) + smooth) / (sum(p**p_exp + t**p_exp) + smooth)``.
        """
        # Reduce in float32 so half-precision inputs cannot overflow the sums.
        probs = torch.sigmoid(inputs.float())
        # The target is already binary; it must be used as-is, not squashed.
        target = target_dice.to(probs.dtype)
        probs = torch.reshape(probs, (-1,))
        target = torch.reshape(target, (-1,))
        numer = (probs * target).sum()
        denor = (probs.pow(self.p) + target.pow(self.p)).sum()
        dice_loss = 1. - (2 * numer + self.smooth) / (denor + self.smooth)
        return dice_loss


DDP

In [21]:

import traceback
import torch
from torchvision.transforms import v2 as T

from torch.utils.data.distributed import DistributedSampler

# Hyper-parameters read by train_func in this cell.
params = {
    # Number of passes over the root (CSV) dataset.
    'epochs': 10,
    # Slices per batch for the per-volume train/test loaders.
    'batch_size': 1,
    # Selects the autocast + GradScaler branch of the training step.
    'amp': True,
    # Shuffle flag of the root (per-volume) loader.
    'shuffle': True,
    # Channel counts handed to ResUnetPlusPlus.
    'in_channel': 1,
    'out_channel': 3,
    # Arguments of CosineAnnealingWarmRestarts.
    'T_0': 1,
    'T_mult': 2,
    # Fraction of each volume's slices used for training; the rest is the test split.
    'proportion': 0.9,
    # Enables the cosine warm-restart learning-rate schedule.
    'cos': True,
}
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Converts a tensor to float32 and then to a PIL image; used by display_dynamicly.
transformation = T.Compose([
    T.ToDtype(torch.float32),
    T.ToPILImage(),
])
torch.manual_seed(3407)


def train_func(rank, *args):
    """Per-process DDP training entry point, meant to be launched by ``mp.spawn``.

    Args:
        rank: Index of this process; also used as the CUDA device index.
        *args: ``args[0]`` is the world size.

    For every epoch each volume listed in the root CSV is loaded, split into
    train/test slices, trained on and evaluated; the averaged losses are
    printed per volume. Any exception is printed with its traceback and the
    process group is torn down on exit.

    NOTE(review): every rank shuffles the root loader and calls ``random_split``
    with its own RNG state, so with more than one process the ranks may work on
    different volumes/splits - confirm this is intended.
    """
    try:
        # The process group has to exist before DDP wraps the model.
        setup(rank, args[0])
        UnetModel = ResUnetPlusPlus(in_channel=params['in_channel'], out_channel=params['out_channel']).to(rank)
        UnetModel = DDP(UnetModel, device_ids=[rank])
        loss_fn = SoftDiceFocalLoss()
        opti = torch.optim.RAdam(UnetModel.parameters(), lr=0.003)
        schd_lr = None
        if params['cos']:
            schd_lr = CosineAnnealingWarmRestarts(opti, T_0=params['T_0'], T_mult=params['T_mult'])

        # With enabled=False autocast and GradScaler are pass-throughs, so one
        # code path serves both the AMP and the full-precision configuration.
        use_amp = params['amp']
        scaler = torch.cuda.amp.GradScaler(init_scale=8192, enabled=use_amp)
        for epoch in tqdm(range(params['epochs']), leave=False):
            rootset = TBADaRoottaset()
            rootloader = DataLoader(rootset, batch_size=1, shuffle=params['shuffle'], drop_last=True)

            for data in tqdm(rootloader, leave=False):
                nii_set = DatasetofEachNII(nii_image_path=data['image'],
                                           nii_label_path=data['mask'])
                n_train = int(len(nii_set) * params['proportion'])
                train_set, test_set = torch.utils.data.random_split(nii_set, [n_train, len(nii_set) - n_train])
                ddp_train_sam = DistributedSampler(train_set)
                ddp_test_sam = DistributedSampler(test_set)
                # Reseed the sampler so the shuffling differs between epochs.
                ddp_train_sam.set_epoch(epoch)

                # DataLoader rejects shuffle=True together with a sampler; the
                # DistributedSampler already shuffles.
                train_loader = DataLoader(train_set, batch_size=params['batch_size'],
                                          drop_last=True,
                                          collate_fn=col_fn, sampler=ddp_train_sam)
                test_loader = DataLoader(test_set, batch_size=params['batch_size'],
                                         drop_last=True,
                                         collate_fn=col_fn, sampler=ddp_test_sam)

                train_loss = 0
                UnetModel.train()
                for niidata in tqdm(train_loader, leave=False):
                    img = niidata['images'].to(rank)
                    label = niidata['labels'].to(rank)
                    opti.zero_grad()
                    # device_type must be a string; `rank` is only the device index.
                    with torch.autocast(device_type='cuda', enabled=use_amp):
                        pred = UnetModel(img)
                        display_dynamicly(img, label, pred)
                        loss_now = loss_fn(pred, label)
                    train_loss += loss_now.item()
                    scaler.scale(loss_now).backward()
                    scaler.step(opti)
                    scaler.update()
                    if schd_lr is not None:
                        schd_lr.step()
                    print(f'Epoch: {epoch} Loss: {loss_now * params["batch_size"]}')

                test_loss = 0
                UnetModel.eval()
                with torch.no_grad():
                    for niidatatest in tqdm(test_loader, leave=False):
                        img = niidatatest['images'].to(rank)
                        label = niidatatest['labels'].to(rank)
                        pred = UnetModel(img)
                        loss = loss_fn(pred, label)
                        test_loss += loss.item()

                # len(loader) already counts batches; guard against empty splits.
                train_loss /= max(len(train_loader), 1)
                test_loss /= max(len(test_loader), 1)
                metrics = {"train_loss": train_loss, 'test_loss': test_loss, "epoch": epoch}
                print(metrics)
    except Exception as e:
        print(f"An error occurred in process {rank}: {e}")
        traceback.print_exc()
    finally:
        # Only tear down a group that was actually created, so a failure in
        # setup() is not masked by a second error here.
        if dist.is_initialized():
            cleanup()


def display_dynamicly(img, label, pred):
    """Plot one randomly chosen sample of the batch: three prediction channels, label and image.

    Args:
        img: Batch of input images.
        label: Batch of label maps.
        pred: Batch of predictions; channels 0, 1 and 2 are shown.

    The figure is closed before returning so repeated calls from the training
    loop do not accumulate open figures.
    """
    fig, axs = plt.subplots(2, 3)
    try:
        random_index = random.randint(0, pred.shape[0] - 1)
        # Detach and move to host memory: the tensors may live on the GPU and
        # be part of the autograd graph.
        sample_pred = pred[random_index].detach().cpu()
        sample_label = label[random_index].detach().cpu()
        sample_img = img[random_index].detach().cpu()

        for channel in range(3):
            axs[0, channel].imshow(transformation(sample_pred[channel]), cmap='gray')
            axs[0, channel].title.set_text(f'Prediction: {channel}')
        axs[1, 0].imshow(transformation(sample_label), cmap='gray')
        axs[1, 0].title.set_text('Label')
        axs[1, 1].imshow(transformation(sample_img), cmap='gray')
        axs[1, 1].title.set_text('Image')
        plt.show()

        # Pause for a short period, allowing the plot to update
        plt.pause(0.1)
    finally:
        plt.close(fig)


def col_fn(images):
    """Collate a list of ``{'image': tensor, 'label': tensor}`` samples into one batch.

    Returns:
        ``{"images": stacked images, "labels": stacked labels}``, each stacked
        along a new leading axis.
    """
    batch = {}
    batch["images"] = torch.stack(tuple(sample['image'] for sample in images))
    batch["labels"] = torch.stack(tuple(sample['label'] for sample in images))
    return batch


def setup(rank, world_size):
    """Join the default process group for this process.

    Sets the rendezvous address and port in the environment, then initialises
    the group with the ``gloo`` backend.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)

    dist.init_process_group("gloo", rank=rank, world_size=world_size)
def cleanup():
    """Destroy the default process group if one has been initialised.

    Calling this without an initialised group is a no-op, so it is safe to use
    in a ``finally`` clause even when ``setup`` failed or never ran.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()


def run_demo(world_size):
    """Spawn ``world_size`` processes running ``train_func`` and wait for them.

    Each process receives its rank as the first argument and ``world_size`` as
    the second.
    """
    spawn_options = dict(args=(world_size,), nprocs=world_size, join=True)
    mp.spawn(train_func, **spawn_options)


# Launch training with a single worker process (world size 1).
# NOTE(review): the recorded output of this cell is a ProcessExitedException;
# presumably the spawned child cannot import functions defined in a notebook
# session - confirm, and consider running this from a .py script instead.
run_demo(1)


ProcessExitedException: process 0 terminated with exit code 1

In [20]:
import os

import torch.distributed as dist
import torch.multiprocessing as mp




ProcessExitedException: process 0 terminated with exit code 1

In [ ]:
import tempfile
from typing import Callable, Union

import dill
import ray
import torch
from PIL._imaging import draw
from einops import rearrange
from torch import nn
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from torch.utils.data import DataLoader
from tqdm import tqdm
from rich.console import Console
from matplotlib import pyplot as plt
from torchvision.transforms import v2 as T

import torch.multiprocessing as mp
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from ray.train.torch import TorchTrainer
from ray.train import ScalingConfig
import cloudpickle
import random

# Hyper-parameters read by train_func and the loader helpers in this cell.
params = {
    # Number of passes over the root (CSV) dataset.
    'epochs': 10,
    # Slices per batch for the per-volume train/test loaders.
    'batch_size': 1,
    # Selects the autocast + GradScaler branch of the training step.
    'amp': True,
    # Shuffle flag of the root (per-volume) loader.
    'shuffle': True,
    # Channel counts handed to ResUnetPlusPlus.
    'in_channel': 1,
    'out_channel': 3,
    # Arguments of CosineAnnealingWarmRestarts.
    'T_0': 1,
    'T_mult': 2,
    # Fraction of each volume's slices used for training; the rest is the test split.
    'proportion': 0.9,
    # Enables the cosine warm-restart learning-rate schedule.
    'cos': True,
}
# Device string used by train_func for labels, evaluation batches and autocast.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Converts a tensor to float32 and then to a PIL image; used by display_dynamicly.
transformation = T.Compose([
    T.ToDtype(torch.float32),
    T.ToPILImage(),
])
torch.manual_seed(3407)


def train_func(rank, world_size):
    """Single-process training loop (no DDP wrapping).

    Args:
        rank: CUDA device index (or device string) to train on when CUDA is
            available; ignored in favour of the CPU otherwise.
        world_size: Unused here; kept so the signature matches the DDP variant.

    For every epoch each volume listed in the root CSV is loaded, split into
    train/test slices, trained on and evaluated; the averaged losses are
    printed per volume.
    """
    # Resolve one device and use it for the model, the images and the labels,
    # so they can never end up on different devices.
    if device != 'cuda':
        run_device = torch.device(device)
    elif isinstance(rank, int):
        run_device = torch.device('cuda', rank)
    else:
        run_device = torch.device(rank)

    model = ResUnetPlusPlus(in_channel=params['in_channel'], out_channel=params['out_channel']).to(run_device)
    loss_fn = SoftDiceFocalLoss()
    opti = torch.optim.RAdam(model.parameters(), lr=0.003)
    schd_lr = None
    if params['cos']:
        schd_lr = CosineAnnealingWarmRestarts(opti, T_0=params['T_0'], T_mult=params['T_mult'])

    # With enabled=False autocast and GradScaler are pass-throughs, so one
    # code path serves both the AMP and the full-precision configuration.
    use_amp = params['amp']
    scaler = torch.cuda.amp.GradScaler(init_scale=8192, enabled=use_amp and run_device.type == 'cuda')
    for epoch in tqdm(range(params['epochs']), leave=False):
        rootset = TBADaRoottaset()
        rootloader = DataLoader(rootset, batch_size=1, shuffle=params['shuffle'], drop_last=True)

        for data in tqdm(rootloader, leave=False):
            nii_set = DatasetofEachNII(nii_image_path=data['image'],
                                       nii_label_path=data['mask'])
            n_train = int(len(nii_set) * params['proportion'])
            train_set, test_set = torch.utils.data.random_split(nii_set, [n_train, len(nii_set) - n_train])

            train_loader = DataLoader(train_set, batch_size=params['batch_size'],
                                      shuffle=True, drop_last=True,
                                      collate_fn=col_fn)
            test_loader = DataLoader(test_set, batch_size=params['batch_size'],
                                     shuffle=True, drop_last=True,
                                     collate_fn=col_fn)

            train_loss = 0
            model.train()
            for niidata in tqdm(train_loader, leave=False):
                img = niidata['images'].to(run_device)
                label = niidata['labels'].to(run_device)
                opti.zero_grad()
                with torch.autocast(device_type=run_device.type, enabled=use_amp):
                    pred = model(img)
                    display_dynamicly(img, label, pred)
                    loss_now = loss_fn(pred, label)
                train_loss += loss_now.item()
                scaler.scale(loss_now).backward()
                scaler.step(opti)
                scaler.update()
                if schd_lr is not None:
                    schd_lr.step()
                print(f'Epoch: {epoch} Loss: {loss_now * params["batch_size"]}')

            test_loss = 0
            model.eval()
            with torch.no_grad():
                for niidatatest in tqdm(test_loader, leave=False):
                    img = niidatatest['images'].to(run_device)
                    label = niidatatest['labels'].to(run_device)
                    pred = model(img)
                    loss = loss_fn(pred, label)
                    test_loss += loss.item()

            # len(loader) already counts batches; guard against empty splits.
            train_loss /= max(len(train_loader), 1)
            test_loss /= max(len(test_loader), 1)
            metrics = {"train_loss": train_loss, 'test_loss': test_loss, "epoch": epoch}
            print(metrics)


def display_dynamicly(img, label, pred):
    """Plot one randomly chosen sample of the batch: three prediction channels, label and image.

    Args:
        img: Batch of input images.
        label: Batch of label maps.
        pred: Batch of predictions; channels 0, 1 and 2 are shown.

    The figure is closed before returning so repeated calls from the training
    loop do not accumulate open figures.
    """
    fig, axs = plt.subplots(2, 3)
    try:
        random_index = random.randint(0, pred.shape[0] - 1)
        # Detach and move to host memory: the tensors may live on the GPU and
        # be part of the autograd graph.
        sample_pred = pred[random_index].detach().cpu()
        sample_label = label[random_index].detach().cpu()
        sample_img = img[random_index].detach().cpu()

        for channel in range(3):
            axs[0, channel].imshow(transformation(sample_pred[channel]), cmap='gray')
            axs[0, channel].title.set_text(f'Prediction: {channel}')
        axs[1, 0].imshow(transformation(sample_label), cmap='gray')
        axs[1, 0].title.set_text('Label')
        axs[1, 1].imshow(transformation(sample_img), cmap='gray')
        axs[1, 1].title.set_text('Image')
        plt.show()

        # Pause for a short period, allowing the plot to update
        plt.pause(0.1)
    finally:
        plt.close(fig)


def get_root_dataloader(param: dict):
    """Build the loader over the root CSV dataset, one volume per batch.

    Args:
        param: Configuration dict; only ``param['shuffle']`` is read.
    """
    return DataLoader(TBADaRoottaset(), batch_size=1, shuffle=param['shuffle'], drop_last=True)


def get_nii_dataloader(param: dict, fn, nii_image_path: str, nii_label_path: str):
    """Load one volume and return ``(train_loader, test_loader)`` over its slices.

    Args:
        param: Configuration dict; ``proportion``, ``batch_size`` and
            ``shuffle`` are read.
        fn: Collate function handed to both loaders.
        nii_image_path: Path of the image volume.
        nii_label_path: Path of the label volume.
    """
    volume = DatasetofEachNII(nii_image_path, nii_label_path)
    total = len(volume)
    n_train = int(total * param['proportion'])
    splits = torch.utils.data.random_split(volume, [n_train, total - n_train])
    # Both loaders share the same settings; the first split is the training set.
    return tuple(
        DataLoader(split, batch_size=param['batch_size'], shuffle=param['shuffle'], drop_last=True,
                   collate_fn=fn)
        for split in splits
    )


def getdataloader(param: dict, fn: Callable, root: bool, nii_image_path: str = None, nii_label_path: str = None) -> \
        Union[DataLoader,
        tuple[DataLoader, DataLoader]]:
    """Dispatch to the root loader or to the per-volume train/test loaders.

    Args:
        param: Configuration dict forwarded to the selected builder.
        fn: Collate function; only used for the per-volume loaders.
        root: When true, return the root loader and ignore the paths.
        nii_image_path: Image volume path for the per-volume case.
        nii_label_path: Label volume path for the per-volume case.
    """
    if not root:
        return get_nii_dataloader(param, fn, nii_image_path, nii_label_path)
    return get_root_dataloader(param)


def col_fn(images):
    """Collate a list of ``{'image': tensor, 'label': tensor}`` samples into one batch.

    Returns:
        ``{"images": stacked images, "labels": stacked labels}``, each stacked
        along a new leading axis.
    """
    stacked_images = torch.stack(tuple(sample['image'] for sample in images))
    stacked_labels = torch.stack(tuple(sample['label'] for sample in images))
    return dict(images=stacked_images, labels=stacked_labels)


def setup(rank, world_size):
    """Join the default process group for this process.

    Sets the rendezvous address and port in the environment, then initialises
    the group with the ``gloo`` backend.

    Args:
        rank: Rank of the calling process.
        world_size: Total number of processes in the group.
    """
    rendezvous = {'MASTER_ADDR': 'localhost', 'MASTER_PORT': '12355'}
    os.environ.update(rendezvous)

    dist.init_process_group("gloo", rank=rank, world_size=world_size)

# if __name__ == "__main__":
#     ray.shutdown()
#     ray.init(include_dashboard=False, dashboard_host='0.0.0.0', ignore_reinit_error=True, dashboard_port=8265)

#     scaling_config = ScalingConfig(num_workers=1, use_gpu=True, resources_per_worker={
#         "CPU": 4,
#         "GPU": 2,
#     }, )
#     torch_config = ray.train.torch.TorchConfig(backend='gloo')
#     trainer = ray.train.torch.TorchTrainer(train_func, train_loop_config=args, scaling_config=scaling_config,
#                                            torch_config=torch_config)
#     result = trainer.fit()
# train_func()
