In [1]:

from pipeline.component.nn import save_to_fate

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%%save_to_fate dataset Fate_Dataset.py
import os
import torch
from federatedml.nn.dataset.base import Dataset
import random
import numpy as np
from torchvision.transforms import transforms
import h5py

class RandomCrop(object):
    """
    Crop randomly the image in a sample
    Args:
    output_size (int): Desired output size
    """

    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        (c, w, h, d) = image.shape
        w1 = np.random.randint(0, w - self.output_size[0])
        h1 = np.random.randint(0, h - self.output_size[1])
        d1 = np.random.randint(0, d - self.output_size[2])

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[:,w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        return {'image': image, 'label': label}


class CenterCrop(object):
    def __init__(self, output_size):
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        (c,w, h, d) = image.shape

        w1 = int(round((w - self.output_size[0]) / 2.))
        h1 = int(round((h - self.output_size[1]) / 2.))
        d1 = int(round((d - self.output_size[2]) / 2.))

        label = label[w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]
        image = image[:,w1:w1 + self.output_size[0], h1:h1 + self.output_size[1], d1:d1 + self.output_size[2]]

        return {'image': image, 'label': label}


class RandomRotFlip(object):
    """
    Crop randomly flip the dataset in a sample
    Args:
    output_size (int): Desired output size
    """

    def __call__(self, sample):
        image, label = sample['image'], sample['label']
        k = np.random.randint(0, 4)
        image = np.stack([np.rot90(x,k) for x in image],axis=0)
        label = np.rot90(label, k)
        axis = np.random.randint(1, 4)
        image = np.flip(image, axis=axis).copy()
        label = np.flip(label, axis=axis-1).copy()

        return {'image': image, 'label': label}


def augment_gaussian_noise(data_sample, noise_variance=(0, 0.1)):
    if noise_variance[0] == noise_variance[1]:
        variance = noise_variance[0]
    else:
        variance = random.uniform(noise_variance[0], noise_variance[1])
    data_sample = data_sample + np.random.normal(0.0, variance, size=data_sample.shape)
    return data_sample


class GaussianNoise(object):
    def __init__(self, noise_variance=(0, 0.1), p=0.5):
        self.prob = p
        self.noise_variance = noise_variance

    def __call__(self, sample):
        image = sample['image']
        label = sample['label']
        if np.random.uniform() < self.prob:
            image = augment_gaussian_noise(image, self.noise_variance)
        return {'image': image, 'label': label}


class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""
    def __call__(self, sample):
        image = sample['image']
        label = sample['label']

        image = torch.from_numpy(image).float()
        label = torch.from_numpy(label).long()

        return {'image': image, 'label': label}


class BraTS_Dataset(Dataset):
    def __init__(self):
        
        self.paths = None
        self.ids = None
        self.transform = transforms.Compose([
        RandomRotFlip(),
        RandomCrop((160,160,128)),
        GaussianNoise(p=0.1),
        ToTensor()])

    def load(self, path):
        paths = os.listdir(path)
        self.paths = [os.path.join(path, p) for p in paths]
        self.ids = [name.split('_')[1] for name in paths]
        return self
    
    def get_sample_ids(self):  # implement the get sample id interface, simply return ids
        return self.ids

    def __getitem__(self, item):
        h5f = h5py.File(self.paths[item], 'r')
        image = h5f['image'][:]
        label = h5f['label'][:]
        #[0,1,2,4] -> [0,1,2,3]
        label[label == 4] = 3
        # print(image.shape)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        return sample['image'], sample['label']

    def __len__(self):
        return len(self.paths)

    def collate(self, batch):
        return [torch.cat(v) for v in zip(*batch)]

In [3]:
%%save_to_fate loss Fate_Loss.py
import torch.nn.functional as F
import torch.nn as nn
import torch
import numpy as np
from einops import rearrange


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep, warmup_epochs=0, start_warmup_value=0.):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule


def Dice(output, target, eps=1e-3):
    inter = torch.sum(output * target,dim=(1,2,-1)) + eps
    union = torch.sum(output,dim=(1,2,-1)) + torch.sum(target,dim=(1,2,-1)) + eps * 2
    x = 2 * inter / union
    dice = torch.mean(x)
    return dice


