In [None]:
import os
import SimpleITK as sitk
import glob
import monai
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    AddChanneld,
    SpatialPadd,
    RandRotate90d,
    RandShiftIntensityd,
    EnsureTyped,
    EnsureType,
    MapTransform,
    Resized,
    Invertd,
    ToTensord,
    NormalizeIntensityd,
    RandFlipd,
    Lambdad,
    Activations,
    AsDiscrete,
)
from monai.metrics import ROCAUCMetric
from monai.data import CacheDataset, ThreadDataLoader,DataLoader, Dataset, decollate_batch,load_decathlon_datalist
import torch
from monai.utils import first, set_determinism
import torch.nn as  nn
from torch.nn import Linear,  Softmax
import torch.nn.functional as F

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
torch.multiprocessing.set_sharing_strategy('file_system')
set_determinism(seed=1)


In [None]:
import pandas as pd
df_raw = pd.read_csv('判定_fill_df.csv')
df_cli = df_raw[['patient_ID', 'T_stage', 'HER2_status', 'NAC_classification', 'ER_percentage', 'PR_percentage', 'Ki_67']]


In [None]:
syf1_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/syf_stage1/Mixed'
syf2_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/syf_stage2/Mixed'
zy_adcdir = '/app/liucd/deeplearn_dec/DL_dec/data_adc/zunyi/Mixed'

syf1_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/syf_stage1/Mixed'
syf2_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/syf_stage2/Mixed'
zy_dcedir = '/app/liucd/deeplearn_dec/DL_dec/data/zunyi/Mixed'

