# In [2]:
import torch
import torch.nn as nn
from torchvision import models

def Double_conv(input_channel, output_channel):
    """Return two stacked (3x3 conv, padding 1) -> ReLU stages.

    The first conv maps ``input_channel`` to ``output_channel`` channels and
    the second keeps ``output_channel``. A 3x3 kernel with padding 1 leaves
    the spatial size unchanged.
    """
    stage_layers = []
    for in_ch in (input_channel, output_channel):
        stage_layers.append(nn.Conv2d(in_ch, output_channel, 3, padding=1))
        stage_layers.append(nn.ReLU(inplace=True))
    return nn.Sequential(*stage_layers)

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    ``residual_function`` is conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    ``shortcut`` is the identity, or a strided 1x1 conv + BN projection when
    the stride or the channel count changes. ``forward`` returns
    ``relu(residual + shortcut)``.
    """

    expansion = 1

    def __init__(self, input_channel, output_channel, stride=1):
        super().__init__()
        out_width = output_channel * BasicBlock.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, out_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_width),
        )

        # A projection is only needed when x and the residual differ in
        # spatial size (stride) or channel count.
        if stride == 1 and input_channel == out_width:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(input_channel, out_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_width),
            )

    def forward(self, x):
        summed = self.residual_function(x) + self.shortcut(x)
        return torch.relu(summed)


class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``output_channel * expansion`` channels. ``stride`` is
    applied by the 3x3 conv and, when a projection is needed, by the 1x1
    shortcut conv. ``forward`` returns ``relu(residual + shortcut)``.
    """

    expansion = 4

    def __init__(self, input_channel, output_channel, stride=1):
        super().__init__()
        widened = output_channel * BottleNeck.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, output_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, widened, kernel_size=1, bias=False),
            nn.BatchNorm2d(widened),
        )

        # Identity shortcut unless the residual changes resolution or width.
        needs_projection = stride != 1 or input_channel != widened
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(input_channel, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        summed = self.residual_function(x) + self.shortcut(x)
        return torch.relu(summed)



class ResNet(nn.Module):
    """ResNet encoder with a U-Net-style decoder producing per-pixel output maps.

    Encoder: a 7x7/stride-2 stem (``conv1``) and a 3x3/stride-2 max-pool,
    then four residual stages (``conv2_x`` .. ``conv5_x``) with 64/128/256/512
    base channels; stages 3-5 each halve the resolution. Decoder: the deepest
    features are repeatedly upsampled, concatenated with the matching encoder
    output (stage 4, 3, 2, then the stem) and refined with ``Double_conv``.
    A final head upsamples once more and projects to ``out_channel`` maps.
    The output has the same spatial size as the input.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels of the output tensor.
        block: residual block class (e.g. ``BasicBlock`` or ``BottleNeck``);
            its ``expansion`` attribute scales each stage's output width.
        num_block: sequence of four ints, the block count of each stage.
    """

    def __init__(self, in_channel, out_channel, block, num_block):
        super().__init__()

        # Running channel count, consumed and advanced by _make_layer.
        self.input_channel = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        # Kept for backward compatibility. forward() resizes to each skip
        # tensor's exact size instead, so inputs need not be multiples of 32.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Stage outputs carry `base * block.expansion` channels, so the decoder
        # input widths must scale with it. For BasicBlock (expansion == 1) these
        # are exactly 512+256, 256+128 and 128+64.
        expansion = block.expansion
        self.dconv_up3 = Double_conv(512 * expansion + 256 * expansion, 256)
        self.dconv_up2 = Double_conv(256 + 128 * expansion, 128)
        self.dconv_up1 = Double_conv(128 + 64 * expansion, 64)

        # Input is dconv_up1's 64 channels concatenated with the 64-channel stem.
        self.dconv_last = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channel, 1)
        )

    def _make_layer(self, block, output_channel, num_blocks, stride):
        """Build one residual stage as an ``nn.Sequential`` of blocks.

        Only the first block uses ``stride``; the remaining blocks use 1.
        Advances ``self.input_channel`` to ``output_channel * block.expansion``.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.input_channel, output_channel, block_stride))
            self.input_channel = output_channel * block.expansion
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_to(features, reference):
        """Bilinearly resize ``features`` to the spatial size of ``reference``."""
        # With align_corners=True this matches nn.Upsample(scale_factor=2)
        # exactly whenever the reference is twice the size.
        return nn.functional.interpolate(
            features, size=reference.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2_x(self.maxpool(conv1))
        conv3 = self.conv3_x(conv2)
        conv4 = self.conv4_x(conv3)
        bottle = self.conv5_x(conv4)

        # Each encoder stage rounds odd sizes up, so a fixed 2x upsample would
        # not always match the skip tensor. Resize to the skip's exact size.
        feat = self.dconv_up3(torch.cat([self._upsample_to(bottle, conv4), conv4], dim=1))
        feat = self.dconv_up2(torch.cat([self._upsample_to(feat, conv3), conv3], dim=1))
        feat = self.dconv_up1(torch.cat([self._upsample_to(feat, conv2), conv2], dim=1))
        feat = torch.cat([self._upsample_to(feat, conv1), conv1], dim=1)
        out = self.dconv_last(feat)

        # dconv_last doubles the stem resolution, which overshoots odd input
        # sizes by one pixel; resize so the output always matches the input.
        if out.shape[-2:] != x.shape[-2:]:
            out = self._upsample_to(out, x)
        return out

    


def resnet18(in_channel,out_channel):
    """Build the ResNet-18 variant: BasicBlock with two blocks in each of four stages."""
    blocks_per_stage = [2] * 4
    return ResNet(in_channel, out_channel, BasicBlock, blocks_per_stage)



# In [None]:
import torch
import torch.nn as nn
from torchvision import models

def Double_conv(input_channel, output_channel):
    """Return a 3x3 conv -> ReLU -> 1x1 conv -> ReLU stack.

    The first conv maps ``input_channel`` to ``output_channel`` channels; the
    1x1 conv mixes channels at each position. Both stages preserve the spatial
    size, which the decoder's skip concatenations rely on.
    """
    return nn.Sequential(
        nn.Conv2d(input_channel, output_channel, 3, padding=1),
        nn.ReLU(inplace=True),
        # A 1x1 kernel needs padding=0 to keep H and W unchanged; padding=1
        # would grow each side by 2 and break torch.cat with the skip tensors.
        nn.Conv2d(output_channel, output_channel, 1, padding=0),
        nn.ReLU(inplace=True)
    )

class BasicBlock(nn.Module):
    """Two-convolution residual block used by ResNet-18 and ResNet-34.

    ``residual_function`` is conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    ``shortcut`` is the identity, or a strided 1x1 conv + BN projection when
    the stride or the channel count changes. ``forward`` returns
    ``relu(residual + shortcut)``.
    """

    expansion = 1

    def __init__(self, input_channel, output_channel, stride=1):
        super().__init__()
        out_width = output_channel * BasicBlock.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, out_width, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_width),
        )

        # A projection is only needed when x and the residual differ in
        # spatial size (stride) or channel count.
        if stride == 1 and input_channel == out_width:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(input_channel, out_width, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_width),
            )

    def forward(self, x):
        summed = self.residual_function(x) + self.shortcut(x)
        return torch.relu(summed)


class BottleNeck(nn.Module):
    """Bottleneck residual block: 1x1 reduce -> 3x3 (strided) -> 1x1 expand.

    The block outputs ``output_channel * expansion`` channels. ``stride`` is
    applied by the 3x3 conv and, when a projection is needed, by the 1x1
    shortcut conv. ``forward`` returns ``relu(residual + shortcut)``.
    """

    expansion = 4

    def __init__(self, input_channel, output_channel, stride=1):
        super().__init__()
        widened = output_channel * BottleNeck.expansion

        self.residual_function = nn.Sequential(
            nn.Conv2d(input_channel, output_channel, kernel_size=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, output_channel, kernel_size=3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(output_channel),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channel, widened, kernel_size=1, bias=False),
            nn.BatchNorm2d(widened),
        )

        # Identity shortcut unless the residual changes resolution or width.
        needs_projection = stride != 1 or input_channel != widened
        self.shortcut = (
            nn.Sequential(
                nn.Conv2d(input_channel, widened, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(widened),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        summed = self.residual_function(x) + self.shortcut(x)
        return torch.relu(summed)



class ResNet(nn.Module):
    """ResNet encoder with a U-Net-style decoder producing per-pixel output maps.

    Encoder: a 7x7/stride-2 stem (``conv1``) and a 3x3/stride-2 max-pool,
    then four residual stages (``conv2_x`` .. ``conv5_x``) with 64/128/256/512
    base channels; stages 3-5 each halve the resolution. Decoder: the deepest
    features are repeatedly upsampled, concatenated with the matching encoder
    output (stage 4, 3, 2, then the stem) and refined with ``Double_conv``.
    A final head upsamples once more and projects to ``out_channel`` maps.
    The output has the same spatial size as the input.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels of the output tensor.
        block: residual block class (e.g. ``BasicBlock`` or ``BottleNeck``);
            its ``expansion`` attribute scales each stage's output width.
        num_block: sequence of four ints, the block count of each stage.
    """

    def __init__(self, in_channel, out_channel, block, num_block):
        super().__init__()

        # Running channel count, consumed and advanced by _make_layer.
        self.input_channel = 64

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channel, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True))
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.conv2_x = self._make_layer(block, 64, num_block[0], 1)
        self.conv3_x = self._make_layer(block, 128, num_block[1], 2)
        self.conv4_x = self._make_layer(block, 256, num_block[2], 2)
        self.conv5_x = self._make_layer(block, 512, num_block[3], 2)
        # Kept for backward compatibility. forward() resizes to each skip
        # tensor's exact size instead, so inputs need not be multiples of 32.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Stage outputs carry `base * block.expansion` channels, so the decoder
        # input widths must scale with it. For BasicBlock (expansion == 1) these
        # are exactly 512+256, 256+128 and 128+64.
        expansion = block.expansion
        self.dconv_up3 = Double_conv(512 * expansion + 256 * expansion, 256)
        self.dconv_up2 = Double_conv(256 + 128 * expansion, 128)
        self.dconv_up1 = Double_conv(128 + 64 * expansion, 64)

        # Input is dconv_up1's 64 channels concatenated with the 64-channel stem.
        self.dconv_last = nn.Sequential(
            nn.Conv2d(64 + 64, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(64, out_channel, 1)
        )

    def _make_layer(self, block, output_channel, num_blocks, stride):
        """Build one residual stage as an ``nn.Sequential`` of blocks.

        Only the first block uses ``stride``; the remaining blocks use 1.
        Advances ``self.input_channel`` to ``output_channel * block.expansion``.
        """
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.input_channel, output_channel, block_stride))
            self.input_channel = output_channel * block.expansion
        return nn.Sequential(*layers)

    @staticmethod
    def _upsample_to(features, reference):
        """Bilinearly resize ``features`` to the spatial size of ``reference``."""
        # With align_corners=True this matches nn.Upsample(scale_factor=2)
        # exactly whenever the reference is twice the size.
        return nn.functional.interpolate(
            features, size=reference.shape[-2:], mode='bilinear', align_corners=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2_x(self.maxpool(conv1))
        conv3 = self.conv3_x(conv2)
        conv4 = self.conv4_x(conv3)
        bottle = self.conv5_x(conv4)

        # Each encoder stage rounds odd sizes up, so a fixed 2x upsample would
        # not always match the skip tensor. Resize to the skip's exact size.
        feat = self.dconv_up3(torch.cat([self._upsample_to(bottle, conv4), conv4], dim=1))
        feat = self.dconv_up2(torch.cat([self._upsample_to(feat, conv3), conv3], dim=1))
        feat = self.dconv_up1(torch.cat([self._upsample_to(feat, conv2), conv2], dim=1))
        feat = torch.cat([self._upsample_to(feat, conv1), conv1], dim=1)
        out = self.dconv_last(feat)

        # dconv_last doubles the stem resolution, which overshoots odd input
        # sizes by one pixel; resize so the output always matches the input.
        if out.shape[-2:] != x.shape[-2:]:
            out = self._upsample_to(out, x)
        return out

    


def resnet18(in_channel, out_channel, pretrain=False):
    """Build the ResNet-18 variant: BasicBlock with two blocks in each of four stages.

    Args:
        in_channel: channels of the input tensor.
        out_channel: channels of the output tensor.
        pretrain: loading pretrained weights is not implemented. Passing True
            raises NotImplementedError rather than silently returning a
            randomly initialised model.

    Raises:
        NotImplementedError: if ``pretrain`` is truthy.
    """
    # The original signature used full-width `，` and `）`, a SyntaxError.
    if pretrain:
        raise NotImplementedError("resnet18: pretrained weights are not supported")
    return ResNet(in_channel, out_channel, BasicBlock, [2, 2, 2, 2])
