In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile
from thop import clever_format
torch.backends.cudnn.enabled = False

传统卷积

In [15]:
class double_conv2d_bn(nn.Module):
    """Two stacked (Conv2d -> BatchNorm2d -> ReLU) stages using standard convolutions.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by both convolutions.
        kernel_size, strides, padding: forwarded to both ``nn.Conv2d`` layers
            (``strides`` is passed as ``stride``). With the defaults 3/1/1 the
            spatial size is preserved.

    The convolutions keep ``bias=True`` so parameter counts and state_dict keys
    are unchanged, even though the following BatchNorm has its own shift term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        # Zero-argument super(): `super(double_conv2d_bn, self)` re-resolves the
        # class name from module globals at call time, and a later cell rebinds
        # `double_conv2d_bn`, which would make this class unconstructible.
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, stride=strides, padding=padding, bias=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        # conv -> BN -> ReLU, applied twice.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(bn(conv(x)))
        return x

class deconv2d_bn(nn.Module):
    """Learned upsampling block: ConvTranspose2d -> BatchNorm2d -> ReLU.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the transposed convolution.
        kernel_size, strides: forwarded to ``nn.ConvTranspose2d`` (``strides`` is
            passed as ``stride``). With the defaults 2/2 it doubles H and W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        # Zero-argument super(): `super(deconv2d_bn, self)` re-resolves the class
        # name from module globals at call time, and a later cell rebinds
        # `deconv2d_bn`, which would make this class unconstructible.
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        return F.relu(self.bn1(self.conv1(x)))


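深度可分离卷积（depthwise-separable convolution）。注意：下一个单元格重新定义了 `double_conv2d_bn` 和 `deconv2d_bn`，会覆盖上面的标准卷积版本；此后定义的所有模型都使用深度可分离卷积块。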

In [16]:
class depthwise_separable_conv(nn.Module):
    """MobileNet-V1-style depthwise-separable convolution block.

    Pipeline: depthwise k x k conv (groups == in_channels) -> BN -> ReLU ->
    pointwise 1x1 conv -> BN -> ReLU. Both convolutions are bias-free; each is
    followed by a BatchNorm layer.

    Args:
        in_channels: channels of the input (also the depthwise output width).
        out_channels: channels produced by the pointwise convolution.
        kernel_size, stride, padding: applied to the depthwise convolution only.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        # Depthwise stage: one k x k filter per input channel.
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size,
                                   stride=stride, padding=padding,
                                   groups=in_channels, bias=False)
        # Pointwise stage: 1x1 conv that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(in_channels, out_channels, 1,
                                   stride=1, padding=0, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        stages = ((self.depthwise, self.bn1), (self.pointwise, self.bn2))
        for conv, norm in stages:
            x = F.relu(norm(conv(x)))
        return x

class double_conv2d_bn(nn.Module):
    """Two stacked depthwise-separable conv blocks (drop-in for the standard double conv).

    NOTE: this redefinition shadows the standard-convolution ``double_conv2d_bn``
    from the previous cell; every model constructed after this cell uses this version.

    Args:
        in_channels: channels entering the first block.
        out_channels: channels produced by both blocks.
        kernel_size, strides, padding: forwarded to both ``depthwise_separable_conv``
            blocks (``strides`` is passed as their ``stride``).
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, strides=1, padding=1):
        super().__init__()
        shared = {"kernel_size": kernel_size, "stride": strides, "padding": padding}
        self.conv1 = depthwise_separable_conv(in_channels, out_channels, **shared)
        self.conv2 = depthwise_separable_conv(out_channels, out_channels, **shared)

    def forward(self, x):
        return self.conv2(self.conv1(x))

class deconv2d_bn(nn.Module):
    """Learned upsampling block: ConvTranspose2d -> BatchNorm2d -> ReLU.

    Behaves the same as the ``deconv2d_bn`` from the previous cell, which this
    definition shadows. With the defaults (kernel_size=2, strides=2) it doubles H and W.
    """

    def __init__(self, in_channels, out_channels, kernel_size=2, strides=2):
        super().__init__()
        self.conv1 = nn.ConvTranspose2d(in_channels, out_channels, kernel_size,
                                        stride=strides, bias=True)
        self.bn1 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        upsampled = self.conv1(x)
        normalized = self.bn1(upsampled)
        return F.relu(normalized)

In [17]:
class UnetS1(nn.Module):
    """One-level U-Net: 1-channel input -> 3-channel sigmoid map of the same spatial size.

    Encoder: double_conv2d_bn (1->12) -> 2x2 max-pool -> double_conv2d_bn (12->24).
    Decoder: deconv2d_bn (24->12) upsampling, concatenation with the 12-channel
    skip tensor, double_conv2d_bn (24->12), then a 3x3 conv (12->3) and a sigmoid.
    Odd input H/W are supported (see forward).
    """

    def __init__(self):
        super().__init__()
        self.layer1_conv = double_conv2d_bn(1, 12)
        # NOTE(review): `pointwise` (2->12) is never called in forward(); it only
        # adds unused parameters. Kept so the state_dict keys stay the same.
        self.pointwise = nn.Conv2d(2, 12, kernel_size=1, stride=1, padding=0, bias=False)
        self.layer2_conv = double_conv2d_bn(12, 24)
        self.layer3_conv = double_conv2d_bn(24, 12)
        self.layer4_conv = nn.Conv2d(12, 3, kernel_size=3, stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(24, 12)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        conv1_2 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1_2, 2)

        conv2 = self.layer2_conv(pool1)

        convt1 = self.deconv1(conv2)
        # For odd H/W, max-pooling floors the size, so the 2x upsample comes back
        # one pixel short. Resize to the skip tensor (same scheme as UnetS3) so
        # the channel concatenation does not fail.
        if convt1.shape[2:] != conv1_2.shape[2:]:
            convt1 = F.interpolate(convt1, size=conv1_2.shape[2:], mode='bilinear', align_corners=True)
        concat1 = torch.cat([convt1, conv1_2], dim=1)
        conv3 = self.layer3_conv(concat1)

        conv4 = self.layer4_conv(conv3)
        return self.sigmoid(conv4)
model = UnetS1()
inp = torch.rand(10, 1, 300, 300)
# Shape smoke test only: no need to build an autograd graph for this batch.
with torch.no_grad():
    outp = model(inp)
print("输出形状:", outp.shape)
assert outp.shape == (10, 3, 300, 300), outp.shape

输出形状: torch.Size([10, 3, 300, 300])


In [18]:
class UnetS2(nn.Module):
    """Two-level U-Net: 1-channel input -> 3-channel sigmoid map of the same spatial size.

    Encoder: double_conv2d_bn 1->12, 12->24, 24->48 with a 2x2 max-pool between
    levels. Decoder: deconv2d_bn upsampling + skip concatenation + double_conv2d_bn
    (48->24, 24->12), then a 3x3 conv (12->3) and a sigmoid.

    H/W need not be divisible by 4: an upsampled tensor that pooling left one
    pixel short is resized to its skip tensor before concatenation.
    """

    def __init__(self):
        super().__init__()
        self.layer1_conv = double_conv2d_bn(1, 12)
        # NOTE(review): `pointwise` is never called in forward(); it only adds
        # unused parameters. Kept so the state_dict keys stay the same.
        self.pointwise = nn.Conv2d(1, 12, kernel_size=1, stride=1, padding=0, bias=False)
        self.layer2_conv = double_conv2d_bn(12, 24)
        self.layer3_conv = double_conv2d_bn(24, 48)

        self.layer4_conv = double_conv2d_bn(48, 24)
        self.layer5_conv = double_conv2d_bn(24, 12)
        self.layer6_conv = nn.Conv2d(12, 3, kernel_size=3, stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(48, 24)
        self.deconv2 = deconv2d_bn(24, 12)

        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _match_size(up, skip):
        """Resize ``up`` to ``skip``'s H/W when they differ (same scheme as UnetS3)."""
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return up

    def forward(self, x):
        conv1_2 = self.layer1_conv(x)
        pool1 = F.max_pool2d(conv1_2, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)

        conv3 = self.layer3_conv(pool2)

        convt1 = self._match_size(self.deconv1(conv3), conv2)
        concat1 = torch.cat([convt1, conv2], dim=1)
        conv4 = self.layer4_conv(concat1)

        convt2 = self._match_size(self.deconv2(conv4), conv1_2)
        concat2 = torch.cat([convt2, conv1_2], dim=1)
        conv5 = self.layer5_conv(concat2)

        conv6 = self.layer6_conv(conv5)
        return self.sigmoid(conv6)
model = UnetS2()
inp = torch.rand(10, 1, 300, 300)
# Shape smoke test only: no need to build an autograd graph for this batch.
with torch.no_grad():
    outp = model(inp)
print("输出形状:", outp.shape)
assert outp.shape == (10, 3, 300, 300), outp.shape

输出形状: torch.Size([10, 3, 300, 300])


In [19]:
class UnetS3(nn.Module):
    """Three-level U-Net: 1-channel input -> 3-channel sigmoid map of the same spatial size.

    Encoder: double_conv2d_bn 1->12->24->48->96 with a 2x2 max-pool between levels.
    Each decoder level upsamples with deconv2d_bn, resizes to the skip tensor if
    pooling an odd size lost a pixel, concatenates, and refines with
    double_conv2d_bn. A final 3x3 conv (12->3) and a sigmoid produce the output.
    """

    def __init__(self):
        super().__init__()
        self.layer1_conv = double_conv2d_bn(1, 12)
        # NOTE(review): `pointwise` is never called in forward(); it only adds unused parameters.
        self.pointwise = nn.Conv2d(1, 12, kernel_size=1, stride=1, padding=0, bias=False)
        self.layer2_conv = double_conv2d_bn(12, 24)
        self.layer3_conv = double_conv2d_bn(24, 48)
        self.layer4_conv = double_conv2d_bn(48, 96)
        self.layer5_conv = double_conv2d_bn(96, 48)
        self.layer6_conv = double_conv2d_bn(48, 24)
        self.layer7_conv = double_conv2d_bn(24, 12)

        self.layer8_conv = nn.Conv2d(12, 3, kernel_size=3, stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(96, 48)
        self.deconv2 = deconv2d_bn(48, 24)
        self.deconv3 = deconv2d_bn(24, 12)

        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _match_size(up, skip):
        """Bilinearly resize ``up`` to ``skip``'s H/W when they differ."""
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return up

    def _up_block(self, deconv, refine, low, skip):
        """Upsample ``low``, align it to ``skip``, concatenate on channels, refine."""
        up = self._match_size(deconv(low), skip)
        return refine(torch.cat([up, skip], dim=1))

    def forward(self, x):
        skip1 = self.layer1_conv(x)
        skip2 = self.layer2_conv(F.max_pool2d(skip1, 2))
        skip3 = self.layer3_conv(F.max_pool2d(skip2, 2))
        bottom = self.layer4_conv(F.max_pool2d(skip3, 2))

        dec = self._up_block(self.deconv1, self.layer5_conv, bottom, skip3)
        dec = self._up_block(self.deconv2, self.layer6_conv, dec, skip2)
        dec = self._up_block(self.deconv3, self.layer7_conv, dec, skip1)

        return self.sigmoid(self.layer8_conv(dec))
model = UnetS3()
inp = torch.rand(10, 1, 300, 300)
# Shape smoke test only: no need to build an autograd graph for this batch.
with torch.no_grad():
    outp = model(inp)
print("输出形状:", outp.shape)
assert outp.shape == (10, 3, 300, 300), outp.shape

输出形状: torch.Size([10, 3, 300, 300])


In [20]:
class MetaUnetS1(nn.Module):
    """One-level U-Net: 2-channel input -> 3-channel sigmoid map of the same spatial size.

    Unlike UnetS1, the first stage is a single bias-free 1x1 conv (2->12) with no
    normalization or activation. Then: 2x2 max-pool -> double_conv2d_bn (12->24)
    -> deconv2d_bn (24->12) -> concat with the 12-channel skip -> double_conv2d_bn
    (24->12) -> 3x3 conv (12->3) -> sigmoid. Odd input H/W are supported.
    """

    def __init__(self):
        super().__init__()
        self.pointwise = nn.Conv2d(2, 12, kernel_size=1, stride=1, padding=0, bias=False)
        self.layer2_conv = double_conv2d_bn(12, 24)
        self.layer3_conv = double_conv2d_bn(24, 12)
        self.layer4_conv = nn.Conv2d(12, 3, kernel_size=3, stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(24, 12)

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        conv1_2 = self.pointwise(x)
        pool1 = F.max_pool2d(conv1_2, 2)

        conv2 = self.layer2_conv(pool1)
        convt1 = self.deconv1(conv2)
        # For odd H/W, pooling floors the size and the 2x upsample is one pixel
        # short. Resize to the skip tensor (same scheme as UnetS3) so cat works.
        if convt1.shape[2:] != conv1_2.shape[2:]:
            convt1 = F.interpolate(convt1, size=conv1_2.shape[2:], mode='bilinear', align_corners=True)
        concat1 = torch.cat([convt1, conv1_2], dim=1)
        conv3 = self.layer3_conv(concat1)

        conv4 = self.layer4_conv(conv3)
        return self.sigmoid(conv4)

In [21]:
class MetaUnetS2(nn.Module):
    """Two-level U-Net: 2-channel input -> 4-channel sigmoid map of the same spatial size.

    The first stage is a single bias-free 1x1 conv (2->12) with no normalization
    or activation. The rest uses the same channel plan as UnetS2 (12->24->48 down,
    48->24->12 up with skip concatenations), but the 3x3 head outputs 4 channels.
    H/W need not be divisible by 4 (see ``_match_size``).
    """

    def __init__(self):
        super().__init__()
        self.pointwise = nn.Conv2d(2, 12, kernel_size=1, stride=1, padding=0, bias=False)
        self.layer2_conv = double_conv2d_bn(12, 24)
        self.layer3_conv = double_conv2d_bn(24, 48)
        self.layer4_conv = double_conv2d_bn(48, 24)
        self.layer5_conv = double_conv2d_bn(24, 12)
        self.layer6_conv = nn.Conv2d(12, 4, kernel_size=3, stride=1, padding=1, bias=True)

        self.deconv1 = deconv2d_bn(48, 24)
        self.deconv2 = deconv2d_bn(24, 12)

        self.sigmoid = nn.Sigmoid()

    @staticmethod
    def _match_size(up, skip):
        """Resize ``up`` to ``skip``'s H/W when pooling an odd size left it one pixel short."""
        if up.shape[2:] != skip.shape[2:]:
            up = F.interpolate(up, size=skip.shape[2:], mode='bilinear', align_corners=True)
        return up

    def forward(self, x):
        conv1_2 = self.pointwise(x)
        pool1 = F.max_pool2d(conv1_2, 2)

        conv2 = self.layer2_conv(pool1)
        pool2 = F.max_pool2d(conv2, 2)
        conv3 = self.layer3_conv(pool2)

        convt1 = self._match_size(self.deconv1(conv3), conv2)
        concat1 = torch.cat([convt1, conv2], dim=1)
        conv4 = self.layer4_conv(concat1)

        convt2 = self._match_size(self.deconv2(conv4), conv1_2)
        concat2 = torch.cat([convt2, conv1_2], dim=1)
        conv5 = self.layer5_conv(concat2)

        conv6 = self.layer6_conv(conv5)
        return self.sigmoid(conv6)
model = MetaUnetS2()
inp = torch.rand(1, 2, 300, 300)
# Shape smoke test only: no need to build an autograd graph.
with torch.no_grad():
    outp = model(inp)
print("输出形状:", outp.shape)
assert outp.shape == (1, 4, 300, 300), outp.shape

输出形状: torch.Size([1, 4, 300, 300])


In [41]:
import torch
from torch.utils.benchmark import Timer

def benchmark_model(model, input_size, device='cuda', num_threads=1, number=100):
    """Time inference of ``model`` with PyTorch's official benchmark utility.

    Args:
        model: ``nn.Module`` to benchmark. It is moved to ``device`` and put in
            eval mode in place.
        input_size: shape of the random input tensor, e.g. ``(1, 2, 300, 300)``.
        device: any device spec accepted by ``torch.device`` (``'cuda'``,
            ``'cuda:0'``, ``'cpu'`` or a ``torch.device`` instance).
        num_threads: thread count passed to ``Timer``.
        number: how many forward passes ``timer.timeit`` executes (default 100).

    Returns:
        The ``torch.utils.benchmark`` ``Measurement`` (it is also printed).
    """
    device = torch.device(device)
    model.to(device)
    model.eval()

    # Allocate the input directly on the target device so that specs such as
    # 'cuda:0' or torch.device('cuda') also put the input next to the model.
    x = torch.randn(input_size, device=device)

    def model_inference():
        with torch.no_grad():
            return model(x)

    timer = Timer(
        stmt="model_inference()",
        globals={"model_inference": model_inference},
        num_threads=num_threads,
        label="Model Inference",
        description="Benchmark model inference speed",
    )

    # Per the torch.utils.benchmark docs, Timer does its own warm-up and CUDA sync.
    result = timer.timeit(number)

    print(result)
    return result

# Usage example.
# NOTE: cuDNN is disabled at the top of this notebook
# (torch.backends.cudnn.enabled = False), so CUDA timings here exclude cuDNN kernels.
bench_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = MetaUnetS2()
# Trailing ';' stops Jupyter from echoing the returned Measurement a second time
# (benchmark_model already prints it).
benchmark_model(model, (1, 2, 300, 300), device=bench_device);

<torch.utils.benchmark.utils.common.Measurement object at 0x00000287A2CD4E80>
Model Inference
Benchmark model inference speed
  11.47 ms
  1 measurement, 100 runs , 1 thread


<torch.utils.benchmark.utils.common.Measurement object at 0x00000287A2CD4E80>
Model Inference
Benchmark model inference speed
  11.47 ms
  1 measurement, 100 runs , 1 thread

In [39]:

model = MetaUnetS1()

# Use the GPU if available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move the model once, then switch to eval mode (BatchNorm uses running statistics).
model = model.to(device)
model.eval()

input_tensor = torch.randn(1, 2, 400, 400, device=device)

# NOTE: thop's `profile` reports multiply-accumulate operations (MACs), roughly
# half the FLOP count. The `flops` name is kept in case later cells use it.
flops, params = profile(model, inputs=(input_tensor,), verbose=False)
# Convert to human-readable strings (K/M/G suffixes).
flops, params = clever_format([flops, params], '%.3f')

print(f"运算量：{flops}, 参数量：{params}")


运算量：468.000M, 参数量：3.771K
