In [21]:
import torch
import torch.nn as nn
import torchvision.models as models

class DoubleConv(nn.Module):
    """Two back-to-back (3x3 conv => BatchNorm => ReLU) stages.

    Both convolutions use ``kernel_size=3`` with ``padding=1``, so the
    spatial size of the input is preserved; only the channel count changes
    from ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels of the incoming tensor.
        out_channels: Number of channels produced by each of the two stages.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # The first stage maps in_channels -> out_channels; the second keeps
        # the width at out_channels.
        for stage_in in (in_channels, out_channels):
            stages += [
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv stages to ``x`` and return the result."""
        out = self.double_conv(x)
        return out

class AttentionBlock(nn.Module):
    """Channel attention: rescale each channel by a learned gate in (0, 1).

    The gate is computed by global average pooling, a 1x1 bottleneck conv
    with ReLU, a 1x1 expansion conv back to ``in_channels`` and a sigmoid.

    Args:
        in_channels: Number of channels of the input (and output) tensor.
        reduction_ratio: Divisor applied to ``in_channels`` to size the
            bottleneck. The bottleneck never goes below one channel.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(AttentionBlock, self).__init__()
        # in_channels // reduction_ratio is 0 whenever in_channels is smaller
        # than reduction_ratio, which would create a zero-width bottleneck
        # whose gates no longer depend on the input. Clamp to one channel.
        hidden_channels = max(1, in_channels // reduction_ratio)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_channels, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by its attention gate."""
        # attn has a 1x1 spatial size (from the adaptive pool) and broadcasts
        # over the spatial dimensions of x.
        attn = self.channel_attention(x)
        return x * attn

