# Building the model and loading the trained weights

In [1]:
import torch
from collections import OrderedDict

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# `args` only exists when this code runs after command-line arguments were
# parsed; when it is missing (or lacks the attribute) fall back to 1 channel.
try:
    in_channels = args.in_channels
except (NameError, AttributeError):
    in_channels = 1
finally:
    # Everything indented below (imports and model definitions) lives in this
    # `finally` clause so it is defined whether or not `args` is available.
    import math
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from collections import OrderedDict


    class _DenseLayer(nn.Sequential):
        """One dense layer: (BN -> ReLU -> Conv3d 3x3x3) twice, then concatenation.

        Sub-modules are registered in the order norm1, relu1, conv1, norm2,
        relu2, conv2, so ``nn.Sequential.forward`` applies them in that order.
        The output is the input with ``growth_rate`` new channels appended
        along dim 1.

        Args:
            num_input_features: channel count of the incoming tensor.
            growth_rate: channels produced by this layer (k in the paper).
            bn_size: width multiplier of the intermediate convolution, which
                has ``bn_size * growth_rate`` output channels.
            drop_rate: dropout probability for the new features; dropout is
                only applied when this is greater than 0.
        """

        def __init__(self, num_input_features, growth_rate, bn_size, drop_rate):
            super().__init__()
            hidden_channels = bn_size * growth_rate
            stages = (
                ('norm1', nn.BatchNorm3d(num_input_features)),
                ('relu1', nn.ReLU(inplace=True)),
                ('conv1', nn.Conv3d(num_input_features, hidden_channels,
                                    kernel_size=3, stride=1, padding=1,
                                    bias=False)),
                ('norm2', nn.BatchNorm3d(hidden_channels)),
                ('relu2', nn.ReLU(inplace=True)),
                ('conv2', nn.Conv3d(hidden_channels, growth_rate,
                                    kernel_size=3, stride=1, padding=1,
                                    bias=False)),
            )
            for stage_name, stage in stages:
                self.add_module(stage_name, stage)
            self.drop_rate = drop_rate

        def forward(self, x):
            fresh = super().forward(x)
            if self.drop_rate > 0:
                fresh = F.dropout(fresh, p=self.drop_rate, training=self.training)
            # Dense connectivity: keep the input and append the new channels.
            return torch.cat((x, fresh), dim=1)


    class _DenseBlock(nn.Sequential):
        """A stack of ``num_layers`` dense layers named denselayer1..denselayerN.

        Each layer appends ``growth_rate`` channels, so layer j (1-based)
        receives ``num_input_features + (j - 1) * growth_rate`` channels.
        """

        def __init__(self, num_layers, num_input_features, bn_size, growth_rate,
                     drop_rate):
            super().__init__()
            width = num_input_features
            for position in range(1, num_layers + 1):
                self.add_module(
                    'denselayer{}'.format(position),
                    _DenseLayer(width, growth_rate, bn_size, drop_rate),
                )
                width += growth_rate


    class _Transition(nn.Sequential):
        """Transition between dense blocks: BN -> ReLU -> 1x1x1 Conv3d -> pool.

        The pool is ``MaxPool3d(kernel_size=1, stride=2)``. With a one-voxel
        window it keeps every other voxel along each spatial axis (strided
        subsampling) rather than taking a max over a neighbourhood.
        """

        def __init__(self, num_input_features, num_output_features):
            super().__init__()
            pieces = [
                ('norm', nn.BatchNorm3d(num_input_features)),
                ('relu', nn.ReLU(inplace=True)),
                ('conv', nn.Conv3d(num_input_features, num_output_features,
                                   kernel_size=1, stride=1, bias=False)),
                ('pool', nn.MaxPool3d(kernel_size=1, stride=2)),
            ]
            for label, piece in pieces:
                self.add_module(label, piece)


    class DenseNet(nn.Module):
        """
        3D DenseNet-BC style classifier with a mask-driven attention stage.

        forward(x, mask) runs: stem (Conv3d -> BN -> ReLU [-> AvgPool3d]) ->
        AttentionModule(features, mask) -> dense blocks, each followed by a
        transition except the last -> BN -> ReLU -> global average pool -> Linear.

        Args:
            in_channels (int) - number of channels of the input volume
            conv1_t_size (int) - stem kernel size along the first spatial
              dimension (the other two use 5)
            conv1_t_stride (int) - stem stride along the first spatial
              dimension (the other two use 2)
            no_max_pool (bool) - if True, omit the stem pooling layer ('pool1');
              note that this layer is an AvgPool3d, not a max pool
            growth_rate (int) - how many filters to add each layer (k in paper)
            block_config (sequence of ints) - how many layers in each dense block
            num_init_features (int) - the number of filters to learn in the first
              convolution layer. NOTE(review): the attention mask branch must
              produce the same channel count — confirm before changing from 64.
            bn_size (int) - multiplicative factor for number of bottle neck layers
              (i.e. bn_size * k features in the bottleneck layer)
            drop_rate (float) - dropout rate after each dense layer
            num_classes (int) - number of classification classes
            **kwargs - accepted and ignored
        """

        def __init__(self,
                     in_channels=1,
                     conv1_t_size=5,
                     conv1_t_stride=2,
                     no_max_pool=False,
                     growth_rate=32,
                     block_config=(6, 12, 24, 16),
                     num_init_features=64,
                     bn_size=4,
                     drop_rate=0,
                     num_classes=2,
                     **kwargs):

            super().__init__()

            # Stem. Build the layer list locally and register it once; the names
            # (conv1/norm1/relu1/pool1) become state_dict keys.
            stem = [('conv1',
                     nn.Conv3d(in_channels,
                               num_init_features,
                               kernel_size=(conv1_t_size, 5, 5),
                               stride=(conv1_t_stride, 2, 2),
                               padding=(1, 1, 1),
                               bias=False)),
                    ('norm1', nn.BatchNorm3d(num_init_features)),
                    ('relu1', nn.ReLU(inplace=True))]
            if not no_max_pool:
                stem.append(
                    ('pool1', nn.AvgPool3d(kernel_size=3, stride=2, padding=1)))
            self.features = nn.Sequential(OrderedDict(stem))
            self.attention_module = AttentionModule(num_init_features)

            num_features = num_init_features
            self.blocks_and_transitions = nn.ModuleList()
            last_index = len(block_config) - 1

            for i, num_layers in enumerate(block_config):
                self.blocks_and_transitions.append(
                    _DenseBlock(num_layers=num_layers,
                                num_input_features=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate))
                num_features += num_layers * growth_rate

                # No transition after the last dense block; each transition
                # halves the channel count.
                if i != last_index:
                    self.blocks_and_transitions.append(
                        _Transition(num_input_features=num_features,
                                    num_output_features=num_features // 2))
                    num_features //= 2

            # Final batch norm and linear classifier.
            self.final_norm = nn.BatchNorm3d(num_features)
            self.classifier = nn.Linear(num_features, num_classes)

            # Single initialisation pass over every sub-module:
            # Kaiming-normal (fan_out, ReLU gain) for Conv3d weights,
            # weight=1 / bias=0 for BatchNorm3d, and zero bias for the classifier
            # (its weight keeps PyTorch's default Linear init).
            for m in self.modules():
                if isinstance(m, nn.Conv3d):
                    nn.init.kaiming_normal_(m.weight,
                                            mode='fan_out',
                                            nonlinearity='relu')
                elif isinstance(m, nn.BatchNorm3d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.Linear):
                    nn.init.constant_(m.bias, 0)

        def forward(self, x, mask):
            """Classify a batch of volumes, gating stem features with ``mask``.

            Args:
                x: input batch for the stem Conv3d (``in_channels`` channels).
                mask: single-channel volume passed to AttentionModule; presumably
                    shaped (N, 1, *spatial size of x*) — TODO confirm.

            Returns:
                Unnormalised class scores of shape (N, num_classes).
            """
            x = self.features(x)
            x = self.attention_module(x, mask)
            for stage in self.blocks_and_transitions:
                x = stage(x)
            x = F.relu(self.final_norm(x), inplace=True)
            x = F.adaptive_avg_pool3d(x, output_size=(1, 1, 1)).view(x.size(0), -1)
            return self.classifier(x)


    def generate_model(model_depth, custom_block_config=None, **kwargs):
        """Build a DenseNet with a standard or custom dense-block layout.

        Args:
            model_depth: 121, 169, 201 or 264 for the standard layouts, or 'my'
                for a fully custom layout.
            custom_block_config: number of layers per dense block. When given it
                is used instead of the standard layout; it is required when
                ``model_depth == 'my'``.
            **kwargs: forwarded to ``DenseNet`` (``num_init_features`` and
                ``growth_rate`` are fixed to 64 and 32 here).

        Raises:
            ValueError: if ``model_depth`` is not one of the accepted values, or
                is 'my' without ``custom_block_config``. (Explicit raises rather
                than ``assert`` so validation survives ``python -O``.)
        """
        standard_configs = {
            121: (6, 12, 24, 16),
            169: (6, 12, 32, 32),
            201: (6, 12, 48, 32),
            264: (6, 12, 64, 48),
        }
        if model_depth not in (121, 169, 201, 264, 'my'):
            raise ValueError(
                "model_depth must be one of 121, 169, 201, 264 or 'my', "
                "got {!r}".format(model_depth))

        # Traditional DenseNet or custom architecture
        if custom_block_config is not None:
            block_config = custom_block_config
        elif model_depth == 'my':
            raise ValueError("model_depth='my' requires custom_block_config")
        else:
            block_config = standard_configs[model_depth]

        model = DenseNet(num_init_features=64,
                         growth_rate=32,
                         block_config=block_config,
                         **kwargs)
        return model


    def DenseNet121(**kwargs):
        """Build a DenseNet with the standard 121 block layout; kwargs go to generate_model."""
        return generate_model(model_depth=121, **kwargs)


    def DenseNet169(**kwargs):
        """Build a DenseNet with the standard 169 block layout; kwargs go to generate_model."""
        return generate_model(model_depth=169, **kwargs)


    def DenseNet201(**kwargs):
        """Build a DenseNet with the standard 201 block layout; kwargs go to generate_model."""
        return generate_model(model_depth=201, **kwargs)


    def DenseNet264(**kwargs):
        """Build a DenseNet with the standard 264 block layout; kwargs go to generate_model."""
        return generate_model(model_depth=264, **kwargs)
import torch
import torch.nn as nn
import torch.nn.functional as F
import os

class AttentionModule(nn.Module):
    """Scale stem features inside a binarised, downsampled region-of-interest mask.

    The mask is thresholded (``> 0``), passed through its own stem
    (Conv3d 5x5x5 stride 2 -> BatchNorm3d -> ReLU -> AvgPool3d stride 2) and
    thresholded again. Feature-map voxels where that processed mask is 1 are
    multiplied by ``softplus(attention_coefficient) + 1`` (always > 1); all
    other voxels become 0.

    Args:
        num_input_features: channel count of the feature map passed to
            ``forward``. The mask branch produces this many channels so the
            element-wise product lines up.
    """

    def __init__(self, num_input_features):
        super(AttentionModule, self).__init__()
        # Learnable scalar, mapped through softplus(.) + 1 in forward.
        self.attention_coefficient = nn.Parameter(torch.ones([1], dtype=torch.float32))
        self.softplus = nn.Softplus()
        # Mask stem: kernel/stride/padding and pooling match DenseNet's default
        # image stem, so both produce the same spatial size.
        self.conv1 = nn.Conv3d(1, num_input_features, kernel_size=(5, 5, 5),
                               stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
        self.norm1 = nn.BatchNorm3d(num_input_features)
        self.ReLU = nn.ReLU(inplace=True)
        self.pool1 = nn.AvgPool3d(kernel_size=3, stride=2, padding=1)

    @staticmethod
    def _binarize(t):
        """Return 1.0 where ``t > 0`` and 0.0 elsewhere, in torch's default float dtype."""
        return (t > 0).to(torch.get_default_dtype())

    def forward(self, feature_map, mask):
        """Apply the habitat attention to ``feature_map``.

        Args:
            feature_map: stem output with ``num_input_features`` channels and the
                spatial size produced by the mask stem.
            mask: single-channel volume; voxels > 0 mark the region of interest.
                Presumably shaped (N, 1, D, H, W) like the image input — confirm.

        Returns:
            ``feature_map`` scaled inside the processed mask and zeroed outside.
        """
        mask_processed = self._binarize(mask)
        mask_processed = self.conv1(mask_processed)
        mask_processed = self.norm1(mask_processed)
        mask_processed = self.ReLU(mask_processed)
        mask_processed = self.pool1(mask_processed)
        # Thresholding is not differentiable, so no gradient reaches conv1/norm1
        # through this branch; their weights come from initialisation or the
        # loaded checkpoint (norm1's running statistics still update in train mode).
        mask_processed = self._binarize(mask_processed)
        # Softplus keeps the learned scale strictly above 1.
        positive_attention_coefficient = self.softplus(self.attention_coefficient) + 1
        # Habitat attention feature map
        return feature_map * positive_attention_coefficient * mask_processed
# Build the custom DenseNet: four dense blocks with (2, 4, 4, 2) layers and 20%
# dropout after each dense layer. Pass along the input channel count resolved
# from `args` (or the default of 1) so it is actually used.
custom_model = generate_model("my",
                              custom_block_config=(2, 4, 4, 2),
                              drop_rate=0.2,
                              in_channels=in_channels)

# nn.Module.to moves parameters/buffers in place and returns the same module.
model = custom_model.to(device)
# Kaiming/He initialization
import torch.nn as nn
def init_weights(m):
    """Re-initialise one module in place; intended for ``model.apply(init_weights)``.

    * Conv3d and Linear: Kaiming-normal weights (fan_out, ReLU gain) and a zero
      bias when the layer has one.
    * BatchNorm3d: weight 1 and bias 0 when the layer is affine (layers built
      with ``affine=False`` have no weight/bias and are skipped).
    Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv3d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.BatchNorm3d):
        if m.affine:
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)

# Apply the Kaiming/constant initialisation to every sub-module. The strict
# load_state_dict call below overwrites every parameter and buffer, so this only
# matters if loading is skipped.
model.apply(init_weights)

# Load the trained weights. map_location places the tensors on `device`, so a
# checkpoint saved from a GPU run still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (or pass weights_only=True on PyTorch versions that support it).
WEIGHTS_PATH = 'C:/Users/XuRan/Desktop/mode_spilt_benchMark-final/Trained model.pth'
model.load_state_dict(torch.load(WEIGHTS_PATH, map_location=device))


<All keys matched successfully>

# Print the model architecture

In [2]:
model  # last expression of the cell: the notebook displays its repr (the module tree)

DenseNet(
  (features): Sequential(
    (conv1): Conv3d(1, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (pool1): AvgPool3d(kernel_size=3, stride=2, padding=1)
  )
  (attention_module): AttentionModule(
    (softplus): Softplus(beta=1, threshold=20)
    (conv1): Conv3d(1, 64, kernel_size=(5, 5, 5), stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
    (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (ReLU): ReLU(inplace=True)
    (pool1): AvgPool3d(kernel_size=3, stride=2, padding=1)
  )
  (blocks_and_transitions): ModuleList(
    (0): _DenseBlock(
      (denselayer1): _DenseLayer(
        (norm1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
        (conv1): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), pa