In [1]:
import os
import SimpleITK as sitk
import glob
import monai
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    AddChanneld,
    SpatialPadd,
    RandRotate90d,
    RandShiftIntensityd,
    EnsureTyped,
    EnsureType,
    MapTransform,
    Resized,
    Invertd,
    ToTensord,
    NormalizeIntensityd,
    RandFlipd,
    Lambdad,
    Activations,
    AsDiscrete,
)
from monai.metrics import ROCAUCMetric
from monai.data import CacheDataset, ThreadDataLoader,DataLoader, Dataset, decollate_batch,load_decathlon_datalist
import torch
from monai.utils import first, set_determinism
import torch.nn as  nn
from torch.nn import Linear,  Softmax
import torch.nn.functional as F
import pandas as pd
from sklearn import metrics
from random import seed, shuffle
torch.multiprocessing.set_sharing_strategy('file_system')


In [33]:

import pandas as pd
df_raw = pd.read_csv('/app/liucd/判定_fill_df.csv')
df_cli = df_raw[['patient_ID', 'T_stage', 'HER2_status', 'NAC_classification', 'ER_percentage', 'PR_percentage', 'Ki_67']]


syf_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/syf/Mixed'
zy_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/zunyi/Mixed'
sd_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/shandong/Mixed'
yizhong_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/yizhong/Mixed'
xian_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/xian/Mixed'

syf_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/syf/Mixed'
zy_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/zunyi/Mixed'
sd_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/shandong/Mixed'
yizhong_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/yizhong/Mixed'
xian_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/xian/Mixed'