train_adcimages = sorted(glob.glob(os.path.join(syf1_adcdir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(syf2_adcdir,  '*.nii.gz'))) 

train_dceimages = sorted(glob.glob(os.path.join(syf1_dcedir,  '*.nii.gz'))) + \
                 sorted(glob.glob(os.path.join(syf2_dcedir,  '*.nii.gz'))) 

val_adcimages = sorted(glob.glob(os.path.join(zy_adcdir,  '*.nii.gz'))) 
val_dceimages =  sorted(glob.glob(os.path.join(zy_dcedir,  '*.nii.gz')))


train_clinical = []
for file_path in train_adcimages:
    p_id = file_path.split('_')[-4]
    clinical_data = df_cli[df_cli['patient_ID'] == int(p_id)].values.tolist()[0][1:]
    train_clinical.append(clinical_data)

val_clinical = []
for file_path in val_adcimages:
    p_id = file_path.split('_')[-4]
    clinical_data = df_cli[df_cli['patient_ID'] == int(p_id)].values.tolist()[0][1:]
    val_clinical.append(clinical_data)
    

train_dict = [{'image_adc': image_adc, 'image_dce': image_dce, 'clinical': clinical,  'label': int(image_adc.split('_')[-1].replace('.nii.gz', ''))} 
                  for image_adc, image_dce, clinical in zip(train_adcimages,  train_dceimages, train_clinical)]
val_dict = [{'image_adc': image_adc, 'image_dce': image_dce, 'clinical': clinical,  'label': int(image_adc.split('_')[-1].replace('.nii.gz', ''))} 
                  for image_adc, image_dce, clinical in zip(val_adcimages, val_dceimages, val_clinical)]

print(train_dict[-1])
len(train_dict), len(val_dict), len(train_dict + val_dict)


{'image_adc': '/app/liucd/deeplearn_dec/DL_dec/data_adc/zunyi/Mixed/2022_10_20_1650366_ADC2_0000_0.nii.gz', 'image_dce': '/app/liucd/deeplearn_dec/DL_dec/data/zunyi/Mixed/2022_10_20_1650366_+C2_0000_0.nii.gz', 'clinical': [3.0, 1.0, 2.0, 0.0, 0.0, 0.2], 'label': 0}


(656, 225, 881)

In [None]:
train_transforms = Compose(
        [
            LoadImaged(keys=["image_adc", "image_dce"]),
            EnsureChannelFirstd(keys=["image_adc", "image_dce"]),
            Orientationd(keys=["image_adc", 'image_dce'], axcodes="RAS"),
            Resized(keys=["image_adc"], spatial_size=(64, 64, 16)),
            Resized(keys=["image_dce"], spatial_size=(96, 96, 32)),
            
            NormalizeIntensityd(keys=["image_adc", "image_dce"], nonzero=True, channel_wise=True),
            
            RandFlipd( keys=["image_adc", ], spatial_axis=[0], prob=0.50),
            RandFlipd( keys=["image_adc", ], spatial_axis=[1], prob=0.50),
            RandFlipd( keys=["image_adc", ], spatial_axis=[2], prob=0.50),
            
            RandFlipd( keys=["image_dce", ], spatial_axis=[0], prob=0.50),
            RandFlipd( keys=["image_dce", ], spatial_axis=[1], prob=0.50),
            RandFlipd( keys=["image_dce", ], spatial_axis=[2], prob=0.50),
            
            RandRotate90d(keys=["image_adc", 'image_dce'], prob=0.50, max_k=3 ),
            RandShiftIntensityd( keys=["image_adc", 'image_dce'], offsets=0.10, prob=0.50),
            
            ToTensord(keys=['image_adc', 'image_dce','clinical',  'label'])
        ]
    )

val_transforms = Compose(
        [
            LoadImaged(keys=["image_adc",'image_dce' ]),
            EnsureChannelFirstd(keys=["image_adc", 'image_dce']),
            Orientationd(keys=["image_adc",'image_dce'], axcodes="RAS"),
            Resized(keys=["image_adc"], spatial_size=(64, 64, 16)),
            Resized(keys=["image_dce"], spatial_size=(96, 96, 32)),
            
            NormalizeIntensityd(keys=["image_adc", 'image_dce'], nonzero=True, channel_wise=True),
            ToTensord(keys=['image_adc', 'image_dce','clinical', 'label'])
        ]
    )


train_ds = CacheDataset(data=train_dict, transform=val_transforms, cache_rate=1.0, num_workers=24)
val_ds = CacheDataset(data=val_dict, transform=val_transforms, cache_rate=1.0, num_workers=24)

Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████| 656/656 [00:53<00:00, 12.21it/s]
Loading dataset: 100%|███████████████████████████████████████████████████████████████████████████████| 225/225 [00:11<00:00, 19.69it/s]


In [None]:
# create a training data loader
train_loader = DataLoader(train_ds, batch_size=12, num_workers=16, pin_memory=True)

# create a validation data loader
val_loader = DataLoader(val_ds, batch_size=12, num_workers=16, pin_memory=True)

In [None]:

# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class DoubleTower(nn.Module):
    def __init__(self, 
                 pretrained_dce='', 
                 pretrained_adc='', 
                 device = torch.device("cuda"),
                 num_classes=2, 
                 fc_hidden_size = 128
                ):
        super().__init__()
        self.pretrained_dce = pretrained_dce
        self.pretrained_adc = pretrained_adc
        self.fc_hidden_size = fc_hidden_size
        self.num_classes = num_classes
        self.device = device
        
        self.model_dce = monai.networks.nets.resnet34(spatial_dims=3, n_input_channels=1, num_classes=2, feed_forward=False).to(self.device)
        self.model_adc = monai.networks.nets.resnet34(spatial_dims=3, n_input_channels=1, num_classes=2, feed_forward=False).to(self.device)
        
        if  pretrained_dce != '':
            dce_dict = self.model_dce.state_dict()
            dce_pretrain = torch.load(self.pretrained_dce, map_location=self.device)
            dce_pretrain_dict = {k:v for k, v in dce_pretrain.items() if  k in  dce_dict.keys()}
            dce_dict.update(dce_pretrain_dict)
            self.model_dce.load_state_dict(dce_dict)

        if  pretrained_adc !='':
            adc_dict = self.model_adc.state_dict()
            adc_pretrain = torch.load(self.pretrained_adc, map_location=self.device)
            adc_pretrain_dict = {k:v for k, v in adc_pretrain.items() if  k in  adc_dict.keys()}
            adc_dict.update(adc_pretrain_dict)
            self.model_adc.load_state_dict(adc_dict)
        
        
        # self.Linear1 = Linear(1024 + 6, self.num_classes, device=self.device) 
        self.Linear1 = Linear(1024, self.fc_hidden_size, device=self.device)  # 1024 是 所有下采样特征图globalpool之后拼接的结果        
        self.Linear2 = Linear(self.fc_hidden_size + 6, self.num_classes, device=self.device)  
        
    
    def forward(self, x1, x2, structured_data):  # x 是SegResNet的输入影像矩阵
        
        encode_output1 = self.model_dce(x1)
        encode_output2 = self.model_dce(x2)
        
        concatenated = torch.concat([encode_output1, encode_output2], dim=-1)
        
        fc1 = F.relu(self.Linear1(concatenated)) 
        # fc1 = nn.Dropout(0.2)(fc1)
       
        fc2 = self.Linear2( torch.concat([fc1, structured_data], dim=-1))
        return F.log_softmax(fc2, dim=-1)



In [15]:

dce_pretrain_path = ''
adc_pretrain_path = ''

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = DoubleTower(dce_pretrain_path, adc_pretrain_path, device = device)

pretrained_path = './Dropout/best_metric_model_classification3d_dict.pth'
model.load_state_dict(torch.load(pretrained_path, map_location=device))


<All keys matched successfully>

In [21]:
post_pred = Compose([Activations(softmax=True)])
post_label = Compose([AsDiscrete(to_onehot=2)])

loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
auc_metric = ROCAUCMetric()

# start a typical PyTorch training
val_interval = 1
best_metric = -1
best_metric_epoch = -1
max_epochs = 150
for epoch in range(max_epochs):
 
    model.train()
    epoch_loss = 0
    val_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        input_dce, input_adc, input_clinical, labels = batch_data["image_dce"].to(device), batch_data['image_adc'].to(device), batch_data["clinical"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(input_dce, input_adc, input_clinical)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size

    epoch_loss /= step
    print(f"epoch {epoch + 1} average  train loss: {epoch_loss:.4f}")
    
    if (epoch + 1) % val_interval == 0:
       
        model.eval()
        with torch.no_grad():
            y_pred = torch.tensor([], dtype=torch.float32, device=device)
            y = torch.tensor([], dtype=torch.long, device=device)
            
            step2 = 0
            for val_data in val_loader:
                step2 += 1
                val_dce, val_adc, val_clinical, val_labels = val_data["image_dce"].to(device),val_data["image_adc"].to(device), val_data["clinical"].to(device), val_data["label"].to(device)
                val_output = model(val_dce, val_adc, val_clinical)
                y_pred = torch.cat([y_pred, val_output], dim=0)
                y = torch.cat([y, val_labels], dim=0)
                val_loss += loss_function(val_output, val_labels).item()
                
            val_loss /= step2
            
            acc_value = torch.eq(y_pred.argmax(dim=1), y)
            acc_metric = acc_value.sum().item() / len(acc_value)
            y_onehot = [post_label(i) for i in decollate_batch(y, detach=False)]
            y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
            auc_metric(y_pred_act, y_onehot)
            auc_result = auc_metric.aggregate()
            auc_metric.reset()
            del y_pred_act, y_onehot
            if auc_result > best_metric:
                best_metric = auc_result
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model_classification3d_dict.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}".format(
                    epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
                )
            )
            print(val_output)
    if epoch > 3:
        break
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")


