In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import data
from numpy import *
import numpy as np
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
from torchvision.transforms import functional as tf
import glob
from PIL import Image
import sys
import os
import time
%matplotlib inline
%config InlineBackend.figure_format = "retina"
import random
def setup_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global generator and
    Python's ``random`` module, then configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    # Ask cuDNN to use deterministic kernels only.
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select a different algorithm from run to run,
    # so it has to be switched off as well for results to be repeatable.
    torch.backends.cudnn.benchmark = False
# Set the random seed.
setup_seed(88)

# Report the PyTorch build and the available accelerator.
print(torch.__version__)
print(torch.cuda.is_available())
# get_device_name() raises when no CUDA device is present, so only query it
# when CUDA is actually available.
if torch.cuda.is_available():
    print('GPU name: ',torch.cuda.get_device_name(0))
else:
    print('GPU name: ', 'none (running on CPU)')

1.12.0
True
GPU name:  NVIDIA GeForce RTX 3060 Ti


In [17]:
#### Unet
class Downsample(nn.Module):
    """Encoder stage: optional 2x2 max-pool followed by two 3x3 Conv-BN-ReLU layers.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # A 3x3 kernel with padding=1 leaves the spatial size unchanged.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_relu = nn.Sequential(*layers)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x, is_pool=True):
        """Run the stage.

        Args:
            x: Input feature map.
            is_pool: When true (default) the input is max-pooled by 2 before
                the convolutions; pass False to keep the input resolution.

        Returns:
            The convolved feature map with ``out_channels`` channels.
        """
        features = self.pool(x) if is_pool else x
        return self.conv_relu(features)


class Upsample(nn.Module):
    """Decoder stage: two 3x3 Conv-BN-ReLU layers, then a stride-2 transposed conv.

    The stage consumes ``2 * channels`` input channels, reduces them to
    ``channels`` and finally emits ``channels // 2`` channels at twice the
    spatial resolution.

    Args:
        channels: Channel count after the two convolutions.
    """

    def __init__(self, channels):
        super().__init__()
        conv_layers = [
            nn.Conv2d(2 * channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
        ]
        self.conv_relu = nn.Sequential(*conv_layers)
        # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W exactly.
        self.upconv = nn.Sequential(
            nn.ConvTranspose2d(
                channels,
                channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                output_padding=1,
            )
        )

    def forward(self, x):
        """Convolve ``x`` and upsample the result by a factor of two."""
        return self.upconv(self.conv_relu(x))


class Unet_model(nn.Module):
    """Five-level U-Net built from ``Downsample`` / ``Upsample`` stages.

    The encoder pools four times and every decoder step exactly doubles the
    resolution, so the input height and width should be multiples of 16 for
    the skip concatenations to line up.

    Args:
        in_channels: Number of channels of the input image (default 1).
        num_classes: Number of output channels / classes (default 26).
    """

    def __init__(self, in_channels=1, num_classes=26):
        super(Unet_model, self).__init__()
        # Encoder
        self.down1 = Downsample(in_channels, 64)
        self.down2 = Downsample(64, 128)
        self.down3 = Downsample(128, 256)
        self.down4 = Downsample(256, 512)
        self.down5 = Downsample(512, 1024)

        # First upsampling step out of the bottleneck.
        self.up = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )

        # Decoder; each stage takes the concatenation of a skip and the
        # upsampled features, i.e. 2 * channels.
        self.up1 = Upsample(512)
        self.up2 = Upsample(256)
        self.up3 = Upsample(128)

        # Final double convolution at full resolution (used without pooling).
        self.conv_2 = Downsample(128, 64)

        # 1x1 convolution mapping features to per-class scores.
        self.last = nn.Sequential(nn.Conv2d(64, num_classes, kernel_size=1))

    def forward(self, input):
        """Compute per-pixel class scores.

        Args:
            input: Image batch with ``in_channels`` channels.

        Returns:
            Tensor with ``num_classes`` channels at the input resolution.
        """
        # Encoder path; the first stage keeps the input resolution.
        skip1 = self.down1(input, is_pool=False)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        skip4 = self.down4(skip3)
        x = self.up(self.down5(skip4))

        # Decoder path: concatenate the matching skip, then convolve + upsample.
        x = self.up1(torch.cat([skip4, x], dim=1))
        x = self.up2(torch.cat([skip3, x], dim=1))
        x = self.up3(torch.cat([skip2, x], dim=1))

        x = self.conv_2(torch.cat([skip1, x], dim=1), is_pool=False)
        return self.last(x)

In [None]:
## AFFU
# CBAM attention module
class ChannelAttentionModule(nn.Module):
    """Channel attention branch of CBAM.

    Global average- and max-pooled descriptors are passed through a shared
    two-layer bottleneck (1x1 convolutions); the two results are summed and
    squashed with a sigmoid to give one weight per channel.

    Args:
        channel: Number of channels of the input feature map.
        ratio: Reduction ratio of the bottleneck (default 16).
    """

    def __init__(self, channel, ratio=16):
        super(ChannelAttentionModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # channel // ratio is 0 whenever channel < ratio, which would create an
        # empty bottleneck; keep at least one hidden channel.
        hidden = max(1, channel // ratio)
        self.shared_MLP = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(hidden, channel, 1, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel attention weights of shape (N, channel, 1, 1)."""
        avgout = self.shared_MLP(self.avg_pool(x))
        maxout = self.shared_MLP(self.max_pool(x))
        return self.sigmoid(avgout + maxout)

