In [1]:
import torch
import torch.nn as nn
import torchvision.transforms.functional as TF
import torch.nn.functional as F

class ListSkip(nn.Module):
    """Densely fuse the four UNet encoder skip connections.

    Every skip level receives strided-convolution projections of all shallower
    levels, mapped to its channel count, and summed with it:

        out[0] = s0
        out[1] = conv1(s0) + s1
        out[2] = conv1d(s0) + conv2(s1) + s2
        out[3] = conv1dd(s0) + conv2d(s1) + conv3(s2) + s3

    The strides (2, 4, 8) assume each level has half the spatial size of the
    previous one, as produced by 2x2 max-pooling. A 3x3/padding-1 conv with
    stride s yields ceil(H / s) while pooling yields floor(H / 2**k), so when
    the input size is not a multiple of 8 a projection can be one pixel larger
    than its target; such projections are resized (bilinear) to the target
    size instead of failing on the addition.

    Args:
        features: channel counts of the four skip levels, shallowest first.
            Defaults to (64, 128, 256, 512), the widths this module was
            originally hard-coded for.

    Raises:
        ValueError: if ``features`` does not have exactly four entries.
    """

    def __init__(self, features=(64, 128, 256, 512)):
        super(ListSkip, self).__init__()
        if len(features) != 4:
            raise ValueError('ListSkip expects exactly 4 feature widths, got {}'.format(len(features)))
        c1, c2, c3, c4 = features
        # Attribute names and construction order are unchanged so existing
        # checkpoints (state_dict keys) and seeded initialisation stay compatible.
        self.conv1 = nn.Conv2d(in_channels=c1, out_channels=c2, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(in_channels=c2, out_channels=c3, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(in_channels=c3, out_channels=c4, kernel_size=3, stride=2, padding=1)
        self.conv1d = nn.Conv2d(in_channels=c1, out_channels=c3, kernel_size=3, stride=4, padding=1)
        self.conv1dd = nn.Conv2d(in_channels=c1, out_channels=c4, kernel_size=3, stride=8, padding=1)
        self.conv2d = nn.Conv2d(in_channels=c2, out_channels=c4, kernel_size=3, stride=4, padding=1)

    @staticmethod
    def _match_size(src, ref):
        """Return ``src`` resized to ``ref``'s spatial (H, W), or ``src`` itself if they already agree."""
        if src.shape[-2:] == ref.shape[-2:]:
            return src
        return F.interpolate(src, size=tuple(ref.shape[-2:]), mode='bilinear', align_corners=False)

    def add_skip(self, skip_list):
        """Return the fused skip list (same length, channels and spatial sizes as ``skip_list``).

        Args:
            skip_list: sequence of exactly four NCHW tensors, shallowest first.

        Raises:
            ValueError: if ``skip_list`` does not have exactly four entries.
        """
        if len(skip_list) != 4:
            raise ValueError('add_skip expects 4 skip tensors, got {}'.format(len(skip_list)))
        el1, el2, el3, el4 = skip_list
        fit = self._match_size
        # Summation order matches the original, so matching-size inputs give identical results.
        return [
            el1,
            fit(self.conv1(el1), el2) + el2,
            fit(self.conv1d(el1), el3) + fit(self.conv2(el2), el3) + el3,
            fit(self.conv1dd(el1), el4) + fit(self.conv2d(el2), el4) + fit(self.conv3(el3), el4) + el4,
        ]

    def forward(self, x):
        """Alias for ``add_skip``; ``x`` is the list of four skip tensors."""
        return self.add_skip(x)



  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class Attention(nn.Module):
    """Parameter-free scaled dot-product attention over flattened samples.

    Each of q, k, v is flattened to ``(B, D)`` where ``D`` is the product of
    all non-batch dimensions, and attention is computed between the ``B`` rows:
    ``softmax(q k^T / sqrt(D)) v``. The result is reshaped back to ``q``'s shape.

    NOTE(review): the batch axis acts as the token axis here, so every sample
    attends to every other sample in the batch. A sample's output therefore
    depends on the rest of the batch, and with batch size 1 the output is just
    ``v`` reshaped. Confirm this is intended; per-sample spatial
    self-attention would attend over H*W positions instead.
    """

    def __init__(self,):
        super(Attention, self).__init__()

    def Attention(self, q, k, v):
        """Compute attention between the flattened samples of q and k, applied to v.

        Args:
            q, k, v: tensors whose first dimension is the batch; k and q must
                flatten to the same per-sample length, and v must have as many
                elements per sample as q.

        Returns:
            (attention_weights, output): weights of shape (B_q, B_k) whose rows
            sum to 1, and the attended values reshaped to ``q.shape``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. after permute/transpose, work.
        q_flat = q.reshape(q.size(0), -1)
        k_flat = k.reshape(k.size(0), -1)
        v_flat = v.reshape(v.size(0), -1)
        # Scale by sqrt of the flattened feature length D.
        scale = k_flat.size(-1) ** 0.5
        logits = torch.matmul(q_flat, k_flat.transpose(0, 1)) / scale
        attention_weights = F.softmax(logits, dim=-1)
        output = torch.matmul(attention_weights, v_flat).reshape(q.shape)
        return attention_weights, output

    def forward(self, q, k, v):
        """Alias for the ``Attention`` method; returns (attention_weights, output)."""
        return self.Attention(q, k, v)

In [None]:

class ReluSIG(nn.Module):
    """Gated GELU activation: ``GELU(x) * sigmoid(x ** 2)``.

    Despite the name, the base nonlinearity is GELU, not ReLU. Because the
    gate's argument ``x ** 2`` is non-negative, the gate is never below 0.5.

    Args:
        embedding_dim: accepted for interface compatibility; not used.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # Held as submodules so they show up in the printed module tree.
        self.gelu = nn.GELU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        activated = self.gelu(x)
        gate = self.sigmoid(torch.square(x))
        return activated * gate
class DoubleConv(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm2d -> ReluSIG) stages.

    Both convolutions use stride 1, padding 1 and no bias, so the spatial size
    is preserved; the first maps ``in_channels -> out_channels``, the second
    ``out_channels -> out_channels``. One ReluSIG instance is shared by both
    stages (ReluSIG holds no parameters, so no weights are tied).
    """

    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.relusig = ReluSIG(out_channels)
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.extend([
                nn.Conv2d(stage_in, out_channels, 3, 1, 1, bias=False),
                nn.BatchNorm2d(out_channels),
                self.relusig,
            ])
        # Sequential indices 0..5 match the original layout (conv.0, conv.1, conv.3, conv.4 hold weights).
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        return self.conv(x)
class UNET(nn.Module):
    """UNet with fused skip connections and a bottleneck attention step.

    Encoder: for each width in ``features`` a DoubleConv whose output is stored
    as a skip connection, followed by 2x2 max-pooling.
    Bottleneck: DoubleConv from ``features[-1]`` to ``2 * features[-1]``
    channels, then the parameter-free ``Attention`` module with q = k = v.
    Skips: fused by ``ListSkip`` before decoding.
    Decoder: per level (deepest first) a stride-2 ConvTranspose2d, a resize to
    the skip's size when shapes differ, channel concat ``[skip, x]`` and a
    DoubleConv. A 1x1 conv produces the output; no final activation is applied.

    Args:
        in_channels: input image channels.
        out_channels: output map channels.
        features: encoder widths, shallowest first. ``ListSkip()`` is built
            with its default widths (64, 128, 256, 512), so other values will
            likely fail inside ``add_skip`` -- TODO confirm before changing.
    """

    def __init__(
            self, in_channels=3, out_channels=1, features=(64, 128, 256, 512),
    ):
        super(UNET, self).__init__()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        # dropout and sigmoid are not used in forward(); kept so the module tree is unchanged.
        self.dropout = nn.Dropout(0.3)
        self.att = Attention()
        self.sigmoid = nn.Sigmoid()

        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # ups alternates [ConvTranspose2d, DoubleConv] per level, deepest level first.
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)
        self.add_skip = ListSkip()

    def forward(self, x):
        skip_connections = []
        for down in self.downs:
            x = down(x)
            skip_connections.append(x)  # pre-pool activation of this level
            x = self.pool(x)

        x = self.bottleneck(x)
        _attention_weights, x = self.att(x, x, x)

        # Fuse skips, then reverse so index 0 is the deepest level, matching the decoder order.
        skip_connections = self.add_skip(skip_connections)[::-1]

        for idx in range(0, len(self.ups), 2):
            x = self.ups[idx](x)  # ConvTranspose2d upsampling
            skip_connection = skip_connections[idx // 2]

            # Pooling floors odd sizes, so the upsampled map can differ from its skip.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[idx + 1](concat_skip)

        return self.final_conv(x)


        
def load_model(model_name):
    """Build a model by name.

    Args:
        model_name: model identifier; only ``'UNet'`` is supported.

    Returns:
        A freshly initialised ``UNET(in_channels=1, out_channels=1)``.

    Raises:
        ValueError: if ``model_name`` is not ``'UNet'``.
    """
    if model_name != 'UNet':
        raise ValueError('Please input valid model name, {} not in model zones.'.format(model_name))
    return UNET(in_channels=1, out_channels=1)


if __name__ == '__main__':
    # Smoke check: build the single-channel UNet and print its module tree.
    model = load_model('UNet')
    print(model)