In [2]:
import torch
import torch.nn as nn

In [3]:
# Random tensor of shape (1, 3, 256, 256) used as the dummy input batch for
# the STAN smoke tests in the cells below.
input_img = torch.randn(1, 3, 256, 256)

In [131]:
class Encoder(nn.Module):
    """
    One STAN encoder stage: a 3x3 branch plus a fused 5x5 / 1x1 branch.

    Every path is two convolutions, each followed by ReLU. All convolutions use
    ``padding='same'``, so only the 2x2 max-pool changes the spatial size.

    :param in_channels: int, channels of ``x_branche1`` (input of the 5x5 and 1x1 paths)
    :param out_channels: int, channels produced by each of the three paths
    :param stem: bool, if True the 3x3 path also expects ``in_channels`` channels,
        otherwise it expects ``in_channels // 2``
    :param middle: bool, if True ``forward`` returns the un-pooled concatenation of
        the three paths instead of the 4-tuple
    """

    def __init__(self, in_channels, out_channels, stem=False, middle=False):
        super().__init__()

        # Outside the stem the 3x3 branch is fed by the previous stage's 3x3
        # output, which carries half as many channels as the fused branch.
        narrow_in = in_channels if stem else in_channels // 2
        self.conv3x3_1 = nn.Conv2d(narrow_in, out_channels, kernel_size=3, padding='same')
        self.conv3x3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.conv5x5_1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding='same')
        self.conv5x5_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.conv1x1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same')
        self.conv1x1_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        # The keyword is ``inplace``; nn.ReLU has no ``replace`` argument.
        self.relu = nn.ReLU(inplace=True)

        self.down = nn.MaxPool2d(kernel_size=2, stride=2)

        self.middle = middle

    def _conv_relu_pair(self, first, second, x):
        """Apply ``first`` then ``second``, each followed by ReLU."""
        return self.relu(second(self.relu(first(x))))

    def forward(self, x_branche1, x_branche2):
        """
        Run the stage.

        :param x_branche1: tensor fed to the 5x5 and 1x1 paths
        :param x_branche2: tensor fed to the 3x3 path
        :return: if ``middle`` is True, ``cat(5x5, 1x1, 3x3)`` along dim 1; otherwise
            ``(skip1, skip2, x1_3_down, x5_1_down)`` where ``skip2`` is the 3x3 output,
            ``skip1`` is ``cat(3x3, 5x5, 1x1)`` and the two ``*_down`` tensors are the
            max-pooled 3x3 output and the max-pooled ``cat(5x5, 1x1)``
        """
        x1_3 = self._conv_relu_pair(self.conv3x3_1, self.conv3x3_2, x_branche2)
        x1_5 = self._conv_relu_pair(self.conv5x5_1, self.conv5x5_2, x_branche1)
        x1_1 = self._conv_relu_pair(self.conv1x1_1, self.conv1x1_2, x_branche1)

        if self.middle:
            return torch.cat((x1_5, x1_1, x1_3), 1)

        skip2 = x1_3
        x5_1 = torch.cat((x1_5, x1_1), 1)
        skip1 = torch.cat((x1_3, x5_1), 1)

        # Pooled tensors feed the next stage; the skips go to the decoder.
        return skip1, skip2, self.down(skip2), self.down(x5_1)