class SpatialAttentionModule(nn.Module):
    """Spatial attention branch of CBAM.

    The channel-wise mean and maximum of the input are stacked into a
    two-channel map, filtered by a 7x7 convolution and passed through a
    sigmoid, yielding one weight per spatial location.
    """

    def __init__(self):
        super().__init__()
        # padding=3 keeps the spatial size with the 7x7 kernel.
        self.conv2d = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=7, stride=1, padding=3)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel attention map with the spatial size of ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        stacked = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv2d(stacked))

class CBAM(nn.Module):
    """Convolutional block attention: channel attention followed by spatial attention.

    Args:
        channel: Number of channels of the feature map to be re-weighted.
    """

    def __init__(self, channel):
        super().__init__()
        self.channel_attention = ChannelAttentionModule(channel)
        self.spatial_attention = SpatialAttentionModule()

    def forward(self, x):
        """Re-weight ``x`` per channel, then per spatial location."""
        channel_refined = self.channel_attention(x) * x
        spatial_weights = self.spatial_attention(channel_refined)
        return spatial_weights * channel_refined

# attention Unet 5layer
# Attention module
class Attention_block(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip features ``x`` are each projected to
    ``F_int`` channels with a 1x1 Conv-BN, summed, passed through ReLU and
    reduced to a single-channel sigmoid map that multiplies ``x``.

    NOTE: a later cell in this notebook defines another class with the same
    name; if that cell has been executed, this definition is shadowed.

    Args:
        F_g: Channel count of the gating signal.
        F_l: Channel count of the skip features.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(in_ch, out_ch):
            # 1x1 convolution followed by batch normalisation.
            return (
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                nn.BatchNorm2d(out_ch),
            )

        self.W_g = nn.Sequential(*project(F_g, F_int))
        self.W_x = nn.Sequential(*project(F_l, F_int))
        self.psi = nn.Sequential(*project(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` scaled by the attention coefficients.

        ``g`` and ``x`` are summed after projection, so they must share the
        same spatial size.
        """
        gate = self.relu(self.W_g(g) + self.W_x(x))
        return x * self.psi(gate)
# Atrous spatial pyramid pooling (ASPP) layer
class ASPP(nn.Module):
    """Parallel dilated 3x3 convolutions fused by a 1x1 convolution.

    One Conv-ReLU-BN branch is created per dilation rate and registered as
    ``aspp_block1``, ``aspp_block2``, ... so the parameter names of the
    default three-rate configuration are unchanged. The branch outputs are
    concatenated along the channel axis and reduced to ``ch_out`` channels.

    Args:
        ch_in: Channel count of the input feature map.
        ch_out: Channel count of each branch and of the fused output.
        rate: Sequence of dilation rates, one per branch (default 6, 12, 18).

    Raises:
        ValueError: If ``rate`` is empty.
    """

    def __init__(self, ch_in, ch_out, rate=(6, 12, 18)):
        super(ASPP, self).__init__()

        rates = list(rate)
        if not rates:
            raise ValueError("rate must contain at least one dilation value")

        # Names of the registered branches, in concatenation order.
        self._branch_names = []
        for index, dilation in enumerate(rates, start=1):
            name = "aspp_block{}".format(index)
            # padding == dilation keeps the spatial size for a 3x3 kernel.
            branch = nn.Sequential(
                nn.Conv2d(
                    ch_in, ch_out, 3, stride=1, padding=dilation, dilation=dilation
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm2d(ch_out),
            )
            self.add_module(name, branch)
            self._branch_names.append(name)

        self.output = nn.Conv2d(len(rates) * ch_out, ch_out, 1)
        self._init_weights()

    def forward(self, x):
        """Apply every dilated branch to ``x`` and fuse the results."""
        branch_outputs = [getattr(self, name)(x) for name in self._branch_names]
        return self.output(torch.cat(branch_outputs, dim=1))

    def _init_weights(self):
        """Kaiming-initialise conv weights; reset BatchNorm to scale 1, shift 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

# Residual module
class Residual_block(nn.Module):
    """Two 3x3 Conv-BN-PReLU layers plus a projected (1x1 Conv-BN) shortcut.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count of the output feature map.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        main_path = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(out_channels),
            nn.PReLU(),
        ]
        self.conv_block = nn.Sequential(*main_path)
        # The 1x1 projection lets the shortcut match the main path's channel count.
        shortcut = [
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels),
        ]
        self.conv_skip = nn.Sequential(*shortcut)

    def forward(self, x):
        """Return the sum of the convolutional path and the shortcut."""
        main = self.conv_block(x)
        shortcut = self.conv_skip(x)
        return main + shortcut

# Consecutive convolution layers used by the encoder
def contracting_block(in_channels, out_channels):
    """Build a block of two 3x3 Conv-ReLU-BatchNorm layers.

    Args:
        in_channels: Channel count fed to the first convolution.
        out_channels: Channel count produced by both convolutions.

    Returns:
        An ``nn.Sequential`` holding the six layers; spatial size is preserved.
    """

    def conv3x3(channels_in):
        # stride 1 and padding 1 keep height and width unchanged.
        return nn.Conv2d(
            in_channels=channels_in,
            out_channels=out_channels,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
        )

    return torch.nn.Sequential(
        conv3x3(in_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
        conv3x3(out_channels),
        nn.ReLU(),
        nn.BatchNorm2d(out_channels),
    )


# The same two-convolution structure is also used repeatedly on the upsampling
# path, so it is exposed there under the name `double_conv`.
double_conv = contracting_block

# Decoder upsampling module
class expansive_block(nn.Module):
    """Bilinear 2x upsampling followed by a 3x3 Conv-ReLU-BatchNorm layer.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=1,
                padding=1,
            ),
            nn.ReLU(),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        """Double the spatial size of ``x`` and convolve the result."""
        upsampled = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        return self.block(upsampled)


def final_block(in_channels, out_channels):
    """Return the 1x1 convolution that maps features to output channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Number of output channels.
    """
    head = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
    return head


#——————————————————————————————————————————————————————————————————————————————————————————————————————————————
#——————————————————————————————————————————————5层 Attention_Unet——————————————————————————————————————————————
class AttUNet(nn.Module):
    """Five-level attention U-Net with CBAM / ASPP side branches.

    On top of a residual encoder and an attention-gated decoder, the network
    adds two side paths that are concatenated before the final 1x1 conv:

    * a shallow full-resolution branch (``UpFromInput`` + CBAM, 32 channels);
    * the gated skip features of levels 2-4, each refined by CBAM, upsampled
      to full resolution, concatenated (896 channels) and fused by ASPP
      (128 channels).

    The encoder pools four times and the decoder doubles the resolution four
    times, so input height and width should be multiples of 16 for the
    concatenations to line up.

    Args:
        in_channel: Number of channels of the input image.
        out_channel: Number of output channels / classes.
    """

    def __init__(self, in_channel, out_channel):
        super(AttUNet, self).__init__()
        # Encode
        self.conv_encode1 = Residual_block(in_channels=in_channel, out_channels=64)
        self.conv_pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode2 = Residual_block(in_channels=64, out_channels=128)
        self.conv_pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode3 = Residual_block(in_channels=128, out_channels=256)
        self.conv_pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode4 = Residual_block(in_channels=256, out_channels=512)
        self.conv_pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv_encode5 = Residual_block(in_channels=512, out_channels=1024)

        # Decode
        self.conv_decode4 = expansive_block(1024, 512)
        self.att4 = Attention_block(F_g=512, F_l=512, F_int=256)
        self.double_conv4 = double_conv(1024, 512)

        self.conv_decode3 = expansive_block(512, 256)
        self.att3 = Attention_block(F_g=256, F_l=256, F_int=128)
        self.double_conv3 = double_conv(512, 256)

        self.conv_decode2 = expansive_block(256, 128)
        self.att2 = Attention_block(F_g=128, F_l=128, F_int=64)
        self.double_conv2 = double_conv(256, 128)

        self.conv_decode1 = expansive_block(128, 64)
        self.att1 = Attention_block(F_g=64, F_l=64, F_int=32)
        self.double_conv1 = double_conv(128, 64)

        # 64 (decoder) + 32 (shallow branch) + 128 (ASPP bridge) channels.
        self.final_layer = final_block(64+32+128, out_channel)

        # Shallow full-resolution branch taken straight from the input.
        # It must accept `in_channel` channels, not a hard-coded single channel,
        # otherwise the network fails for any multi-channel input.
        self.UpFromInput = nn.Sequential(nn.Conv2d(in_channel, 32, kernel_size=3, padding=1),
                                        nn.BatchNorm2d(32),
                                        nn.ReLU(inplace=True))
        self.CBAM_upside = CBAM(channel=32)

        # Multi-scale bridge over the gated skip features.
        self.CBAM_layer4 = CBAM(channel=512)
        self.CBAM_layer3 = CBAM(channel=256)
        self.CBAM_layer2 = CBAM(channel=128)
        self.aspp_bridge = ASPP(ch_in=(512+256+128), ch_out=128)

    def forward(self, x):
        """Compute per-pixel scores with ``out_channel`` channels for ``x``."""
        # Shallow branch at input resolution (32 channels).
        up_bridge = self.CBAM_upside(self.UpFromInput(x))

        # Encode
        enc1 = self.conv_encode1(x)
        enc2 = self.conv_encode2(self.conv_pool1(enc1))
        enc3 = self.conv_encode3(self.conv_pool2(enc2))
        enc4 = self.conv_encode4(self.conv_pool3(enc3))
        enc5 = self.conv_encode5(self.conv_pool4(enc4))

        # Decode: upsample, gate the matching skip with the upsampled
        # features, concatenate and fuse.
        dec4 = self.conv_decode4(enc5)
        gated4 = self.att4(g=dec4, x=enc4)
        dec4 = self.double_conv4(torch.cat((gated4, dec4), dim=1))

        dec3 = self.conv_decode3(dec4)
        gated3 = self.att3(g=dec3, x=enc3)
        dec3 = self.double_conv3(torch.cat((gated3, dec3), dim=1))

        dec2 = self.conv_decode2(dec3)
        gated2 = self.att2(g=dec2, x=enc2)
        dec2 = self.double_conv2(torch.cat((gated2, dec2), dim=1))

        dec1 = self.conv_decode1(dec2)
        gated1 = self.att1(g=dec1, x=enc1)
        dec1 = self.double_conv1(torch.cat((gated1, dec1), dim=1))

        # Multi-scale bridge: CBAM on the gated skips of levels 4/3/2, then
        # bilinear upsampling back to the input resolution.
        cbam4 = self.CBAM_layer4(gated4)
        cbam3 = self.CBAM_layer3(gated3)
        cbam2 = self.CBAM_layer2(gated2)

        cbam4 = F.interpolate(cbam4, scale_factor=8, mode='bilinear', align_corners=True)  # 512 channels
        cbam3 = F.interpolate(cbam3, scale_factor=4, mode='bilinear', align_corners=True)  # 256 channels
        cbam2 = F.interpolate(cbam2, scale_factor=2, mode='bilinear', align_corners=True)  # 128 channels

        bridge_in = torch.cat((cbam2, cbam3, cbam4), dim=1)  # 896 channels
        bridge_out = self.aspp_bridge(bridge_in)  # 128 channels

        upside_out = torch.cat((dec1, up_bridge), dim=1)  # 96 channels

        return self.final_layer(torch.cat((bridge_out, upside_out), dim=1))



In [None]:
### UNet++
class VGGBlock(nn.Module):
    """Two 3x3 Conv-BatchNorm-ReLU layers with a configurable middle width.

    Args:
        in_channels: Channel count of the input feature map.
        middle_channels: Channel count after the first convolution.
        out_channels: Channel count after the second convolution.
    """

    def __init__(self, in_channels, middle_channels, out_channels):
        super().__init__()
        # One ReLU instance is shared by both layers (it has no parameters).
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(middle_channels)
        self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply both Conv-BN-ReLU layers; the spatial size is preserved."""
        hidden = self.relu(self.bn1(self.conv1(x)))
        return self.relu(self.bn2(self.conv2(hidden)))

class Up(nn.Module):
    """Upsample one feature map by 2 and concatenate it with another.

    The upsampled map is zero-padded to the spatial size of the reference map
    before concatenation, so odd-sized reference maps are supported.
    """

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)

    def forward(self, x1, x2):
        """Return ``cat([x2, pad(upsample(x1))], dim=1)``.

        Args:
            x1: Lower-resolution feature map to be upsampled.
            x2: Reference feature map whose height/width the result must match.
        """
        x1 = self.up(x1)
        # Dimensions 2 and 3 are height and width. Plain Python ints are used
        # for the size difference: F.pad expects integers, and this avoids
        # allocating tensors and using tensor floor division on every call.
        diff_y = x2.size(2) - x1.size(2)
        diff_x = x2.size(3) - x1.size(3)

        # Pad order is (left, right, top, bottom); any odd remainder goes to
        # the right / bottom edge.
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        return torch.cat([x2, x1], dim=1)


class NestedUNet(nn.Module):
    """UNet++ (nested U-Net) with densely connected skip pathways.

    Node ``x{i}_{j}`` sits at depth ``i`` (resolution divided by ``2**i``) and
    is the ``j``-th node along that depth's skip pathway. Every node with
    ``j > 0`` receives all earlier nodes of its own depth plus the upsampled
    node from one level below.

    Args:
        num_classes: Number of output channels (default 26).
        input_channels: Number of input image channels (default 1).
        deep_supervision: When true, one 1x1 head is attached to each of
            ``x0_1`` .. ``x0_4`` and ``forward`` returns a list of four
            outputs; otherwise a single head on ``x0_4`` is used.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_classes=26, input_channels=1, deep_supervision=False, **kwargs):
        super().__init__()

        # Channel widths of depths 0..4.
        f0, f1, f2, f3, f4 = 32, 64, 128, 256, 512

        self.deep_supervision = deep_supervision

        self.pool = nn.MaxPool2d(2, 2)
        self.up = Up()

        # Backbone (j = 0): the plain encoder column.
        self.conv0_0 = VGGBlock(input_channels, f0, f0)
        self.conv1_0 = VGGBlock(f0, f1, f1)
        self.conv2_0 = VGGBlock(f1, f2, f2)
        self.conv3_0 = VGGBlock(f2, f3, f3)
        self.conv4_0 = VGGBlock(f3, f4, f4)

        # j = 1: one same-depth input plus the upsampled deeper node.
        self.conv0_1 = VGGBlock(f0 + f1, f0, f0)
        self.conv1_1 = VGGBlock(f1 + f2, f1, f1)
        self.conv2_1 = VGGBlock(f2 + f3, f2, f2)
        self.conv3_1 = VGGBlock(f3 + f4, f3, f3)

        # j = 2: two same-depth inputs plus the upsampled deeper node.
        self.conv0_2 = VGGBlock(f0 * 2 + f1, f0, f0)
        self.conv1_2 = VGGBlock(f1 * 2 + f2, f1, f1)
        self.conv2_2 = VGGBlock(f2 * 2 + f3, f2, f2)

        # j = 3
        self.conv0_3 = VGGBlock(f0 * 3 + f1, f0, f0)
        self.conv1_3 = VGGBlock(f1 * 3 + f2, f1, f1)

        # j = 4
        self.conv0_4 = VGGBlock(f0 * 4 + f1, f0, f0)

        # Output heads (1x1 convolutions on depth-0 nodes).
        if deep_supervision:
            self.final1 = nn.Conv2d(f0, num_classes, kernel_size=1)
            self.final2 = nn.Conv2d(f0, num_classes, kernel_size=1)
            self.final3 = nn.Conv2d(f0, num_classes, kernel_size=1)
            self.final4 = nn.Conv2d(f0, num_classes, kernel_size=1)
        else:
            self.final = nn.Conv2d(f0, num_classes, kernel_size=1)

    def forward(self, input):
        """Run the nested U-Net.

        Returns:
            A single tensor of class scores, or a list of four tensors
            (from ``x0_1`` .. ``x0_4``) when deep supervision is enabled.
        """
        pool, up = self.pool, self.up

        # Diagonal 1
        x0_0 = self.conv0_0(input)
        x1_0 = self.conv1_0(pool(x0_0))
        x0_1 = self.conv0_1(up(x1_0, x0_0))

        # Diagonal 2
        x2_0 = self.conv2_0(pool(x1_0))
        x1_1 = self.conv1_1(up(x2_0, x1_0))
        dense0 = torch.cat([x0_0, x0_1], 1)
        x0_2 = self.conv0_2(up(x1_1, dense0))

        # Diagonal 3
        x3_0 = self.conv3_0(pool(x2_0))
        x2_1 = self.conv2_1(up(x3_0, x2_0))
        dense1 = torch.cat([x1_0, x1_1], 1)
        x1_2 = self.conv1_2(up(x2_1, dense1))
        dense0 = torch.cat([x0_0, x0_1, x0_2], 1)
        x0_3 = self.conv0_3(up(x1_2, dense0))

        # Diagonal 4
        x4_0 = self.conv4_0(pool(x3_0))
        x3_1 = self.conv3_1(up(x4_0, x3_0))
        dense2 = torch.cat([x2_0, x2_1], 1)
        x2_2 = self.conv2_2(up(x3_1, dense2))
        dense1 = torch.cat([x1_0, x1_1, x1_2], 1)
        x1_3 = self.conv1_3(up(x2_2, dense1))
        dense0 = torch.cat([x0_0, x0_1, x0_2, x0_3], 1)
        x0_4 = self.conv0_4(up(x1_3, dense0))

        if not self.deep_supervision:
            return self.final(x0_4)

        return [
            self.final1(x0_1),
            self.final2(x0_2),
            self.final3(x0_3),
            self.final4(x0_4),
        ]



# Smoke test: push a dummy batch through NestedUNet and print the output shape.
if __name__ == '__main__':
    #tensorboard --logdir logs_model
    net = NestedUNet()

    # Named `dummy_input` rather than `input`: in a notebook this runs at global
    # scope, where `input` would shadow the builtin for every later cell.
    dummy_input = torch.ones((16, 1, 256, 128))
    # Only the output shape is inspected, so skip building the autograd graph.
    with torch.no_grad():
        y = net(dummy_input)
    print(y.shape)
    

In [None]:
## Attention UNet
# Class built to make implementation of the double convolutions easier
class DoubleConv(nn.Module):
    """Two bias-free 3x3 Conv-BatchNorm-ReLU layers.

    Args:
        in_channels: Channel count of the input feature map.
        out_channels: Channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Bias is disabled on both convolutions; each is followed by BatchNorm.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both convolution layers; the spatial size is preserved."""
        out = self.conv(x)
        return out
    
class Attention_block(nn.Module):
    """Attention gate whose gating signal is at half the skip resolution.

    The skip features ``x`` are projected with a stride-2 1x1 convolution so
    they match the coarser gating signal ``g``; the resulting single-channel
    map is brought back to the skip resolution by a stride-2 transposed
    convolution and multiplied with ``x``.

    NOTE: this redefines ``Attention_block`` from an earlier cell of the
    notebook with different behaviour; whichever cell ran last wins.
    NOTE: the transposed convolution comes after the sigmoid, so the final
    coefficients are not confined to the (0, 1) range.

    Args:
        F_g: Channel count of the gating signal.
        F_l: Channel count of the skip features.
        F_int: Channel count of the intermediate projection.
    """

    def __init__(self,F_g,F_l,F_int):
        super(Attention_block,self).__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
            )

        # stride=2 halves the skip resolution to match the gating signal.
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1,stride=2,padding=0,bias=True),
            nn.BatchNorm2d(F_int)
        )

        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
            nn.ConvTranspose2d(1, 1, kernel_size=2, stride=2)
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self,g,x):
        """Return ``x`` scaled by the attention map computed from ``g`` and ``x``."""
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # Odd skip sizes leave the two projections one pixel apart; align them.
        if x1.shape[2:] != g1.shape[2:]:
            x1 = tf.resize(x1, size=g1.shape[2:])
        psi = self.psi(self.relu(g1+x1))
        # Compare spatial dimensions only: psi has one channel while x has F_l,
        # so comparing full shapes would trigger a resize on every call.
        if psi.shape[2:] != x.shape[2:]:
            psi = tf.resize(psi, size=x.shape[2:])
        return x*psi

class ATT_UNET(nn.Module):
    """Attention U-Net with a configurable number of encoder levels.

    Args:
        in_channels: Number of input image channels (default 1).
        out_channels: Number of output channels / classes (default 26).
        features: Channel widths of the encoder levels, shallowest first.
            Each width must be twice the previous one, because every decoder
            level expects ``2 * feature`` channels from the level below.
    """

    def __init__(
            self, in_channels=1, out_channels=26, features=(64, 128, 256, 512),
    ):
        super(ATT_UNET, self).__init__()
        features = list(features)

        # Defining the layers of the network
        # Convolutions on downward half of UNet:
        self.downs = nn.ModuleList()
        self.down_sample = nn.MaxPool2d(kernel_size=2, stride=2)
        # The bottleneck hangs off the deepest encoder level, whatever the
        # number of levels is (previously hard-coded to index 3).
        self.deepest_conv = DoubleConv(in_channels=features[-1],out_channels=features[-1]*2)
        self.up_convs = nn.ModuleList()
        self.up_samples = nn.ModuleList()
        self.att_block = nn.ModuleList()
        self.final_conv = nn.Conv2d(in_channels=features[0], out_channels=out_channels, kernel_size=1)

        # Downward part of UNet
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Decoder levels, deepest first: attention gate, fusion conv, upsampler.
        for feature in reversed(features):
            self.att_block.append(
                Attention_block(F_g= feature*2,F_l=feature,F_int=feature)
                )
            self.up_convs.append(
                DoubleConv(feature*2, feature)
                )
            self.up_samples.append(
                nn.ConvTranspose2d(feature*2, feature, kernel_size=2, stride=2)
                )

    def forward(self,x):
        """Compute per-pixel scores with ``out_channels`` channels for ``x``."""
        # Skip connections collected on the way down; they feed the attention gates.
        skip_connections = []

        # Downward path, identical to a plain U-Net.
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.down_sample(x)

        # Deepest convolution layer
        x = self.deepest_conv(x)

        # Upward path with attention gates, deepest skip first.
        decoder_levels = zip(
            self.att_block, self.up_samples, self.up_convs, reversed(skip_connections)
        )
        for gate, up_sample, up_conv, skip in decoder_levels:
            # The gate uses the coarse features *before* they are upsampled.
            a = gate(g=x, x=skip)
            x = up_sample(x)
            # Odd-sized skips end up one pixel larger than the upsampled map;
            # resize so the two can be concatenated.
            if x.shape[2:] != a.shape[2:]:
                x = tf.resize(x, size=a.shape[2:])
            x = up_conv(torch.cat((a, x), dim=1))

        return self.final_conv(x)

## Simple test function designed to check that ATT_UNET takes in and outputs tensors of the correct size
def test():
    """Shape check for ATT_UNET on a random single-channel input.

    Builds a random tensor of shape [batch=1, channels=1, height=128,
    width=128], runs it through a 26-class model, prints the input and
    output shapes and asserts that the arg-max label map has the same shape
    as the input (which holds because the input has exactly one channel).
    """
    sample = torch.randn(1, 1, 128, 128)
    model = ATT_UNET(in_channels=1, out_channels=26)
    logits = model(sample)
    # Collapse the class axis to a label map, then restore a channel axis of size 1.
    labels = logits.argmax(dim=1).unsqueeze(dim=1)
    print(sample.shape)
    print(logits.shape)
    assert labels.shape == sample.shape

# Run the ATT_UNET shape check when this cell/script is executed directly.
if __name__ == "__main__":
    test()