def cal_dice(output, target):
    '''
    output: (b, num_class, d, h, w)  target: (b, d, h, w)
    dice1(ET):label4
    dice2(TC):label1 + label4
    dice3(WT): label1 + label2 + label4
    注,这里的label4已经被替换为3
    '''
    output = torch.argmax(output,dim=1)
    dice1 = Dice((output == 3).float(), (target == 3).float())
    dice2 = Dice(((output == 1) | (output == 3)).float(), ((target == 1) | (target == 3)).float())
    dice3 = Dice((output != 0).float(), (target != 0).float())

    return dice1, dice2, dice3


class Loss(nn.Module):
    def __init__(self, n_classes, weight=None, alpha=0.5):
        "dice_loss_plus_cetr_weighted"
        super(Loss, self).__init__()
        self.n_classes = n_classes
        self.weight = weight.cuda()
        # self.weight = weight
        self.alpha = alpha

    def forward(self, input, target):
        # print(torch.unique(target))
        smooth = 0.01

        input1 = F.softmax(input, dim=1)
        target1 = F.one_hot(target,self.n_classes)
        input1 = rearrange(input1,'b n h w s -> b n (h w s)')
        target1 = rearrange(target1,'b h w s n -> b n (h w s)')

        input1 = input1[:, 1:, :]
        target1 = target1[:, 1:, :].float()

        # 以batch为单位计算loss和dice_loss，据说训练更稳定，那我试试
        inter = torch.sum(input1 * target1)
        union = torch.sum(input1) + torch.sum(target1) + smooth
        dice = 2.0 * inter / union

        loss = F.cross_entropy(input,target, weight=self.weight)

        total_loss = (1 - self.alpha) * loss + (1 - dice) * self.alpha

        return total_loss

In [4]:
%%save_to_fate model Fate_Model.py
import torch
import torch.nn as nn


class InConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(InConv, self).__init__()
        self.conv = DoubleConv(in_ch, out_ch)

    def forward(self, x):
        x = self.conv(x)
        return x