class GroupConv(nn.Module):
    """Grouped convolution with different dilation rates and attention.

    The input is split along the channel dimension into ``groups`` equal
    chunks. Chunk ``i`` (1-based) goes through its own 3x3 convolution with
    ``dilation=i`` and ``padding=i`` followed by its own ``AttentionBlock``.
    The branch outputs are concatenated and fused by a 1x1 convolution,
    batch norm and ReLU.

    Args:
        in_channels: Channels of the input tensor; must be divisible by
            ``groups``.
        out_channels: Channels of the output tensor; must be divisible by
            ``groups``.
        groups: Number of parallel branches (and the largest dilation rate).
    """

    def __init__(self, in_channels, out_channels, groups):
        super(GroupConv, self).__init__()
        assert in_channels % groups == 0, "in_channels must be divisible by groups"
        assert out_channels % groups == 0, "out_channels must be divisible by groups"
        self.groups = groups
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels // groups, out_channels // groups, kernel_size=3, padding=d, dilation=d)
            for d in range(1, groups + 1)
        ])
        self.attentions = nn.ModuleList([
            AttentionBlock(out_channels // groups)
            for _ in range(groups)
        ])
        self.integrate_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the dilated branches, attention and fusion to ``x``.

        Raises:
            RuntimeError: If the channel dimension of ``x`` does not match
                the ``in_channels`` the module was built with.
        """
        per_group = self.convs[0].in_channels
        expected = per_group * self.groups
        # Without this check an input with up to groups - 1 extra channels
        # splits into groups + 1 chunks, and zip() below would silently
        # discard the trailing chunk instead of failing.
        if x.size(1) != expected:
            raise RuntimeError(
                f"GroupConv expected {expected} input channels, got {x.size(1)}"
            )
        split_x = torch.split(x, per_group, dim=1)
        conv_x = [conv(split) for conv, split in zip(self.convs, split_x)]
        attn_x = [attn(conv_out) for attn, conv_out in zip(self.attentions, conv_x)]
        x = torch.cat(attn_x, dim=1)
        x = self.integrate_conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class Up(nn.Module):
    """Upscaling then double conv.

    ``x1`` is upsampled 2x by a transposed convolution, padded to the
    spatial size of the skip tensor ``x2``, concatenated with it along the
    channel dimension and passed through a ``DoubleConv``.

    Args:
        in_channels: Channels of the tensor to upsample (``x1``).
        out_channels: Channels produced by the block.
        skip_channels: Channels of the skip tensor (``x2``). When given, the
            transposed convolution projects ``x1`` to ``skip_channels`` and
            the double conv consumes ``2 * skip_channels``. When ``None``
            (default) the legacy sizing is used: the skip is assumed to have
            ``in_channels // 2`` channels, except for ``in_channels == 256``
            where the transposed convolution outputs ``in_channels // 4``
            channels and the double conv consumes ``in_channels // 2``.
    """

    def __init__(self, in_channels, out_channels, skip_channels=None):
        super(Up, self).__init__()
        if skip_channels is not None:
            up_channels = skip_channels
            concat_channels = 2 * skip_channels
        elif in_channels == 256:
            # Legacy special case kept for backward compatibility: the
            # 256-channel stage joins a skip with in_channels // 4 channels.
            up_channels = in_channels // 4
            concat_channels = in_channels // 2
        else:
            up_channels = in_channels // 2
            concat_channels = in_channels
        self.up = nn.ConvTranspose2d(in_channels, up_channels, kernel_size=2, stride=2)
        self.conv = DoubleConv(concat_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, align it with ``x2`` and fuse the two."""
        x1 = self.up(x1)
        print(f'Upsampled x1 shape: {x1.shape}, x2 shape: {x2.shape}')
        # Pad x1 so that its height/width match the skip tensor; the padding
        # is split between the two sides, with the remainder going to the
        # right/bottom.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)

class CustomResNet50UNet(nn.Module):
    """U-Net style segmentation model on a pretrained ResNet-50 encoder.

    The ResNet-50 stem and its four residual stages provide the skip
    features. A ``GroupConv`` bridge processes the deepest feature map, and
    four ``Up`` blocks decode back up to the resolution of the stem output.
    The final 1x1 convolution maps to ``n_classes`` channels.

    Note that the output has the spatial size of the stem feature ``x1``,
    not of the input image: for a ``(1, 3, 224, 512)`` input the output is
    ``(1, n_classes, 112, 256)``.

    Args:
        n_classes: Number of output channels of the final 1x1 convolution.
    """

    def __init__(self, n_classes):
        super(CustomResNet50UNet, self).__init__()
        # The stock ResNet-50 stem is already a 3->64, 7x7, stride-2 conv
        # without bias. The previous code replaced it with an identically
        # configured but randomly initialised conv, which threw away the
        # pretrained stem weights; the pretrained conv1 is now kept as loaded.
        self.resnet50 = models.resnet50(pretrained=True)

        # All ResNet50 children except the last two (the pooling and
        # classification head). forward() calls the resnet50 stages directly,
        # so this container only aliases the same modules; it is retained so
        # the `encoder` attribute and its state_dict keys keep existing.
        self.encoder = nn.Sequential(*list(self.resnet50.children())[:-2])

        # Grouped convolution bridge using four different dilation rates.
        self.group_conv = GroupConv(2048, 2048, groups=4)

        self.up1 = Up(2048, 1024)
        self.up2 = Up(1024, 512)
        self.up3 = Up(512, 256)
        self.up4 = Up(256, 64)

        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        """Encode ``x``, apply the bridge and decode with skip connections.

        The intermediate feature shapes are printed on every call.
        """
        # Encoder
        x1 = self.resnet50.conv1(x)
        x1 = self.resnet50.bn1(x1)
        x1 = self.resnet50.relu(x1)
        x2 = self.resnet50.maxpool(x1)
        print(f'x1 shape: {x1.shape}')

        x2 = self.resnet50.layer1(x2)
        print(f'x2 shape: {x2.shape}')
        x3 = self.resnet50.layer2(x2)
        print(f'x3 shape: {x3.shape}')
        x4 = self.resnet50.layer3(x3)
        print(f'x4 shape: {x4.shape}')
        x5 = self.resnet50.layer4(x4)
        print(f'x5 shape: {x5.shape}')

        # Bridge
        x5 = self.group_conv(x5)
        print(f'Bridge x5 shape: {x5.shape}')

        # Decoder: each stage fuses with the encoder feature one level up.
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)

        x = self.final_conv(x)
        return x

# Test the model: build it, push one random batch through and print the
# output shape.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CustomResNet50UNet(n_classes=1).to(device)
input_tensor = torch.randn(1, 3, 224, 512).to(device)
output = model(input_tensor)
print(output.shape)