epoch 1 average  train loss: 0.5412
saved new best metric model
current epoch: 1 current accuracy: 0.7020 current AUC: 0.7969 best accuracy: 0.7969 at epoch 1
tensor([[-0.6239, -0.7676],
        [-0.1130, -2.2361]], device='cuda:1')
epoch 2 average  train loss: 0.4585
saved new best metric model
current epoch: 2 current accuracy: 0.6887 current AUC: 0.8037 best accuracy: 0.8037 at epoch 2
tensor([[-7.7733e-02, -2.5931e+00],
        [-1.5589e-03, -6.4646e+00]], device='cuda:1')
epoch 3 average  train loss: 0.3625
current epoch: 3 current accuracy: 0.6887 current AUC: 0.7943 best accuracy: 0.8037 at epoch 2
tensor([[-8.9241e-03, -4.7235e+00],
        [-1.4490e-03, -6.5376e+00]], device='cuda:1')
epoch 4 average  train loss: 0.3429
current epoch: 4 current accuracy: 0.6755 current AUC: 0.7691 best accuracy: 0.8037 at epoch 2
tensor([[-1.4233e-04, -8.8576e+00],
        [ 0.0000e+00, -1.8217e+01]], device='cuda:1')
epoch 5 average  train loss: 0.2827
current epoch: 5 current accuracy: 0.685

