In [19]:
import os
import SimpleITK as sitk
import glob
import monai
from monai.transforms import (

    AsDiscrete,
    RandAdjustContrastd,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    AddChanneld,
    SpatialPadd,
    RandRotate90d,
    RandShiftIntensityd,
    EnsureTyped,
    EnsureType,
    MapTransform,
    Resized,
    Invertd,
    ToTensord,
    NormalizeIntensityd,
    RandFlipd,
    Lambdad,
    Activations,
    AsDiscrete,
)
from monai.metrics import ROCAUCMetric
from monai.data import CacheDataset, ThreadDataLoader,DataLoader, Dataset, decollate_batch,load_decathlon_datalist
import torch
import torch.nn as  nn
from torch.nn import Linear,  Softmax
import torch.nn.functional as F
from monai.utils import first, set_determinism
from random import shuffle, seed
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('./log/tensorboard')

torch.multiprocessing.set_sharing_strategy('file_system')
set_determinism(seed=1)


class DoubleTower(nn.Module):
    def __init__(self,
                 pretrained_dce='',
                 pretrained_adc='',
                 device = torch.device("cuda"),
                 num_classes=2,
                 fc_hidden_size = 256
                ):
        super().__init__()
        self.pretrained_dce = pretrained_dce
        self.pretrained_adc = pretrained_adc
        self.fc_hidden_size = fc_hidden_size
        self.num_classes = num_classes
        self.device = device

        self.model_dce = monai.networks.nets.resnet34(spatial_dims=3, n_input_channels=1, num_classes=2, feed_forward=False).to(self.device)
        self.model_adc = monai.networks.nets.resnet34(spatial_dims=3, n_input_channels=1, num_classes=2, feed_forward=False).to(self.device)

        if  pretrained_dce != '':
            dce_dict = self.model_dce.state_dict()
            dce_pretrain = torch.load(self.pretrained_dce, map_location=self.device)
            dce_pretrain_dict = {k:v for k, v in dce_pretrain.items() if  k in  dce_dict.keys()}
            dce_dict.update(dce_pretrain_dict)
            self.model_dce.load_state_dict(dce_dict)

        if  pretrained_adc !='':
            adc_dict = self.model_adc.state_dict()
            adc_pretrain = torch.load(self.pretrained_adc, map_location=self.device)
            adc_pretrain_dict = {k:v for k, v in adc_pretrain.items() if  k in  adc_dict.keys()}
            adc_dict.update(adc_pretrain_dict)
            self.model_adc.load_state_dict(adc_dict)

        self.attn = nn.MultiheadAttention(512, num_heads=8, batch_first=True, device=self.device)

        # self.Linear1 = Linear(1024 + 6, self.num_classes, device=self.device)
        self.Linear1 = Linear(512, self.fc_hidden_size, device=self.device)  # 1024 是 所有下采样特征图globalpool之后拼接的结果
        self.Linear2 = Linear(self.fc_hidden_size, self.num_classes, device=self.device)
        self.dropout = nn.Dropout(0.2)


    def forward(self, x1, x2, rad, structured_data):  # x 是SegResNet的输入影像矩阵

        encode_output1 = self.model_dce(x1)
        encode_output2 = self.model_dce(x2)

        concatenated = encode_output1 * encode_output2
        
        concatenated = concatenated.unsqueeze(1)
        attn_output, _ = self.attn(concatenated, concatenated, concatenated)

        attn_output = attn_output.squeeze(1)

        fc1 = F.relu(self.Linear1(attn_output))
        fc1 = self.dropout(fc1)
        print(fc1.shape)
        fc2 = self.Linear2( torch.concat([fc1, structured_data], dim=-1))
        return F.log_softmax(fc2, dim=-1)


# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")


In [20]:

my_model = DoubleTower()

x1 = torch.randn(8, 1, 96, 96, 32)  # batch, channel, x, y, z
x2 = torch.randn(8, 1, 64, 64, 16)
cli = torch.randn(8, 6)
radiomics = torch.randn(8, 9)
output = my_model(x1.cuda(), x2.cuda(), radiomics.cuda(), cli.cuda())
print('output: ', output.shape)

