In [40]:
import torch.nn as nn
import torch, torchvision

class MIL_Atn_Network(nn.Module):
    def __init__(self, num_classes):
        super(MIL_Atn_Network, self).__init__()
        # import timm
        # self.base_model = timm.create_model('resnet18', pretrained=True,  in_chans=1)
        base_model = torchvision.models.resnet18(weights=False)
        state_dict = torch.load('resnet34-b627a593.pth')
        base_model.load_state_dict(state_dict)
        
        self.feature_extractor = nn.Sequential(*list(base_model.children())[:-1])
        self.fc = torch.nn.Linear(base_model.fc.in_features, num_classes)
        self.atn_fc = nn.Linear(base_model.fc.in_features, 1)  # 全连接层用于计算注意力权重

    def  forward(self, x):
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1, 1)  # 变成3通道
            
        x = x.permute(0, 4, 1, 2, 3)  
        batch_size, patch_z, channel, patch_x, patch_y  = x.size()
        
        x = x.reshape(batch_size*patch_z, channel, patch_x, patch_y) # [batch*patch_z*channel, patch_x, patch_y] # channel = 1
        
        x = self.feature_extractor(x)
    
        x = x.view(batch_size, patch_z, -1)
        
        attention_scores = torch.sigmoid(self.atn_fc(x)) # [batch_size, num_small_images, 1]
        x = torch.sum(x * attention_scores, dim=1)
        print(x.shape)
        x = self.fc(x)

        return x
                                
        
# device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

my_model = MIL_Atn_Network(num_classes=2)

x1 = torch.randn(8, 1, 96, 96, 32)  # batch, channel, x, y, z
output = my_model(x1)

print('output: ', output.shape)


torch.Size([8, 512])
output:  torch.Size([8, 2])


In [17]:
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  