class Decoder(nn.Module):
    """
    STAN decoder: four up-sampling stages, each consuming two encoder skips.

    Each stage doubles the spatial size with a stride-2 transposed convolution,
    concatenates the first skip, applies a 3x3 conv + ReLU, concatenates the
    second skip and applies another 3x3 conv + ReLU.

    :param filters: list[int], five layer widths, default is [32, 64, 128, 256, 512].
        The convolution input widths below are written as multiples of these values
        and line up with the skip tensors when each width is twice the previous one
        (as in the default).
    :param n_classes: int, number of channels of the output map, default is 1
    """

    def __init__(self, filters=None, n_classes=1):
        super().__init__()
        # Build a fresh list per instance rather than sharing a mutable default.
        self.filters = [32, 64, 128, 256, 512] if filters is None else list(filters)
        self.n_classes = n_classes
        # Decoder 1
        self.up1 = nn.ConvTranspose2d(in_channels=self.filters[4]*3, out_channels=self.filters[4], kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3x3_1_dec1 = nn.Conv2d(self.filters[3]*3, self.filters[3]*3, kernel_size=3, padding='same')
        self.conv3x3_2_dec1 = nn.Conv2d(self.filters[4]*3, self.filters[3], kernel_size=3, padding='same')

        # Decoder 2
        self.up2 = nn.ConvTranspose2d(in_channels=self.filters[3], out_channels=self.filters[2], kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3x3_1_dec2 = nn.Conv2d(self.filters[3], self.filters[3], kernel_size=3, padding='same')
        self.conv3x3_2_dec2 = nn.Conv2d(self.filters[2]*5, self.filters[2], kernel_size=3, padding='same')

        # Decoder 3
        self.up3 = nn.ConvTranspose2d(in_channels=self.filters[2], out_channels=self.filters[1], kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3x3_1_dec3 = nn.Conv2d(self.filters[1]*2, self.filters[1], kernel_size=3, padding='same')
        self.conv3x3_2_dec3 = nn.Conv2d(self.filters[1]*4, self.filters[1], kernel_size=3, padding='same')

        # Decoder 4
        self.up4 = nn.ConvTranspose2d(in_channels=self.filters[1], out_channels=self.filters[0], kernel_size=3, stride=2, padding=1, output_padding=1)
        self.conv3x3_1_dec4 = nn.Conv2d(self.filters[0]*2, self.filters[0], kernel_size=3, padding='same')
        self.conv3x3_2_dec4 = nn.Conv2d(self.filters[0]*4, self.n_classes, kernel_size=3, padding='same')

        # The keyword is ``inplace``; nn.ReLU has no ``replace`` argument.
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, skip_connections):
        """
        Decode ``x`` back to full resolution.

        :param x: bottleneck tensor with ``filters[4] * 3`` channels
        :param skip_connections: sequence of eight tensors ordered deepest stage
            first; for each stage the tensor concatenated before the first conv comes
            first, followed by the one concatenated before the second conv
        :return: tensor with ``n_classes`` channels after the last conv + ReLU
        """
        stages = (
            (self.up1, self.conv3x3_1_dec1, self.conv3x3_2_dec1),
            (self.up2, self.conv3x3_1_dec2, self.conv3x3_2_dec2),
            (self.up3, self.conv3x3_1_dec3, self.conv3x3_2_dec3),
            (self.up4, self.conv3x3_1_dec4, self.conv3x3_2_dec4),
        )

        # Skips are kept in locals only: storing them on ``self`` would keep the
        # activations (and their autograd graph) alive after the call returns.
        for idx, (up, conv_first, conv_second) in enumerate(stages):
            first_skip = skip_connections[2 * idx]
            second_skip = skip_connections[2 * idx + 1]

            x = up(x)
            x = self.relu(conv_first(torch.cat((x, first_skip), 1)))
            x = self.relu(conv_second(torch.cat((x, second_skip), 1)))

        # NOTE(review): the final ReLU is applied to the output map as well;
        # confirm this is intended for the loss that consumes it.
        return x

class STAN(nn.Module):
    """
    Class to build STAN architecture: https://arxiv.org/ftp/arxiv/papers/2002/2002.01034.pdf.
    :param in_channels: int, the number of channels of the input image, default is 3
    :param n_classes: int, the number of classes of the segmentation task, default is 1
    :param filters: list[int], the number of filters for each layer, default is [32, 64, 128, 256, 512].
        Must contain five values, each twice the previous one, because every encoder
        stage feeds ``2 * width`` channels into the next stage's 5x5 / 1x1 paths.
    """

    def __init__(self, in_channels=3, n_classes=1, filters=None):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes
        self.filters = [32, 64, 128, 256, 512] if filters is None else list(filters)

        if len(self.filters) != 5:
            raise ValueError("filters must contain exactly 5 values")
        if any(nxt != 2 * cur for cur, nxt in zip(self.filters, self.filters[1:])):
            raise ValueError("each value in filters must be twice the previous one")

        # Encoder
        self.enc1 = Encoder(in_channels, self.filters[0], stem=True)
        self.enc2 = Encoder(self.filters[1], self.filters[1])
        self.enc3 = Encoder(self.filters[2], self.filters[2])
        self.enc4 = Encoder(self.filters[3], self.filters[3])

        # Bottleneck (middle)
        self.middle = Encoder(self.filters[4], self.filters[4], middle=True)

        # Decoder
        self.decoder = Decoder(filters=self.filters, n_classes=self.n_classes)

    def forward(self, x):
        """
        Run the four encoder stages, the bottleneck and the decoder.

        :param x: input batch with ``in_channels`` channels
        :return: decoder output with ``n_classes`` channels
        """
        # The stem receives the same image on both branches.
        skip1_enc1, skip2_enc1, x3_1_down, x5_1_down = self.enc1(x, x)
        skip1_enc2, skip2_enc2, x3_1_down, x5_1_down = self.enc2(x5_1_down, x3_1_down)
        skip1_enc3, skip2_enc3, x3_1_down, x5_1_down = self.enc3(x5_1_down, x3_1_down)
        skip1_enc4, skip2_enc4, x3_1_down, x5_1_down = self.enc4(x5_1_down, x3_1_down)

        # Bottleneck: concatenation of the three paths, filters[4] * 3 channels.
        x = self.middle(x5_1_down, x3_1_down)

        # The decoder expects the skips deepest stage first, 3x3 skip before
        # the wide skip within each stage.
        skips = [
            skip2_enc4, skip1_enc4,
            skip2_enc3, skip1_enc3,
            skip2_enc2, skip1_enc2,
            skip2_enc1, skip1_enc1,
        ]
        return self.decoder(x, skips)
       

# Smoke test: build the model, run the dummy batch through it once and print
# the shape of the result.
stan = STAN()
x = stan(input_img)
print(x.shape)
# Disabled debug prints for intermediate tensors.
# print("middle_x.shape: ", middle_x.shape)
# print("skip1.shape: ", skip1.shape)
# print("skip2.shape: ", skip2.shape)
# print("x1_3_down.shape: ", x1_3_down.shape)
# print("x5_1_down.shape: ", x5_1_down.shape)


torch.Size([1, 1, 256, 256])


In [114]:
class Encoder(nn.Module):
    """
    Activation-free STAN encoder stage: a 3x3 branch plus a fused 5x5 / 1x1 branch.

    All convolutions use ``padding='same'``; the 2x2 max-pool halves the spatial size.

    :param in_channels: int, channels of ``x_branche1`` (input of the 5x5 and 1x1 paths)
    :param out_channels: int, channels produced by each of the three paths
    :param stem: bool, if True the 3x3 path also expects ``in_channels`` channels,
        otherwise it expects ``in_channels // 2``
    """

    def __init__(self, in_channels, out_channels, stem=False):
        super().__init__()

        narrow_in = in_channels if stem else in_channels // 2
        self.conv3x3_1 = nn.Conv2d(narrow_in, out_channels, kernel_size=3, padding='same')
        self.conv3x3_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.conv5x5_1 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding='same')
        self.conv5x5_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.conv1x1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same')
        self.conv1x1_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        # forward() pools through self.down, so it has to be created here.
        self.down = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x_branche1, x_branche2):
        """
        Run the stage.

        :param x_branche1: tensor fed to the 5x5 and 1x1 paths
        :param x_branche2: tensor fed to the 3x3 path
        :return: ``(skip1, skip2, x1_3_down, x5_1_down)`` where ``skip2`` is the 3x3
            output, ``skip1`` is ``cat(3x3, 5x5, 1x1)`` along dim 1, and the two
            ``*_down`` tensors are the pooled 3x3 output and the pooled ``cat(5x5, 1x1)``
        """
        skip2 = self.conv3x3_2(self.conv3x3_1(x_branche2))

        x1_5 = self.conv5x5_2(self.conv5x5_1(x_branche1))
        x1_1 = self.conv1x1_2(self.conv1x1_1(x_branche1))
        x5_1 = torch.cat((x1_5, x1_1), 1)

        skip1 = torch.cat((skip2, x5_1), 1)

        return skip1, skip2, self.down(skip2), self.down(x5_1)