torch.Size([8, 256])


RuntimeError: mat1 and mat2 shapes cannot be multiplied (8x262 and 256x2)

In [21]:
my_model

DoubleTower(
  (model_dce): ResNet(
    (conv1): Conv3d(1, 64, kernel_size=(7, 7, 7), stride=(1, 2, 2), padding=(3, 3, 3), bias=False)
    (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool3d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn2): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): ResNetBlock(
        (conv1): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
        (bn1): BatchNorm3d(64, eps

In [22]:
for name, param in my_model.named_parameters():
    if name != 'Linear2.weight' and name != 'Linear2.bias':
        param.requires_grad = False

In [27]:

dce_dict = my_model.state_dict()
dce_pretrain = torch.load('/app/liucd/deeplearn_dec/DL_multi/NewTrain2Val3/net_selfattn/attn_concat/FreezeWeight/model_0.80_0.847.pth')
dce_pretrain_dict = {k:v for k, v in dce_pretrain.items() if k != 'Linear2.weight' and k != 'Linear2.bias':}
dce_dict.update(dce_pretrain_dict)
my_model.load_state_dict(dce_dict)

In [31]:
dce_pretrain.keys()

odict_keys(['model_dce.conv1.weight', 'model_dce.bn1.weight', 'model_dce.bn1.bias', 'model_dce.bn1.running_mean', 'model_dce.bn1.running_var', 'model_dce.bn1.num_batches_tracked', 'model_dce.layer1.0.conv1.weight', 'model_dce.layer1.0.bn1.weight', 'model_dce.layer1.0.bn1.bias', 'model_dce.layer1.0.bn1.running_mean', 'model_dce.layer1.0.bn1.running_var', 'model_dce.layer1.0.bn1.num_batches_tracked', 'model_dce.layer1.0.conv2.weight', 'model_dce.layer1.0.bn2.weight', 'model_dce.layer1.0.bn2.bias', 'model_dce.layer1.0.bn2.running_mean', 'model_dce.layer1.0.bn2.running_var', 'model_dce.layer1.0.bn2.num_batches_tracked', 'model_dce.layer1.1.conv1.weight', 'model_dce.layer1.1.bn1.weight', 'model_dce.layer1.1.bn1.bias', 'model_dce.layer1.1.bn1.running_mean', 'model_dce.layer1.1.bn1.running_var', 'model_dce.layer1.1.bn1.num_batches_tracked', 'model_dce.layer1.1.conv2.weight', 'model_dce.layer1.1.bn2.weight', 'model_dce.layer1.1.bn2.bias', 'model_dce.layer1.1.bn2.running_mean', 'model_dce.layer

In [32]:
dce_dict.keys()

odict_keys(['model_dce.conv1.weight', 'model_dce.bn1.weight', 'model_dce.bn1.bias', 'model_dce.bn1.running_mean', 'model_dce.bn1.running_var', 'model_dce.bn1.num_batches_tracked', 'model_dce.layer1.0.conv1.weight', 'model_dce.layer1.0.bn1.weight', 'model_dce.layer1.0.bn1.bias', 'model_dce.layer1.0.bn1.running_mean', 'model_dce.layer1.0.bn1.running_var', 'model_dce.layer1.0.bn1.num_batches_tracked', 'model_dce.layer1.0.conv2.weight', 'model_dce.layer1.0.bn2.weight', 'model_dce.layer1.0.bn2.bias', 'model_dce.layer1.0.bn2.running_mean', 'model_dce.layer1.0.bn2.running_var', 'model_dce.layer1.0.bn2.num_batches_tracked', 'model_dce.layer1.1.conv1.weight', 'model_dce.layer1.1.bn1.weight', 'model_dce.layer1.1.bn1.bias', 'model_dce.layer1.1.bn1.running_mean', 'model_dce.layer1.1.bn1.running_var', 'model_dce.layer1.1.bn1.num_batches_tracked', 'model_dce.layer1.1.conv2.weight', 'model_dce.layer1.1.bn2.weight', 'model_dce.layer1.1.bn2.bias', 'model_dce.layer1.1.bn2.running_mean', 'model_dce.layer