In [None]:
import torch
import data_reader

In [None]:
from torch import nn
class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` stages.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the second convolution.
        mid_channels: Channels between the two convolutions; falls back to
            ``out_channels`` when falsy (``None`` or ``0``).
    """

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            # BatchNorm2d rather than LazyBatchNorm2d: the lazy variant takes
            # no num_features, so its first positional argument is `eps` and
            # LazyBatchNorm2d(mid_channels) silently normalised with
            # eps=mid_channels.  Parameter names are the same, so existing
            # checkpoints still load (they were trained with the huge eps).
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Apply both stages; 3x3 kernels with padding=1 keep height and width."""
        return self.double_conv(x)


In [None]:
from torch import nn
class Downsample_block(nn.Module):
    """Encoder stage: a ``DoubleConv`` followed by 2x2, stride-2 max pooling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the ``DoubleConv``.
        kernel_size: Accepted for signature compatibility; not used here.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channels=3, out_channels=64, kernel_size=3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.layer1 = DoubleConv(in_channels, out_channels)
        self.pooling = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Convolve ``x`` and return the max-pooled result."""
        features = self.layer1(x)
        pooled = self.pooling(features)
        return pooled



In [None]:
class Upsampler(nn.Module):
    """Decoder stage: upsample ``x1``, align it to ``x2``, concatenate, convolve.

    Args:
        in_channels: Channels seen by the fusing ``DoubleConv``, i.e. the
            channel count after concatenating ``x2`` with the upsampled ``x1``.
        out_channels: Channels produced by the stage.
        bilinear: If true, upsample with bilinear interpolation (``x1`` keeps
            its channel count).  Otherwise use a stride-2 transposed
            convolution mapping ``in_channels`` to ``in_channels // 2``.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    # Upsample first, then convolve.
    def __init__(self, in_channels, out_channels, bilinear=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # The transposed conv must halve the channels: its output is
            # concatenated with the skip map and the result has to total
            # `in_channels` for `self.conv`.  Emitting `in_channels` here (as
            # before) made the concatenation too wide for any non-empty skip.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with the skip connection ``x2``.

        Args:
            x1: Lower-resolution feature map to upsample.
            x2: Skip-connection feature map; sets the output height and width.

        Returns:
            ``self.conv`` applied to ``[x2, x1]`` concatenated on dim 1.
        """
        x1 = self.up(x1)
        # Distribute the size difference over both sides.  A negative
        # difference makes the padding negative, which crops x1 instead.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)  # fuse skip and size-matched features
        return self.conv(x)

class TransUpsample(nn.Module):
    """Decoder stage that applies a transposed convolution to ``x1`` before fusing.

    Args:
        in_channels: Channels of ``x1``; the transposed convolution keeps them.
        mid_channels: Channels of the concatenation ``[x2, x1]`` fed to the
            ``DoubleConv``.
        outchannels: Channels produced by the stage.
        kernel_size: Kernel size of the transposed convolution.
        stride: Stride of the transposed convolution (keyword-only).  The
            default of 1 keeps the previous behaviour, where the map grows by
            only ``kernel_size - 1`` pixels and ``forward`` zero-pads (or
            crops) the remainder.  Pass ``stride=2`` to really double the
            resolution.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channels, mid_channels, outchannels, kernel_size=2, *args, stride=1, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.up = nn.ConvTranspose2d(in_channels, in_channels, kernel_size, stride=stride)
        self.conv = DoubleConv(mid_channels, outchannels)

    def forward(self, x1, x2):
        """Fuse the transposed-convolved ``x1`` with the skip connection ``x2``.

        Returns:
            ``self.conv`` applied to ``[x2, x1]`` concatenated on dim 1, at
            the height and width of ``x2``.
        """
        x1 = self.up(x1)
        # Distribute the size difference over both sides.  A negative
        # difference makes the padding negative, which crops x1 instead.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = nn.functional.pad(x1, [diffX // 2, diffX - diffX // 2,
                                    diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)  # fuse skip and size-matched features
        return self.conv(x)

# Smoke test for Upsampler: bilinear upsampling takes x from 30x30 to 60x60,
# the negative padding in forward() crops it to y's 40x40, and the 10 + 10
# concatenated channels match in_channels=20.  Expected shape: (1, 30, 40, 40).
a = Upsampler(20,30)
x = torch.zeros((1,10,30,30))
y = torch.zeros((1,10,40,40))
a(x,y).shape


In [None]:
class StnAttention(nn.Module):
    """Spatial-transformer block: regress a 2x3 affine map and resample the input.

    Layer sizes are fixed.  The localisation network expects 128 input
    channels, and ``fc_loc`` expects 15488 = 128 * 11 * 11 flattened features,
    which is what e.g. a 50x50 input yields (50 -> 48 -> 24 -> 22 -> 11).
    Other spatial sizes fail in the first ``Linear`` layer.
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.localization = nn.Sequential(
            nn.Conv2d(128, 64, 3),
            nn.MaxPool2d(2, stride=2),
            nn.ReLU(True),
            nn.Conv2d(64, 128, 3),
            nn.MaxPool2d(2, 2),
            nn.ReLU(True),
        )
        self.fc_loc = nn.Sequential(
            nn.Linear(15488, 50),
            nn.ReLU(True),
            nn.Linear(50, 2 * 3),
        )
        # Start from the identity transform.  With the default random
        # initialisation theta is a near-zero matrix, so the very first
        # forward passes would sample almost every output pixel from the
        # image centre.
        with torch.no_grad():
            self.fc_loc[2].weight.zero_()
            self.fc_loc[2].bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, x):
        """Return ``x`` resampled through the affine map predicted from ``x``.

        The output has the same size as ``x``.
        """
        xs = self.localization(x)
        xs = xs.view(xs.shape[0], -1)
        theta = self.fc_loc(xs).view(-1, 2, 3)
        # align_corners=False is already the default for both calls; passing
        # it explicitly keeps them consistent and avoids the default-change
        # warning.
        grid = nn.functional.affine_grid(theta, x.size(), align_corners=False)
        return nn.functional.grid_sample(x, grid, align_corners=False)
        


In [None]:
class Hnet(nn.Module):
    """Small encoder-decoder segmentation network with skip connections.

    Two ``Downsample_block`` stages feed a ``DoubleConv`` bottleneck, two
    ``TransUpsample`` stages fuse the result with the encoder outputs, and a
    1x1 convolution followed by 2x bilinear upsampling produces the score map.

    Args:
        in_channels: Channels of the input image.
        num_classes: Channels of the returned score map.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, in_channels, num_classes, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        # Encoder and bottleneck (a third down/up stage is currently disabled).
        self.downsampler1 = Downsample_block(in_channels)
        self.downsampler2 = Downsample_block(64, 128)
        self.bottom = DoubleConv(128, 256)

        # Decoder: 384 = 256 bottleneck channels + 128 skip channels.
        self.upsampler1 = TransUpsample(256, 384, 64)

        # Output path: 128 = 64 decoder channels + 64 skip channels.
        self.restrans = TransUpsample(64, 128, 64, 2)
        self.resconv = nn.Conv2d(64, num_classes, 1)
        self.res_upsampler = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

        # Registered as a submodule, but forward() does not call it.
        self.attention = StnAttention()

    def cut_tensor(self, x, target_tensor):
        """Crop the last two axes of ``x`` to the last-axis size of ``target_tensor``.

        Only the last dimension of each tensor is read, so both are treated
        as square.  For an odd size difference the extra row/column is taken
        off the end.
        """
        src_size = x.shape[-1]
        tgt_size = target_tensor.shape[-1]  # images are square
        excess = src_size - tgt_size

        # Leading offset, truncated to an integer.
        lead = int(excess / 2)
        # An odd excess removes one more element at the trailing edge.
        trail = lead + 1 if excess % 2 != 0 else lead

        stop = src_size - trail
        return x[:, :, lead:stop, lead:stop]

    def forward(self, _):
        """Map the input batch ``_`` to a ``num_classes``-channel score map."""
        skip1 = self.downsampler1(_)  # e.g. 100x100 for a 200x200 input
        skip2 = self.downsampler2(skip1)
        bottleneck = self.bottom(skip2)
        decoded = self.upsampler1(bottleneck, skip2)
        fused = self.restrans(decoded, skip1)
        scores = self.resconv(fused)
        return self.res_upsampler(scores)

# Smoke test for Hnet: one 3-channel 200x200 image through a 4-class network.
# The displayed shape should be (1, 4, 200, 200).
x = torch.zeros((1,3,200,200))
a = Hnet(3,4)
a(x).shape

In [None]:
import torch.utils
import torch.utils.data


# Datasets come from the project-local `data_reader` module; they are wrapped
# in DataLoaders by the training cell.
train_data,test_data = data_reader.get_train_and_test_data()

# Prefer the first CUDA device, but fall back to the CPU so this cell does not
# raise on a machine without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# 3 input channels, 4 output classes.
net = Hnet(3,4)
net = net.to(device)



In [None]:
def evaluate_accuracy(model, data_iter, device):
    """Return the fraction of correctly classified pixels over ``data_iter``.

    Targets are prepared exactly as in the training loop
    (``y.squeeze(1).long()``) and predictions are the arg-max over dim 1 of
    the model output.  The model is switched to eval mode; the caller must
    call ``model.train()`` before training again.  Returns 0.0 when the
    iterator yields nothing.

    NOTE(review): the original cell only had a ``${func computes acc}``
    placeholder here; plain pixel accuracy is assumed - confirm the metric.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in data_iter:
            x, y = x.to(device), y.to(device)
            target = y.squeeze(1).long()
            prediction = model(x).argmax(dim=1)
            correct += (prediction == target).sum().item()
            total += target.numel()
    return correct / total if total else 0.0


train_iter = torch.utils.data.DataLoader(train_data, batch_size=16, shuffle=True, num_workers=4)
test_iter = torch.utils.data.DataLoader(test_data, batch_size=64, num_workers=4)
epochs = 10
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss(reduction='none')
losses,train_acc = [],[]
test_acc = []
# Train on whatever device the model already lives on instead of hard-coding it.
device = next(net.parameters()).device
for epoch in range(epochs):
    net.train()  # evaluate_accuracy() leaves the model in eval mode
    running_loss, seen = 0.0, 0
    for x,y in train_iter:
        x,y = x.to(device),y.to(device)
        optimizer.zero_grad()
        pre = net(x)
        # reduction='none' yields one loss value per pixel; average to a scalar.
        loss = loss_fn(pre,y.squeeze(1).long()).mean(1).mean()
        loss.backward()
        optimizer.step()
        # Accumulate plain floats so no tensor (or its graph) outlives the step.
        running_loss += loss.item() * x.size(0)
        seen += x.size(0)
    # Sample-weighted mean loss over the epoch rather than the last batch only.
    losses.append(running_loss / seen if seen else float('nan'))
    train_acc.append(evaluate_accuracy(net, train_iter, device))
    test_acc.append(evaluate_accuracy(net, test_iter, device))
    print(f"epoch: {epoch} loss: {losses[-1]} train_acc:{train_acc[-1]} test_acc: {test_acc[-1]}")