In [4]:
import torch
import torch.nn as nn

In [7]:
def double_conv(in_channels, out_channels, kernel_size=3, padding=0):
    """Build the basic U-Net stage: (Conv1d -> LeakyReLU) applied twice.

    Args:
        in_channels: number of channels fed to the first convolution.
        out_channels: number of channels produced by both convolutions.
        kernel_size: convolution kernel size (default 3).
        padding: zero-padding added on both sides by each convolution. The
            default 0 ("valid" convolution) shortens the length by
            ``2 * (kernel_size - 1)`` overall. With an odd kernel,
            ``padding=(kernel_size - 1) // 2`` keeps the length unchanged.

    Returns:
        ``nn.Sequential(Conv1d, LeakyReLU, Conv1d, LeakyReLU)``.
    """
    return nn.Sequential(
        nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding),
        # inplace=True overwrites the conv output rather than allocating a new tensor.
        nn.LeakyReLU(inplace=True),
        nn.Conv1d(out_channels, out_channels, kernel_size=kernel_size, padding=padding),
        nn.LeakyReLU(inplace=True),
    )

def crop_tensor(tensor, target_tensor):
    """Center-crop ``tensor`` along dim 2 so its length matches ``target_tensor``.

    Used for U-Net skip connections. Unpadded convolutions make the encoder
    feature map longer than the up-sampled decoder map, so it is trimmed
    symmetrically before concatenation.

    Args:
        tensor: 3-D tensor indexed as (batch, channels, length); the one to crop.
        target_tensor: tensor whose dim-2 size is the desired length.

    Returns:
        A slice (view) of ``tensor`` whose dim-2 size equals
        ``target_tensor.size(2)``. If the length difference is odd, the extra
        element is dropped from the end.

    Raises:
        ValueError: if ``tensor`` is shorter than ``target_tensor`` along dim 2.
    """
    target_size = target_tensor.size(2)
    tensor_size = tensor.size(2)
    if tensor_size < target_size:
        raise ValueError(
            f"crop_tensor: tensor length {tensor_size} is shorter than "
            f"target length {target_size}"
        )
    start = (tensor_size - target_size) // 2
    # Slice by the target length (not `tensor_size - start`) so an odd
    # difference still yields exactly `target_size` elements.
    return tensor[:, :, start:start + target_size]
 

class Unet(nn.Module):
    """1-D U-Net built from four down-sampling stages, a bottleneck and four up-sampling stages.

    Every ``double_conv`` here uses unpadded kernel-3 convolutions, so it
    shortens the signal by 4 samples. Skip connections are therefore
    center-cropped with ``crop_tensor`` before being concatenated with the
    up-sampled decoder features.

    For an input of shape (batch, in_channels, 572) the output has shape
    (batch, 64, 388).

    Args:
        in_channels: number of channels of the input signal (default 1).
    """

    def __init__(self, in_channels=1):
        super().__init__()
        # kernel_size=2 with the default stride (= kernel_size) halves the length.
        self.maxpool = nn.MaxPool1d(2)

        # Encoder: channel width doubles at each stage.
        self.down_conv1 = double_conv(in_channels, 64)
        self.down_conv2 = double_conv(64, 128)
        self.down_conv3 = double_conv(128, 256)
        self.down_conv4 = double_conv(256, 512)
        self.down_conv5 = double_conv(512, 1024)

        # Decoder: each ConvTranspose1d doubles the length and halves the
        # channels. The following double_conv consumes that output concatenated
        # with the cropped skip connection, hence twice the input channels.
        self.up_trans1 = nn.ConvTranspose1d(1024, 512, kernel_size=2, stride=2)
        self.up_conv1 = double_conv(1024, 512)
        self.up_trans2 = nn.ConvTranspose1d(512, 256, kernel_size=2, stride=2)
        self.up_conv2 = double_conv(512, 256)
        self.up_trans3 = nn.ConvTranspose1d(256, 128, kernel_size=2, stride=2)
        self.up_conv3 = double_conv(256, 128)
        self.up_trans4 = nn.ConvTranspose1d(128, 64, kernel_size=2, stride=2)
        self.up_conv4 = double_conv(128, 64)
        # NOTE(review): up_trans5 is constructed but never used in forward(), so
        # the network returns 64 channels. It is kept so existing state_dicts
        # still load. TODO: decide whether an output head was intended.
        self.up_trans5 = nn.ConvTranspose1d(64, 1, kernel_size=2, stride=2)

    def forward(self, input):
        """Run the encoder/decoder.

        Args:
            input: tensor of shape (batch_size, channels, length).

        Returns:
            Tensor of shape (batch_size, 64, out_length), e.g. length 572 -> 388.
        """
        # Encoder. Trailing comments give the length for a length-572 input.
        x1 = self.down_conv1(input)   # 568 (skip)
        x2 = self.maxpool(x1)         # 284
        x3 = self.down_conv2(x2)      # 280 (skip)
        x4 = self.maxpool(x3)         # 140
        x5 = self.down_conv3(x4)      # 136 (skip)
        x6 = self.maxpool(x5)         # 68
        x7 = self.down_conv4(x6)      # 64 (skip)
        x8 = self.maxpool(x7)         # 32
        x9 = self.down_conv5(x8)      # 28 (bottleneck)

        # Decoder: upsample, crop the matching skip to the same length,
        # concatenate on the channel axis (dim 1), then convolve.
        x = self.up_trans1(x9)        # 56
        y = crop_tensor(x7, x)
        x = self.up_conv1(torch.cat([x, y], 1))   # 52
        x = self.up_trans2(x)         # 104
        y = crop_tensor(x5, x)
        x = self.up_conv2(torch.cat([x, y], 1))   # 100
        x = self.up_trans3(x)         # 200
        y = crop_tensor(x3, x)
        x = self.up_conv3(torch.cat([x, y], 1))   # 196
        x = self.up_trans4(x)         # 392
        y = crop_tensor(x1, x)
        x = self.up_conv4(torch.cat([x, y], 1))   # 388
        return x
        

In [9]:
# Smoke test: push one random length-572 signal through a freshly built model
# and check the output shape. Seeded so the input and weight init are reproducible.
torch.manual_seed(0)
x = torch.rand(1, 1, 572)
model = Unet()
with torch.no_grad():  # inference only; no autograd graph is needed
    out = model(x)
assert out.shape == (1, 64, 388), out.shape
out.size()

torch.Size([1, 64, 388])
