In [None]:
import os
import uproot
import numpy as np
import scipy.signal as signal
from scipy.signal import savgol_filter
import requests
import zipfile
import matplotlib.pyplot as plt




##### for training####
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
print("module import is done")

In [None]:
from sim2root.Common.IllustrateSimPipe import *
import grand.dataio.root_trees as groot 
print("module import is done")

#### ZHAireS


In [None]:
import matplotlib as mpl
# Raise matplotlib's "too many open figures" warning threshold; the plotting
# helper in this cell can create one figure per antenna.
mpl.rcParams['figure.max_open_warning'] = 50

def plot_voltage(directory, nb_event = 25, plot = False):
    """Collect (and optionally plot) the per-antenna voltage traces of a simulation.

    Opens ``directory`` with ``groot.DataDirectory``, walks the events listed in
    its L0 voltage tree and, for at most ``nb_event`` events, appends the three
    trace components and the matching time axis of every detector unit (DU) to
    flat Python lists.

    Args:
        directory: Path handed to ``groot.DataDirectory``.
        nb_event: Maximum number of events to read from the event list.
        plot: When true, draw one figure per DU, write it to ``figure.png``
            (every DU overwrites the same file) and show it.

    Returns:
        Tuple ``(time, x_dataset, y_dataset, z_dataset)``. Each list holds one
        1-D array per (event, DU) pair, in reading order. ``time`` is
        ``sample_index * t_bin_size - t_pre`` for the corresponding trace.

    Raises:
        SystemExit: If the voltage tree lists no events.
    """
    # Function-scope import: the notebook never imports `sys` explicitly, so the
    # empty-file branch below must not rely on a star-import providing it.
    import sys

    time = []
    x_dataset = []
    y_dataset = []
    z_dataset = []

    d_input = groot.DataDirectory(directory)

    tvoltage_l0 = d_input.tvoltage_l0
    tshower_l0 = d_input.tshower_l0
    trunefieldsim_l0 = d_input.trunefieldsim_l0
    tefield_l0 = d_input.tefield_l0
    trun_l0 = d_input.trun_l0

    # get the list of events
    events_list = tvoltage_l0.get_list_of_events()
    nb_events = len(events_list)

    print('number of events:',nb_events)

    # If there are no events in the file, exit
    if nb_events == 0:
        sys.exit("There are no events in the file! Exiting.")

    event_counter = 0
    max_events_to_store = nb_event
    previous_run = None

    ####################################################################################
    # start looping over the events
    ####################################################################################
    for event_number,run_number in events_list:
        assert isinstance(event_number, int)
        assert isinstance(run_number, int)
        logger.debug(f"Running event_number: {event_number}, run_number: {run_number}")

        # Stop as soon as the requested number of events has been read.
        if event_counter >= max_events_to_store:
            break

        # The shower tree is still positioned on the event (as before), even
        # though only the voltage / efield / run trees are read below.
        tshower_l0.get_event(event_number, run_number)
        tvoltage_l0.get_event(event_number, run_number)
        tefield_l0.get_event(event_number, run_number)

        if previous_run != run_number:                          # load only for new run.
            trun_l0.get_run(run_number)
            trunefieldsim_l0.get_run(run_number)
            previous_run = run_number

        # x,y,z components are stored in the trace; expected shape (nb_du, 3, tbins)
        trace_voltage = np.asarray(tvoltage_l0.trace, dtype=np.float32)
        event_counter += 1

        t_pre_L0 = trunefieldsim_l0.t_pre

        #TODO: this forces a homogeneous antenna array.
        nb_du = trace_voltage.shape[0]
        sig_size = trace_voltage.shape[-1]
        logger.info(f"Event has {nb_du} DUs, with a signal size of: {sig_size}")

        # indices of the antennas of the array participating in this event
        event_dus_indices = tefield_l0.get_dus_indices_in_run(trun_l0)

        # sampling time in ns, sampling freq = 1e9/dt_ns.
        dt_ns_l0 = np.asarray(trun_l0.t_bin_size)[event_dus_indices]

        # loop over all stations.
        for du_idx in range(nb_du):
            logger.debug(f"Running DU number {du_idx}")

            # voltage trace
            trace_voltage_x = trace_voltage[du_idx,0]
            trace_voltage_y = trace_voltage[du_idx,1]
            trace_voltage_z = trace_voltage[du_idx,2]
            trace_voltage_time = np.arange(0,len(trace_voltage_z)) * dt_ns_l0[du_idx] - t_pre_L0

            # Store the time axis too: the function returns `time`, which used
            # to stay empty because the computed axis was never appended.
            time.append(trace_voltage_time)
            x_dataset.append(trace_voltage_x)
            y_dataset.append(trace_voltage_y)
            z_dataset.append(trace_voltage_z)

            if plot:
                print("start ploting")
                fig, axs = plt.subplots(1,1, figsize=(8, 6))
                axs.plot(trace_voltage_time, trace_voltage_x, alpha=0.5, label="polarization N")
                axs.plot(trace_voltage_time, trace_voltage_y, alpha=0.5, label="polarization E")
                axs.plot(trace_voltage_time, trace_voltage_z, alpha=0.5, label="polarization v")
                axs.legend()
                axs.set_title(f"voltage antenna {du_idx}")
                axs.set_xlabel("time in ns")
                axs.set_ylabel("voltage in uV")

                # Save through the figure handle *before* showing it: once
                # plt.show() has run, pyplot's "current figure" may no longer be
                # this one, so a trailing plt.savefig() can write an empty image.
                fig.savefig("figure.png")
                plt.show()

                # Now close the figure to free memory
                plt.close(fig)

    print("Processing complete for specified number of events!")
    return time, x_dataset , y_dataset, z_dataset