class STAN(nn.Module):
    """
    Prototype with the first two STAN encoder stages written out by hand.

    There is no activation and no decoder; ``forward`` returns the tensors
    produced by the second stage only.

    :param in_channels: int, the number of channels of the input image, default is 3
    :param n_classes: int, stored on the instance; no layer uses it yet
    """

    def __init__(self, in_channels=3, n_classes=1):
        super().__init__()
        self.in_channels = in_channels
        self.n_classes = n_classes

        # encoder 1
        self.conv3x3_1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding='same')
        self.conv3x3_2 = nn.Conv2d(32, 32, kernel_size=3, padding='same')

        self.conv5x5_1 = nn.Conv2d(in_channels, 32, kernel_size=5, padding='same')
        self.conv5x5_2 = nn.Conv2d(32, 32, kernel_size=3, padding='same')

        self.conv1x1_1 = nn.Conv2d(in_channels, 32, kernel_size=1, padding='same')
        self.conv1x1_2 = nn.Conv2d(32, 32, kernel_size=3, padding='same')

        # encoder 2
        self.conv3x3_1_enc2 = nn.Conv2d(32, 64, kernel_size=3, padding='same')
        self.conv3x3_2_enc2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')

        self.conv5x5_1_enc2 = nn.Conv2d(64, 64, kernel_size=5, padding='same')
        self.conv5x5_2_enc2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')

        self.conv1x1_1_enc2 = nn.Conv2d(64, 64, kernel_size=1, padding='same')
        self.conv1x1_2_enc2 = nn.Conv2d(64, 64, kernel_size=3, padding='same')

        self.down = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """
        Run both encoder stages.

        :param x: input batch with ``in_channels`` channels
        :return: ``(skip1_enc2, skip2_enc2, x1_3_down_enc2, x5_1_down_enc2)``, i.e. the
            stage-2 ``cat(3x3, 5x5, 1x1)``, the stage-2 3x3 output, and the pooled
            3x3 and ``cat(5x5, 1x1)`` tensors
        """
        # Encoder 1: only the pooled tensors are needed downstream, so the
        # stage-1 skip concatenation is not materialised.
        x1_3 = self.conv3x3_2(self.conv3x3_1(x))
        x5_1 = torch.cat(
            (
                self.conv5x5_2(self.conv5x5_1(x)),
                self.conv1x1_2(self.conv1x1_1(x)),
            ),
            1,
        )
        x1_3_down = self.down(x1_3)
        x5_1_down = self.down(x5_1)

        # Encoder 2: the 3x3 path continues from the pooled 3x3 output, the
        # 5x5 / 1x1 paths from the pooled concatenation.
        skip2_enc2 = self.conv3x3_2_enc2(self.conv3x3_1_enc2(x1_3_down))
        x5_1_enc2 = torch.cat(
            (
                self.conv5x5_2_enc2(self.conv5x5_1_enc2(x5_1_down)),
                self.conv1x1_2_enc2(self.conv1x1_1_enc2(x5_1_down)),
            ),
            1,
        )
        skip1_enc2 = torch.cat((skip2_enc2, x5_1_enc2), 1)

        return skip1_enc2, skip2_enc2, self.down(skip2_enc2), self.down(x5_1_enc2)