x1 shape: torch.Size([1, 64, 112, 256])
x2 shape: torch.Size([1, 256, 56, 128])
x3 shape: torch.Size([1, 512, 28, 64])
x4 shape: torch.Size([1, 1024, 14, 32])
x5 shape: torch.Size([1, 2048, 7, 16])
Bridge x5 shape: torch.Size([1, 2048, 7, 16])
Upsampled x1 shape: torch.Size([1, 1024, 14, 32]), x2 shape: torch.Size([1, 1024, 14, 32])
Upsampled x1 shape: torch.Size([1, 512, 28, 64]), x2 shape: torch.Size([1, 512, 28, 64])
Upsampled x1 shape: torch.Size([1, 256, 56, 128]), x2 shape: torch.Size([1, 256, 56, 128])
Upsampled x1 shape: torch.Size([1, 64, 112, 256]), x2 shape: torch.Size([1, 64, 112, 256])
torch.Size([1, 1, 112, 256])


In [23]:
#seg_model.py
import torch
import torch.nn as nn
import torchvision.models as models

class AttentionBlock(nn.Module):
    """Channel attention: rescale each channel by a learned gate in (0, 1).

    The gate is computed by global average pooling, a 1x1 bottleneck conv
    with ReLU, a 1x1 expansion conv back to ``in_channels`` and a sigmoid.

    Args:
        in_channels: Number of channels of the input (and output) tensor.
        reduction_ratio: Divisor applied to ``in_channels`` to size the
            bottleneck. The bottleneck never goes below one channel.
    """

    def __init__(self, in_channels, reduction_ratio=8):
        super(AttentionBlock, self).__init__()
        # in_channels // reduction_ratio is 0 whenever in_channels is smaller
        # than reduction_ratio, which would create a zero-width bottleneck
        # whose gates no longer depend on the input. Clamp to one channel.
        bottleneck = max(1, in_channels // reduction_ratio)
        self.channel_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, bottleneck, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(bottleneck, in_channels, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return ``x`` scaled per channel by its attention gate."""
        # attn has a 1x1 spatial size (from the adaptive pool) and broadcasts
        # over the spatial dimensions of x.
        attn = self.channel_attention(x)
        return x * attn

class GroupConv(nn.Module):
    """Grouped convolution with different dilation rates and attention.

    The input is split along the channel dimension into ``groups`` equal
    chunks. Chunk ``i`` (1-based) goes through its own 3x3 convolution with
    ``dilation=i`` and ``padding=i`` followed by its own ``AttentionBlock``.
    The branch outputs are concatenated and fused by a 1x1 convolution,
    batch norm and ReLU.

    Args:
        in_channels: Channels of the input tensor; must be divisible by
            ``groups``.
        out_channels: Channels of the output tensor; must be divisible by
            ``groups``.
        groups: Number of parallel branches (and the largest dilation rate).
    """

    def __init__(self, in_channels, out_channels, groups):
        super(GroupConv, self).__init__()
        assert in_channels % groups == 0, "in_channels must be divisible by groups"
        assert out_channels % groups == 0, "out_channels must be divisible by groups"
        self.groups = groups
        self.convs = nn.ModuleList([
            nn.Conv2d(in_channels // groups, out_channels // groups, kernel_size=3, padding=d, dilation=d)
            for d in range(1, groups + 1)
        ])
        self.attentions = nn.ModuleList([
            AttentionBlock(out_channels // groups)
            for _ in range(groups)
        ])
        self.integrate_conv = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the dilated branches, attention and fusion to ``x``.

        Raises:
            RuntimeError: If the channel dimension of ``x`` does not match
                the ``in_channels`` the module was built with.
        """
        per_group = self.convs[0].in_channels
        expected = per_group * self.groups
        # Without this check an input with up to groups - 1 extra channels
        # splits into groups + 1 chunks, and zip() below would silently
        # discard the trailing chunk instead of failing.
        if x.size(1) != expected:
            raise RuntimeError(
                f"GroupConv expected {expected} input channels, got {x.size(1)}"
            )
        split_x = torch.split(x, per_group, dim=1)
        conv_x = [conv(split) for conv, split in zip(self.convs, split_x)]
        attn_x = [attn(conv_out) for attn, conv_out in zip(self.attentions, conv_x)]
        x = torch.cat(attn_x, dim=1)
        x = self.integrate_conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class ResUnet(nn.Module):
    """U-Net style segmentation model on a torchvision ResNet-50 encoder.

    The encoder stages are taken from ``models.resnet50``; the deepest
    feature map goes through a ``GroupConv`` bridge. Each decoder step
    bilinearly upsamples, concatenates the matching encoder feature and
    applies a 3x3 convolution with ReLU. A last upsampling restores the
    input resolution before the 1x1 classification convolution, so the
    output has shape ``(N, num_classes, H, W)`` for an ``(N, 3, H, W)`` input.

    Args:
        num_classes: Number of output channels.
        pretrained: Forwarded to ``models.resnet50(pretrained=...)``.
    """

    def __init__(self, num_classes=1, pretrained=True):
        super(ResUnet, self).__init__()
        self.base_model = models.resnet50(pretrained=pretrained)
        self.base_layers = list(self.base_model.children())

        self.layer0 = nn.Sequential(*self.base_layers[:3])  # 64
        self.layer1 = nn.Sequential(*self.base_layers[3:5])  # 256
        self.layer2 = self.base_layers[5]  # 512
        self.layer3 = self.base_layers[6]  # 1024
        self.layer4 = self.base_layers[7]  # 2048

        # Grouped convolution bridge using four different dilation rates.
        self.group_conv = GroupConv(2048, 2048, groups=4)

        # forward() now resizes to the exact size of each skip feature; these
        # parameter-free modules are retained only so existing attribute
        # access and the module listing stay unchanged.
        self.upsample4 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample3 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.upsample1 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Each decoder conv consumes the upsampled features concatenated with
        # the encoder skip of the same resolution.
        self.conv_up4 = nn.Conv2d(2048 + 1024, 1024, kernel_size=3, padding=1)
        self.conv_up3 = nn.Conv2d(1024 + 512, 512, kernel_size=3, padding=1)
        self.conv_up2 = nn.Conv2d(512 + 256, 256, kernel_size=3, padding=1)
        self.conv_up1 = nn.Conv2d(256 + 64, 64, kernel_size=3, padding=1)

        self.conv_last = nn.Conv2d(64, num_classes, kernel_size=1)

        self.relu = nn.ReLU()

    @staticmethod
    def _resize_to(x, size):
        """Bilinearly resize ``x`` to the spatial ``size`` (height, width)."""
        return nn.functional.interpolate(x, size=size, mode='bilinear', align_corners=True)

    def forward(self, x):
        """Run the encoder, bridge and decoder; return per-pixel logits."""
        input_size = x.shape[2:]

        layer0_out = self.layer0(x)
        layer1_out = self.layer1(layer0_out)
        layer2_out = self.layer2(layer1_out)
        layer3_out = self.layer3(layer2_out)
        layer4_out = self.layer4(layer3_out)

        layer4_out = self.group_conv(layer4_out)

        # Resizing to the skip's own size (instead of a fixed 2x factor)
        # keeps torch.cat valid when the input height/width is not a multiple
        # of 32. With align_corners=True the result is the same as the 2x
        # upsample whenever the sizes are exact doubles.
        x = self._resize_to(layer4_out, layer3_out.shape[2:])
        x = torch.cat([x, layer3_out], dim=1)
        x = self.relu(self.conv_up4(x))

        x = self._resize_to(x, layer2_out.shape[2:])
        x = torch.cat([x, layer2_out], dim=1)
        x = self.relu(self.conv_up3(x))

        x = self._resize_to(x, layer1_out.shape[2:])
        x = torch.cat([x, layer1_out], dim=1)
        x = self.relu(self.conv_up2(x))

        x = self._resize_to(x, layer0_out.shape[2:])
        x = torch.cat([x, layer0_out], dim=1)
        x = self.relu(self.conv_up1(x))

        # Back to the resolution of the input image.
        x = self._resize_to(x, input_size)
        x = self.conv_last(x)
        return x

# Build the model without pretrained weights, run one random batch through it
# and print the output shape.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ResUnet(num_classes=1, pretrained=False).to(device)
input_tensor = torch.randn(1, 3, 224, 512).to(device)
output = model(input_tensor)
print(output.shape)

torch.Size([1, 1, 224, 512])