# Noisy simulation set: read 40 events without plotting, then report the
# shape of every returned collection.
directory = "ZHAireS/sim_Xiaodushan_20221025_220000_RUN0_CD_ZHAireS_0000/" #voltage_29-24992_L0_0000.root
noised_time, noised_trace_x, noised_trace_y, noised_trace_z = plot_voltage(directory, nb_event=40, plot=False)
print(
    *(
        f'shape of {label}:{np.shape(values)}'
        for label, values in {
            "noised_time": noised_time,
            "noised_trace_x": noised_trace_x,
            "noised_trace_y": noised_trace_y,
            "noised_trace_z": noised_trace_z,
        }.items()
    ),
    sep="\n",
)
        

In [None]:
# Noise-free ("NJ") counterpart of the same simulation: read 40 events without
# plotting, then report the shape of every returned collection.
NJ_directory = "ZHAireS-NJ/sim_Xiaodushan_20221025_220000_RUN0_CD_ZHAireS_0000"
clean_time, clean_trace_x, clean_trace_y, clean_trace_z = plot_voltage(NJ_directory, nb_event=40, plot=False)
print(
    *(
        f'shape of {label}:{np.shape(values)}'
        for label, values in {
            "clean_time": clean_time,
            "clean_trace_x": clean_trace_x,
            "clean_trace_y": clean_trace_y,
            "clean_trace_z": clean_trace_z,
        }.items()
    ),
    sep="\n",
)

In [None]:
from different_metrices import train_and_validate_model
from different_metrices import calculate_psnr_with_peak
from different_metrices import peak_to_peak_ratio
from different_metrices import get_reconstructed_signals
from different_metrices import visualize_denoised_signal
from different_metrices import psnr_loss
from different_metrices import plot_loss_vs_psnr
from different_metrices import plot_metrics



import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import torch.nn.functional as F
import numpy as np
from scipy.signal import hilbert
from skimage.metrics import peak_signal_noise_ratio as psnr
from torch.optim.lr_scheduler import CyclicLR