# Smoke test: run the dummy batch through the two-stage prototype and print
# the shape of each of the four returned tensors.
stan = STAN()
skip1, skip2, x1_3_down, x5_1_down = stan(input_img)
print("skip1.shape: ", skip1.shape)
print("skip2.shape: ", skip2.shape)
print("x1_3_down.shape: ", x1_3_down.shape)
print("x5_1_down.shape: ", x5_1_down.shape)



skip1.shape:  torch.Size([1, 192, 128, 128])
skip2.shape:  torch.Size([1, 64, 128, 128])
x1_3_down.shape:  torch.Size([1, 64, 64, 64])
x5_1_down.shape:  torch.Size([1, 128, 64, 64])


In [None]:
class StemLayer(nn.Module):
    """
    First encoder stage: a 3x3 path on one input and 5x5 / 1x1 paths on the other.

    No activation is applied. All convolutions use ``padding='same'``.

    :param in_channels: int, channels of both inputs
    :param out_channels: int, channels produced by each path
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        same = dict(padding='same')

        self.conv3_3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, **same)
        self.conv3_3_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, **same)

        self.conv1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, **same)
        self.conv5_5 = nn.Conv2d(in_channels, out_channels, kernel_size=5, **same)
        # NOTE(review): this single 3x3 layer is applied after both the 5x5 and
        # the 1x1 convolution, so the two paths share its weights - confirm
        # that the sharing is intended.
        self.conv5_1_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, **same)

        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self,x_branche1, x_branche2):
        """
        :param x_branche1: tensor fed to the 3x3 path
        :param x_branche2: tensor fed to the 5x5 and 1x1 paths
        :return: ``(skip_concat, x_3x3_skip, x_branch1_concat_down, x_branch2_down)``:
            ``cat(5x5, 1x1, first 3x3 conv output)``, the output of the second 3x3
            conv, the pooled ``cat(5x5, 1x1)`` and the pooled second-3x3 output
        """
        # 5x5 and 1x1 paths, both finished by the shared 3x3 layer.
        from_5x5 = self.conv5_1_3(self.conv5_5(x_branche2))
        from_1x1 = self.conv5_1_3(self.conv1_1(x_branche2))
        fused = torch.cat((from_5x5, from_1x1), 1)

        # 3x3 path; the intermediate activation is reused in the wide skip.
        narrow_mid = self.conv3_3(x_branche1)
        narrow_out = self.conv3_3_3(narrow_mid)

        wide_skip = torch.cat((fused, narrow_mid), 1)

        return wide_skip, narrow_out, self.downsample(fused), self.downsample(narrow_out)


class Encoder(nn.Module):
    """
    Encoder stage after the stem: a 3x3 path plus 5x5 / 1x1 paths, no activation.

    The 3x3 path maps ``in_channels // 2`` to ``out_channels`` so that, when stages
    are chained with doubling widths, its pooled output has the
    ``in_channels // 2`` channels the next stage's 3x3 path expects.

    :param in_channels: int, channels of ``x_branche1`` (5x5 and 1x1 paths);
        ``x_branche2`` must carry ``in_channels // 2`` channels
    :param out_channels: int, channels produced by each path
    :param middle: bool, if True ``forward`` returns one pooled tensor made of all
        three paths instead of the 4-tuple
    """

    def __init__(self, in_channels, out_channels, middle=False):
        super().__init__()
        self.conv3_3 = nn.Conv2d(in_channels//2, out_channels, kernel_size=3, padding='same')
        self.conv3_3_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.conv1_1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding='same')
        self.conv5_5 = nn.Conv2d(in_channels, out_channels, kernel_size=5,  padding='same')
        # NOTE(review): one 3x3 layer follows both the 5x5 and the 1x1
        # convolution, so the two paths share its weights - confirm intent.
        self.conv5_1_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding='same')

        self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)

        self.middle = middle

    def forward(self, x_branche1, x_branche2):
        """
        :param x_branche1: tensor fed to the 5x5 and 1x1 paths
        :param x_branche2: tensor fed to the 3x3 path
        :return: if ``middle`` is True, the max-pooled ``cat(5x5, 1x1, 3x3)``; otherwise
            ``(skip_concat, x_3x3_skip, x_branch1_concat_down, x_branch2_down)``:
            ``cat(5x5, 1x1, first 3x3 conv output)``, the output of the second 3x3
            conv, the pooled ``cat(5x5, 1x1)`` and the pooled second-3x3 output
        """
        # Branch 1: 5x5 and 1x1 paths.
        x_5x5 = self.conv5_1_3(self.conv5_5(x_branche1))
        x_1x1 = self.conv5_1_3(self.conv1_1(x_branche1))

        # Branch 2: 3x3 path.
        x_b_2 = self.conv3_3(x_branche2)
        x_3x3_skip = self.conv3_3_3(x_b_2)

        if self.middle:
            # NOTE(review): the bottleneck is pooled once more here - confirm
            # that the decoder expects this extra halving.
            return self.downsample(torch.cat((x_5x5, x_1x1, x_3x3_skip), 1))

        x_concat = torch.cat((x_5x5, x_1x1), 1)
        skip_concat = torch.cat((x_concat, x_b_2), 1)

        return (
            skip_concat,
            x_3x3_skip,
            self.downsample(x_concat),
            self.downsample(x_3x3_skip),
        )

class STAN(nn.Module):
    """
    Encoder-only STAN draft: stem, three encoder stages and a middle block.

    :param in_channels: int, the number of channels of the input image, default is 3
    :param out_channels: int, stored on the instance; no layer uses it yet
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filters = [32, 64, 128, 256, 512, 1024]
        widths = self.filters

        self.encoder1 = StemLayer(self.in_channels, widths[0])
        self.encoder2 = Encoder(widths[1], widths[1])
        self.encoder3 = Encoder(widths[2], widths[2])
        self.encoder4 = Encoder(widths[3], widths[3])

        self.middle_block = Encoder(widths[4], widths[4], middle=True)

        # Placeholders for a decoder that is not written yet; forward() does
        # not call either of them.
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.d1_conv1_3x3 = nn.Conv2d(widths[4], widths[3], kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """
        Run the encoder path and return the middle block's output.

        Prints the shapes of the stem's two pooled outputs and of the result.
        The skip tensors returned by the stages are discarded because there is
        no decoder to consume them.
        """
        # The stem receives the same image on both of its inputs.
        _, _, wide, narrow = self.encoder1(x, x)
        print("x_branch1_concat_down.shape:", wide.shape)
        print("x_branch2_down.shape:", narrow.shape)

        for stage in (self.encoder2, self.encoder3, self.encoder4):
            _, _, wide, narrow = stage(wide, narrow)

        x = self.middle_block(wide, narrow)
        print(x.shape)

        return x
    
# Smoke test: run the dummy batch through the encoder-only draft.
stan = STAN()
output = stan(input_img)

In [68]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

class encoder_block(nn.Module):
    """
    STAN encoder stage taking a ``(x_1, x_2)`` pair.

    ``x_1`` goes through a 3x3 convolution, ``x_2`` through a 5x5 and a 1x1
    convolution; each result is passed through ReLU, then through ``conv3_2``
    and ReLU again.

    :param input_channels_1: int, channels of ``x_1`` (3x3 path)
    :param input_channels_2: int, channels of ``x_2`` (5x5 and 1x1 paths)
    :param output_channels: int, channels produced by every convolution
    """

    def __init__(self, input_channels_1, input_channels_2, output_channels):
        super().__init__()
        self.conv5 = nn.Conv2d(input_channels_2, output_channels, kernel_size=5, padding='same')
        self.conv1 = nn.Conv2d(input_channels_2, output_channels, kernel_size=1, padding='same')
        self.conv3_1 = nn.Conv2d(input_channels_1, output_channels, kernel_size=3, padding='same')
        # NOTE(review): this one layer is the second convolution of all three
        # paths, so they share its weights - confirm that this is intended.
        self.conv3_2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding='same')
        self.maxp = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: pair ``(x_1, x_2)`` of tensors with equal spatial size
        :return: ``(out1, out2, skip1, skip2)``: the pooled 3x3 path, the pooled
            ``cat(5x5, 1x1)``, the un-pooled 3x3 path, and
            ``cat(5x5, 1x1, 3x3 path after its first conv + ReLU)``
        """
        narrow_in, wide_in = x
        act = self.relu

        # First convolution of each path.
        stage_one = [
            act(layer(inp))
            for layer, inp in (
                (self.conv5, wide_in),
                (self.conv1, wide_in),
                (self.conv3_1, narrow_in),
            )
        ]
        narrow_mid = stage_one[2]

        # Second (shared) convolution of each path.
        from_5x5, from_1x1, narrow_out = [act(self.conv3_2(t)) for t in stage_one]

        fused = torch.cat((from_5x5, from_1x1), dim=1)
        wide_skip = torch.cat((fused, narrow_mid), dim=1)

        return self.maxp(narrow_out), self.maxp(fused), narrow_out, wide_skip

class middle_block(nn.Module):
    """
    Bottleneck stage: same three paths as ``encoder_block`` but without pooling.

    :param input_channels_1: int, channels of ``x_1`` (3x3 path)
    :param input_channels_2: int, channels of ``x_2`` (5x5 and 1x1 paths)
    :param output_channels: int, channels produced by every convolution
    """

    def __init__(self, input_channels_1, input_channels_2, output_channels):
        super().__init__()
        self.conv5 = nn.Conv2d(input_channels_2, output_channels, kernel_size=5, padding='same')
        self.conv1 = nn.Conv2d(input_channels_2, output_channels, kernel_size=1, padding='same')
        self.conv3_1 = nn.Conv2d(input_channels_1, output_channels, kernel_size=3, padding='same')
        # NOTE(review): shared second convolution for all three paths -
        # confirm that the weight sharing is intended.
        self.conv3_2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding='same')
        # Created but not called by forward().
        self.maxp = nn.MaxPool2d(kernel_size=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: pair ``(x_1, x_2)`` of tensors with equal spatial size
        :return: ``cat(5x5 path, 1x1 path, 3x3 path)`` along dim 1, i.e.
            ``3 * output_channels`` channels at the input resolution
        """
        narrow_in, wide_in = x

        paths = []
        for first_conv, source in (
            (self.conv5, wide_in),
            (self.conv1, wide_in),
            (self.conv3_1, narrow_in),
        ):
            hidden = self.relu(first_conv(source))
            paths.append(self.relu(self.conv3_2(hidden)))

        return torch.cat(paths, dim=1)