In [68]:

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model = DoubleTower(dce_pretrain_path, adc_pretrain_path, device = device)

pretrained_path = './best_metric_model_classification3d_dict.pth'
model.load_state_dict(torch.load(pretrained_path, map_location=device))
print(model.state_dict()['model_dce.conv1.weight'][0, 0, 0, 0, :3])

model.eval()
with torch.no_grad():
    y_pred = torch.tensor([], dtype=torch.float32, device=device)
    y = torch.tensor([], dtype=torch.long, device=device)

    step2 = 0
    for val_data in val_loader:
        step2 += 1
        val_dce, val_adc, val_clinical, val_labels = val_data["image_dce"].to(device),val_data["image_adc"].to(device), val_data["clinical"].to(device), val_data["label"].to(device)
        
        val_output = model(val_dce, val_adc, val_clinical)
        print(val_output)
        y_pred = torch.cat([y_pred, val_output], dim=0)
        y = torch.cat([y, val_labels], dim=0)
        val_loss += loss_function(val_output, val_labels).item()

    val_loss /= step2

    acc_value = torch.eq(y_pred.argmax(dim=1), y)
    acc_metric = acc_value.sum().item() / len(acc_value)
    y_onehot = [post_label(i) for i in decollate_batch(y, detach=False)]
    y_pred_act = [post_pred(i) for i in decollate_batch(y_pred)]
    auc_metric(y_pred_act, y_onehot)
    auc_result = auc_metric.aggregate()
    auc_metric.reset()
    del y_pred_act, y_onehot
    
    print(
        "current epoch: {} current accuracy: {:.4f} current AUC: {:.4f} best accuracy: {:.4f} at epoch {}".format(
            epoch + 1, acc_metric, auc_result, best_metric, best_metric_epoch
        )
    )
    print(val_output)


tensor([-0.0018,  0.0067,  0.0166], device='cuda:1')
tensor([[-9.5228e-03, -4.6588e+00],
        [-1.7582e-04, -8.6462e+00],
        [-2.5987e-05, -1.0559e+01],
        [-2.4063e-03, -6.0308e+00],
        [-6.3670e-01, -7.5297e-01],
        [-2.4522e-02, -3.7204e+00],
        [-4.8411e-01, -9.5775e-01],
        [-6.4882e-03, -5.0410e+00],
        [-1.2420e-03, -6.6917e+00],
        [-5.4836e-06, -1.2119e+01],
        [-1.2413e-02, -4.3952e+00],
        [-1.7677e-04, -8.6408e+00]], device='cuda:1')
tensor([[-3.1168e-04, -8.0736e+00],
        [-2.9440e-01, -1.3664e+00],
        [-1.2053e-02, -4.4244e+00],
        [-1.0921e-03, -6.8203e+00],
        [-4.7563e-05, -9.9541e+00],
        [-1.0266e-02, -4.5840e+00],
        [-1.7068e-01, -1.8521e+00],
        [-8.4742e-03, -4.7750e+00],
        [-1.6891e-04, -8.6861e+00],
        [-3.8628e-04, -7.8591e+00],
        [-4.9232e-05, -9.9199e+00],
        [-1.6700e-04, -8.6976e+00]], device='cuda:1')
tensor([[-7.1954e-03, -4.9379e+00],
        [-3

In [55]:
current epoch: 2 current accuracy: 0.7318 current AUC: 0.8203 best accuracy: 0.8213 at epoch 1

SyntaxError: invalid syntax (2887209575.py, line 1)

In [59]:
dce_dict = model.state_dict()

tensor(-0.0018, device='cuda:1')