In [None]:
class ResidualBlock(nn.Module):
    """Residual block built from two 1-D convolutions.

    The main path is ``conv1 -> ReLU -> conv2``; its result is added to the
    (optionally projected) input and passed through a final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size of both convolutions.
        stride: Stride of the block. It is applied once, in ``conv1``, and
            mirrored on the shortcut so both branches keep the same length.
        padding: Padding of both convolutions. The residual sum only lines up
            when the convolutions preserve length at stride 1, i.e.
            ``padding == kernel_size // 2`` with an odd ``kernel_size``.
    """
    def __init__(self, in_channels, out_channels, kernel_size, stride, padding):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.relu = nn.ReLU(inplace=True)
        # Only the first convolution is strided. Striding both (as before)
        # shrank the main path twice while the shortcut stayed at full length,
        # so every stride != 1 failed on the residual addition.
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size, stride=1, padding=padding)

        # Project the shortcut whenever the main path changes the tensor shape:
        # a different channel count or a different length (stride != 1).
        if in_channels != out_channels or stride != 1:
            self.adjust_channels = nn.Conv1d(in_channels, out_channels, 1, stride=stride)
        else:
            self.adjust_channels = None

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, in_channels, length)``."""
        identity = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)

        # Apply the skip connection
        if self.adjust_channels is not None:
            identity = self.adjust_channels(identity)

        out += identity
        out = self.relu(out)
        return out
    
class DecoderResidualBlock(nn.Module):
    """Upsampling residual block built from two 1-D transposed convolutions.

    The main path is ``conv1 (stride 2) -> ReLU -> conv2 (stride 1)``; its
    result is added to a stride-2 projection of the input and passed through a
    final ReLU.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the block.
        kernel_size: Kernel size of both transposed convolutions.
        padding: Padding of both transposed convolutions.
        output_padding: Extra length added by the strided layers (``conv1`` and
            the shortcut projection).
    """
    def __init__(self, in_channels, out_channels, kernel_size, padding, output_padding):
        super(DecoderResidualBlock, self).__init__()
        self.conv1 = nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=2, padding=padding, output_padding=output_padding)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.ConvTranspose1d(out_channels, out_channels, kernel_size, stride=1, padding=padding)

        # The main path always upsamples (stride 2), so the shortcut must be
        # projected to the new length in every case, not only when the channel
        # count changes. Previously in_channels == out_channels left the
        # shortcut at the input length and the residual addition failed.
        self.adjust_channels = nn.ConvTranspose1d(in_channels, out_channels, 1, stride=2, output_padding=output_padding)

    def forward(self, x):
        """Apply the block to ``x`` of shape ``(batch, in_channels, length)``."""
        out = self.conv1(x)
        out = self.relu(out)
        out = self.conv2(out)

        # Apply the skip connection
        identity = self.adjust_channels(x)

        out += identity
        out = self.relu(out)
        return out
    
class Autoencoder(nn.Module):
    """Residual convolutional autoencoder for 3-channel 1-D signals.

    The encoder is three ``ResidualBlock`` stages (3 -> 32 -> 64 -> 128
    channels), each followed by a max-pool that halves the length. The decoder
    is two ``DecoderResidualBlock`` stages (128 -> 64 -> 32) plus a final
    transposed convolution (32 -> 3); each of the three doubles the length.

    Args:
        input_size: Not used by the network (all layers are length-agnostic);
            kept so existing calls keep working.
        kernel_size: Kernel size of all convolutions. Even values are bumped to
            the next odd number so that ``padding = kernel_size // 2`` preserves
            the length.
    """

    # Total down-sampling factor of the encoder (three pools of factor 2).
    _LENGTH_MULTIPLE = 8

    def __init__(self, input_size=1024, kernel_size=3):
        super(Autoencoder, self).__init__()
        kernel_size = kernel_size + 1 if kernel_size % 2 == 0 else kernel_size
        padding = kernel_size // 2

        # Encoder with Residual Blocks
        self.encoder = nn.Sequential(
            ResidualBlock(3, 32, kernel_size, stride=1, padding=padding),
            nn.MaxPool1d(2, stride=2),
            ResidualBlock(32, 64, kernel_size, stride=1, padding=padding),
            nn.MaxPool1d(2, stride=2),
            ResidualBlock(64, 128, kernel_size, stride=1, padding=padding),
            nn.MaxPool1d(2, stride=2)
        )

        # Decoder
        self.decoder = nn.Sequential(
            DecoderResidualBlock(128, 64, kernel_size, padding=kernel_size//2, output_padding=1),
            DecoderResidualBlock(64, 32, kernel_size, padding=kernel_size//2, output_padding=1),
            nn.ConvTranspose1d(32, 3, kernel_size, stride=2, padding=kernel_size//2, output_padding=1)
        )

    def forward(self, x):
        """Return the reconstruction of ``x``, with the same length as ``x``.

        ``x`` is expected as ``(batch, 3, length)``. If ``length`` is not a
        multiple of 8, the pooling layers would drop the trailing samples and
        the output would come back shorter than the input. To avoid that, the
        input is zero-padded on the right up to the next multiple of 8 and the
        output is cropped back. Inputs whose length already is a multiple of 8
        take the unmodified path.
        """
        length = x.shape[-1]
        missing = (-length) % self._LENGTH_MULTIPLE
        if missing:
            x = nn.functional.pad(x, (0, missing))

        x = self.encoder(x)
        x = self.decoder(x)

        if missing:
            x = x[..., :length]
        return x


def split_indices(n, train_frac=0.7, valid_frac=0.2, seed=None):
    """Randomly partition ``range(n)`` into training, validation and test indices.

    Args:
        n: Number of samples.
        train_frac: Fraction of samples assigned to the training set.
        valid_frac: Fraction of samples assigned to the validation set. The
            remaining samples form the test set.
        seed: Optional seed for a reproducible split. When ``None`` (default)
            the indices are shuffled with NumPy's global random state.

    Returns:
        Tuple ``(train_indices, valid_indices, test_indices)`` of disjoint
        integer arrays that together contain every index exactly once.

    Raises:
        ValueError: If a fraction lies outside ``[0, 1]`` or the two fractions
            add up to more than 1.
    """
    fractions_in_range = 0 <= train_frac <= 1 and 0 <= valid_frac <= 1
    # Small tolerance so that sums such as 0.7 + 0.3 are not rejected because
    # of floating-point rounding.
    if not fractions_in_range or train_frac + valid_frac > 1 + 1e-9:
        raise ValueError(
            f"train_frac and valid_frac must lie in [0, 1] and sum to at most 1, "
            f"got {train_frac} and {valid_frac}"
        )

    indices = np.arange(n)
    if seed is None:
        np.random.shuffle(indices)
    else:
        np.random.default_rng(seed).shuffle(indices)

    train_size = int(n * train_frac)
    valid_size = int(n * valid_frac)

    train_indices = indices[:train_size]
    valid_indices = indices[train_size:train_size + valid_size]
    test_indices = indices[train_size + valid_size:]

    return train_indices, valid_indices, test_indices



class CustomDataset(Dataset):
    """Paired noisy / clean three-component signals exposed as a torch ``Dataset``.

    Every item is a ``(noised, clean)`` pair of ``float32`` tensors obtained by
    stacking the X, Y and Z component of one sample along a new first axis.
    """

    def __init__(self, noised_signals, clean_signals, indices=None):
        """
        Args:
            noised_signals: Tuple of lists containing noised X, Y, Z signal components.
            clean_signals: Tuple of lists containing clean X, Y, Z signal components.
            indices: Array-like of sample indices this dataset exposes. Defaults
                to every sample of the first noised component.
        """
        if indices is None:
            indices = list(range(len(noised_signals[0])))
        self.indices = indices

        # Keep references to the full collections; `indices` selects from them.
        self.noised_signals = noised_signals
        self.clean_signals = clean_signals

    def __len__(self):
        return len(self.indices)

    @staticmethod
    def _stack_components(components, sample):
        """Stack the three components of ``sample`` into one float32 tensor."""
        stacked = np.stack([components[axis][sample] for axis in range(3)], axis=0)
        return torch.tensor(stacked, dtype=torch.float32)

    def __getitem__(self, idx):
        # Translate the dataset-local position into an index of the full collections.
        sample = self.indices[idx]
        noised = self._stack_components(self.noised_signals, sample)
        clean = self._stack_components(self.clean_signals, sample)
        return noised, clean

In [None]:
# Bundle the per-component trace lists the way CustomDataset expects them.
noised_signals = (noised_trace_x, noised_trace_y, noised_trace_z)
clean_signals = (clean_trace_x, clean_trace_y, clean_trace_z)

# All component lists are assumed to have the same length.
total_samples = len(noised_trace_x)
train_indices, valid_indices, test_indices = split_indices(total_samples)

# One dataset per split; all three share the same underlying signal lists.
train_dataset, valid_dataset, test_dataset = (
    CustomDataset(noised_signals, clean_signals, indices=subset)
    for subset in (train_indices, valid_indices, test_indices)
)

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_loader, test_loader = (
    DataLoader(split, batch_size=4, shuffle=False)
    for split in (valid_dataset, test_dataset)
)


In [None]:
# Train the autoencoder on (noisy -> clean) pairs and track per-epoch metrics.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f'Device set to : {device}')
num_epochs = 1000
base_lr = 0.0001
max_lr = 0.006 
model = Autoencoder(kernel_size=3).to(device)
optimizer = optim.AdamW(model.parameters(), lr=base_lr,)
# NOTE(review): the scheduler is constructed but never stepped (the step call
# in the batch loop is commented out), so the learning rate recorded below
# does not change between epochs. Re-enable the call to actually cycle it.
scheduler = CyclicLR(optimizer, base_lr=base_lr, max_lr=max_lr, 
                     step_size_up=5, step_size_down=20, 
                     mode='triangular', cycle_momentum=False)
criterion = nn.MSELoss()


training_losses, validation_losses, validation_psnr, learning_rates, validation_peak_to_peak = [], [], [], [], []
for epoch in range(num_epochs):
    model.train()
    total_train_loss = 0
    for noisy_data, clean_data in train_loader:
        noisy_data, clean_data = noisy_data.to(device), clean_data.to(device)

        optimizer.zero_grad()
        outputs = model(noisy_data)

        loss = criterion(outputs, clean_data)
        loss.backward()
        optimizer.step()
        # scheduler.step()
        # MSELoss returns the mean over the batch. Weight it by the batch size
        # so that dividing by the number of samples below gives the true
        # per-sample mean (summing batch means and dividing by the sample count
        # under-reported the loss by roughly the batch size).
        total_train_loss += loss.item() * noisy_data.size(0)
    
    avg_train_loss = total_train_loss / len(train_loader.dataset)
    training_losses.append(avg_train_loss)

    model.eval()
    total_valid_loss, total_psnr, total_peak_to_peak_ratio = 0, 0, 0
    with torch.no_grad():
        for noisy_data, clean_data in valid_loader:
            clean_data, noisy_data = clean_data.to(device), noisy_data.to(device)

            outputs = model(noisy_data)
            loss = criterion(outputs, clean_data)
            # Same batch-size weighting as for the training loss.
            total_valid_loss += loss.item() * noisy_data.size(0)
            psnr_value = calculate_psnr_with_peak(clean_data.cpu().numpy(), outputs.cpu().numpy())
            total_psnr += psnr_value
            ratio = peak_to_peak_ratio(clean_data.cpu().numpy(), outputs.cpu().numpy())
            total_peak_to_peak_ratio += ratio

    avg_valid_loss = total_valid_loss / len(valid_loader.dataset)
    validation_losses.append(avg_valid_loss)
    # NOTE(review): the two metrics below are accumulated once per batch but
    # divided by the number of samples. Whether that is right depends on what
    # the helpers return (per-batch sum vs. per-batch mean) - confirm against
    # different_metrices before trusting the absolute values.
    avg_psnr = total_psnr / len(valid_loader.dataset)
    validation_psnr.append(avg_psnr)
    avg_peak_to_peak_ratio = total_peak_to_peak_ratio / len(valid_loader.dataset)
    validation_peak_to_peak.append(avg_peak_to_peak_ratio)
    learning_rates.append(scheduler.get_last_lr()[0])

    print(f'Epoch {epoch+1}/{num_epochs}, Training Loss: {avg_train_loss:.4f}, Validation Loss: {avg_valid_loss:.4f}, Validation PSNR: {avg_psnr:.2f}, Validation Peak-to-Peak: {avg_peak_to_peak_ratio:.2f}, Learning Rate: {learning_rates[-1]:.6f}')
    
epochs = range(1, num_epochs + 1)

plot_metrics(epochs, training_losses, validation_losses, validation_psnr, learning_rates= learning_rates, validation_peak_to_peak = validation_peak_to_peak)
# Pickles the whole module object; loading it later requires the class
# definitions from this notebook to be importable.
torch.save(model, 'Torch Model')

In [None]:
# Run the trained model on the test split (on CPU) and, for the first sample of
# every batch, compare the clean, denoised and noisy trace channel by channel.
device = torch.device("cpu")
model = model.to(device)
model.eval()

channel_names = ['X Channel', 'Y Channel', 'Z Channel']
sample_idx = 0  # only the first sample of each batch is drawn

with torch.no_grad():
    for noisy_data, clean_data in test_loader:
        noisy_data, clean_data = noisy_data.to(device), clean_data.to(device)
        denoised_output = model(noisy_data)

        for channel_idx, channel_name in enumerate(channel_names):
            clean_np = clean_data[sample_idx, channel_idx].cpu().numpy()
            noisy_np = noisy_data[sample_idx, channel_idx].cpu().numpy()
            denoised_np = denoised_output[sample_idx, channel_idx].cpu().numpy()

            # Peak of the clean trace over the standard deviation of the added noise.
            snr = np.max(clean_np) / np.std(noisy_np - clean_np)

            fig, (ax_reference, ax_noisy) = plt.subplots(2, 1, figsize=(25, 16))

            # Upper panel: clean reference against the network output.
            ax_reference.plot(clean_np, label=f'Pure - {channel_name}', color='blue')
            ax_reference.plot(denoised_np, label='Denoised signal', linestyle='--', color='orange')
            ax_reference.legend()
            ax_reference.set_title(f'Denoised vs Pure Signal - {channel_name}')

            # Lower panel: the noisy network input.
            ax_noisy.plot(noisy_np, label=f'Noisy signal - {channel_name}, snr  = {snr}', color='red')
            ax_noisy.legend()
            ax_noisy.set_title(f'Noisy Signal - {channel_name}, snr  = {snr}')

            plt.show()