class decoder_block(nn.Module):
    """
    STAN decoder stage: upsample, then merge two skip tensors one after another.

    :param input_channels: int, channels of the tensor to upsample
    :param output_channels: int, channels after the transposed convolution and
        after each of the two 3x3 convolutions; ``skip1`` must carry
        ``output_channels`` channels and ``skip2`` ``3 * output_channels`` so the
        concatenations match the ``*2`` / ``*4`` convolution inputs
    """

    def __init__(self, input_channels, output_channels):
        super().__init__()
        # kernel 3 / stride 2 / padding 1 / output_padding 1 doubles H and W.
        self.deconv3 = nn.ConvTranspose2d(
            input_channels,
            output_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
        )
        self.conv3_1 = nn.Conv2d(output_channels * 2, output_channels, kernel_size=3, padding='same')
        self.conv3_2 = nn.Conv2d(output_channels * 4, output_channels, kernel_size=3, padding='same')
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        :param x: triple ``(x_t, skip1, skip2)``
        :return: tensor with ``output_channels`` channels at twice the spatial
            size of ``x_t``
        """
        coarse, narrow_skip, wide_skip = x

        merged = torch.cat((self.deconv3(coarse), narrow_skip), dim=1)
        merged = self.relu(self.conv3_1(merged))

        merged = torch.cat((merged, wide_skip), dim=1)
        return self.relu(self.conv3_2(merged))

class stan_architecture(nn.Module):
    """
    Full STAN network: four encoder stages, a middle block and four decoder stages.

    :param initial_channels: int, channels of the input image, default is 3
    :param filters: sequence of five ints, the width of each stage, default is
        (32, 64, 128, 256, 512)
    :param final_channels: int, channels of the output map, default is 1
    :param verbose: bool, if True (the default) ``forward`` prints the shapes of
        the encoder and middle-block tensors on every call; pass False to silence it
    """

    def __init__(self, initial_channels=3, filters=(32, 64, 128, 256, 512), final_channels=1, verbose=True):
        super().__init__()
        self.verbose = verbose
        # The first stage receives the same image on both inputs.
        self.enc1 = encoder_block(initial_channels, initial_channels, filters[0])
        self.enc2 = encoder_block(filters[0], filters[1], filters[1])
        self.enc3 = encoder_block(filters[1], filters[2], filters[2])
        self.enc4 = encoder_block(filters[2], filters[3], filters[3])
        self.mid = middle_block(filters[3], filters[4], filters[4])
        # The middle block concatenates three paths, hence the factor 3.
        self.dec4 = decoder_block(filters[4]*3, filters[3])
        self.dec3 = decoder_block(filters[3], filters[2])
        self.dec2 = decoder_block(filters[2], filters[1])
        self.dec1 = decoder_block(filters[1], filters[0])
        self.conv = nn.Conv2d(in_channels=filters[0], out_channels=final_channels, kernel_size=3, padding='same')
        self.relu = nn.ReLU()

    def _report(self, *items):
        """Print ``items`` like ``print`` does, but only when ``verbose`` is set."""
        if self.verbose:
            print(*items)

    def forward(self, input_image):
        """
        :param input_image: batch with ``initial_channels`` channels
        :return: tensor with ``final_channels`` channels after the last conv + ReLU
        """
        x = input_image
        self._report("input: ", x.shape)

        self._report("********** Encoder 1 **********")
        narrow, wide, skip1, skip2 = self.enc1((x, x))
        self._report("e1_out1", narrow.shape)
        self._report("e1_out2", wide.shape)
        self._report("e1_skip1", skip1.shape)
        self._report("e1_skip2", skip2.shape)
        skips = [(skip1, skip2)]

        for number, stage in enumerate((self.enc2, self.enc3, self.enc4), start=2):
            self._report(f"********** Encoder {number} **********")
            narrow, wide, skip1, skip2 = stage((narrow, wide))
            for tensor in (narrow, wide, skip1, skip2):
                self._report(tensor.shape)
            skips.append((skip1, skip2))

        self._report("********** Middle **********")
        out = self.mid((narrow, wide))
        self._report(out.shape)

        # Decoder stages consume the skips deepest stage first.
        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        for stage, (skip1, skip2) in zip(decoders, reversed(skips)):
            out = stage((out, skip1, skip2))

        # NOTE(review): ReLU is applied to the final map as well; confirm this
        # matches the loss used for training.
        return self.relu(self.conv(out))

# Smoke test: instantiate the network and push one random batch through it.
model = stan_architecture(3, [32, 64, 128, 256, 512], 1)
x = torch.randn(1, 3, 256, 256, dtype=torch.float, requires_grad=False)
# print(summary(model, (3, 256, 256), batch_size=16))
# x is rebound from the input batch to the network output.
x = model(x)
print(x.shape)


input:  torch.Size([1, 3, 256, 256])
********** Encoder 1 **********
e1_out1 torch.Size([1, 32, 128, 128])
e1_out2 torch.Size([1, 64, 128, 128])
e1_skip1 torch.Size([1, 32, 256, 256])
e1_skip2 torch.Size([1, 96, 256, 256])
********** Encoder 2 **********
torch.Size([1, 64, 64, 64])
torch.Size([1, 128, 64, 64])
torch.Size([1, 64, 128, 128])
torch.Size([1, 192, 128, 128])
********** Encoder 3 **********
torch.Size([1, 128, 32, 32])
torch.Size([1, 256, 32, 32])
torch.Size([1, 128, 64, 64])
torch.Size([1, 384, 64, 64])
********** Encoder 4 **********
torch.Size([1, 256, 16, 16])
torch.Size([1, 512, 16, 16])
torch.Size([1, 256, 32, 32])
torch.Size([1, 768, 32, 32])
********** Middle **********
torch.Size([1, 1536, 16, 16])
torch.Size([1, 1, 256, 256])