train_adcimages = sorted(glob.glob(os.path.join(syf_adcdir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(zy_adcdir,  '*.nii.gz')))


train_dceimages = sorted(glob.glob(os.path.join(syf_dcedir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(zy_dcedir,  '*.nii.gz')))

val_adcimages =  sorted(glob.glob(os.path.join(sd_adcdir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(yizhong_adcdir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(xian_adcdir,  '*.nii.gz'))) 

val_dceimages =  sorted(glob.glob(os.path.join(sd_dcedir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(yizhong_dcedir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(xian_dcedir,  '*.nii.gz'))) 



train_clinical = []
for file_path in train_adcimages:
    p_id = file_path.split('_')[-4]
    clinical_data = df_cli[df_cli['patient_ID'] == int(p_id)].values.tolist()[0][1:]
    train_clinical.append(clinical_data)

val_clinical = []
for file_path in val_adcimages:
    p_id = file_path.split('_')[-4]
    clinical_data = df_cli[df_cli['patient_ID'] == int(p_id)].values.tolist()[0][1:]
    val_clinical.append(clinical_data)


train_dict = [{'image_adc': image_adc, 'image_dce': image_dce, 'clinical': clinical,  'label': int(image_adc.split('_')[-1].replace('.nii.gz', ''))}
                  for image_adc, image_dce, clinical in zip(train_adcimages,  train_dceimages, train_clinical)]
val_dict = [{'image_adc': image_adc, 'image_dce': image_dce, 'clinical': clinical,  'label': int(image_adc.split('_')[-1].replace('.nii.gz', ''))}
                  for image_adc, image_dce, clinical in zip(val_adcimages, val_dceimages, val_clinical)]

print(train_dict[-1])
print(len(train_dict), len(val_dict), len(train_dict + val_dict))

train_transforms = Compose(
        [
            LoadImaged(keys=["image_adc", "image_dce"]),
            EnsureChannelFirstd(keys=["image_adc", "image_dce"]),
            Orientationd(keys=["image_adc", 'image_dce'], axcodes="RAS"),
            Resized(keys=["image_adc"], spatial_size=(64, 64, 16)),
            Resized(keys=["image_dce"], spatial_size=(96, 96, 32)),
            
            NormalizeIntensityd(keys=["image_adc", "image_dce"], nonzero=True, channel_wise=True),
            
            RandFlipd( keys=["image_adc", ], spatial_axis=[0], prob=0.50),
            RandFlipd( keys=["image_adc", ], spatial_axis=[1], prob=0.50),
            RandFlipd( keys=["image_adc", ], spatial_axis=[2], prob=0.50),
            
            RandFlipd( keys=["image_dce", ], spatial_axis=[0], prob=0.50),
            RandFlipd( keys=["image_dce", ], spatial_axis=[1], prob=0.50),
            RandFlipd( keys=["image_dce", ], spatial_axis=[2], prob=0.50),
            
            RandRotate90d(keys=["image_adc", 'image_dce'], prob=0.50, max_k=3 ),
            RandShiftIntensityd( keys=["image_adc", 'image_dce'], offsets=0.10, prob=0.50),
            
            ToTensord(keys=['image_adc', 'image_dce','clinical',  'label'])
        ]
    )

val_transforms = Compose(
        [
            LoadImaged(keys=["image_adc",'image_dce' ]),
            EnsureChannelFirstd(keys=["image_adc", 'image_dce']),
            Orientationd(keys=["image_adc",'image_dce'], axcodes="RAS"),
            Resized(keys=["image_adc"], spatial_size=(64, 64, 16)),
            Resized(keys=["image_dce"], spatial_size=(96, 96, 32)),
            
            NormalizeIntensityd(keys=["image_adc", 'image_dce'], nonzero=True, channel_wise=True),
            ToTensord(keys=['image_adc', 'image_dce','clinical', 'label'])
        ]
    )



{'image_adc': '/app/liucd/deeplearn_dec/DL_dec/data_adc/zunyi/Mixed/2022_10_20_1650366_ADC2_0000_0.nii.gz', 'image_dce': '/app/liucd/deeplearn_dec/DL_dec/data/zunyi/Mixed/2022_10_20_1650366_+C2_0000_0.nii.gz', 'clinical': [3.0, 1.0, 2.0, 0.0, 0.0, 0.2], 'label': 0}
691 619 1310


In [35]:

train_ds = Dataset(data=train_dict, transform=train_transforms,)# cache_rate=1, num_workers=12)
val_ds = Dataset(data=val_dict, transform=val_transforms,) # cache_rate=1, num_workers=12)

# create a training data loader
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True, num_workers=16, pin_memory=True)

# create a validation data loader
val_loader = DataLoader(val_ds, batch_size=8, num_workers=8, pin_memory=True)



In [46]:

class DoubleTower(nn.Module):
    def __init__(self,
                 pretrained_dce='',
                 pretrained_adc='',
                 device = torch.device("cuda"),
                 num_classes=2,
                 fc_hidden_size = 128
                ):
        super().__init__()
        self.pretrained_dce = pretrained_dce
        self.pretrained_adc = pretrained_adc
        self.fc_hidden_size = fc_hidden_size
        self.num_classes = num_classes
        self.device = device

        self.model_dce = monai.networks.nets.ViT(in_channels=1, img_size=(96, 96, 32), 
                              hidden_size=64, num_layers=6, num_heads=8, 
                              patch_size=(16, 16, 16), pos_embed='conv', classification=False)
        self.model_adc = monai.networks.nets.ViT(in_channels=1, img_size=(64, 64, 16), 
                              hidden_size=64, num_layers=6, num_heads=8, 
                              patch_size=(16, 16, 16), pos_embed='conv', classification=False)

    
        self.attn = nn.MultiheadAttention(512, num_heads=8, batch_first=True, device=self.device)

        # self.Linear1 = Linear(1024 + 6, self.num_classes, device=self.device)
        self.Linear1 = Linear(512, self.fc_hidden_size, device=self.device)  # 1024 是 所有下采样特征图globalpool之后拼接的结果
        self.Linear2 = Linear(self.fc_hidden_size + 6, self.num_classes, device=self.device)
        self.dropout = nn.Dropout(0.2)


    def forward(self, x1, x2, structured_data):  # x 是SegResNet的输入影像矩阵

        encode_output1, _ = self.model_dce(x1)
        encode_output2, _ = self.model_adc(x2)

        concatenated = torch.concat([encode_output1, encode_output2], dim=-1)
        
        concatenated = concatenated.unsqueeze(1)
        attn_output, _ = self.attn(concatenated, concatenated, concatenated)

        attn_output = attn_output.squeeze(1)

        fc1 = F.relu(self.Linear1(attn_output))
        fc1 = self.dropout(fc1)

        fc2 = self.Linear2( torch.concat([fc1, structured_data], dim=-1))
        return F.log_softmax(fc2, dim=-1)


# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

my_model = DoubleTower()

x1 = torch.randn(8, 1, 96, 96, 32)  # batch, channel, x, y, z
x2 = torch.randn(8, 1, 64, 64, 16)
cli = torch.randn(8, 6)
output = my_model(x1, x2, cli)
print('output: ', output.shape)

RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 72 but got size 16 for tensor number 1 in the list.

In [38]:
net = monai.networks.nets.ViT(in_channels=1, img_size=(96, 96, 32), 
                              hidden_size=64, num_layers=6, num_heads=8, 
                              patch_size=(16, 16, 16), pos_embed='conv', classification=False)


In [39]:
x = torch.randn(8, 1, 64, 64, 16)
y, temp = net(x)

RuntimeError: The size of tensor a (16) must match the size of tensor b (72) at non-singleton dimension 1

In [31]:
y.shape

torch.Size([8, 72, 64])