In [3]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from Signal_Analyzer import  *

In [4]:
class UNet1D(nn.Module):
    """Small two-level 1D U-Net.

    The encoder applies two Conv1d+ReLU stages (16 and 32 channels), each
    followed by a ``MaxPool1d(2)``. A 64-channel bottleneck follows, and the
    decoder mirrors the encoder with two stride-2 ``ConvTranspose1d``
    upsamplings, concatenating the matching encoder feature map (skip
    connection) before each decoder convolution.

    Args:
        in_channels: Number of channels in the input signal.
        out_channels: Number of channels produced by the final convolution.

    Shape:
        Input ``(batch, in_channels, length)`` -> output
        ``(batch, out_channels, length)``. ``length`` does not have to be a
        multiple of 4: upsampled features are right-padded with zeros to the
        length of their skip connection. ``length`` must still be at least 4
        so that both pooling stages produce a non-empty output.
    """

    def __init__(self, in_channels=1, out_channels=1):
        super().__init__()
        # Encoder: every convolution uses kernel_size=3 with padding=1, so only
        # the pooling layers change the temporal length.
        self.encoder1 = nn.Sequential(nn.Conv1d(in_channels, 16, kernel_size=3, padding=1), nn.ReLU())
        self.pool1 = nn.MaxPool1d(2)
        self.encoder2 = nn.Sequential(nn.Conv1d(16, 32, kernel_size=3, padding=1), nn.ReLU())
        self.pool2 = nn.MaxPool1d(2)

        self.bottleneck = nn.Sequential(nn.Conv1d(32, 64, kernel_size=3, padding=1), nn.ReLU())

        # Decoder: the decoder convolutions take twice the channels emitted by
        # the preceding upconv because the skip connection is concatenated
        # along the channel axis.
        self.upconv2 = nn.ConvTranspose1d(64, 32, kernel_size=2, stride=2)
        self.decoder2 = nn.Sequential(nn.Conv1d(64, 32, kernel_size=3, padding=1), nn.ReLU())
        self.upconv1 = nn.ConvTranspose1d(32, 16, kernel_size=2, stride=2)
        self.decoder1 = nn.Sequential(nn.Conv1d(32, 16, kernel_size=3, padding=1), nn.ReLU())

        self.output_conv = nn.Conv1d(16, out_channels, kernel_size=3, padding=1)

    @staticmethod
    def _match_length(upsampled, skip):
        """Right-pad ``upsampled`` with zeros so its last dim equals ``skip``'s.

        ``MaxPool1d(2)`` floors odd lengths, so the stride-2 transposed
        convolution can return one sample fewer than the skip tensor; without
        this alignment ``torch.cat`` raises a size-mismatch error.
        """
        missing = skip.size(-1) - upsampled.size(-1)
        if missing == 0:
            return upsampled
        return F.pad(upsampled, (0, missing))

    def forward(self, x):
        """Run the U-Net on ``x`` of shape ``(batch, in_channels, length)``.

        Returns:
            Tensor of shape ``(batch, out_channels, length)``.
        """
        # Encoder
        enc1 = self.encoder1(x)
        enc1p = self.pool1(enc1)

        enc2 = self.encoder2(enc1p)
        enc2p = self.pool2(enc2)

        # Bottleneck
        bottleneck = self.bottleneck(enc2p)

        # Decoder: upsample, align with the skip connection, concatenate on
        # the channel axis, then convolve.
        dec2 = self._match_length(self.upconv2(bottleneck), enc2)
        dec2 = torch.cat((dec2, enc2), dim=1)
        dec2 = self.decoder2(dec2)

        dec1 = self._match_length(self.upconv1(dec2), enc1)
        dec1 = torch.cat((dec1, enc1), dim=1)
        dec1 = self.decoder1(dec1)

        # Output
        out = self.output_conv(dec1)
        return out

In [5]:
# Seed PyTorch's global RNG so the randomly initialised weights and the random
# input below (and therefore the plotted output) are identical on every
# Restart & Run All.
torch.manual_seed(0)

# Initialize the model
model = UNet1D()

# Assuming we have an input signal (e.g., random noise)
input_signal = torch.randn(1, 1, 128)  # (batch_size, channels, signal_length)

# Generate an output signal using the model
output_signal = model(input_signal)

# Sanity check: with the default single input/output channel the network
# returns a tensor of the same shape as its input.
assert output_signal.shape == input_signal.shape



In [None]:
# Detach from the autograd graph, convert to NumPy, and keep only the first
# sample / first channel so a 1-D trace is left to plot.
output_signal_np = output_signal.detach().numpy()[0, 0]

# Draw the trace using the explicit Figure/Axes interface.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(output_signal_np, label='Generated Signal')
ax.set(
    title='Generated Signal from UNet1D',
    xlabel='Sample Index',
    ylabel='Amplitude',
)
ax.legend()
ax.grid(True)
plt.show()