class Down(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(Down, self).__init__()
        self.mpconv = nn.Sequential(
            nn.MaxPool3d(2, 2),
            DoubleConv(in_ch, out_ch)
        )

    def forward(self, x):
        x = self.mpconv(x)
        return x

class OutConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(OutConv, self).__init__()
        self.conv = nn.Conv3d(in_ch, out_ch, 1)
        # self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.conv(x)
        # x = self.sigmoid(x)
        return x

class DoubleConv(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(DoubleConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x

class Up(nn.Module):
    def __init__(self, in_ch, skip_ch,out_ch):
        super(Up, self).__init__()
        self.up = nn.ConvTranspose3d(in_ch, in_ch, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_ch+skip_ch, out_ch)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x2, x1], dim=1)
        x = self.conv(x)
        return x


class UNet(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(UNet, self).__init__()
        features = [32,64,128,256]

        self.inc = InConv(in_channels, features[0])
        self.down1 = Down(features[0], features[1])
        self.down2 = Down(features[1], features[2])
        self.down3 = Down(features[2], features[3])
        self.down4 = Down(features[3], features[3])

        self.up1 = Up(features[3], features[3], features[2])
        self.up2 = Up(features[2], features[2], features[1])
        self.up3 = Up(features[1], features[1], features[0])
        self.up4 = Up(features[0], features[0], features[0])
        self.outc = OutConv(features[0], num_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)

        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        x = self.outc(x)
        return x

In [1]:
! ls ../../../../examples/data/Dataset

BraTS2021_00000_mri_norm2.h5  BraTS2021_01042_mri_norm2.h5
BraTS2021_00002_mri_norm2.h5  BraTS2021_01043_mri_norm2.h5
BraTS2021_00003_mri_norm2.h5  BraTS2021_01044_mri_norm2.h5
BraTS2021_00005_mri_norm2.h5  BraTS2021_01045_mri_norm2.h5
BraTS2021_00006_mri_norm2.h5  BraTS2021_01046_mri_norm2.h5
BraTS2021_00008_mri_norm2.h5  BraTS2021_01047_mri_norm2.h5
BraTS2021_00009_mri_norm2.h5  BraTS2021_01048_mri_norm2.h5
BraTS2021_00011_mri_norm2.h5  BraTS2021_01049_mri_norm2.h5
BraTS2021_00012_mri_norm2.h5  BraTS2021_01050_mri_norm2.h5
BraTS2021_00014_mri_norm2.h5  BraTS2021_01051_mri_norm2.h5
BraTS2021_00016_mri_norm2.h5  BraTS2021_01052_mri_norm2.h5
BraTS2021_00017_mri_norm2.h5  BraTS2021_01053_mri_norm2.h5
BraTS2021_00018_mri_norm2.h5  BraTS2021_01054_mri_norm2.h5
BraTS2021_00019_mri_norm2.h5  BraTS2021_01055_mri_norm2.h5
BraTS2021_00020_mri_norm2.h5  BraTS2021_01056_mri_norm2.h5
BraTS2021_00021_mri_norm2.h5  BraTS2021_01057_mri_norm2.h5
BraTS2021_00022_mri_norm2.h5  BraTS2021_01058_mri_norm2.

In [4]:
from federatedml.nn.dataset.Fate_Dataset import BraTS_Dataset

In [5]:
ds = BraTS_Dataset()

In [6]:
# load MNIST data and check 
ds.load('../../../../examples/data/brats_host/')
print(len(ds))

625


In [5]:
print(ds[0])

(tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]],

         ...,

         [[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0.,

In [7]:
import torch as t
from torch import nn
from pipeline import fate_torch_hook
from pipeline.component import HomoNN
from pipeline.backend.pipeline import PipeLine
from pipeline.component import Reader, Evaluation, DataTransform
from pipeline.interface import Data, Model

t = fate_torch_hook(t)

In [8]:
import os
# bind data path to name & namespace
fate_project_path = os.path.abspath('../../../../')
host = 9999
guest = 10000
arbiter = 9999
pipeline = PipeLine().set_initiator(role='guest', party_id=guest).set_roles(guest=guest, host=host,
                                                                            arbiter=arbiter)

data_0 = {"name": "brats_guest", "namespace": "test"}
data_1 = {"name": "brats_host", "namespace": "test"}

data_path_0 = fate_project_path + '/examples/data/brats_guest'
data_path_1 = fate_project_path + '/examples/data/brats_host'
pipeline.bind_table(name=data_0['name'], namespace=data_0['namespace'], path=data_path_0)
pipeline.bind_table(name=data_1['name'], namespace=data_1['namespace'], path=data_path_1)

{'namespace': 'test', 'table_name': 'brats_host'}

In [9]:
# 定义reader
reader_0 = Reader(name="reader_0")
reader_0.get_party_instance(role='guest', party_id=guest).component_param(table=data_0)
reader_0.get_party_instance(role='host', party_id=host).component_param(table=data_1)

In [10]:
from pipeline.component.homo_nn import DatasetParam, TrainerParam

dataset_param = DatasetParam(dataset_name='Fate_Dataset')

# your loss class
weights_list = [0.2, 0.3, 0.25, 0.25]
weights_tensor = t.tensor(weights_list)
loss = t.nn.CustLoss(loss_module_name='Fate_Loss', class_name='Loss', n_classes=4, weight=weights_list)

model = t.nn.Sequential(
    # the class_num=10 is the initialzation parameter for my model
    t.nn.CustModel(module_name='Fate_Model', class_name='UNet', num_classes=4, in_channels=4)
)

nn_component = HomoNN(name='nn_0',
                      model=model, # my cust model
                      loss=loss,
                      optimizer=t.optim.SGD(model.parameters(),momentum=0.9, lr=1e-4, weight_decay=5e-4),
                      dataset=dataset_param,
                      trainer=TrainerParam(trainer_name='fedavg_trainer', epochs=20, batch_size=1, validation_freqs=5, cuda=0),
                      torch_seed=21 # global random seed
                      )

In [11]:
pipeline.add_component(reader_0)
pipeline.add_component(nn_component, data=Data(train_data=reader_0.output.data))
pipeline.add_component(Evaluation(name='eval_0', eval_type='multi'), data=Data(data=nn_component.output.data))

<pipeline.backend.pipeline.PipeLine at 0x7f1b707de460>

In [12]:
pipeline.compile()
pipeline.fit()

[32m2023-11-14 20:04:42.893[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m83[0m - [1mJob id is 202311142004377897060
[0m
[32m2023-11-14 20:04:42.905[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:00[0m
[32m2023-11-14 20:04:43.919[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:01[0m
[32m2023-11-14 20:04:44.934[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m98[0m - [1m[80D[1A[KJob is still waiting, time elapse: 0:00:02[0m
[0mm2023-11-14 20:04:46.977[0m | [1mINFO    [0m | [36mpipeline.utils.invoker.job_submitter[0m:[36mmonitor_job_status[0m:[36m125[0m - [1m
[32m2023-11-14 20:04:46.979[0m | [1mINFO    [0m | 