In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import numpy as np

# Device configuration: use the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

# Reproducibility: the notebook draws random inputs (torch.normal) and randomly
# initialises networks, so seed every RNG library the notebook imports.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

cpu


In [22]:
class UNetBlock1D(nn.Module):
    """Two stacked ``Conv1d -> BatchNorm1d -> ReLU`` stages.

    The first stage maps ``in_channels -> out_channels`` and the second keeps
    ``out_channels``. With stride 1 and the default ``kernel_size=3, padding=1``
    the sequence length is preserved, so only the channel count changes.
    """

    def __init__(self, in_channels, out_channels, kernel_size = 3, padding = 1):
        super().__init__()
        conv_kwargs = dict(kernel_size=kernel_size, padding=padding)

        # Stage 1: in_channels -> out_channels.
        self.conv1 = nn.Conv1d(in_channels, out_channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm1d(out_channels)

        # Stage 2: out_channels -> out_channels.
        self.conv2 = nn.Conv1d(out_channels, out_channels, **conv_kwargs)
        self.bn2 = nn.BatchNorm1d(out_channels)

    def forward(self, x):
        """Apply both stages in order; ``x`` is presumably ``(batch, in_channels, length)``."""
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            x = F.relu(norm(conv(x)))
        return x

In [24]:
# Sanity check: with the default kernel_size=3 / padding=1 the block keeps the
# sequence length and only changes the channel count (here 1 -> 1).
UN = UNetBlock1D(1, 1)
# Named x_demo rather than `input` so the Python builtin is not shadowed.
x_demo = torch.normal(0, 1, size=(1, 1, 32))
# Call the block once: in training mode every call updates BatchNorm running stats.
y_demo = UN(x_demo)
assert y_demo.shape == x_demo.shape
y_demo.shape

torch.Size([1, 1, 32])

In [25]:
class DenoisingUNet1D(nn.Module):
    """1-D U-Net-style denoiser with a fully connected bottleneck.

    Encoder: ``hidden_layers + 1`` ``UNetBlock1D`` blocks. The first maps
    ``in_channels -> features`` and each following block doubles the channel
    count. There is no pooling, so the sequence length is unchanged as long as
    ``padding`` is length-preserving (e.g. ``(kernel_size - 1) // 2``).

    Bottleneck: the deepest activation, of shape
    ``(batch, features * 2**hidden_layers, seq_len)``, is flattened per sample,
    passed through a square ``nn.Linear`` + ReLU, and reshaped back.

    Decoder: ``hidden_layers`` blocks, each concatenating the matching encoder
    activation (skip connection) on the channel axis and halving the channel
    count, followed by a final block mapping ``features -> out_channels``.

    NOTE(review): the final block ends in BatchNorm + ReLU, so the output is
    always non-negative. For zero-mean signals a plain final Conv1d may be more
    appropriate.

    Args:
        in_channels: channels of the input signal.
        out_channels: channels of the output signal.
        hidden_layers: number of channel-doubling encoder levels (and of
            matching decoder levels); ``0`` is allowed.
        features: channels produced by the first encoder block.
        kernel_size: convolution kernel size used by every block.
        padding: convolution padding used by every block.
        seq_len: input sequence length the bottleneck ``nn.Linear`` is sized
            for. ``None`` (default) means ``features``, which keeps the
            original sizing.
    """

    def __init__(self, in_channels, out_channels, hidden_layers, features=64, kernel_size=3, padding=1, seq_len=None):
        super().__init__()

        if seq_len is None:
            # Backward-compatible default: the bottleneck used to be sized as
            # if the sequence length were equal to `features`.
            seq_len = features

        self.hidden_layers = hidden_layers
        self.features = features
        self.seq_len = seq_len
        self.encoder_layers = nn.ModuleList()
        self.decoder_layers = nn.ModuleList()

        def _block(c_in, c_out):
            # Forward the conv hyper-parameters to every block.
            return UNetBlock1D(c_in, c_out, kernel_size=kernel_size, padding=padding)

        # Encoder: in_channels -> features -> 2*features -> ... -> features * 2**hidden_layers
        self.encoder_layers.append(_block(in_channels, features))
        for level in range(hidden_layers):
            self.encoder_layers.append(_block(features * 2**level, features * 2**(level + 1)))

        # Bottleneck: dense layer over the flattened (channels * length) activation.
        bottleneck_dim = features * 2**hidden_layers * seq_len
        self.flatten = nn.Flatten()  # (batch, C, L) -> (batch, C * L)
        self.bottleneck = nn.Linear(bottleneck_dim, bottleneck_dim)

        # Decoder: each block takes (current channels + skip channels) and halves the channel count.
        for level in range(hidden_layers):
            in_dim = features * 2**(hidden_layers - level)
            out_dim = features * 2**(hidden_layers - level - 1)
            self.decoder_layers.append(_block(in_dim + out_dim, out_dim))

        # Final layer: features -> out_channels.
        self.decoder_layers.append(_block(features, out_channels))

    def forward(self, x):
        """Run the encoder, bottleneck and decoder.

        Args:
            x: tensor of shape ``(batch, in_channels, seq_len)``.

        Returns:
            Tensor of shape ``(batch, out_channels, seq_len)``.
        """
        # Encoding path: keep every activation; skips[-1] is the bottleneck input.
        skips = []
        for layer in self.encoder_layers:
            x = layer(x)
            skips.append(x)

        # Bottleneck: restore the shape using the real batch size rather than
        # assuming a batch of 1, so batched inputs work.
        batch, channels, length = x.shape
        x = F.relu(self.bottleneck(self.flatten(x)))
        x = torch.reshape(x, (batch, channels, length))

        # Decoding path: pair the decoder blocks with the remaining encoder
        # outputs, deepest first (skips[-1] was already consumed by the bottleneck).
        for layer, skip in zip(self.decoder_layers[:-1], reversed(skips[:-1])):
            x = layer(torch.cat((x, skip), dim=1))

        # Final block maps features -> out_channels.
        return self.decoder_layers[-1](x)


In [28]:
# Mono in / mono out, 4 channel-doubling levels starting at 16 features
# (the deepest encoder level therefore has 16 * 2**4 = 256 channels).
DUN = DenoisingUNet1D(in_channels=1, out_channels=1, hidden_layers=4, features=16)
DUN

DenoisingUNet1D(
  (encoder_layers): ModuleList(
    (0): UNetBlock1D(
      (conv1): Conv1d(1, 16, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn1): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(16, 16, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn2): BatchNorm1d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): UNetBlock1D(
      (conv1): Conv1d(16, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (2): UNetBlock1D(
      (conv1): Conv1d(32, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kern

In [29]:
# DUN's bottleneck nn.Linear is sized for sequences of length `features`
# (16 here), so derive the length from the model instead of hard-coding it.
# NOTE(review): `input` shadows the Python builtin; the name is kept because
# later cells may refer to it.
input = torch.normal(0, 1, size=(1, 1, DUN.features))
denoised = DUN(input)
assert denoised.shape == input.shape
denoised.shape

torch.Size([1, 1, 16])