In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from Signal_Analyzer import  *


In [3]:
class UNet1D(nn.Module):
    """Small 1-D U-Net: two encoder stages, a bottleneck and two decoder stages.

    Every convolution uses ``kernel_size=3, padding=1`` followed by ``ReLU``
    (the final output convolution has no activation). Each encoder stage is
    followed by ``MaxPool1d(2)``; each decoder stage upsamples with a
    stride-2 ``ConvTranspose1d``, concatenates the matching encoder
    activation along the channel axis, and applies another conv + ReLU.

    Args:
        in_channels: Number of channels of the input signal.
        out_channels: Number of channels produced by the output convolution.

    Shape:
        - Input: ``(batch, in_channels, length)``
        - Output: ``(batch, out_channels, length)``

    ``length`` does not have to be a multiple of 4: when pooling drops an odd
    trailing sample, the upsampled tensor is zero-padded on the right so it
    can be concatenated with its skip connection. Signals too short to be
    pooled twice (fewer than 4 samples) are not supported.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        self.encoder1 = nn.Sequential(nn.Conv1d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU())
        self.pool1 = nn.MaxPool1d(2)
        self.encoder2 = nn.Sequential(nn.Conv1d(16, 32, kernel_size=3, padding=1), nn.ReLU())
        self.pool2 = nn.MaxPool1d(2)

        self.bottleneck = nn.Sequential(nn.Conv1d(32, 64, kernel_size=3, padding=1), nn.ReLU())

        # Decoder convolutions take twice the upconv channel count because the
        # skip connection is concatenated along the channel axis first.
        self.upconv2 = nn.ConvTranspose1d(64, 32, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=3, padding=1), nn.ReLU())
        self.upconv1 = nn.ConvTranspose1d(32, 16, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=3, padding=1), nn.ReLU())

        self.output_conv = nn.Conv1d(16, out_channels, kernel_size=3, padding=1)

    @staticmethod
    def _match_length(upsampled, skip):
        """Right-pad ``upsampled`` with zeros until its last dim equals ``skip``'s."""
        missing = skip.size(-1) - upsampled.size(-1)
        if missing > 0:
            # MaxPool1d(2) floors odd lengths, so the stride-2 upconv comes
            # back one sample short of the skip tensor in that case.
            upsampled = F.pad(upsampled, (0, missing))
        return upsampled

    def forward(self, x):
        """Run the U-Net on ``x`` and return a tensor of the same length.

        Args:
            x: Tensor of shape ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length)``.
        """
        # Encoder
        enc1 = self.encoder1(x)
        enc1p = self.pool1(enc1)

        enc2 = self.encoder2(enc1p)
        enc2p = self.pool2(enc2)

        # Bottleneck
        bottleneck = self.bottleneck(enc2p)

        # Decoder
        dec2 = self._match_length(self.upconv2(bottleneck), enc2)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self._match_length(self.upconv1(dec2), enc1)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        # Output
        out = self.output_conv(dec1)
        return out

In [5]:
# Example of using UNet1D
# Seed the global RNG so that both the random input signal and the model's
# initial weights are identical on every Restart & Run All.
torch.manual_seed(0)

# Assuming input signal of length 128
input_signal = torch.randn(1, 1, 128)  # (batch_size, channels, signal_length)
model = UNet1D()

In [6]:
import matplotlib.pyplot as plt

# Run the forward pass BEFORE plotting: `output_signal` must exist when the
# plot call executes, otherwise a fresh kernel raises NameError.
output_signal = model(input_signal)

# Plot the output signal
fig, ax = plt.subplots()
# detach() drops the autograd graph so the tensor can be converted to NumPy;
# squeeze() removes the singleton batch/channel dimensions for a 1-D line plot.
ax.plot(output_signal.squeeze().detach().numpy())
ax.set_xlabel('Time')
ax.set_ylabel('Amplitude')
ax.set_title('Output Signal')
plt.show()

: 