In [1]:
import os
import time
import random
import pickle
import numpy as np
import pandas as pd
from tqdm import tqdm
from scipy import signal, sparse
from scipy.stats import norm
import matplotlib.pyplot as plt
from matplotlib import cm
import matplotlib.gridspec as gridspec
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import itertools
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.utils.data as data_utils
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torch.nn.parameter import Parameter
import torch.nn.functional as F
from torch import einsum


# Cluster-specific locations: the home directory holds the CPU-Net checkout and the
# scratch directory is where the pickled waveform datasets are looked up later on.
home_dir = '/nas/longleaf/home/kbhimani/'
scratch_dir = '/work/users/k/b/kbhimani/'
eng_peak = 'fep'  # training peak; this tag is embedded in the dataset file names

# Run from inside the repository (presumably so the local tools/dataset/network modules import).
os.chdir(f"{home_dir}/CPU-Net")



# Loading CPU-Net and support functions
from tools import (calc_current_amplitude, process_all_waveforms, calculate_tn, check_peak_alignment,
                   get_tail_slope, inf_train_gen, LambdaLR, weights_init_normal, select_quantile, calculate_iou)
from dataset import SplinterDataset, SEQ_LEN, LSPAN, RSPAN
from network import PositionalUNet, RNN

# Probe for CUDA once; the answer is reused for the report and for the device choice.
cuda_available = torch.cuda.is_available()
print(f"CUDA available: {cuda_available}")

# List every visible GPU by index and name (skipped entirely on CPU-only machines).
if cuda_available:
    print(f"Number of CUDA Devices: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"CUDA Device {i}: {torch.cuda.get_device_name(i)}")

# Everything downstream runs on the first GPU when there is one, otherwise on the CPU.
DEVICE = torch.device("cuda:0") if cuda_available else torch.device("cpu")

# Release cached allocator blocks left over from earlier runs in this kernel.
torch.cuda.empty_cache()




CUDA available: True
Number of CUDA Devices: 1
CUDA Device 0: NVIDIA A100-PCIE-40GB MIG 2g.10gb


In [2]:
# BATCH_SIZE = 32 # batch size, each batch is drawn from the infinite train generator
# baseline_len = 200 # number of samples assigned to baseline portions
# rising_edge_len = 250 # number of samples assigned to rising edge
# tail_len = 350 # number of samples assigned to tail 
# baseline_weight=3.0 # weight given to baseline portion of the waveform in loss function
# ris_edge_weight=10.0 # weight giveing to rising edge of the waveform in loss function
# tail_weight=7.0 # weight giving to the RC decay tail of the wavefrom in loss function
# ITERS = 5000 # max number of interations to run
# DECAY = 2500 # iteration at which learning rate starts to decay
# LRATE_Gen =1e-2 # learning rate of the generator
# LRATE_Disc =1e-3 # learning rate of the discriminator
# cyc_loss_weight = 10 # weight of the cycle consistent loss in training, eg loss(sim->data->sim)
# iden_loss_weight = 5 # weight of idenentity loss, for example ATN(data)- data
# gan_loss_weight = 1 # weight of the generator loss. ATN(sim) - data
# max_grad_norm = 40 # Maximum norm for gradient clipping
# w_decay = 1e-3 # weight decay in the optimizers
# n_disc_iters = 25  # Set the number of iterations after which discriminators will be updated
# max_sample = 2e4 # numbers of samples to used for training
# --- Batching and waveform segmentation -------------------------------------
BATCH_SIZE = 32        # batch size; each batch is drawn from the infinite train generator
baseline_len = 200     # number of samples assigned to the baseline portion
rising_edge_len = 250  # number of samples assigned to the rising edge
tail_len = 350         # number of samples assigned to the tail

# --- Per-region weights in the loss function --------------------------------
baseline_weight = 3.0   # weight given to the baseline portion of the waveform
ris_edge_weight = 10.0  # weight given to the rising edge of the waveform
tail_weight = 7.0       # weight given to the RC decay tail of the waveform

# --- Optimisation schedule --------------------------------------------------
ITERS = 7000       # max number of iterations to run
DECAY = 2500       # iteration at which the learning rate starts to decay
LRATE_Gen = 1e-2   # learning rate of the generator
LRATE_Disc = 1e-3  # learning rate of the discriminator
max_grad_norm = 40  # maximum norm for gradient clipping
w_decay = 1e-3      # weight decay in the optimizers
n_disc_iters = 25   # number of iterations after which the discriminators are updated

# --- Relative weights of the loss terms -------------------------------------
cyc_loss_weight = 10  # cycle-consistency loss, e.g. loss(sim -> data -> sim)
iden_loss_weight = 5  # identity loss, e.g. ATN(data) - data
gan_loss_weight = 1   # generator loss, ATN(sim) - data

# --- Dataset size -----------------------------------------------------------
max_sample = 2e4  # number of samples to use for training (kept as a float)

# Fail fast on an inconsistent schedule: the decay must begin before the final iteration.
if not DECAY < ITERS:
    raise ValueError("DECAY must be less than ITERS to avoid division by zero in the learning rate scheduler.")

In [3]:
# sim_pulses = scratch_dir+f'cpu_net_datasets/{eng_peak}_sim_noise_pz.pkl'
# det_pulses = scratch_dir+f'cpu_net_datasets/{eng_peak}_sim_preamp_noise_pz.pkl'
# Pickle streams for the selected energy peak. `sim_pulses` is passed to the dataset as
# the simulation input; `det_pulses` (a preamp-processed simulation file, judging by its
# name) is passed in the detector-data slot.
sim_pulses = f'{scratch_dir}cpu_net_datasets/{eng_peak}_sim_pz.pkl'
det_pulses = f'{scratch_dir}cpu_net_datasets/{eng_peak}_sim_preamp_pz.pkl'

In [4]:
# import numpy as np
# import torch.utils.data as data_utils
# import torch.nn as nn
# import pickle
# import matplotlib.pyplot as plt
# from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
# from tools import calculate_tn
# from tqdm import tqdm
# import random
# from scipy.optimize import curve_fit

# '''
# Parameters for training waveform construction.
# LSPAN: how many sample to select to the left of time point 0 (start of the rise)
# RSPAN: how many sample to select to the right of time point 0 (start of the rise)
# SEQ_LEN: total length of the input pulses, always equal to LSPAN+RSPAN
# '''
# LSPAN=300
# RSPAN=500
# SEQ_LEN=LSPAN+RSPAN
# t_n = 99.9
# base_thres = 0.005 # mean of first 50 smaples should be less than this value
# tail_thres = 0.80 # last 50 samples should be greater than this value
# # chi_squared_threshold= 0.002
# # popt_threshold = -2.6e-4
# norm_tail_height = 0.80
# norm_samples = 5
# class SplinterDataset(Dataset):
#     '''
#     Splinter is the name of our local Ge detector
#     '''
#     def __init__(self, event_dset="DetectorPulses.pickle", siggen_dset="SimulatedPulses.pickle", n_max=1e7, chi_squared_threshold=1, popt_threshold_under=-2, popt_threshold_over=2):
#         self.n_max = n_max
#         self.chi_squared_threshold = chi_squared_threshold
#         self.popt_threshold_over = popt_threshold_over
#         self.popt_threshold_under = popt_threshold_under
#         self.chi_squared_coeff = []
#         self.tau_fits = []
        
#         # Load data and simulation waveforms synchronously
#         self.event_dict, self.siggen_dict = self.event_loader_synchronized(event_dset, siggen_dset)
#         print("Number of Data events:", len(self.event_dict))
#         print("Number of Simulation events:", len(self.siggen_dict))

#         # Set the class attributes for thresholds here
#         self.size = min(len(self.event_dict), len(self.siggen_dict))
#         self.event_ids = [wdict["event"] for wdict in self.siggen_dict]
        
#     def __len__(self):
#         # Return the minimum size between event_dict and siggen_dict to avoid out-of-range errors
#         return min(len(self.event_dict), len(self.siggen_dict))

#     def __getitem__(self, idx):
#         # Use a single simulated waveform based on the index and transform it
#         siggenwf = self.transform(self.siggen_dict[idx]["wf"], self.siggen_dict[idx]["tp0"], sim=True)
#         # Transform the real waveform for comparison or any other purpose
#         real_wf = self.transform(self.event_dict[idx]["wf"], self.event_dict[idx]["tp0"])
#         # Return the real waveform, the single transformed simulated waveform, and the original waveform
#         event_id = self.siggen_dict[idx].get("event", -1)  # Default to -1 or suitable value if not found
#         # Return the event_id as part of the output
#         return real_wf[None, :], siggenwf[None, :], self.event_dict[idx], self.event_dict[idx]
        
#     def return_label(self):
#         return self.trainY
    
#     def set_raw_waveform(self,raw_wf):
#         self.raw_waveform = raw_wf

#     def get_original_waveform(self,wf, input=False):
#         if input:
#             return self.input_transform.recon_waveform(wf)
#         else:
#             return self.output_transform.recon_waveform(wf)
        


#     def normalize_waveform(self, wf):
#         """Normalize waveform to have values between 0 and 1."""
#         min_val = np.min(wf)
#         max_val = np.max(wf)
#         if max_val > min_val:
#             return (wf - min_val) / (max_val - min_val)
#         else:
#             # Handle the case where max_val equals min_val (e.g., constant waveforms)
#             return np.zeros_like(wf)  # or wf * 0 to return a waveform of zeros

#     def transform(self, wf, tp0, sim=False):
#         """Transform waveform by padding based on tp0 and then normalizing."""
#         wf = np.array(wf)
#         # Ensure tp0 is an integer
#         tp0 = int(round(tp0))
#         left_padding = max(LSPAN - tp0, 0)
#         right_padding = max((RSPAN + tp0) - len(wf), 0)
#         # Apply padding
#         wf_padded = np.pad(wf, (left_padding, right_padding), mode='edge')
#         # Adjust tp0 after padding
#         tp0_adjusted = tp0 + left_padding
#         # Slice the waveform around the adjusted tp0 to ensure consistent length
#         wf_sliced = wf_padded[(tp0_adjusted - LSPAN):(tp0_adjusted + RSPAN)]
#         # Normalize the waveform after padding and slicing
#         wf_normalized = self.normalize_waveform(wf_sliced)
#         # Don't normalize if it is sim as it is already normalized
#         return wf_normalized

#     def event_loader_synchronized(self, data_address, sim_address, elow=-99999, ehi=99999):
#         data_wf_list = []
#         sim_wf_list = []
#         count = 0

#         print("Chi squared cut is", self.chi_squared_threshold)
#         print("Tail slope cut over is", self.popt_threshold_over)
#         print("Tail slope cut under is", self.popt_threshold_under)

#         with open(data_address, "rb") as data_file, open(sim_address, "rb") as sim_file:
#             while True:
#                 if count > self.n_max:
#                     break
#                 try:
#                     # Load both data and simulation waveforms
#                     data_wdict = pickle.load(data_file, encoding='latin1')
#                     sim_wdict = pickle.load(sim_file, encoding='latin1')

#                     data_wf = data_wdict["wf"]
#                     sim_wf = sim_wdict["wf"]
                    
                    
#                     data_tp0 = 600 # we know this values for simulations, it same as the padding added
#                     sim_tp0 = 600             

#                     # Transform and check conditions for the data waveform
#                     data_transformed_wf = self.transform(data_wf, data_tp0)
#                     # if len(data_transformed_wf) != SEQ_LEN or np.any(data_transformed_wf[:250] > 0.025) or np.any(data_transformed_wf[-50:] < tail_thres):
#                     #     continue  # Skip this pair if data does not meet criteria

#                     # Transform the simulation waveform
#                     sim_transformed_wf = self.transform(sim_wf, sim_tp0)
                    
#                     # if len(sim_transformed_wf) != SEQ_LEN or np.any(sim_transformed_wf[:250] > 0.025) or np.any(sim_transformed_wf[-50:] < tail_thres):
#                     #     continue  # Skip this pair if sim does not meet criteria
                        
#                     # Append both waveforms if all conditions are met
#                     data_wf_list.append(data_wdict)
#                     sim_wf_list.append(sim_wdict)
#                     count += 1

#                     if count % 10000 == 0:
#                         print(f"{count} waveform pairs loaded.")

#                 except EOFError:
#                     break

#         return data_wf_list, sim_wf_list
    
#     def get_field_from_dict(self, input_dict, fieldname):
#         field_list = []
#         for event in input_dict:
#             field_list.append(event[fieldname])
#         return field_list
    
#     def get_current_amp(self,wf):
#         return max(np.diff(wf.flatten()))
    
#     def plot_waveform(self):
#         num_waveforms_to_plot = 200

#         # Create the first figure for all data waveforms
#         plt.figure(figsize=(10, 5))
#         for i in range(num_waveforms_to_plot):
#             # Fetch the real waveform from the dataset
#             real_wf, sim_wf, _, _ = self.__getitem__(i)
#             # Plot the real data waveform on the figure
#             plt.plot(real_wf[0], linewidth=0.5, label=f'Data Pulse {i+1}' if i < 1 else "")  # Add label only once for legend

#         plt.title(f"{num_waveforms_to_plot} Data Pulses")
#         plt.xlabel("Time Sample [ns]")
#         plt.ylabel("Amplitude")
#         plt.grid(True, which='both', linestyle='--', linewidth=0.5)
#         plt.minorticks_on()
#         plt.grid(which='minor', linestyle=':', linewidth='0.5')
#         plt.legend(loc='upper right')
#         plt.savefig('figs/all_data_pulses.png', dpi=200)
#         plt.show()

#         # Create the second figure for all simulated waveforms
#         plt.figure(figsize=(10, 5))
#         for i in range(num_waveforms_to_plot):
#             # Fetch the corresponding simulated waveform from the dataset
#             real_wf, sim_wf, _, _ = self.__getitem__(i)
#             # Plot the simulated waveform on the figure
#             plt.plot(sim_wf[0], linewidth=0.5, label=f'Simulated Pulse {i+1}' if i < 1 else "")  # Add label only once for legend

#         plt.title(f"{num_waveforms_to_plot} Simulated Pulses")
#         plt.xlabel("Time Sample [ns]")
#         plt.ylabel("Amplitude")
#         plt.grid(True, which='both', linestyle='--', linewidth=0.5)
#         plt.minorticks_on()
#         plt.grid(which='minor', linestyle=':', linewidth='0.5')
#         plt.legend(loc='upper right')
#         plt.savefig('figs/all_simulated_pulses.png', dpi=200)
#         plt.show()

    
#     def linear(self, x, a, b):
#         """Linear function ax + b"""
#         return a * x + b
    
#     def process_wf_log_linear(self, wf):
#         sample = 300
#         if len(wf) < sample:
#             # Return default values if waveform is too short
#             return np.nan, [np.nan, np.nan]  # Ensure popt is a list or array to safely index [0] later
#         x_data = np.arange(sample)
#         y_data = np.log(np.clip(wf[-sample:], 1e-10, None))  # Log of last 300 samples
#         try:
#             popt, pcov = curve_fit(self.linear, x_data, y_data, maxfev=100000)
#             # Calculate residuals and chi-squared for goodness of fit
#             residuals = y_data - self.linear(x_data, *popt)
#             chi_squared = np.sum((residuals ** 2) / self.linear(x_data, *popt))
#         except Exception as e:
#             # Handle fitting errors
#             popt = [np.nan, np.nan]  # Ensure popt is a list or array
#             chi_squared = np.nan
#         return -chi_squared, popt[0] #chi squared would be negative since log of number between 0,1 is negavtive, so we return positive value  
    

In [5]:
import numpy as np
import torch.utils.data as data_utils
import torch.nn as nn
import pickle
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from tools import calculate_tn
from tqdm import tqdm
import random
from scipy.optimize import curve_fit

# Parameters for training waveform construction.
# NOTE: these assignments rebind the LSPAN / RSPAN / SEQ_LEN names that were imported
# from the `dataset` module in the first cell.
LSPAN = 400                  # samples kept to the left of the alignment point tp0
RSPAN = 400                  # samples kept to the right of the alignment point tp0
SEQ_LEN = LSPAN + RSPAN      # total length of every transformed waveform

t_n = 99.9                   # second argument passed to calculate_tn when locating tp0

# Quality-cut thresholds, applied to waveforms after normalization.
base_thres = 0.005           # mean of the first 250 samples must not exceed this value
tail_thres = 0.78            # each of the last 50 samples must be at least this value

# Normalization: the mean of the last `norm_samples` samples is scaled to `norm_tail_height`.
norm_tail_height = 0.80
norm_samples = 5

class SplinterDataset(Dataset):
    """Paired detector / simulation waveform dataset.

    Splinter is the name of our local Ge detector.

    Two pickle streams are read in lockstep (one record from each per step). A pair is
    kept only when both waveforms can be aligned with ``calculate_tn`` and both pass
    ``check_valid_sim_waveform`` after windowing and normalization. Item ``idx`` of the
    dataset is therefore the ``idx``-th accepted pair.
    """

    def __init__(self, event_dset="DetectorPulses.pickle", siggen_dset="SimulatedPulses.pickle", n_max=1e5, chi_squared_threshold=1, popt_threshold_under=-2, popt_threshold_over=2):
        """Load and filter the paired waveforms.

        Args:
            event_dset: path of the pickle stream holding the detector records.
            siggen_dset: path of the pickle stream holding the simulated records.
            n_max: loading stops as soon as more than ``n_max`` pairs have been accepted,
                i.e. up to ``n_max + 1`` pairs are kept.
            chi_squared_threshold, popt_threshold_under, popt_threshold_over: stored on
                the instance and echoed while loading; no cut in this class uses them.
        """
        self.n_max = n_max
        self.chi_squared_threshold = chi_squared_threshold
        self.popt_threshold_over = popt_threshold_over
        self.popt_threshold_under = popt_threshold_under
        self.chi_squared_coeff = []
        self.tau_fits = []
        self.event_dict, self.siggen_dict = self.event_loader_synchronized(event_dset, siggen_dset)
        print("Number of Data events:", len(self.event_dict))
        print("Number of Simulations events:", len(self.siggen_dict))
        self.size = min(len(self.event_dict), len(self.siggen_dict))
        # Same fallback as __getitem__ historically used: records without an "event"
        # field get the id -1 instead of aborting construction with a KeyError.
        self.event_ids = [wdict.get("event", -1) for wdict in self.siggen_dict]

    def __len__(self):
        """Number of usable pairs (the shorter of the two record lists)."""
        return min(len(self.event_dict), len(self.siggen_dict))

    def __getitem__(self, idx):
        """Return ``(real_wf, sim_wf, record, record)`` for pair ``idx``.

        Both waveforms are windowed/normalized by ``transform`` and get a leading
        axis of size one, so each has shape ``(1, SEQ_LEN)``.
        """
        sim_record = self.siggen_dict[idx]
        data_record = self.event_dict[idx]
        siggenwf = self.transform(sim_record["wf"], sim_record["tp0"], sim=True)
        real_wf = self.transform(data_record["wf"], data_record["tp0"])
        # NOTE(review): the detector record fills BOTH trailing slots; the simulation
        # record is never returned. Kept as-is because downstream consumers of this
        # 4-tuple are not visible here -- confirm whether the last slot should be sim_record.
        return real_wf[None, :], siggenwf[None, :], data_record, data_record

    def normalize_waveform(self, wf):
        """Scale the tail to ``norm_tail_height`` and remove the baseline offset.

        The waveform is multiplied by ``norm_tail_height / mean(wf[-norm_samples:])``
        (scaling is skipped when that mean is exactly zero), then the mean of the first
        200 samples of the scaled waveform is subtracted.
        """
        tail_mean = np.mean(wf[-norm_samples:])
        if tail_mean != 0:
            normalized_wf = wf * norm_tail_height / tail_mean
        else:
            normalized_wf = wf  # Avoid division by zero
        return normalized_wf - np.mean(normalized_wf[:200])

    def transform(self, wf, tp0, sim=False):
        """Cut a ``SEQ_LEN`` window around ``tp0`` and normalize it.

        The window covers ``LSPAN`` samples before and ``RSPAN`` samples after ``tp0``
        (rounded to the nearest integer). Where the window would run off either end of
        the waveform, the edge value is repeated. ``sim`` is accepted for interface
        compatibility but has no effect.
        """
        wf = np.array(wf)
        tp0 = int(round(tp0))
        left_padding = max(LSPAN - tp0, 0)
        right_padding = max((RSPAN + tp0) - len(wf), 0)
        wf_padded = np.pad(wf, (left_padding, right_padding), mode='edge')
        tp0_adjusted = tp0 + left_padding  # tp0 moves right by the amount of left padding
        wf_sliced = wf_padded[(tp0_adjusted - LSPAN):(tp0_adjusted + RSPAN)]
        return self.normalize_waveform(wf_sliced)

    def event_loader_synchronized(self, data_address, sim_address):
        """Read both pickle streams in lockstep and keep the pairs that pass the cuts.

        Each accepted record dict gets a ``"tp0"`` entry (the ``calculate_tn`` result).
        Reading stops when either stream is exhausted or once more than ``self.n_max``
        pairs have been accepted.

        Returns:
            ``(data_records, sim_records)``: two equally long lists of record dicts.
        """
        data_wf_list = []
        sim_wf_list = []
        count = 0

        print("Chi squared cut is", self.chi_squared_threshold)
        print("Tail slope cut over is", self.popt_threshold_over)
        print("Tail slope cut under is", self.popt_threshold_under)

        # NOTE(security): pickle.load can execute arbitrary code while unpickling;
        # only point this loader at files from a trusted source.
        with open(data_address, "rb") as data_file, open(sim_address, "rb") as sim_file:
            # `<=` keeps the historical cap of n_max + 1 accepted pairs.
            while count <= self.n_max:
                try:
                    data_wdict = pickle.load(data_file, encoding='latin1')
                    sim_wdict = pickle.load(sim_file, encoding='latin1')
                except EOFError:
                    break  # one of the streams has no more records

                data_wf = data_wdict["wf"]
                sim_wf = sim_wdict["wf"]

                try:
                    data_tp0 = calculate_tn(data_wf, t_n)
                    sim_tp0 = calculate_tn(sim_wf, t_n)
                except Exception:
                    # Deliberate best effort: a pair that cannot be aligned is dropped.
                    continue
                data_wdict["tp0"] = data_tp0
                sim_wdict["tp0"] = sim_tp0

                data_transformed_wf = self.transform(data_wf, data_tp0)
                sim_transformed_wf = self.transform(sim_wf, sim_tp0)

                # Apply the same (simulation-grade) cuts to both data and simulation
                pair_ok = (self.check_valid_sim_waveform(data_transformed_wf)
                           and self.check_valid_sim_waveform(sim_transformed_wf))
                if not pair_ok:
                    continue

                data_wf_list.append(data_wdict)
                sim_wf_list.append(sim_wdict)
                count += 1

                # Report progress only when the counter actually advanced; otherwise
                # every rejected pair would repeat the previous message (starting with
                # "0 waveform pairs loaded." before the first acceptance).
                if count % 10000 == 0:
                    print(f"{count} waveform pairs loaded.")

        return data_wf_list, sim_wf_list

    def check_valid_waveform(self, wf):
        """Return True when a normalized waveform passes the basic quality cuts.

        Required: length ``SEQ_LEN``, no NaN, not identically zero, every one of the
        first 250 samples <= 0.01, every one of the last 50 samples >= ``tail_thres``,
        and the mean of the first 250 samples <= ``base_thres``.
        """
        wf = np.asarray(wf)
        if len(wf) != SEQ_LEN or np.any(np.isnan(wf)) or not np.any(wf != 0):
            return False
        head = wf[:250]
        return bool(
            np.all(head <= 0.01)
            and np.all(wf[-50:] >= tail_thres)
            and np.mean(head) <= base_thres
        )

    def check_valid_sim_waveform(self, wf):
        """Basic cuts plus a monotonicity cut on the first 400 samples.

        On top of ``check_valid_waveform``, no sample among the first 400 may be lower
        than its predecessor.
        """
        wf = np.asarray(wf)
        return self.check_valid_waveform(wf) and bool(np.all(np.diff(wf[:400]) >= 0))

    def plot_waveform(self):
        """Overlay the first (up to 100) data pulses and simulated pulses.

        Two figures are drawn, saved to ``figs/all_data_pulses.png`` and
        ``figs/all_simulated_pulses.png``, and shown.
        """
        # Never index past the end of a small dataset.
        num_waveforms_to_plot = min(100, len(self))
        # Transform every pair once and reuse it for both figures.
        pairs = [self[i][:2] for i in range(num_waveforms_to_plot)]

        panels = (
            (0, "Data Pulses", 'figs/all_data_pulses.png'),
            (1, "Simulated Pulses", 'figs/all_simulated_pulses.png'),
        )
        for slot, label, out_path in panels:
            plt.figure(figsize=(10, 5))
            for pair in pairs:
                plt.plot(pair[slot][0], linewidth=0.5)

            plt.title(f"{num_waveforms_to_plot} {label}")
            plt.xlabel("Time Sample [ns]")
            plt.ylabel("Amplitude")
            plt.grid(True, which='both', linestyle='--', linewidth=0.5)
            plt.minorticks_on()
            plt.grid(which='minor', linestyle=':', linewidth='0.5')
            plt.savefig(out_path, dpi=200)
            plt.show()


In [None]:
# Split settings. With validation_split = 0.0 every event lands in the training set and
# the validation index list is empty, so `test_loader` yields no batches.
validation_split = 0.0
shuffle_dataset = True
random_seed = 42222

# n_max is max_sample - 1 because the loader keeps up to n_max + 1 accepted pairs.
dataset = SplinterDataset(
    det_pulses,
    sim_pulses,
    n_max=max_sample - 1,
    chi_squared_threshold=100,
    popt_threshold_under=-10,
    popt_threshold_over=10,
)

# Shuffle the event indices reproducibly before carving off the validation slice.
indices = np.arange(len(dataset))
if shuffle_dataset:
    np.random.seed(random_seed)
    np.random.shuffle(indices)

split = int(validation_split * len(dataset))
val_indices = indices[:split]
train_indices = indices[split:]

# Each loader draws from its own index subset in random order; incomplete batches are dropped.
train_sampler = SubsetRandomSampler(train_indices)
valid_sampler = SubsetRandomSampler(val_indices)
train_loader = DataLoader(dataset, sampler=train_sampler, batch_size=BATCH_SIZE, drop_last=True)
test_loader = DataLoader(dataset, sampler=valid_sampler, batch_size=BATCH_SIZE, drop_last=True)

# Wrap the training loader with the project's infinite-generator helper.
data = inf_train_gen(train_loader)

Chi squared cut is 100
Tail slope cut over is 10
Tail slope cut under is -10


In [None]:
# Visual sanity check: overlay the leading data pulses and simulated pulses that passed
# the cuts. The two figures are also saved under figs/.
dataset.plot_waveform()

In [None]:
'''
This script contains the PositionalUNet network along with 3 candidate discriminators:
* RNN+Attention discriminator
* CNN+PositionalEncoding Discriminator
* Fully Connected Discriminators
we have tested all 3 discriminators, and the RNN+Attention works the best.
'''

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn.parameter import Parameter
from torch.nn import init
import torch.nn.functional as F
import math
from dataset import SEQ_LEN


class DoubleConv(nn.Module):
    '''
    Two stacked Conv1d -> BatchNorm1d -> LeakyReLU stages (kernel sizes 11 and 7).
    Both convolutions are bias-free and padded so the sequence length is unchanged;
    only the channel count changes: in_channels -> mid_channels -> out_channels.
    When mid_channels is not given (or is falsy), out_channels is used for it.
    '''
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels

        wide_stage = [
            nn.Conv1d(in_channels, hidden, kernel_size=11, padding=5, bias=False),
            nn.BatchNorm1d(hidden),
            nn.LeakyReLU(inplace=True),
        ]
        narrow_stage = [
            nn.Conv1d(hidden, out_channels, kernel_size=7, padding=3, bias=False),
            nn.BatchNorm1d(out_channels),
            nn.LeakyReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*wide_stage, *narrow_stage)

    def forward(self, x):
        '''Run both stages on x, laid out as (batch, in_channels, length).'''
        return self.double_conv(x)


class Down(nn.Module):
    '''
    Encoder step of the U-Net: MaxPool1d(2) halves the sequence length, then a
    DoubleConv maps in_channels to out_channels.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pool = nn.MaxPool1d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x):
        '''Pool, then convolve.'''
        return self.maxpool_conv(x)


class Up(nn.Module):
    '''
    Decoder step of the U-Net: upsample x1 by a factor of two, pad it to the length of
    the skip connection x2, concatenate along the channel axis and apply a DoubleConv.

    With bilinear=True the upsampling is linear interpolation (the 1-D counterpart of
    bilinear) and the DoubleConv uses in_channels // 2 hidden channels; otherwise a
    stride-2 transposed convolution that halves the channel count is used.
    '''
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose1d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode='linear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1, x2):
        '''x1: coarser feature map to upsample; x2: skip connection from the encoder.'''
        upsampled = self.up(x1)
        # Distribute the length mismatch over both ends (the extra sample goes right).
        gap = x2.size(2) - upsampled.size(2)
        left = gap // 2
        upsampled = F.pad(upsampled, [left, gap - left])
        # Skip-connection channels come first, upsampled channels second.
        return self.conv(torch.cat([x2, upsampled], dim=1))


class OutConv(nn.Module):
    '''
    Output head: a kernel-size-1 convolution that maps in_channels to out_channels
    independently at every position of the sequence.
    '''
    def __init__(self, in_channels, out_channels):
        super().__init__()
        pointwise = nn.Conv1d(in_channels, out_channels, kernel_size=1)
        # Kept inside a Sequential so the parameters stay named "conv.0.*".
        self.conv = nn.Sequential(pointwise)

    def forward(self, x):
        '''Project x of layout (batch, in_channels, length) to out_channels.'''
        return self.conv(x)
    
class PositionalEncoding(nn.Module):
    '''
    Fixed sinusoidal positional encoding added to a (batch, d_model, length) tensor.

    A table of shape (1, d_model, max_len) is precomputed: even channels hold sines and
    odd channels hold cosines of position * frequency, with frequencies decaying
    geometrically from 1 down towards 1/10000. The table is a buffer (saved with the
    module, not trained).

    Args:
        d_model: number of channels of the input; odd values are supported.
        start:   position offset; the encoding for positions start .. start+length-1 is used.
        dropout: dropout probability applied after the encoding is added.
        max_len: number of positions in the table; start + length must not exceed it.
        factor:  scale applied to the encoding before it is added.
    '''
    def __init__(self, d_model, start=0, dropout=0.1, max_len=10000, factor=1.0):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.factor = factor

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one cosine channel fewer than sine channels, so trim
        # the angle table to fit; for even d_model this slice is the whole table.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Store as (1, d_model, max_len) so it broadcasts over the batch axis.
        pe = pe.unsqueeze(0).transpose(1, 2)
        self.register_buffer('pe', pe)
        self.start = start

    def forward(self, x):
        '''Add the scaled encoding for positions start .. start + x.size(2) - 1, then apply dropout.'''
        x = x + self.factor * self.pe[:, :, self.start:(self.start + x.size(2))]
        return self.dropout(x)


class PositionalUNet(nn.Module):
    '''
    1-D U-Net with sinusoidal positional encodings and a stochastic bottleneck.

    Encoder: one DoubleConv followed by four Down steps (each halves the length).
    Bottleneck: two 1x1 convolutions produce a mean and a log-variance, from which a
    sample is drawn with the reparametrization trick on every forward pass.
    Decoder: four Up steps with skip connections, then a 1x1 output convolution.
    The network takes a single input channel and produces a single output channel.
    '''
    def __init__(self):
        super(PositionalUNet, self).__init__()
        self.bilinear = True
        # Base channel width; every layer width below is a multiple of it.
        multi = 40

        # Encoder: 1 -> 40 -> 80 -> 160 -> 320 channels.
        self.inc = DoubleConv(1, multi)
        self.down1 = Down(multi, multi * 2)
        self.down2 = Down(multi * 2, multi * 4)
        self.down3 = Down(multi * 4, multi * 8)
        # With linear upsampling the deepest layers use half the channels (factor = 2).
        factor = 2 if self.bilinear else 1
        self.down4 = Down(multi * 8, multi * 16 // factor)

        # 1x1 convolutions giving the mean and the log-variance of the bottleneck.
        self.fc_mean = torch.nn.Conv1d(multi * 16 // factor, multi * 16 // factor, 1)
        self.fc_var = torch.nn.Conv1d(multi * 16 // factor, multi * 16 // factor, 1)

        # Decoder: each Up receives the concatenation of the skip connection and the
        # upsampled features, hence the doubled input width.
        self.up1 = Up(multi * 16, multi * 8 // factor, self.bilinear)
        self.up2 = Up(multi * 8, multi * 4 // factor, self.bilinear)
        self.up3 = Up(multi * 4, multi * 2 // factor, self.bilinear)
        self.up4 = Up(multi * 2, multi // factor, self.bilinear)
        self.outc = OutConv(multi // factor, 1)

        # One positional encoding per stage; d_model matches that stage's channel count.
        self.pe1 = PositionalEncoding(multi)
        self.pe2 = PositionalEncoding(multi * 2)
        self.pe3 = PositionalEncoding(multi * 4)
        self.pe4 = PositionalEncoding(multi * 8)
        self.pe5 = PositionalEncoding(multi * 16 // factor)
        # NOTE(review): `start` is a position offset, yet the decoder encodings pass
        # channel-derived values (160, 80, 80) -- confirm this is intentional.
        self.pe6 = PositionalEncoding(multi * 8 // factor, start=multi * 4)
        self.pe7 = PositionalEncoding(multi * 4 // factor, start=multi * 2)
        self.pe8 = PositionalEncoding(multi * 2 // factor, start=multi * 2)
        # NOTE(review): pe9 is constructed but never called in forward().
        self.pe9 = PositionalEncoding(multi // factor, start=0, factor=1.0)

    def reparametrize(self, mu, logvar):
        '''
        Reparametrization trick used in variational autoencoders.

        Returns mu + eps * exp(0.5 * logvar) with eps drawn from a standard normal of
        the same shape as mu. The in-place ops act on freshly created tensors, so the
        arguments themselves are not modified.
        '''
        std = logvar.mul(0.5).exp_()
        eps = torch.randn_like(mu)
        return eps.mul(std).add_(mu)

    def forward(self, x):
        '''
        Forward pass for the positional U-Net model with the reparametrization step.

        x is laid out as (batch, 1, length); the result has one channel as well.
        Noise is sampled at the bottleneck on every call -- nothing in this method
        switches the sampling off.
        '''
        # Encoder path; x1..x4 are kept as skip connections.
        x1 = self.pe1(self.inc(x))
        x2 = self.pe2(self.down1(x1))
        x3 = self.pe3(self.down2(x2))
        x4 = self.pe4(self.down3(x3))
        x5 = self.down4(x4)
        # Stochastic bottleneck: fc_var's output is interpreted as the log-variance.
        x5 = self.pe5(self.reparametrize(self.fc_mean(x5), self.fc_var(x5)))

        # Decoder path; no positional encoding is applied after the last Up step.
        x = self.pe6(self.up1(x5, x4))
        x = self.pe7(self.up2(x, x3))
        x = self.pe8(self.up3(x, x2))
        x = self.up4(x, x1)
        output = self.outc(x)
        return output

import torch
import torch.nn as nn

class RNN(nn.Module):
    '''
    GRU discriminator with dot-product attention for pulse signals.

    The input waveform is trimmed, quantised into integer bins, embedded and fed
    to a two-layer GRU. Attention scores are the softmax of the inner product
    between every GRU output step and the final hidden state.

    Args:
        get_attention: when True, ``forward`` returns the attention scores
            instead of the discriminator probability.
    '''
    def __init__(self, get_attention=False):
        super(RNN, self).__init__()

        bidirec = True  # Whether to use a bidirectional RNN
        self.bidirec = bidirec
        feed_in_dim = 128
        self.seg = 1  # Segment waveform to reduce its length. Set to 1 for no segmentation.
        self.emb_dim = 64
        self.emb_tick = 1 / 1000.0  # Embedding resolution (width of one amplitude bin)
        self.embedding = nn.Embedding(1000, self.emb_dim)  # Use range [0, 1000) for embedding
        self.seq_len = (SEQ_LEN - 100) // self.seg  # forward() drops 50 samples at each end

        # Two-layer GRU; both directions use feed_in_dim // 2 hidden units.
        self.RNNLayer = torch.nn.GRU(
            input_size=self.emb_dim,
            hidden_size=feed_in_dim // 2,
            num_layers=2,
            batch_first=True,
            bidirectional=bidirec,
            dropout=0.2,
        )
        if bidirec:
            feed_in_dim *= 2

        # These two layers are not used in forward(); they are kept so that the
        # parameter set (and therefore saved state_dicts) stays unchanged.
        self.attention_weight = nn.Linear(feed_in_dim // 2, feed_in_dim // 2, bias=False)
        self.norm = torch.nn.BatchNorm1d(feed_in_dim // 2)
        self.get_attention = get_attention

        fc1 = feed_in_dim
        self.fcnet = nn.Linear(fc1, 1)

    def forward(self, x):
        '''
        Score a batch of waveforms.

        Args:
            x: 3-D tensor indexed as ``x[:, :, 50:-50]``; after trimming, the
               values must reshape to ``(-1, self.seq_len)``.

        Returns:
            Attention scores of shape (batch, seq_len) if ``self.get_attention``
            is True, otherwise sigmoid probabilities of shape (batch, 1).
        '''
        # Skip the first 50 and last 50 points of every waveform
        x = x[:, :, 50:-50]

        # Reshape input according to segmentation (batch_size, seq_len)
        x = x.view(-1, self.seq_len)

        # Quantise amplitudes into integer bins clipped to the embedding range [0, 1000)
        x = torch.clamp(x / self.emb_tick, 0, 999).long()

        # Embedding lookup
        x = self.embedding(x)

        bsize = x.size(0)  # Batch size
        output, hidden = self.RNNLayer(x)

        # Keep only the last layer's final hidden state(s)
        if self.bidirec:
            hidden = hidden[-2:]  # forward and backward states of the last layer
            hidden = hidden.transpose(0, 1).reshape(bsize, -1)
        else:
            hidden = hidden[-1]

        attention_scores = self.calculate_attention_scores(output, hidden)

        if self.get_attention:
            return attention_scores

        # Attention-weighted sum of the GRU outputs over the time axis
        context = torch.sum(attention_scores.unsqueeze(-1).expand_as(output) * output, dim=1)

        # Classify from the context vector concatenated with the hidden state
        x = self.fcnet(torch.cat([context, hidden], dim=-1))

        return torch.sigmoid(x)  # probability in (0, 1)

    def calculate_attention_scores(self, output, hidden):
        '''
        Softmax over time of the inner product between each output step and ``hidden``.

        Args:
            output: (batch, time, features) GRU outputs.
            hidden: (batch, features) summary state.

        Returns:
            Tensor of shape (batch, time) whose rows sum to 1.
        '''
        inner_product = torch.einsum("ijl,il->ij", output, hidden)
        attention_scores = torch.softmax(inner_product, dim=-1)
        return attention_scores

    def get_attention_weights(self, x):
        '''
        Return the attention scores for ``x`` regardless of how the model was built.

        ``forward`` branches on ``self.get_attention``, so that flag is switched
        on for the duration of the call and restored afterwards, even if the
        forward pass raises.
        '''
        previous = self.get_attention
        self.get_attention = True
        try:
            return self.forward(x)
        finally:
            self.get_attention = previous


In [None]:
class WFDist(nn.Module):
    '''
    Waveform Distance: an L1 loss in which every sample position is scaled by a
    region-specific weight (baseline, rising edge, tail) before comparison.

    The region lengths are read from the module-level globals ``baseline_len``,
    ``rising_edge_len`` and ``tail_len``; their sum must equal the flattened
    length of one waveform. The weight vector is moved to ``DEVICE`` on creation.

    Args:
        baseline_weight: weight for the first ``baseline_len`` samples.
        ris_edge_weight: weight for the next ``rising_edge_len`` samples.
        tail_weight: weight for the final ``tail_len`` samples.
    '''
    def __init__(self, baseline_weight, ris_edge_weight, tail_weight):
        super(WFDist, self).__init__()
        self.criterion = nn.L1Loss()
        self.weight = torch.tensor(
            [baseline_weight] * baseline_len
            + [ris_edge_weight] * rising_edge_len
            + [tail_weight] * tail_len
        ).to(DEVICE)

    def forward(self, x1, x2):
        '''
        Mean weighted absolute difference between two batches of waveforms.

        Every sample is flattened and multiplied by ``self.weight``. Because all
        samples have the same length, the mean over the whole batch equals the
        average of the per-sample means, so one batched call replaces a Python
        loop over the batch. The result is not normalised by the weight sum.
        '''
        weighted_1 = x1.flatten(start_dim=1) * self.weight
        weighted_2 = x2.flatten(start_dim=1) * self.weight
        return self.criterion(weighted_1, weighted_2)

In [None]:
# Build the two generators, two discriminators, losses, optimizers and LR schedulers.
# Domain A = detector pulses, domain B = simulated pulses.
target_real = torch.ones(BATCH_SIZE,1).to(DEVICE) # discriminator target for "real" pulses
target_fake = torch.zeros(BATCH_SIZE,1).to(DEVICE) # discriminator target for "fake" pulses
netG_A2B = PositionalUNet() # Generator from Data to Simulations (saved as IATN after training)
netG_B2A = PositionalUNet() # Generator from Simulations to Data (saved as ATN after training)
netD_A = RNN().apply(weights_init_normal) # Discriminator whose job is to verify if a pulse looks like data
netD_B = RNN().apply(weights_init_normal) # Discriminator whose job is to verify if a pulse looks like simulations
netG_A2B.to(DEVICE)
netG_B2A.to(DEVICE)
netD_A.to(DEVICE)
netD_B.to(DEVICE)
criterion_GAN = nn.BCELoss().to(DEVICE)
criterion_cycle = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)
criterion_identity = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)

# A single optimizer updates both generators; each discriminator has its own.
optimizer_G = torch.optim.Adam(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), lr=LRATE_Gen, betas=(0.5, 0.999), weight_decay=w_decay)
optimizer_D_A = torch.optim.Adam(netD_A.parameters(), lr=LRATE_Disc, betas=(0.5, 0.999), weight_decay=w_decay)
optimizer_D_B = torch.optim.Adam(netD_B.parameters(), lr=LRATE_Disc, betas=(0.5, 0.999), weight_decay=w_decay)


# NOTE: the training loop only steps lr_scheduler_G; the .step() calls for the two
# discriminator schedulers are commented out there, so LRATE_Disc stays constant.
lr_scheduler_G = torch.optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=LambdaLR(ITERS, 0, DECAY).step)
lr_scheduler_D_A = torch.optim.lr_scheduler.LambdaLR(optimizer_D_A, lr_lambda=LambdaLR(ITERS, 0, DECAY).step)
lr_scheduler_D_B = torch.optim.lr_scheduler.LambdaLR(optimizer_D_B, lr_lambda=LambdaLR(ITERS, 0, DECAY).step)


def count_parameters(model):
    """Return the total element count of the parameters of ``model`` that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Report the number of trainable parameters of each generator and discriminator.
print(f"Total trainable parameters in netG_A2B: {count_parameters(netG_A2B)}")
print(f"Total trainable parameters in netG_B2A: {count_parameters(netG_B2A)}")
print(f"Total trainable parameters in netD_A: {count_parameters(netD_A)}")
print(f"Total trainable parameters in netD_B: {count_parameters(netD_B)}")

In [None]:
import os

# Factor multiplied into target_real in the BCE losses below; 1 leaves the
# "real" labels at exactly 1 (no smoothing).
label_smoothing = 1
# Directory to save the periodic checkpoints written inside the training loop
save_dir = "/work/users/k/b/kbhimani/cpu_net_weights/validation"
os.makedirs(save_dir, exist_ok=True)  # Ensure directory exists

# Function to check for NaN values and stop training if found
def check_for_nan(tensor, name):
    """
    Report whether ``tensor`` contains any NaN.

    Prints a message mentioning ``name`` when a NaN is found. The function only
    reports; the caller decides whether to stop training.

    Returns:
        True if at least one element is NaN, otherwise False.
    """
    has_nan = bool(torch.isnan(tensor).any())
    if has_nan:
        print(f"NaN detected in {name}. Stopping training.")
    return has_nan

# Initialize lists to store loss values; the training loop appends one entry
# per completed iteration to each of them.
losses_G = []
losses_D_A = []
losses_D_B = []
losses_GAN_A2B = []
losses_GAN_B2A = []
losses_identity_A = []
losses_identity_B = []
losses_cycle_ABA = []
losses_cycle_BAB = []
learning_rates_G = []  # generator learning rate read from lr_scheduler_G
l1_data_sim = []  # plain L1 distance between real_A and fake_A (monitoring only)

# Main adversarial training loop: one generator update per iteration (identity +
# GAN + cycle losses) and a discriminator update every n_disc_iters iterations.
# Any NaN detected by check_for_nan breaks out of the loop.
for iteration in tqdm(range(ITERS)):
    netG_A2B.train()
    netG_B2A.train()

    #########################
    # A: Detector Pulses
    # B: Simulated Pulses
    #########################

    # NOTE(review): `data` is presumably the infinite train generator yielding
    # (detector, simulated) batches of BATCH_SIZE rows — confirm; the BCE targets
    # target_real/target_fake have a fixed (BATCH_SIZE, 1) shape.
    real_A, real_B = next(data)
    real_A = real_A.to(DEVICE).float()
    real_B = real_B.to(DEVICE).float()

    # Check for NaN values in real pulses
    if check_for_nan(real_A, "real_A") or check_for_nan(real_B, "real_B"):
        break

    ###### Generators A2B and B2A ######
    optimizer_G.zero_grad()

    # Identity loss: a generator fed a pulse already in its output domain
    # should return it unchanged.
    same_B = netG_A2B(real_B)
    if check_for_nan(same_B, "same_B"):
        break
    loss_identity_B = criterion_identity(same_B, real_B) * iden_loss_weight

    same_A = netG_B2A(real_A)
    if check_for_nan(same_A, "same_A"):
        break
    loss_identity_A = criterion_identity(same_A, real_A) * iden_loss_weight

    # GAN loss: generators are rewarded when the discriminator labels their
    # output as real.
    fake_B = netG_A2B(real_A)
    if check_for_nan(fake_B, "fake_B"):
        break
    pred_fake = netD_B(fake_B)
    loss_GAN_A2B = criterion_GAN(pred_fake, target_real * label_smoothing) * gan_loss_weight

    fake_A = netG_B2A(real_B)
    if check_for_nan(fake_A, "fake_A"):
        break
    pred_fake = netD_A(fake_A)
    loss_GAN_B2A = criterion_GAN(pred_fake, target_real * label_smoothing) * gan_loss_weight

    # Cycle loss: translating to the other domain and back should recover the input.
    recovered_A = netG_B2A(fake_B)
    if check_for_nan(recovered_A, "recovered_A"):
        break
    loss_cycle_ABA = criterion_cycle(recovered_A, real_A) * cyc_loss_weight

    recovered_B = netG_A2B(fake_A)
    if check_for_nan(recovered_B, "recovered_B"):
        break
    loss_cycle_BAB = criterion_cycle(recovered_B, real_B) * cyc_loss_weight

    # Monitoring metric only: it is logged below but is not part of loss_G.
    l1_sim = F.l1_loss(real_A, fake_A)
    if check_for_nan(l1_sim, "l1_sim"):
        break

    # Total loss for generators
    loss_G = (
        loss_identity_A
        + loss_identity_B
        + loss_cycle_ABA
        + loss_cycle_BAB
        + loss_GAN_A2B
        + loss_GAN_B2A
    )
    if check_for_nan(loss_G, "loss_G"):
        break
    loss_G.backward()

    # Apply gradient clipping for the generators
    torch.nn.utils.clip_grad_norm_(itertools.chain(netG_A2B.parameters(), netG_B2A.parameters()), max_grad_norm)

    optimizer_G.step()

    # Update discriminators every n_disc_iters iterations
    if iteration % n_disc_iters  == 0 : #or iteration>5000
        ###### Discriminator A (Detector Pulses) ######
        optimizer_D_A.zero_grad()

        # Real loss
        pred_real = netD_A(real_A)
        if check_for_nan(pred_real, "pred_real (Discriminator A)"):
            break
        loss_D_real = criterion_GAN(pred_real, target_real * label_smoothing)

        # Fake loss (detach so no gradient flows back into the generator)
        pred_fake = netD_A(fake_A.detach())
        if check_for_nan(pred_fake, "pred_fake (Discriminator A)"):
            break
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss for Discriminator A
        loss_D_A = loss_D_real + loss_D_fake
        if check_for_nan(loss_D_A, "loss_D_A"):
            break
        loss_D_A.backward()

        # Apply gradient clipping for Discriminator A
        torch.nn.utils.clip_grad_norm_(netD_A.parameters(), max_grad_norm)

        optimizer_D_A.step()

        ###### Discriminator B (Simulated Pulses) ######
        optimizer_D_B.zero_grad()

        # Real loss
        pred_real = netD_B(real_B)
        if check_for_nan(pred_real, "pred_real (Discriminator B)"):
            break
        loss_D_real = criterion_GAN(pred_real, target_real * label_smoothing)

        # Fake loss (detach so no gradient flows back into the generator)
        pred_fake = netD_B(fake_B.detach())
        if check_for_nan(pred_fake, "pred_fake (Discriminator B)"):
            break
        loss_D_fake = criterion_GAN(pred_fake, target_fake)

        # Total loss for Discriminator B
        loss_D_B = loss_D_real + loss_D_fake
        if check_for_nan(loss_D_B, "loss_D_B"):
            break
        loss_D_B.backward()

        # Apply gradient clipping for Discriminator B
        torch.nn.utils.clip_grad_norm_(netD_B.parameters(), max_grad_norm)

        optimizer_D_B.step()

    current_lr_G = lr_scheduler_G.get_last_lr()[0]
    # Append each loss to its corresponding list.
    # NOTE: loss_D_A / loss_D_B are only recomputed on discriminator-update
    # iterations; in between, the most recent value is logged again.
    losses_G.append(loss_G.item())
    losses_D_A.append(loss_D_A.item())
    losses_D_B.append(loss_D_B.item())
    losses_GAN_A2B.append(loss_GAN_A2B.item())
    losses_GAN_B2A.append(loss_GAN_B2A.item())
    losses_identity_A.append(loss_identity_A.item())
    losses_identity_B.append(loss_identity_B.item())
    losses_cycle_ABA.append(loss_cycle_ABA.item())
    losses_cycle_BAB.append(loss_cycle_BAB.item())
    learning_rates_G.append(current_lr_G)
    l1_data_sim.append(l1_sim.item())

    # Only the generator learning rate is scheduled; the discriminator
    # scheduler steps are disabled.
    lr_scheduler_G.step()
    # lr_scheduler_D_A.step()
    # lr_scheduler_D_B.step()

    # Save model weights every 50 iterations
    if iteration % 50 == 0:
        torch.save(netG_B2A.state_dict(), f"{save_dir}/netG_B2A_iter_{iteration}.pt")
        torch.save(netG_A2B.state_dict(), f"{save_dir}/netG_A2B_iter_{iteration}.pt")
        torch.save(netD_A.state_dict(), f"{save_dir}/netD_A_iter_{iteration}.pth")
        torch.save(netD_B.state_dict(), f"{save_dir}/netD_B_iter_{iteration}.pth")

    # if iteration % 50 == 0:
    #     print(f"Iteration {iteration}: Loss_G: {loss_G.item()}, Loss_D_A: {loss_D_A.item()}, Loss_D_B: {loss_D_B.item()}")

# Save final model weights and loss arrays to disk.
# Create the output directories first so a finished training run cannot fail
# at the very last step because a folder is missing.
os.makedirs("data_emulation/model_weights", exist_ok=True)
os.makedirs("data_emulation/plot_data", exist_ok=True)

# Naming: the simulation-to-data generator (netG_B2A) is stored as "ATN" and the
# data-to-simulation generator (netG_A2B) as "IATN".
torch.save(netG_B2A.state_dict(), f"data_emulation/model_weights/{eng_peak}_ATN.pt")
torch.save(netG_A2B.state_dict(), f"data_emulation/model_weights/{eng_peak}_IATN.pt")
torch.save(netD_A.state_dict(), f"data_emulation/model_weights/{eng_peak}_netD_A.pth")
torch.save(netD_B.state_dict(), f"data_emulation/model_weights/{eng_peak}_netD_B.pth")


# Save loss arrays (one value per completed training iteration)
np.save("data_emulation/plot_data/losses_G.npy", np.array(losses_G))
np.save("data_emulation/plot_data/losses_D_A.npy", np.array(losses_D_A))
np.save("data_emulation/plot_data/losses_D_B.npy", np.array(losses_D_B))
np.save("data_emulation/plot_data/losses_GAN_A2B.npy", np.array(losses_GAN_A2B))
np.save("data_emulation/plot_data/losses_GAN_B2A.npy", np.array(losses_GAN_B2A))
np.save("data_emulation/plot_data/losses_identity_A.npy", np.array(losses_identity_A))
np.save("data_emulation/plot_data/losses_identity_B.npy", np.array(losses_identity_B))
np.save("data_emulation/plot_data/losses_cycle_ABA.npy", np.array(losses_cycle_ABA))
np.save("data_emulation/plot_data/losses_cycle_BAB.npy", np.array(losses_cycle_BAB))
np.save("data_emulation/plot_data/learning_rates_G.npy", np.array(learning_rates_G))
np.save("data_emulation/plot_data/l1_data_sim.npy", np.array(l1_data_sim))


In [None]:
# Plot the per-iteration L1 distance between real_A and fake_A that the
# training cell saved to disk (log-scaled y axis).
l1_data_sim = np.load("data_emulation/plot_data/l1_data_sim.npy")
plt.plot(l1_data_sim, label='L1 Loss')
plt.title('L1 Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.minorticks_on()
plt.legend()

In [None]:
import numpy as np
import matplotlib.pyplot as plt

def moving_average(x, w):
    """
    Moving average of ``x`` over a window of ``w`` samples.

    Only fully overlapping windows are kept ('valid' convolution), so for
    ``len(x) >= w`` the result has ``len(x) - w + 1`` entries.
    """
    window = np.ones(w)
    window_sums = np.convolve(x, window, 'valid')
    return window_sums / w

# Load the loss histories written by the training cell and plot their moving
# averages in a 2x2 grid (discriminator, generator, identity, cycle).
# Load existing losses
losses_G = np.load("data_emulation/plot_data/losses_G.npy")
losses_D_A = np.load("data_emulation/plot_data/losses_D_A.npy")
losses_D_B = np.load("data_emulation/plot_data/losses_D_B.npy")
losses_GAN_A2B = np.load("data_emulation/plot_data/losses_GAN_A2B.npy")
losses_GAN_B2A = np.load("data_emulation/plot_data/losses_GAN_B2A.npy")
losses_identity_A = np.load("data_emulation/plot_data/losses_identity_A.npy")
losses_identity_B = np.load("data_emulation/plot_data/losses_identity_B.npy")
losses_cycle_ABA = np.load("data_emulation/plot_data/losses_cycle_ABA.npy")
losses_cycle_BAB = np.load("data_emulation/plot_data/losses_cycle_BAB.npy")


# Set the moving average window
win = 10
# Create a figure with a grid of 2x2 subplots
plt.figure(figsize=(25, 12))
# x axis 1..(N - win + 1): one point per 'valid' moving-average window
iterations = np.linspace(1,len(losses_G)-win+1,len(losses_G)-win+1)

# `iterations` starts at 1, so this mask keeps every point; raise the threshold
# to skip early iterations in the first two panels.
cut = (iterations>0)

# Discriminator Losses
plt.subplot(2, 2, 1)
plt.plot(iterations[cut], moving_average(losses_D_A, win)[cut], label='Data')
plt.plot(iterations[cut], moving_average(losses_D_B, win)[cut], label='Simulation')
plt.title('Discriminator Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.minorticks_on()
plt.legend()

# Generator Losses
plt.subplot(2, 2, 2)
plt.plot(iterations[cut], moving_average(losses_GAN_A2B, win)[cut], label='Data to Simulation')
plt.plot(iterations[cut], moving_average(losses_GAN_B2A, win)[cut], label='Simulation to Data')
plt.title('Generator Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.minorticks_on()
plt.legend()

# Identity Losses (the `cut` mask is not applied to this panel)
plt.subplot(2, 2, 3)
plt.plot(iterations, moving_average(losses_identity_A, win), label='Data')
plt.plot(iterations, moving_average(losses_identity_B, win), label='Simulation')
plt.title('Identity Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.minorticks_on()
plt.legend()

# Cycle Consistency Losses (the `cut` mask is not applied to this panel)
plt.subplot(2, 2, 4)
plt.plot(iterations, moving_average(losses_cycle_ABA, win), label='Data to Simulation to Data')
plt.plot(iterations, moving_average(losses_cycle_BAB, win), label='Simulation to Data to Simulation')
plt.title('Cycle Consistency Losses')
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.yscale('log')
plt.minorticks_on()
plt.legend()

# Disabled panels: tail-slope losses are not recorded by the training loop above.
# # Tail Slope Distribution Losses A
# plt.subplot(2, 3, 5)
# plt.plot(moving_average(tail_slope_losses_A, win), label='Tail Slope Loss A')
# plt.title('Tail Slope Distribution Losses A')
# plt.xlabel('Iteration')
# plt.ylabel('Loss')
# plt.yscale('log')
# plt.minorticks_on()
# plt.legend()

# # Tail Slope Distribution Losses B
# plt.subplot(2, 3, 6)
# plt.plot(moving_average(tail_slope_losses_B, win), label='Tail Slope Loss B')
# plt.title('Tail Slope Distribution Losses B')
# plt.xlabel('Iteration')
# plt.ylabel('Loss')
# plt.yscale('log')
# plt.minorticks_on()
# plt.legend()

# Adjust layout (saving the figure is currently disabled)
plt.tight_layout()
# plt.savefig('figs/loss_funcs_with_tail_slope.pdf')


In [None]:
# Rebuild the two generators and load the final weights saved by the training cell.
# ATN holds the netG_B2A weights (simulation -> data); IATN holds the netG_A2B
# weights (data -> simulation).
# map_location=DEVICE lets GPU-saved checkpoints load on a CPU-only machine,
# matching the other checkpoint loaders in this notebook.
ATN = PositionalUNet()
ATN.to(DEVICE)
pretrained_dict = torch.load(f'data_emulation/model_weights/{eng_peak}_ATN.pt', map_location=DEVICE, weights_only=True)
# model_dict is a merged copy kept for inspection; the weights actually loaded
# into the network come from pretrained_dict.
model_dict = ATN.state_dict()
model_dict.update(pretrained_dict)
ATN.load_state_dict(pretrained_dict)
ATN.eval()

IATN = PositionalUNet()
IATN.to(DEVICE)
pretrained_dict_inv = torch.load(f'data_emulation/model_weights/{eng_peak}_IATN.pt', map_location=DEVICE, weights_only=True)

model_dict_inv = IATN.state_dict()
model_dict_inv.update(pretrained_dict_inv)
IATN.load_state_dict(pretrained_dict_inv)
IATN.eval()
data_dict_loader = train_loader

In [None]:
# Translate one batch with the trained generators and pull out a single
# waveform of each kind as flat NumPy arrays for plotting.
# NOTE(review): train_loader presumably yields (detector, simulated, extra, extra)
# 4-tuples — confirm against the dataset definition.
wf, wf_deconv, a,b = next(iter(train_loader))
wf = wf.to(DEVICE)
wf_deconv = wf_deconv.to(DEVICE)
# The forward passes are not wrapped in torch.no_grad(), so autograd graphs are built here.
outputs  = ATN(wf_deconv)  # ATN applied to the second loader element
outputs_inv = IATN(outputs)  # IATN applied to the ATN output (round trip)
iwf = 2 # the ith waveform in the batch to plot
detector_pulse = wf[iwf,0,:].cpu().data.numpy().flatten()
simulated_pulse = wf_deconv[iwf,0,:].cpu().data.numpy().flatten()
translated_pulse = outputs[iwf,0,:].cpu().data.numpy().flatten()
translated_pulse_inv = outputs_inv[iwf,0,:].cpu().data.numpy().flatten()

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
# Upper bound (exclusive) on the checkpoint iteration visualised below; the
# plotting functions walk range(0, num_steps, 50).
plot_steps_gif=5000
# Directories to save attention weights plots
attention_weights_dir = "/nas/longleaf/home/kbhimani/CPU-Net/data_emulation/giffs/attention_weights"
os.makedirs(attention_weights_dir, exist_ok=True)

def visualize_attention_weights_stacked(model_A, model_B, train_loader, device, num_steps, save_dir,
                                        weights_dir="/work/users/k/b/kbhimani/cpu_net_weights/validation"):
    """
    Save one attention-weight figure per discriminator checkpoint.

    For every step in ``range(0, num_steps, 50)`` the checkpoints
    ``netD_A_iter_{step}.pth`` and ``netD_B_iter_{step}.pth`` are loaded from
    ``weights_dir`` and the attention weights of the first waveform of one fixed
    batch are plotted in two stacked panels.

    Args:
        model_A, model_B: discriminators exposing ``get_attention_weights``;
            their weights are overwritten and they are left in eval mode.
        train_loader: loader yielding 4-tuples; the first two items are used.
        device: device for the input batch and for ``torch.load``'s map_location.
        num_steps: exclusive upper bound on the checkpoint iteration.
        save_dir: directory receiving ``attention_weights_step_{step}.png``.
        weights_dir: directory holding the checkpoints (defaults to the
            location used by the training cell).
    """
    # Get a single batch from the train loader; the same batch is used for every step
    real_A, real_B, _, _ = next(iter(train_loader))
    real_A = real_A.to(device)
    real_B = real_B.to(device)

    for step in tqdm(range(0, num_steps, 50)):
        # Load saved weights for both models
        model_A.load_state_dict(torch.load(f"{weights_dir}/netD_A_iter_{step}.pth", weights_only=True, map_location=device))
        model_B.load_state_dict(torch.load(f"{weights_dir}/netD_B_iter_{step}.pth", weights_only=True, map_location=device))

        model_A.eval()
        model_B.eval()

        with torch.no_grad():
            attention_weights_A = model_A.get_attention_weights(real_A).cpu().numpy()[0]
            attention_weights_B = model_B.get_attention_weights(real_B).cpu().numpy()[0]

        # Zero-pad 50 samples on each side: RNN.forward trims 50 samples from
        # both ends, so this restores the original pulse length.
        attention_weights_A_padded = np.pad(attention_weights_A, (50, 50), mode='constant')
        attention_weights_B_padded = np.pad(attention_weights_B, (50, 50), mode='constant')

        # Plot the attention weights (stacked vertically)
        fig, axs = plt.subplots(2, 1, figsize=(10, 10))

        # Each axis is derived from its own data rather than a fixed 800 samples.
        axs[0].plot(np.arange(len(attention_weights_A_padded)), attention_weights_A_padded, label="Attention Weights A", color='tab:blue')
        axs[0].set_title(f'Attention Weights for Model A (Step {step})')
        axs[0].set_xlabel('Time Steps')
        axs[0].set_ylabel('Attention Score')
        axs[0].grid(True)

        axs[1].plot(np.arange(len(attention_weights_B_padded)), attention_weights_B_padded, label="Attention Weights B", color='tab:red')
        axs[1].set_title(f'Attention Weights for Model B (Step {step})')
        axs[1].set_xlabel('Time Steps')
        axs[1].set_ylabel('Attention Score')
        axs[1].grid(True)

        plt.tight_layout()
        plt.savefig(f"{save_dir}/attention_weights_step_{step}.png")
        plt.close(fig)  # free the figure; many are created in this loop

# Call the function.
# NOTE: this rebinds netD_A / netD_B to freshly constructed, attention-returning
# discriminators on the CPU; the trained discriminators held under these names
# are no longer reachable afterwards.
netD_A = RNN(get_attention=True)
netD_B = RNN(get_attention=True)
visualize_attention_weights_stacked(netD_A, netD_B, train_loader, 'cpu', num_steps=plot_steps_gif, save_dir=attention_weights_dir)


In [None]:
# Cycle ABA and BAB directories: one PNG per checkpoint step is written to each.
cycle_aba_dir = "/nas/longleaf/home/kbhimani/CPU-Net/data_emulation/giffs/cycle_aba"
os.makedirs(cycle_aba_dir, exist_ok=True)

cycle_bab_dir = "/nas/longleaf/home/kbhimani/CPU-Net/data_emulation/giffs/cycle_bab"
os.makedirs(cycle_bab_dir, exist_ok=True)

def plot_cycle_ABA_through_training(netG_A2B, netG_B2A, train_loader, device, num_steps, save_dir,
                                    weights_dir="/work/users/k/b/kbhimani/cpu_net_weights/validation"):
    """
    Save one A -> B -> A cycle figure per generator checkpoint.

    For every step in ``range(0, num_steps, 50)`` both generator checkpoints are
    loaded from ``weights_dir`` and the first waveform of one fixed batch is
    plotted together with its translation and its reconstruction.

    Args:
        netG_A2B, netG_B2A: generators; their weights are overwritten with each
            checkpoint in turn, so on return they hold the last loaded
            checkpoint and are in eval mode.
        train_loader: loader yielding 4-tuples; only the first item is used.
        device: device for the batch and for ``torch.load``'s map_location.
        num_steps: exclusive upper bound on the checkpoint iteration.
        save_dir: directory receiving ``cycle_aba_step_{step}.png``.
        weights_dir: directory holding the checkpoints.
    """
    # Get a single batch from the train loader; reused for every checkpoint
    real_A, _, _, _ = next(iter(train_loader))
    real_A = real_A.to(device)

    # Sample-index axis sized from the waveform instead of a fixed 800 samples
    n_samples = real_A.shape[-1]
    time = np.linspace(0, n_samples - 1, n_samples)

    for step in tqdm(range(0, num_steps, 50)):
        # Load saved weights for both models. map_location keeps the load
        # working when checkpoint and target devices differ; weights_only
        # restricts unpickling to tensors, as in the other loaders in this file.
        netG_A2B.load_state_dict(torch.load(f"{weights_dir}/netG_A2B_iter_{step}.pt", map_location=device, weights_only=True))
        netG_B2A.load_state_dict(torch.load(f"{weights_dir}/netG_B2A_iter_{step}.pt", map_location=device, weights_only=True))

        netG_A2B.eval()
        netG_B2A.eval()

        with torch.no_grad():
            fake_B = netG_A2B(real_A)
            recovered_A = netG_B2A(fake_B)

        # Plot the results
        fig = plt.figure(figsize=(10, 5))
        plt.plot(time, real_A[0, 0, :].cpu().numpy(), label="Target Detector Pulse (Real A)", color="#0072BD", linestyle=":")
        plt.plot(time, fake_B[0, 0, :].cpu().numpy(), label="Translated Pulse (Fake B)", color="#D95319")
        plt.plot(time, recovered_A[0, 0, :].cpu().numpy(), label="Recovered Pulse (Recovered A)", color="#7E2F8E")

        plt.title(f"Cycle ABA at Step {step}")
        plt.xlabel("Time [ns]")
        plt.ylabel("Amplitude")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"{save_dir}/cycle_aba_step_{step}.png")
        plt.close(fig)

# Call the function: IATN plays the A->B (data to simulation) role, ATN the B->A role.
# NOTE: the function loads checkpoints into IATN and ATN in place, so afterwards
# they hold the weights of the last checkpoint visited, not the final weights.
plot_cycle_ABA_through_training(IATN, ATN, train_loader, DEVICE, num_steps=plot_steps_gif, save_dir=cycle_aba_dir)

def plot_cycle_BAB_through_training(netG_A2B, netG_B2A, train_loader, device, num_steps, save_dir,
                                    weights_dir="/work/users/k/b/kbhimani/cpu_net_weights/validation"):
    """
    Save one B -> A -> B cycle figure per generator checkpoint.

    For every step in ``range(0, num_steps, 50)`` both generator checkpoints are
    loaded from ``weights_dir`` and the first waveform of one fixed batch is
    plotted together with its translation and its reconstruction.

    Args:
        netG_A2B, netG_B2A: generators; their weights are overwritten with each
            checkpoint in turn, so on return they hold the last loaded
            checkpoint and are in eval mode.
        train_loader: loader yielding 4-tuples; only the second item is used.
        device: device for the batch and for ``torch.load``'s map_location.
        num_steps: exclusive upper bound on the checkpoint iteration.
        save_dir: directory receiving ``cycle_bab_step_{step}.png``.
        weights_dir: directory holding the checkpoints.
    """
    # Get a single batch from the train loader; reused for every checkpoint
    _, real_B, _, _ = next(iter(train_loader))
    real_B = real_B.to(device)

    # Sample-index axis sized from the waveform instead of a fixed 800 samples
    n_samples = real_B.shape[-1]
    time = np.linspace(0, n_samples - 1, n_samples)

    for step in tqdm(range(0, num_steps, 50)):
        # Load saved weights for both models. map_location keeps the load
        # working when checkpoint and target devices differ; weights_only
        # restricts unpickling to tensors, as in the other loaders in this file.
        netG_A2B.load_state_dict(torch.load(f"{weights_dir}/netG_A2B_iter_{step}.pt", map_location=device, weights_only=True))
        netG_B2A.load_state_dict(torch.load(f"{weights_dir}/netG_B2A_iter_{step}.pt", map_location=device, weights_only=True))

        netG_A2B.eval()
        netG_B2A.eval()

        with torch.no_grad():
            fake_A = netG_B2A(real_B)
            recovered_B = netG_A2B(fake_A)

        # Plot the results
        fig = plt.figure(figsize=(10, 5))
        plt.plot(time, real_B[0, 0, :].cpu().numpy(), label="Target Simulated Pulse (Real B)", color="#0072BD", linestyle=":")
        plt.plot(time, fake_A[0, 0, :].cpu().numpy(), label="Translated Pulse (Fake A)", color="#D95319")
        plt.plot(time, recovered_B[0, 0, :].cpu().numpy(), label="Recovered Pulse (Recovered B)", color="#7E2F8E")

        plt.title(f"Cycle BAB at Step {step}")
        plt.xlabel("Time [ns]")
        plt.ylabel("Amplitude")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(f"{save_dir}/cycle_bab_step_{step}.png")
        plt.close(fig)

# Call the function: IATN plays the A->B (data to simulation) role, ATN the B->A role.
# NOTE: the function loads checkpoints into IATN and ATN in place, so afterwards
# they hold the weights of the last checkpoint visited, not the final weights.
plot_cycle_BAB_through_training(IATN, ATN, train_loader, DEVICE, num_steps=plot_steps_gif, save_dir=cycle_bab_dir)


In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def visualize_attention_weights_for_both(model_A, model_B, weights_path_A, weights_path_B, test_loader, device):
    """
    Show waveforms and discriminator attention side by side for both domains.

    Loads the given weights into ``model_A`` and ``model_B``, takes the first
    waveform of one batch from ``test_loader`` and draws a 2x2 figure: the top
    row shows the detector and simulated pulse coloured by time interval, the
    bottom row the corresponding zero-padded attention weights.
    """
    # Restore weights, then switch both networks to evaluation mode
    for net, path in ((model_A, weights_path_A), (model_B, weights_path_B)):
        net.load_state_dict(torch.load(path, map_location=device, weights_only=True))
    model_A.eval()
    model_B.eval()

    # One batch is enough; only its first two items are used
    real_A, real_B, a, b = next(iter(test_loader))
    real_A = real_A.to(device)
    real_B = real_B.to(device)

    def padded_attention(net, pulses):
        # Attention of the first waveform, padded with 50 zeros on both ends
        with torch.no_grad():
            first = net.get_attention_weights(pulses).cpu().numpy()[0]
        return np.pad(first, (50, 50), mode='constant', constant_values=0)

    attention_weights_A_padded = padded_attention(model_A, real_A)
    attention_weights_B_padded = padded_attention(model_B, real_B)

    # (start, end, colour) of each highlighted time interval
    time_intervals = [(0, 201, 'tab:blue'), (200, 451, 'tab:red'), (451, 800, 'tab:green')]
    detector_pulses = real_A[0].cpu().numpy()[0]
    simulated_pulses = real_B[0].cpu().numpy()[0]

    # Two rows (waveform, attention) by two columns (model A, model B)
    fig, axs = plt.subplots(2, 2, figsize=(20, 10))

    def draw_column(col, pulse, attention, pulse_title, attention_label, attention_title):
        wave_ax = axs[0, col]
        attn_ax = axs[1, col]
        for start, end, color in time_intervals:
            wave_ax.plot(np.arange(start, min(end, len(pulse))), pulse[start:end], color=color, label=f'{start}-{end} ns')
        wave_ax.set_title(pulse_title)
        wave_ax.legend()

        attn_ax.plot(attention, label=attention_label)
        attn_ax.set_title(attention_title)
        attn_ax.legend()

    draw_column(0, detector_pulses, attention_weights_A_padded,
                'Detector Pulses', 'Attention Weights A', 'Attention Weights (Detector Pulses)')
    draw_column(1, simulated_pulses, attention_weights_B_padded,
                'Simulated Pulses', 'Attention Weights B', 'Attention Weights (Simulated Pulses)')

    # Adjust layout and display the plot
    plt.tight_layout()
    plt.show()
    
# Fresh attention-returning discriminators; the final training weights are
# loaded into them inside the function. This rebinds netD_A / netD_B.
netD_A = RNN(get_attention=True)
netD_B = RNN(get_attention=True)
weights_path_A = f'data_emulation/model_weights/{eng_peak}_netD_A.pth'
weights_path_B = f'data_emulation/model_weights/{eng_peak}_netD_B.pth'

# Now call the function to visualize the attention weights for both models
# (the batch is drawn from train_loader and evaluated on the CPU)
visualize_attention_weights_for_both(netD_A, netD_B, weights_path_A, weights_path_B, train_loader, 'cpu')


In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_cycle_BAB(netG_A2B, netG_B2A, real_A, real_B, DEVICE):
    """
    Plot one simulated -> data -> simulated cycle for the first waveform of a batch.

    ``real_B`` is translated by ``netG_B2A`` and mapped back by ``netG_A2B``;
    the first waveform of ``real_A`` is drawn as a reference. ``DEVICE`` is
    accepted for signature symmetry but not used.
    """
    with torch.no_grad():
        fake_A = netG_B2A(real_B)  # Simulated to data
        recovered_B = netG_A2B(fake_A)  # Data back to simulated

    # (batch, label, colour, line style) in drawing order
    traces = (
        (real_B, 'Simulated Pulse (Real B)', '#0072BD', '-'),
        (fake_A, 'Translated Pulse (Fake A)', '#D95319', '--'),
        (recovered_B, 'Recovered Pulse (Recovered B)', '#7E2F8E', '-.'),
        (real_A, 'Target Data Pulse (Real A)', '#4DBEEE', ':'),
    )
    # First waveform of every batch as a NumPy array
    curves = [(batch[0, 0, :].cpu().numpy(), label, color, style)
              for batch, label, color, style in traces]

    time = np.linspace(0, 799, 800)  # Time axis for the waveforms

    plt.figure(figsize=(10, 5))
    for values, label, color, style in curves:
        plt.plot(time, values, label=label, color=color, linestyle=style)
    plt.title("Cycle BAB: Simulated -> Data -> Simulated")
    plt.xlabel('Time [ns]')
    plt.ylabel('Normalized Pulses')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    # plt.savefig("figs/result_comp_cycle_BAB.png")
    plt.show()

def plot_cycle_ABA(netG_A2B, netG_B2A, real_A, real_B, DEVICE):
    """
    Plot one data -> simulated -> data cycle for the first waveform of a batch.

    ``real_A`` is translated by ``netG_A2B`` and mapped back by ``netG_B2A``;
    the first waveform of ``real_B`` is drawn as a reference. ``DEVICE`` is
    accepted for signature symmetry but not used.
    """
    with torch.no_grad():
        fake_B = netG_A2B(real_A)  # Data to simulated
        recovered_A = netG_B2A(fake_B)  # Simulated back to data

    # (batch, label, colour, line style) in drawing order
    traces = (
        (real_A, 'Original Pulse (Real A)', '#4DBEEE', '-'),
        (fake_B, 'Simulated Pulse (Fake B)', '#77AC30', '--'),
        (recovered_A, 'Recovered Pulse (Recovered A)', '#A2142F', '-.'),
        (real_B, 'Target Simulated Pulse (Real B)', '#0072BD', ':'),
    )
    # First waveform of every batch as a NumPy array
    curves = [(batch[0, 0, :].cpu().numpy(), label, color, style)
              for batch, label, color, style in traces]

    time = np.linspace(0, 799, 800)  # Time axis for the waveforms

    plt.figure(figsize=(10, 5))
    for values, label, color, style in curves:
        plt.plot(time, values, label=label, color=color, linestyle=style)
    plt.title("Cycle ABA: Data -> Simulated -> Data")
    plt.xlabel('Time [ns]')
    plt.ylabel('Normalized Pulses')
    plt.legend(loc='upper left')
    plt.grid(True)
    plt.tight_layout()
    # plt.savefig("figs/result_comp_cycle_ABA.png")
    plt.show()

# Extract a single batch from the train_loader (the plots use its first waveform)
real_A, real_B, _, _ = next(iter(train_loader))
real_A = real_A.to(DEVICE)
real_B = real_B.to(DEVICE)

# Call the functions using the same pulses for both cycles.
# IATN is passed as netG_A2B (data -> simulation) and ATN as netG_B2A.
plot_cycle_BAB(IATN, ATN, real_A, real_B, DEVICE)
plot_cycle_ABA(IATN, ATN, real_A, real_B, DEVICE)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

def plot_cycle_BAB_with_style(netG_A2B, netG_B2A, train_loader, DEVICE, sample_rate=5, eng_peak="eng_peak", eng_peak_load="eng_peak_load"):
    """Plot B -> A -> B cycle results for several batches as a 1x3 panel figure.

    For each of ``sample_rate`` batches drawn from ``train_loader``, the second
    item of the batch (``real_B``) is translated with ``netG_B2A`` and mapped back
    with ``netG_A2B``.  The first waveform of each batch is drawn in the
    "Simulated", "Translated" and "Recovered" panels, and the figure is written to
    ``figs/result_comp_1x3_cycle_BAB.png``.

    Args:
        netG_A2B: callable mapping a batch from domain A to domain B.
        netG_B2A: callable mapping a batch from domain B to domain A.
        train_loader: iterable of batches; item ``[1]`` of each batch is used.
        DEVICE: torch device the batch is moved to before running the networks.
        sample_rate: number of batches to draw (one waveform is plotted per batch).
        eng_peak, eng_peak_load: unused; retained for backward compatibility.
    """
    real_Bs, fake_As, recovered_Bs = [], [], []

    # Create the iterator once.  Calling iter(train_loader) inside the loop would
    # rebuild the loader iterator on every pass and, for a non-shuffling loader,
    # return the same first batch each time.
    batches = iter(train_loader)
    for _ in range(sample_rate):
        try:
            batch = next(batches)
        except StopIteration:
            # Fewer batches than requested samples: start over from the beginning.
            batches = iter(train_loader)
            batch = next(batches)
        real_B = batch[1].to(DEVICE)

        # Generate and recover
        with torch.no_grad():
            fake_A = netG_B2A(real_B)
            recovered_B = netG_A2B(fake_A)

        # Keep only the first waveform of each batch.
        real_Bs.append(real_B[0, 0, :].cpu().numpy())
        fake_As.append(fake_A[0, 0, :].cpu().numpy())
        recovered_Bs.append(recovered_B[0, 0, :].cpu().numpy())

    # x-axis derived from the waveform length instead of a hard-coded 800 samples.
    n_points = len(real_Bs[0]) if real_Bs else 0
    sample_axis = np.linspace(0, n_points - 1, n_points)
    cut = sample_axis > 5  # hides samples 0..5 at the start of every pulse

    colors = {
        "real_B": "#0072BD",  # bright blue
        "fake_A": "#D95319",  # bright orange
        "recovered_B": "#7E2F8E",  # purple
    }

    fig, axs = plt.subplots(1, 3, figsize=(20, 5))
    titles = ["Simulated Pulses", "Translated Pulses", "Recovered Pulses"]
    waveform_lists = [real_Bs, fake_As, recovered_Bs]
    color_keys = ["real_B", "fake_A", "recovered_B"]

    for ax, title, waveforms, color_key in zip(axs, titles, waveform_lists, color_keys):
        for i, waveform in enumerate(waveforms):
            ax.plot(sample_axis[cut], waveform[cut], color=colors[color_key], label=title if i == 0 else "")
        ax.set_title(title)
        ax.set_xlabel('Time [ns]')
        ax.set_ylabel('Normalized Pulses')
        # ax.legend(loc='upper right')

    fig.tight_layout()  # Adjust subplots to fit into the figure area.
    os.makedirs("figs", exist_ok=True)  # savefig does not create missing directories
    plt.savefig("figs/result_comp_1x3_cycle_BAB.png")
    plt.show()
    
# Draw 10 batches and plot the first waveform of each.
# NOTE(review): eng_peak / eng_peak_load are accepted by the function but never read in its body.
plot_cycle_BAB_with_style(IATN, ATN, train_loader, DEVICE, sample_rate=10, eng_peak="SEP", eng_peak_load="DEP")

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch

def plot_cycle_ABA_with_style(netG_A2B, netG_B2A, train_loader, DEVICE, sample_rate=5, eng_peak="eng_peak", eng_peak_load="eng_peak_load"):
    """Plot A -> B -> A cycle results for several batches as a 1x3 panel figure.

    For each of ``sample_rate`` batches drawn from ``train_loader``, the first
    item of the batch (``real_A``) is translated with ``netG_A2B`` and mapped back
    with ``netG_B2A``.  The first waveform of each batch is drawn in the
    "Detector", "Translated" and "Recovered" panels, and the figure is written to
    ``figs/result_comp_1x3_cycle_ABA.png``.

    Args:
        netG_A2B: callable mapping a batch from domain A to domain B.
        netG_B2A: callable mapping a batch from domain B to domain A.
        train_loader: iterable of batches; item ``[0]`` of each batch is used.
        DEVICE: torch device the batch is moved to before running the networks.
        sample_rate: number of batches to draw (one waveform is plotted per batch).
        eng_peak, eng_peak_load: unused; retained for backward compatibility.
    """
    real_As, fake_Bs, recovered_As = [], [], []

    # Create the iterator once.  Calling iter(train_loader) inside the loop would
    # rebuild the loader iterator on every pass and, for a non-shuffling loader,
    # return the same first batch each time.
    batches = iter(train_loader)
    for _ in range(sample_rate):
        try:
            batch = next(batches)
        except StopIteration:
            # Fewer batches than requested samples: start over from the beginning.
            batches = iter(train_loader)
            batch = next(batches)
        real_A = batch[0].to(DEVICE)

        # Generate and recover
        with torch.no_grad():
            fake_B = netG_A2B(real_A)
            recovered_A = netG_B2A(fake_B)

        # Keep only the first waveform of each batch.
        real_As.append(real_A[0, 0, :].cpu().numpy())
        fake_Bs.append(fake_B[0, 0, :].cpu().numpy())
        recovered_As.append(recovered_A[0, 0, :].cpu().numpy())

    # x-axis derived from the waveform length instead of a hard-coded 800 samples.
    n_points = len(real_As[0]) if real_As else 0
    sample_axis = np.linspace(0, n_points - 1, n_points)
    cut = sample_axis > 5  # hides samples 0..5 at the start of every pulse

    colors = {
        "real_A": "#0072BD",  # bright blue
        "fake_B": "#D95319",  # bright orange
        "recovered_A": "#7E2F8E",  # purple
    }

    fig, axs = plt.subplots(1, 3, figsize=(20, 5))
    titles = ["Detector Pulses", "Translated Pulses", "Recovered Pulses"]
    waveform_lists = [real_As, fake_Bs, recovered_As]
    color_keys = ["real_A", "fake_B", "recovered_A"]

    for ax, title, waveforms, color_key in zip(axs, titles, waveform_lists, color_keys):
        for i, waveform in enumerate(waveforms):
            ax.plot(sample_axis[cut], waveform[cut], color=colors[color_key], label=title if i == 0 else "")
        ax.set_title(title)
        ax.set_xlabel('Time [ns]')
        ax.set_ylabel('Normalized Pulses')
        # ax.legend(loc='upper right')

    fig.tight_layout()  # Adjust subplots to fit into the figure area.
    os.makedirs("figs", exist_ok=True)  # savefig does not create missing directories
    plt.savefig("figs/result_comp_1x3_cycle_ABA.png")
    plt.show()

# Draw 10 batches and plot the first waveform of each.
# NOTE(review): eng_peak / eng_peak_load are accepted by the function but never read in its body.
plot_cycle_ABA_with_style(IATN, ATN, train_loader, DEVICE, sample_rate=10, eng_peak="DEP", eng_peak_load="FEP")

In [None]:
# Show the torch device currently held in DEVICE.
print(DEVICE)

In [None]:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np
import heapq

# Define the WFDist class
class WFDist(nn.Module):
    '''
    Waveform Distance: an L1 loss in which every sample of both inputs is multiplied
    by a per-sample weight before the mean absolute difference is taken.

    The weight vector is the concatenation of three constant regions, in this order:
    baseline, rising edge, tail.  Region lengths come from ``segment_lens`` or, when
    that is omitted, from the notebook globals ``baseline_len``, ``rising_edge_len``
    and ``tail_len``.  Each flattened waveform must have exactly
    ``sum(segment_lens)`` samples.

    Args:
        baseline_weight: weight applied to the baseline region.
        ris_edge_weight: weight applied to the rising-edge region.
        tail_weight: weight applied to the tail region.
        segment_lens: optional ``(baseline, rising_edge, tail)`` lengths.
        device: optional device for the weight; defaults to the global ``DEVICE``.
    '''
    def __init__(self, baseline_weight, ris_edge_weight, tail_weight, segment_lens=None, device=None):
        super(WFDist, self).__init__()
        self.criterion = nn.L1Loss()
        if segment_lens is None:
            # Backward-compatible default: read the lengths from the notebook globals.
            segment_lens = (baseline_len, rising_edge_len, tail_len)
        n_base, n_rise, n_tail = segment_lens
        weight = torch.tensor(
            [baseline_weight] * n_base +
            [ris_edge_weight] * n_rise +
            [tail_weight] * n_tail
        )
        if device is None:
            device = DEVICE
        # Registered as a non-persistent buffer so that module.to(...) moves it with
        # the module while state_dict() contents stay unchanged.
        self.register_buffer("weight", weight.to(device), persistent=False)

    def forward(self, x1, x2):
        """Return the mean weighted absolute difference over the whole batch.

        With equally long waveforms this equals the average of the per-waveform
        L1 losses, without a Python loop over the batch dimension.
        """
        flat1 = x1.reshape(x1.size(0), -1)
        flat2 = x2.reshape(x2.size(0), -1)
        return self.criterion(flat1 * self.weight, flat2 * self.weight)

# Function to plot the highest-loss waveform pairs of two categories, one category per row
def plot_top_waveforms(top_sim, top_translated,
                       labels_sim=('Data', 'Sim'),
                       labels_translated=('Data', 'Translated'),
                       title_tags=(' (Sim)', '')):
    """Plot two rows of waveform pairs, one subplot per ``(loss, wf_a, wf_b)`` entry.

    Args:
        top_sim: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the first row.
        top_translated: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the second row.
        labels_sim: legend labels for ``(wf_a, wf_b)`` in the first row.
        labels_translated: legend labels for ``(wf_a, wf_b)`` in the second row.
        title_tags: text inserted after "L1 Loss" in the titles of row one and row two.

    The number of columns follows the longer of the two inputs (at least one), so
    passing more than five entries no longer raises an IndexError.
    """
    rows = (
        (list(top_sim), labels_sim, title_tags[0]),
        (list(top_translated), labels_translated, title_tags[1]),
    )
    n_cols = max(len(rows[0][0]), len(rows[1][0]), 1)
    # squeeze=False keeps axs two-dimensional even when there is a single column.
    fig, axs = plt.subplots(2, n_cols, figsize=(4 * n_cols, 10), squeeze=False)

    for row_idx, (entries, (first_label, second_label), tag) in enumerate(rows):
        for col_idx, (loss, first_wf, second_wf) in enumerate(entries):
            ax = axs[row_idx, col_idx]
            ax.plot(first_wf, label=first_label)
            ax.plot(second_wf, label=second_label)
            ax.set_title(f'Top {col_idx+1} L1 Loss{tag}: {loss:.4f}')
            ax.legend()
            ax.grid(True)

    plt.tight_layout()
    plt.show()

# Initialize lists to store loss values.
# Each entry is a tuple (loss, reference waveform, translated waveform).
atn_loss_list = []  # Loss between data and translated simulated waveforms
iatn_loss_list = []  # Loss between simulated and translated data waveforms

# Set the models to evaluation mode
ATN.eval()
IATN.eval()

# Initialize the loss criterion
criterion_valid = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)

# Disable gradient calculation for evaluation
with torch.no_grad():
    # Iterate over the data loader
    for wf, wf_deconv, rawwf, x in tqdm(train_loader):
        # Move input waveforms to the correct device
        wf = wf.to(DEVICE)  # Real data
        wf_deconv = wf_deconv.to(DEVICE)  # Simulated data

        # Generate data-like waveforms using ATN
        translated_sim = ATN(wf_deconv.float())

        # Generate sim-like waveforms using IATN
        translated_data = IATN(wf.float())

        # Iterate over each waveform in the batch
        for i in range(wf.size(0)):
            real_wf = wf[i, 0]  # Real waveform
            sim_wf = wf_deconv[i, 0]  # Simulated waveform
            trans_sim_wf = translated_sim[i, 0]  # Translated simulated waveform
            trans_data_wf = translated_data[i, 0]  # Translated data waveform

            # Calculate the loss between real data and translated simulated waveform (ATN performance)
            atn_loss = criterion_valid(real_wf.unsqueeze(0), trans_sim_wf.unsqueeze(0))
            atn_loss_list.append((atn_loss.item(), real_wf.cpu().numpy(), trans_sim_wf.cpu().numpy()))

            # Calculate the loss between simulated waveform and translated data waveform (IATN performance)
            iatn_loss = criterion_valid(sim_wf.unsqueeze(0), trans_data_wf.unsqueeze(0))
            iatn_loss_list.append((iatn_loss.item(), sim_wf.cpu().numpy(), trans_data_wf.cpu().numpy()))

# Find the top 5 highest losses for both ATN and IATN
top_atn_losses = heapq.nlargest(5, atn_loss_list, key=lambda x: x[0])
top_iatn_losses = heapq.nlargest(5, iatn_loss_list, key=lambda x: x[0])

# Plot the top 5 waveforms with the highest losses for ATN and IATN.
# The ATN row holds (data, ATN-translated sim) pairs and the IATN row holds
# (sim, IATN-translated data) pairs, so the legend labels are set explicitly;
# the function's defaults would label the simulated reference pulse as "Data".
plot_top_waveforms(
    top_atn_losses,
    top_iatn_losses,
    labels_sim=('Data', 'Translated Sim'),
    labels_translated=('Sim', 'Translated Data'),
    title_tags=(' (ATN)', ' (IATN)'),
)

# Calculate and print overall average losses
average_atn_loss = sum(x[0] for x in atn_loss_list) / len(atn_loss_list) if atn_loss_list else 0
average_iatn_loss = sum(x[0] for x in iatn_loss_list) / len(iatn_loss_list) if iatn_loss_list else 0

print(f"Average WFDist loss between data pulses and ATN-translated simulated pulses: {average_atn_loss}")
print(f"Average WFDist loss between simulated pulses and IATN-translated data pulses: {average_iatn_loss}")

In [None]:
# Extract loss values from the lists (element 0 of every stored tuple)
atn_losses = [x[0] for x in atn_loss_list]
iatn_losses = [x[0] for x in iatn_loss_list]

# Define the number of bins for the histograms
num_bins = 25

# Calculate histogram data for ATN losses.
# np.argsort sorts ascending, so the last 10 indices are the 10 most populated bins,
# ordered from least to most populated.
atn_hist, atn_bin_edges = np.histogram(atn_losses, bins=num_bins)
atn_top_bins_indices = np.argsort(atn_hist)[-10:]  # Indices of the top 10 most frequent bins

# Calculate histogram data for IATN losses
iatn_hist, iatn_bin_edges = np.histogram(iatn_losses, bins=num_bins)
iatn_top_bins_indices = np.argsort(iatn_hist)[-10:]  # Indices of the top 10 most frequent bins

# Plotting histograms for visualization.
# Each plt.hist call below receives only a bin count, so the two histograms derive
# their bin edges independently from their own data.
plt.figure(figsize=(14, 6))

# Histogram for ATN losses
plt.hist(atn_losses, bins=num_bins, alpha=0.7, label='ATN Losses (Data - Translated Sim)', color='blue')

# Histogram for IATN losses
plt.hist(iatn_losses, bins=num_bins, alpha=0.7, label='IATN Losses (Sim - Translated Data)', color='orange')

# Adding titles and labels
plt.title('Distribution of WFDist Losses for ATN and IATN')
plt.xlabel('WFDist Loss')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()

# Set the number of samples to plot per bin (read as a global by plot_samples_from_bins)
num_samples_per_bin = 3  # Adjustable variable

# Function to plot a few sample waveforms for each of the selected histogram bins
def plot_samples_from_bins(loss_list, bin_edges, top_bins_indices, label, input_type, samples_per_bin=None):
    """Plot a few reference waveforms from each selected loss-histogram bin.

    Args:
        loss_list: sequence of ``(loss, waveform, ...)`` tuples; only the first
            waveform of each tuple is drawn.
        bin_edges: bin edges as returned by ``np.histogram``.
        top_bins_indices: indices of the bins to show, one subplot per bin.
        label: text appended to the figure title.
        input_type: ``'ATN'`` or ``'IATN'``; selects the legend wording.  Any other
            value draws no curves.
        samples_per_bin: maximum number of waveforms per subplot; defaults to the
            notebook global ``num_samples_per_bin``.
    """
    if samples_per_bin is None:
        samples_per_bin = num_samples_per_bin

    top_bins_indices = list(top_bins_indices)
    n_panels = len(top_bins_indices)
    # Two columns; the number of rows grows with the number of requested bins
    # (10 bins reproduce the original 5x2 grid and figure size).
    n_rows = max((n_panels + 1) // 2, 1)
    fig, axs = plt.subplots(n_rows, 2, figsize=(15, 5 * n_rows), squeeze=False)
    axs = axs.flatten()

    legend_prefix = {'ATN': 'Real Data', 'IATN': 'Simulated Data'}.get(input_type)
    last_bin_idx = len(bin_edges) - 2

    for idx, bin_idx in enumerate(top_bins_indices):
        # Get the bin range
        bin_min = bin_edges[bin_idx]
        bin_max = bin_edges[bin_idx + 1]

        # np.histogram bins are half-open except the last one, which also includes
        # its right edge; mirror that so the maximum-loss sample is not dropped.
        closed_right = bin_idx == last_bin_idx
        samples_in_bin = [
            sample for loss, *sample in loss_list
            if bin_min <= loss and (loss < bin_max or (closed_right and loss == bin_max))
        ]

        shown = samples_in_bin[:samples_per_bin]
        for k, sample in enumerate(shown):
            if legend_prefix is not None:
                # Label only the first curve to avoid repeated legend entries.
                curve_label = f'{legend_prefix} (Bin [{bin_min:.2f}, {bin_max:.2f}])' if k == 0 else None
                axs[idx].plot(sample[0], label=curve_label)
        if shown:
            axs[idx].legend()
            axs[idx].grid(True)

    plt.suptitle(f'Sample Waveforms from Top {n_panels} Most Frequent Bins - {label}')
    plt.tight_layout()
    plt.show()

# Plotting samples from top 10 bins for ATN and IATN.
# Each call uses the bin edges and bin indices computed from its own loss list.
plot_samples_from_bins(atn_loss_list, atn_bin_edges, atn_top_bins_indices, label='ATN Losses', input_type='ATN')
plot_samples_from_bins(iatn_loss_list, iatn_bin_edges, iatn_top_bins_indices, label='IATN Losses', input_type='IATN')


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Set the group size for combining time points
group_size = 20  # Group by 20 time points, adjustable variable

# Number of leading entries to analyse.  This currently selects every stored entry;
# set it to a smaller value (e.g. 500) for a quicker debugging pass.
subset_size = len(atn_loss_list)
atn_loss_subset = atn_loss_list[:subset_size]
iatn_loss_subset = iatn_loss_list[:subset_size]

# Per-sample absolute differences for ATN and IATN.  These are plain (unweighted)
# |a - b| values, not the weighted WFDist loss stored in element 0 of each tuple.
atn_diff = [np.abs(real - trans_sim) for _, real, trans_sim in atn_loss_subset]
iatn_diff = [np.abs(sim - trans_data) for _, sim, trans_data in iatn_loss_subset]

# Convert to numpy arrays
atn_diff = np.array(atn_diff)  # Shape: (num_waveforms, num_time_points)
iatn_diff = np.array(iatn_diff)  # Shape: (num_waveforms, num_time_points)

# Transpose to group the data by time points
atn_diff_transposed = atn_diff.T  # Shape: (num_time_points, num_waveforms)
iatn_diff_transposed = iatn_diff.T  # Shape: (num_time_points, num_waveforms)

# Group the data by the specified group size: each list element holds, for every
# waveform, the mean absolute difference over one block of `group_size` time points.
grouped_atn_diffs = [
    np.mean(atn_diff_transposed[i:i + group_size], axis=0)
    for i in range(0, atn_diff_transposed.shape[0], group_size)
]

grouped_iatn_diffs = [
    np.mean(iatn_diff_transposed[i:i + group_size], axis=0)
    for i in range(0, iatn_diff_transposed.shape[0], group_size)
]

# Plotting box plots for ATN differences grouped by time points.
# The title reports the number of waveforms actually used instead of a fixed "500".
plt.figure(figsize=(20, 10))
plt.boxplot(grouped_atn_diffs, patch_artist=True)
plt.title(f'Box Plot of ATN Losses Grouped by Time Points (First {len(atn_loss_subset)} Samples)')
plt.xlabel('Grouped Time Points')
plt.ylabel('Mean |Data - Translated Sim|')
plt.grid(True)
plt.tight_layout()
plt.show()

# Plotting box plots for IATN differences grouped by time points
plt.figure(figsize=(20, 10))
plt.boxplot(grouped_iatn_diffs, patch_artist=True)
plt.title(f'Box Plot of IATN Losses Grouped by Time Points (First {len(iatn_loss_subset)} Samples)')
plt.xlabel('Grouped Time Points')
plt.ylabel('Mean |Sim - Translated Data|')
plt.grid(True)
plt.tight_layout()
plt.show()


The code below directly compares simulated pulses versus data pulses, and translated pulses versus data pulses.

In [None]:
class WFDist(nn.Module):
    '''
    Waveform Distance: an L1 loss in which every sample of both inputs is multiplied
    by a per-sample weight before the mean absolute difference is taken.

    The weight vector is the concatenation of three constant regions, in this order:
    baseline, rising edge, tail.  Region lengths come from ``segment_lens`` or, when
    that is omitted, from the notebook globals ``baseline_len``, ``rising_edge_len``
    and ``tail_len``.  Each flattened waveform must have exactly
    ``sum(segment_lens)`` samples.

    Args:
        baseline_weight: weight applied to the baseline region.
        ris_edge_weight: weight applied to the rising-edge region.
        tail_weight: weight applied to the tail region.
        segment_lens: optional ``(baseline, rising_edge, tail)`` lengths.
        device: optional device for the weight; defaults to the global ``DEVICE``.
    '''
    def __init__(self, baseline_weight, ris_edge_weight, tail_weight, segment_lens=None, device=None):
        super(WFDist, self).__init__()
        self.criterion = nn.L1Loss()
        if segment_lens is None:
            # Backward-compatible default: read the lengths from the notebook globals.
            segment_lens = (baseline_len, rising_edge_len, tail_len)
        n_base, n_rise, n_tail = segment_lens
        weight = torch.tensor([baseline_weight]*n_base+[ris_edge_weight]*n_rise+[tail_weight]*n_tail)
        if device is None:
            device = DEVICE
        # Registered as a non-persistent buffer so that module.to(...) moves it with
        # the module while state_dict() contents stay unchanged.
        self.register_buffer("weight", weight.to(device), persistent=False)

    def forward(self, x1, x2):
        """Return the mean weighted absolute difference over the whole batch.

        With equally long waveforms this equals the average of the per-waveform
        L1 losses, without a Python loop over the batch dimension.
        """
        flat1 = x1.reshape(x1.size(0), -1)
        flat2 = x2.reshape(x2.size(0), -1)
        return self.criterion(flat1 * self.weight, flat2 * self.weight)

import heapq
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# Function to plot the highest-loss waveform pairs of two categories, one category per row
def plot_top_waveforms(top_sim, top_translated,
                       labels_sim=('Data', 'Sim'),
                       labels_translated=('Data', 'Translated'),
                       title_tags=(' (Sim)', '')):
    """Plot two rows of waveform pairs, one subplot per ``(loss, wf_a, wf_b)`` entry.

    Args:
        top_sim: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the first row.
        top_translated: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the second row.
        labels_sim: legend labels for ``(wf_a, wf_b)`` in the first row.
        labels_translated: legend labels for ``(wf_a, wf_b)`` in the second row.
        title_tags: text inserted after "L1 Loss" in the titles of row one and row two.

    The number of columns follows the longer of the two inputs (at least one), so
    passing more than five entries no longer raises an IndexError.
    """
    rows = (
        (list(top_sim), labels_sim, title_tags[0]),
        (list(top_translated), labels_translated, title_tags[1]),
    )
    n_cols = max(len(rows[0][0]), len(rows[1][0]), 1)
    # squeeze=False keeps axs two-dimensional even when there is a single column.
    fig, axs = plt.subplots(2, n_cols, figsize=(4 * n_cols, 10), squeeze=False)

    for row_idx, (entries, (first_label, second_label), tag) in enumerate(rows):
        for col_idx, (loss, first_wf, second_wf) in enumerate(entries):
            ax = axs[row_idx, col_idx]
            ax.plot(first_wf, label=first_label)
            ax.plot(second_wf, label=second_label)
            ax.set_title(f'Top {col_idx+1} L1 Loss{tag}: {loss:.4f}')
            ax.legend()
            ax.grid(True)

    plt.tight_layout()
    plt.show()

# Initialize lists to store loss values and waveforms.
# Each entry is a tuple (loss value, data waveform, compared waveform) with the
# waveforms stored as NumPy arrays.
l1_data_sim = []  # L1 loss between data pulses and simulation pulses
l1_data_translated = []  # L1 loss between data pulses and translated pulses

# Set the model to evaluation mode and disable gradient calculation
ATN.eval()
criterion_valid = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)
with torch.no_grad():
    # Iterate over the data loader (rawwf and x are unpacked but not used below)
    for wf, wf_deconv, rawwf, x in tqdm(train_loader):
        # Move input waveforms to the correct device
        wf = wf.to(DEVICE)
        wf_deconv = wf_deconv.to(DEVICE)

        # Generate translated waveforms using the model
        gan_wf = ATN(wf_deconv.float())

        # Iterate over each waveform in the batch
        for i in range(wf.size(0)):
            # Get the real waveform, simulated waveform, and translated waveform
            real_wf = wf[i, 0]  # Real waveform tensor on DEVICE
            sim_wf = wf_deconv[i, 0]  # Simulated waveform tensor on DEVICE
            transfer_wf = gan_wf[i, 0]  # Translated waveform tensor on DEVICE

            # Calculate the L1 loss between data and simulation waveforms individually.
            # unsqueeze(0) adds a leading batch dimension of size 1 for the criterion.
            l1_sim = criterion_valid(real_wf.unsqueeze(0), sim_wf.unsqueeze(0))
            l1_data_sim.append((l1_sim.item(), real_wf.cpu().numpy(), sim_wf.cpu().numpy()))

            # Calculate the L1 loss between data and translated waveforms individually
            l1_translated = criterion_valid(real_wf.unsqueeze(0), transfer_wf.unsqueeze(0))
            l1_data_translated.append((l1_translated.item(), real_wf.cpu().numpy(), transfer_wf.cpu().numpy()))

# Find the top 5 highest losses using heapq
top_sim = heapq.nlargest(5, l1_data_sim, key=lambda x: x[0])
top_translated = heapq.nlargest(5, l1_data_translated, key=lambda x: x[0])

# Plotting the top 5 waveforms with the highest L1 loss for both categories
plot_top_waveforms(top_sim, top_translated)

# Calculate overall average L1 losses (0 when the corresponding list is empty)
average_l1_data_sim = sum(x[0] for x in l1_data_sim) / len(l1_data_sim) if l1_data_sim else 0
average_l1_data_translated = sum(x[0] for x in l1_data_translated) / len(l1_data_translated) if l1_data_translated else 0

# Print the average L1 losses
print(f"L1 loss between data pulses and simulation pulses: {average_l1_data_sim}")
print(f"L1 loss between data pulses and translated pulses: {average_l1_data_translated}")

In [None]:
# Relative difference, in percent, between the data-vs-translated average and the
# data-vs-sim average, taken with respect to the data-vs-sim average.
# abs() discards the sign, so the number alone does not say which average is smaller.
# NOTE(review): raises ZeroDivisionError when average_l1_data_sim is 0 (its value when no losses were collected).
print(abs(average_l1_data_translated-average_l1_data_sim)*100/average_l1_data_sim)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Extract L1 loss values from l1_data_sim and l1_data_translated (element 0 of every tuple)
l1_sim_losses = [x[0] for x in l1_data_sim]
l1_translated_losses = [x[0] for x in l1_data_translated]

# Define the number of bins for the histograms
num_bins = 50

# Plotting histograms for visualization.
# Each plt.hist call below receives only a bin count, so the two histograms derive
# their bin edges independently from their own data.
plt.figure(figsize=(14, 6))

# Histogram for L1 losses between data and simulated pulses
plt.hist(l1_sim_losses, bins=num_bins, alpha=0.7, label='L1 Losses (Data vs. Sim)', color='blue')

# Histogram for L1 losses between data and translated pulses
plt.hist(l1_translated_losses, bins=num_bins, alpha=0.7, label='L1 Losses (Data vs. Translated)', color='orange')

# Adding titles and labels
plt.title('Distribution of WFDist Losses for Data vs. Sim and Data vs. Translated')
plt.xlabel('WFDist Loss')
plt.ylabel('Frequency')
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Set the group size for combining time points
group_size = 10  # You can change this variable to adjust the grouping

# Use only the first `subset_size` entries for debugging purposes
# (fewer are used when the lists are shorter than that).
subset_size = 500
l1_data_sim_subset = l1_data_sim[:subset_size]
l1_data_translated_subset = l1_data_translated[:subset_size]

# Absolute (unweighted) differences at each time point between the real waveform
# and the simulated / translated waveform.
diff_sim_subset = [np.abs(real - sim) for _, real, sim in l1_data_sim_subset]  # Differences for simulated waveforms
diff_translated_subset = [np.abs(real - trans) for _, real, trans in l1_data_translated_subset]  # Differences for translated waveforms

# Convert to numpy arrays
diff_sim_subset = np.array(diff_sim_subset)  # Shape: (num_waveforms, num_time_points)
diff_translated_subset = np.array(diff_translated_subset)  # Shape: (num_waveforms, num_time_points)

# Difference between the two error maps: positive values mean the translated pulse
# is closer to the data than the simulated pulse at that time point.
diff_combined_subset = diff_sim_subset - diff_translated_subset  # Shape: (num_waveforms, num_time_points)

# Transpose to have time points as rows for grouping
diff_combined_transposed_subset = diff_combined_subset.T  # Shape: (num_time_points, num_waveforms)

# Group the data by the specified group size: each list element holds, for every
# waveform, the mean over one block of `group_size` time points.
grouped_diffs = [
    np.mean(diff_combined_transposed_subset[i:i + group_size], axis=0)
    for i in range(0, diff_combined_transposed_subset.shape[0], group_size)
]

# Plotting box plots for each group of time points.
# The title is built from group_size so it cannot drift from the value used above.
plt.figure(figsize=(20, 10))
plt.boxplot(grouped_diffs, patch_artist=True)
plt.title(f'Box Plot of L1 Loss Differences Between Simulated and Translated Waveforms (Grouped by {group_size} Time Points)')
plt.xlabel('Grouped Time Points')
plt.ylabel('L1 Loss Difference (Simulated - Translated)')
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Initialize arrays to store per-time L1 losses.
# All four arrays have SEQ_LEN entries, so every stored waveform must have SEQ_LEN samples.
l1_data_sim_time = np.zeros(SEQ_LEN)  # Sum of L1 loss at each time point across all pulses for data vs. simulation
l1_data_translated_time = np.zeros(SEQ_LEN)  # Sum of L1 loss at each time point across all pulses for data vs. translated
count_sim = np.zeros(SEQ_LEN)  # Count of valid pulses at each time point for data vs. simulation
count_translated = np.zeros(SEQ_LEN)  # Count of valid pulses at each time point for data vs. translated

# Iterate over all data points (the stored scalar loss is not used here)
for loss, real_wf, sim_wf in l1_data_sim:
    per_time_loss = np.abs(real_wf - sim_wf)  # L1 loss at each time point
    l1_data_sim_time += per_time_loss
    count_sim += 1  # Increment count (every time point is incremented by the same amount)

for loss, real_wf, trans_wf in l1_data_translated:
    per_time_loss = np.abs(real_wf - trans_wf)  # L1 loss at each time point
    l1_data_translated_time += per_time_loss
    count_translated += 1  # Increment count

# Normalize by the count to get average loss per time point.
# With empty input lists the counts stay 0 and this division is 0/0.
average_l1_data_sim_time = l1_data_sim_time / count_sim
average_l1_data_translated_time = l1_data_translated_time / count_translated

# Plotting the average L1 loss over time (pulse timeline) for both simulations and translations
plt.figure(figsize=(14, 6))
plt.plot(average_l1_data_sim_time, label='Simulated vs Data', color='red', alpha=0.5)
plt.plot(average_l1_data_translated_time, label='Translated vs Data', color='blue', alpha=0.5)
plt.title('Average L1 Loss Over Pulse Time')
plt.xlabel('Time (Pulse Sample Index)')
plt.ylabel('Average L1 Loss')
plt.grid(True)
plt.legend()
plt.show()

# Identifying where the largest errors occur in the pulse timeline.
# Negating before argsort orders the indices from highest to lowest average loss.
max_loss_indices_sim = np.argsort(-average_l1_data_sim_time)[:10]  # Indices of top 10 highest loss times for sim vs data
max_loss_indices_translated = np.argsort(-average_l1_data_translated_time)[:10]  # Indices of top 10 highest loss times for translated vs data

print("Top 10 time indices with highest L1 loss for Simulated vs Data:", max_loss_indices_sim)
print("Top 10 time indices with highest L1 loss for Translated vs Data:", max_loss_indices_translated)


In [None]:
import heapq
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
import numpy as np

# Define the WFDist class
class WFDist(nn.Module):
    '''
    Waveform Distance: an L1 loss in which every sample of both inputs is multiplied
    by a per-sample weight before the mean absolute difference is taken.

    The weight vector is the concatenation of three constant regions, in this order:
    baseline, rising edge, tail.  Region lengths come from ``segment_lens`` or, when
    that is omitted, from the notebook globals ``baseline_len``, ``rising_edge_len``
    and ``tail_len``.  Each flattened waveform must have exactly
    ``sum(segment_lens)`` samples.

    Args:
        baseline_weight: weight applied to the baseline region.
        ris_edge_weight: weight applied to the rising-edge region.
        tail_weight: weight applied to the tail region.
        segment_lens: optional ``(baseline, rising_edge, tail)`` lengths.
        device: optional device for the weight; defaults to the global ``DEVICE``.
    '''
    def __init__(self, baseline_weight, ris_edge_weight, tail_weight, segment_lens=None, device=None):
        super(WFDist, self).__init__()
        self.criterion = nn.L1Loss()
        if segment_lens is None:
            # Backward-compatible default: read the lengths from the notebook globals.
            segment_lens = (baseline_len, rising_edge_len, tail_len)
        n_base, n_rise, n_tail = segment_lens
        weight = torch.tensor(
            [baseline_weight] * n_base +
            [ris_edge_weight] * n_rise +
            [tail_weight] * n_tail
        )
        if device is None:
            device = DEVICE
        # Registered as a non-persistent buffer so that module.to(...) moves it with
        # the module while state_dict() contents stay unchanged.
        self.register_buffer("weight", weight.to(device), persistent=False)

    def forward(self, x1, x2):
        """Return the mean weighted absolute difference over the whole batch.

        With equally long waveforms this equals the average of the per-waveform
        L1 losses, without a Python loop over the batch dimension.
        """
        flat1 = x1.reshape(x1.size(0), -1)
        flat2 = x2.reshape(x2.size(0), -1)
        return self.criterion(flat1 * self.weight, flat2 * self.weight)

# Rescale a waveform by its tail level, then remove its baseline offset
def normalize_waveform(wf):
    """Scale ``wf`` so its tail sits at ``norm_tail_height``, then zero its baseline.

    The tail level is the mean of the last ``norm_samples`` samples.  When that
    mean is exactly zero the scaling step is skipped to avoid dividing by zero.
    Afterwards the mean of the first 200 samples is subtracted from every sample.
    ``norm_samples`` and ``norm_tail_height`` are read from the notebook globals.
    """
    tail_level = np.mean(wf[-norm_samples:])
    # Leave the waveform unscaled when the tail averages to exactly zero.
    scaled = wf if tail_level == 0 else wf * norm_tail_height / tail_level
    baseline_offset = np.mean(scaled[:200])
    return scaled - baseline_offset

# Function to transform waveform based on tp0
def transform_2(wf):
    """Cut a fixed window around the waveform's own tp0 and normalize it.

    tp0 is the rounded result of ``calculate_tn(wf, t_n)``.  The waveform is
    edge-padded wherever the window of ``LSPAN`` samples before and ``RSPAN``
    samples after tp0 would run past either end, so the slice always has
    ``LSPAN + RSPAN`` samples.  The slice is then passed to ``normalize_waveform``.

    Args:
        wf: one-dimensional waveform (any sequence accepted by ``np.array``).

    Returns:
        The windowed, normalized waveform as a NumPy array.
    """
    wf = np.array(wf)
    # Compute tp0 from the waveform passed in.  The previous version read the
    # global `real_wf` here, so every input was windowed around whatever waveform
    # that global happened to hold at call time.
    tp0 = int(round(calculate_tn(wf, t_n)))
    left_padding = max(LSPAN - tp0, 0)
    right_padding = max((RSPAN + tp0) - len(wf), 0)
    # mode='edge' repeats the first / last sample into the padded region.
    wf_padded = np.pad(wf, (left_padding, right_padding), mode='edge')
    tp0_adjusted = tp0 + left_padding  # tp0 position inside the padded array
    wf_sliced = wf_padded[(tp0_adjusted - LSPAN):(tp0_adjusted + RSPAN)]
    wf_normalized = normalize_waveform(wf_sliced)
    return wf_normalized

# Function to plot the highest-loss waveform pairs of two categories, one category per row
def plot_top_waveforms(top_sim, top_translated,
                       labels_sim=('Data', 'Sim'),
                       labels_translated=('Data', 'Translated'),
                       title_tags=(' (Sim)', '')):
    """Plot two rows of waveform pairs, one subplot per ``(loss, wf_a, wf_b)`` entry.

    Args:
        top_sim: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the first row.
        top_translated: iterable of ``(loss, wf_a, wf_b)`` tuples drawn in the second row.
        labels_sim: legend labels for ``(wf_a, wf_b)`` in the first row.
        labels_translated: legend labels for ``(wf_a, wf_b)`` in the second row.
        title_tags: text inserted after "L1 Loss" in the titles of row one and row two.

    The number of columns follows the longer of the two inputs (at least one), so
    passing more than five entries no longer raises an IndexError.
    """
    rows = (
        (list(top_sim), labels_sim, title_tags[0]),
        (list(top_translated), labels_translated, title_tags[1]),
    )
    n_cols = max(len(rows[0][0]), len(rows[1][0]), 1)
    # squeeze=False keeps axs two-dimensional even when there is a single column.
    fig, axs = plt.subplots(2, n_cols, figsize=(4 * n_cols, 10), squeeze=False)

    for row_idx, (entries, (first_label, second_label), tag) in enumerate(rows):
        for col_idx, (loss, first_wf, second_wf) in enumerate(entries):
            ax = axs[row_idx, col_idx]
            ax.plot(first_wf, label=first_label)
            ax.plot(second_wf, label=second_label)
            ax.set_title(f'Top {col_idx+1} L1 Loss{tag}: {loss:.4f}')
            ax.legend()
            ax.grid(True)

    plt.tight_layout()
    plt.show()

# Initialize lists to store loss values and waveforms.
# Each entry is a tuple (loss value, transformed data waveform, transformed compared waveform).
l1_data_sim = []  # L1 loss between data pulses and simulation pulses
l1_data_translated = []  # L1 loss between data pulses and translated pulses

# Set the model to evaluation mode and disable gradient calculation
ATN.eval()
criterion_valid = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)
with torch.no_grad():
    # Iterate over the data loader (rawwf and x are unpacked but not used below)
    for wf, wf_deconv, rawwf, x in tqdm(train_loader):
        # Move input waveforms to the correct device
        wf = wf.to(DEVICE)
        wf_deconv = wf_deconv.to(DEVICE)

        # Generate translated waveforms using the model
        gan_wf = ATN(wf_deconv.float())

        # Iterate over each waveform in the batch
        for i in range(wf.size(0)):
            # Get the real waveform, simulated waveform, and translated waveform.
            # .cpu().numpy() turns each of them into a NumPy array in host memory.
            real_wf = wf[i, 0].cpu().numpy()  # Real waveform as a NumPy array
            sim_wf = wf_deconv[i, 0].cpu().numpy()  # Simulated waveform as a NumPy array
            transfer_wf = gan_wf[i, 0].cpu().numpy()  # Translated waveform as a NumPy array

            # Transform the waveforms using the transform function
            data_transformed_wf = transform_2(real_wf)
            sim_transformed_wf = transform_2(sim_wf)
            translated_transformed_wf = transform_2(transfer_wf)

            # Convert the transformed waveforms back to tensors for loss calculation;
            # unsqueeze(0) adds a leading batch dimension of size 1.
            data_transformed_tensor = torch.tensor(data_transformed_wf).to(DEVICE).unsqueeze(0)
            sim_transformed_tensor = torch.tensor(sim_transformed_wf).to(DEVICE).unsqueeze(0)
            translated_transformed_tensor = torch.tensor(translated_transformed_wf).to(DEVICE).unsqueeze(0)

            # Calculate the L1 loss between data and simulation waveforms individually
            l1_sim = criterion_valid(data_transformed_tensor, sim_transformed_tensor)
            l1_data_sim.append((l1_sim.item(), data_transformed_wf, sim_transformed_wf))

            # Calculate the L1 loss between data and translated waveforms individually
            l1_translated = criterion_valid(data_transformed_tensor, translated_transformed_tensor)
            l1_data_translated.append((l1_translated.item(), data_transformed_wf, translated_transformed_wf))

# Find the top 5 highest losses using heapq
top_sim = heapq.nlargest(5, l1_data_sim, key=lambda x: x[0])
top_translated = heapq.nlargest(5, l1_data_translated, key=lambda x: x[0])

# Plotting the top 5 waveforms with the highest L1 loss for both categories
plot_top_waveforms(top_sim, top_translated)

# Calculate overall average L1 losses (0 when the corresponding list is empty)
average_l1_data_sim = sum(x[0] for x in l1_data_sim) / len(l1_data_sim) if l1_data_sim else 0
average_l1_data_translated = sum(x[0] for x in l1_data_translated) / len(l1_data_translated) if l1_data_translated else 0

# Print the average L1 losses
print(f"L1 loss between data pulses and simulation pulses: {average_l1_data_sim}")
print(f"L1 loss between data pulses and translated pulses: {average_l1_data_translated}")


In [None]:
# import torch
# import torch.nn as nn
# import matplotlib.pyplot as plt
# from tqdm import tqdm
# import numpy as np

# class WFDist(nn.Module):
#     '''
#     Waveform Distance: A special L1 loss that gives more weight to the rising and falling edges of each pulse
#     baseline(0,250), rising edge=(250,500), tail=(500,800)
#     '''
#     def __init__(self, baseline_weight, ris_edge_weight, tail_weight):
#         super(WFDist, self).__init__()
#         self.criterion = nn.L1Loss(reduction='none')  # 'none' to manually apply weights
#         # Save the initial weight distribution without setting the length
#         self.baseline_weight = baseline_weight
#         self.ris_edge_weight = ris_edge_weight
#         self.tail_weight = tail_weight

#     def forward(self, x1, x2):
#         # Dynamically create the weight tensor based on the current length of the input
#         length = x1.view(-1).size(0)
#         baseline_len = min(250, length)  # Ensure that baseline_len doesn't exceed the input size
#         ris_edge_len = min(500, length - baseline_len)
#         tail_len = max(0, length - baseline_len - ris_edge_len)
#         weight = torch.tensor(
#             [self.baseline_weight] * baseline_len +
#             [self.ris_edge_weight] * ris_edge_len +
#             [self.tail_weight] * tail_len
#         ).to(x1.device)
        
#         # Compute weighted L1 loss
#         loss_out = self.criterion(x1.view(-1), x2.view(-1)) * weight
#         return loss_out.sum() / weight.sum()  # Normalize by the sum of weights

# def smooth_waveform(wf, window_size=5):
#     """Apply a simple moving average filter to smooth the waveform."""
#     return np.convolve(wf, np.ones(window_size) / window_size, mode='same')

# # Function to calculate L1 loss between data pulses, simulation pulses, and translated pulses
# def calculate_l1_loss(data_loader, ATN, DEVICE):
#     l1_data_sim = []  # L1 loss between data pulses and simulation pulses
#     l1_data_translated = []  # L1 loss between data pulses and translated pulses

#     # Set the model to evaluation mode and disable gradient calculation
#     ATN.eval()
#     criterion_valid = WFDist(baseline_weight, ris_edge_weight, tail_weight).to(DEVICE)
#     window_size = 5  # Define window size used for smoothing
    
#     with torch.no_grad():
#         # Iterate over the data loader
#         for wf, wf_deconv, rawwf, x in tqdm(data_loader):
#             # Move input waveforms to the correct device
#             wf = wf.to(DEVICE)
#             wf_deconv = wf_deconv.to(DEVICE)
            
#             # Generate translated waveforms using the model
#             gan_wf = ATN(wf_deconv.float())
            
#             # Iterate over each waveform in the batch
#             for iwf in range(wf.size(0)):
#                 # Get the real waveform, simulated waveform, and translated waveform
#                 real_wf = wf[iwf, 0].cpu().numpy()  # Real waveform tensor on DEVICE
#                 sim_wf = wf_deconv[iwf, 0].cpu().numpy()  # Simulated waveform tensor on DEVICE
#                 transfer_wf = gan_wf[iwf, 0].cpu().numpy()  # Translated waveform tensor on DEVICE

#                 # Apply smoothing to the waveforms
#                 smooth_real_wf = smooth_waveform(real_wf, window_size)
#                 smooth_sim_wf = smooth_waveform(sim_wf, window_size)
#                 smooth_transfer_wf = smooth_waveform(transfer_wf, window_size)

#                 # Trim the last window_size samples to avoid edge effects
#                 trim_len = window_size
#                 smooth_real_wf = smooth_real_wf[:-trim_len]
#                 smooth_sim_wf = smooth_sim_wf[:-trim_len]
#                 smooth_transfer_wf = smooth_transfer_wf[:-trim_len]

#                 # Convert smoothed waveforms back to tensors and move them to the correct device
#                 smooth_real_wf = torch.tensor(smooth_real_wf, device=DEVICE)
#                 smooth_sim_wf = torch.tensor(smooth_sim_wf, device=DEVICE)
#                 smooth_transfer_wf = torch.tensor(smooth_transfer_wf, device=DEVICE)

#                 # Calculate the L1 loss between smoothed data and simulation waveforms
#                 l1_sim = criterion_valid(smooth_real_wf, smooth_sim_wf)
#                 l1_data_sim.append((l1_sim.item(), smooth_real_wf.cpu().numpy(), smooth_sim_wf.cpu().numpy()))

#                 # Calculate the L1 loss between smoothed data and translated waveforms
#                 l1_translated = criterion_valid(smooth_real_wf, smooth_transfer_wf)
#                 l1_data_translated.append((l1_translated.item(), smooth_real_wf.cpu().numpy(), smooth_transfer_wf.cpu().numpy()))


#     # Plot an example of smoothed vs. original waveforms
#     plot_example_smoothing(real_wf, smooth_real_wf.cpu().numpy())

#     # Sort the losses to get the top 5 highest losses for both categories
#     l1_data_sim.sort(reverse=True, key=lambda x: x[0])
#     l1_data_translated.sort(reverse=True, key=lambda x: x[0])

#     # Extract the top 5 waveforms with the highest losses
#     top_sim = l1_data_sim[:5]
#     top_translated = l1_data_translated[:5]

#     # Plotting the top 5 waveforms with the highest L1 loss for both categories
#     plot_top_waveforms(top_sim, top_translated)

#     # Calculate overall average L1 losses
#     average_l1_data_sim = sum(x[0] for x in l1_data_sim) / len(l1_data_sim) if l1_data_sim else 0
#     average_l1_data_translated = sum(x[0] for x in l1_data_translated) / len(l1_data_translated) if l1_data_translated else 0
    
#     return average_l1_data_sim, average_l1_data_translated

# def plot_example_smoothing(original_wf, smoothed_wf):
#     """
#     Plot an example of an original waveform and its smoothed version.
#     """
#     plt.figure(figsize=(10, 5))
#     plt.plot(original_wf, label='Original Waveform', color='blue', linewidth=0.7)
#     plt.plot(smoothed_wf, label='Smoothed Waveform', color='red', linestyle='--', linewidth=1)
#     plt.title('Example of Smoothing on a Waveform')
#     plt.xlabel('Time Sample [ns]')
#     plt.ylabel('Amplitude')
#     plt.grid(True, which='both', linestyle='--', linewidth=0.5)
#     plt.minorticks_on()
#     plt.legend()
#     plt.show()

# def plot_top_waveforms(top_sim, top_translated):
#     """
#     Function to plot the top 5 waveforms with the highest L1 loss in both categories.
#     """
#     fig, axs = plt.subplots(2, 5, figsize=(20, 10))

#     # Plot top 5 highest L1 loss waveforms for data vs. simulation
#     for i, (loss, real_wf, sim_wf) in enumerate(top_sim):
#         axs[0, i].plot(real_wf, label='Data')
#         axs[0, i].plot(sim_wf, label='Sim')
#         axs[0, i].set_title(f'Top {i+1} L1 Loss (Sim): {loss:.4f}')
#         axs[0, i].legend()
#         axs[0, i].grid(True)

#     # Plot top 5 highest L1 loss waveforms for data vs. translated
#     for i, (loss, real_wf, trans_wf) in enumerate(top_translated):
#         axs[1, i].plot(real_wf, label='Data')
#         axs[1, i].plot(trans_wf, label='Translated')
#         axs[1, i].set_title(f'Top {i+1} L1 Loss (Translated): {loss:.4f}')
#         axs[1, i].legend()
#         axs[1, i].grid(True)

#     plt.tight_layout()
#     plt.show()

# # Call the function with your train_loader
# average_l1_data_sim, average_l1_data_translated = calculate_l1_loss(train_loader, ATN, DEVICE)

# # Print the average L1 losses
# print(f"L1 loss between data pulses and simulation pulses: {average_l1_data_sim}")
# print(f"L1 loss between data pulses and translated pulses: {average_l1_data_translated}")


In [None]:
# import torch
# import torch.nn.functional as F
# from tqdm import tqdm

# # Function to calculate MSE between data pulses, simulation pulses, and translated pulses
# def calculate_mse(data_loader, ATN, DEVICE):
#     mse_data_sim = []  # MSE between data pulses and simulation pulses
#     mse_data_translated = []  # MSE between data pulses and translated pulses

#     # Set the model to evaluation mode and disable gradient calculation
#     ATN.eval()
#     with torch.no_grad():
#         # Iterate over the data loader
#         for wf, wf_deconv, rawwf, x in tqdm(data_loader):
#             # Move input waveforms to the correct device
#             wf = wf.to(DEVICE)
#             wf_deconv = wf_deconv.to(DEVICE)
            
#             # Generate translated waveforms using the model
#             gan_wf = ATN(wf_deconv.float())
            
#             # Iterate over each waveform in the batch
#             for iwf in range(wf.size(0)):
#                 # Get the real waveform, simulated waveform, and translated waveform
#                 real_wf = wf[iwf, 0]  # Real waveform tensor on DEVICE
#                 sim_wf = wf_deconv[iwf, 0]  # Simulated waveform tensor on DEVICE
#                 transfer_wf = gan_wf[iwf, 0]  # Translated waveform tensor on DEVICE
                
#                 # Calculate the MSE between data and simulation waveforms
#                 mse_sim = F.mse_loss(real_wf, sim_wf)  
#                 mse_data_sim.append(mse_sim.item())  

#                 # Calculate the MSE between data and translated waveforms
#                 mse_translated = F.mse_loss(real_wf, transfer_wf)  
#                 mse_data_translated.append(mse_translated.item())

#     # Calculate overall average MSEs
#     average_mse_data_sim = sum(mse_data_sim) / len(mse_data_sim)
#     average_mse_data_translated = sum(mse_data_translated) / len(mse_data_translated)
    
#     return average_mse_data_sim, average_mse_data_translated

# # Call the function with your train_loader
# average_mse_data_sim, average_mse_data_translated = calculate_mse(train_loader, ATN, DEVICE)

# # Print the average MSEs
# print(f"Average MSE between data pulses and simulation pulses: {average_mse_data_sim}")
# print(f"Average MSE between data pulses and translated pulses: {average_mse_data_translated}")


In [None]:
def calculate_tp0(wf_blsub, cross):
    """
    Calculate tp_0 from a baseline-subtracted waveform.

    The index of the waveform maximum is located first; tp_0 is then the last
    sample *before* that maximum whose amplitude is still below `cross`.

    Parameters:
    - wf_blsub (array-like): Baseline-subtracted waveform (expected to be 1-D).
    - cross (float): Amplitude threshold; samples strictly below it count as
      "not yet risen".
    Returns:
    - tp_0: Integer sample index of the last below-threshold sample before the
      maximum, or np.nan if the waveform is empty or no such sample exists
      (e.g. the maximum is at index 0).
    """
    # Ensure wf is a numpy array
    wf_blsub = np.asarray(wf_blsub)
    if wf_blsub.size == 0:
        # np.argmax raises on empty input; report "no crossing" instead.
        return np.nan
    tp_max = np.argmax(wf_blsub)
    # Indices (before the maximum) where the waveform is still below threshold
    below_threshold = np.where(wf_blsub[:tp_max] < cross)[0]
    if below_threshold.size > 0:
        return below_threshold[-1]  # Last below-threshold sample before max
    # The bare name `NaN` is undefined in this notebook; np.nan is the real sentinel.
    return np.nan
def calc_dt(wf, cross):
    """
    Calculate the drift time as the difference between t95 and tp_0.

    Parameters:
    - wf (array-like): Waveform, passed unchanged to calculate_tn and calculate_tp0.
    - cross (float): Amplitude threshold forwarded to calculate_tp0.
    Returns:
    - t95 - tp_0 (in the same units the two helpers return), or np.nan if
      either value is NaN.
    """
    # The value requested from calculate_tn is 95, not 99 as the old name/docstring
    # suggested. NOTE(review): presumably the sample at which the waveform reaches
    # 95% of its maximum — confirm against tools.calculate_tn.
    t95 = calculate_tn(wf, 95)
    tp_0 = calculate_tp0(wf, cross)
    if np.isnan(t95) or np.isnan(tp_0):
        return np.nan
    return t95 - tp_0

In [None]:
# Per-waveform features gathered over the whole train_loader.
# Naming: no prefix / "data" = detector pulse, "sim"/"siggen" = simulated pulse,
# "gan" = pulse produced by ATN from the simulated one.
ts = []         # tail slope, detector pulses
gan_ts = []     # tail slope, translated pulses
sim_ts = []     # tail slope, simulated pulses
data_ca = []    # current amplitude, detector pulses
gan_ca = []     # current amplitude, translated pulses
sim_ca = []     # current amplitude, simulated pulses
data_wf = []    # detector waveforms (1-D numpy arrays)
siggen_wf = []  # simulated waveforms (1-D numpy arrays)
dt_sim = []     # drift time, simulated pulses
dt_data = []    # drift time, detector pulses
event_eng = []  # first element of c["energy"] for each event
dt_gan = []     # drift time, translated pulses
i = 0           # batch counter, only used by the optional early stop below

# NOTE(review): ATN's train/eval mode is whatever earlier cells left it in;
# call ATN.eval() beforehand if it contains dropout/batch-norm layers.
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for wf, wf_deconv, rawwf, c in tqdm(train_loader):
        # if i == 20:  # process only the first 20 batches
        #     break
        bsize = wf.size(0)
        gan_wf = ATN(wf_deconv.to(DEVICE).float())

        # Move channel 0 of each batch to host memory once instead of per waveform.
        data_batch = wf[:, 0].cpu().numpy()
        siggen_batch = wf_deconv[:, 0].cpu().numpy()
        gan_batch = gan_wf[:, 0].detach().cpu().numpy()

        for iwf in range(bsize):
            datawf = data_batch[iwf].flatten()
            siggenwf = siggen_batch[iwf].flatten()
            transfer_wf = gan_batch[iwf].flatten()

            ts.append(get_tail_slope(datawf))
            gan_ts.append(get_tail_slope(transfer_wf))
            sim_ts.append(get_tail_slope(siggenwf))
            data_ca.append(calc_current_amplitude(datawf))
            gan_ca.append(calc_current_amplitude(transfer_wf))
            sim_ca.append(calc_current_amplitude(siggenwf))
            siggen_wf.append(siggenwf)
            data_wf.append(datawf)
            dt_sim.append(calc_dt(siggenwf, 0.005))
            dt_data.append(calc_dt(datawf, 0.005))
            dt_gan.append(calc_dt(transfer_wf, 0.005))
            event_eng.append(c["energy"][iwf].cpu().numpy().flatten()[0])
        #     plt.plot(datawf)
        #     plt.axvline(calculate_tp0(datawf, 0.002), color='b', alpha=0.5, label='tp_0')
        #     plt.axvline(calculate_tn(datawf, 99),color='r', alpha=0.5, label='tp_99')
        #     plt.legend()
        #     plt.show()
        #     plt.plot(siggenwf)
        #     plt.axvline(calculate_tp0(siggenwf,0.002), color='b', alpha=0.5, label='tp_0')
        #     plt.axvline(calculate_tn(siggenwf,99),color='r', alpha=0.5, label='tp_99')
        #     plt.legend()
        #     break
        # break
        i += 1

In [None]:
# Drift-time histograms for simulated, detector, and translated pulses,
# followed by the IoU of each against the detector distribution.
NS_PER_SAMPLE = 16  # sample index -> ns; presumably the digitizer period — TODO confirm
db = np.linspace(250, 2000, 100)  # histogram bin edges on the ns axis
dt_sim_plot = np.array(dt_sim) * NS_PER_SAMPLE
dt_data_plot = np.array(dt_data) * NS_PER_SAMPLE
dt_gan_plot = np.array(dt_gan) * NS_PER_SAMPLE

plt.hist(dt_sim_plot, bins=db, histtype="step", linewidth=2, density=False, color="tab:red", alpha=0.6, label="Sim Pulse")
plt.hist(dt_data_plot, bins=db, histtype="step", linewidth=2, density=False, color="tab:blue", alpha=0.6, label="Data Pulse")
# Legend label previously read "Tranlated Pulse" (typo).
plt.hist(dt_gan_plot, bins=db, histtype="step", linewidth=2, density=False, color="tab:green", alpha=0.6, label="Translated Pulse")
plt.legend()
plt.ylabel("# of Events")
plt.xlabel("Drift time (ns)")

print(f"Drift Time IoU between Detector Peak and Simulated Peak: {calculate_iou(dt_data_plot, dt_sim_plot, db, normed=False):.10f}")
print(f"Drift Time IoU between Detector Peak and Translated Peak: {calculate_iou(dt_data_plot, dt_gan_plot, db, normed=False):.10f}")

In [None]:
# Tail-slope histograms for detector vs. translated pulses, plus their IoU.
rg = np.linspace(-29.4e-5, -29.2e-5, 50)  # histogram bin edges
log_status = False  # set True for a log-scaled count axis
plt.hist(ts, bins=rg, histtype="step", linewidth=2, log=log_status, color="#1f77b4", label="Detector Pulse", alpha=0.5)
plt.hist(gan_ts, bins=rg, histtype="step", linewidth=2, log=log_status, color="#ff7f0e", label="Translated Pulse", alpha=0.5)
# plt.axvline(x=0,color="#2ca02c",linewidth=3,label="Simulated Pulse")
# plt.xlim(-5,8)
# A single legend call is enough; argument-less plt.xticks()/plt.yticks() only
# query the current ticks and do not resize labels, so they were dropped.
plt.legend()
plt.ylabel("# of Events")
plt.xlabel("Tail Slope")
# plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.minorticks_on()
plt.xticks(rotation=45)
print(f"Tail slope IoU between Detector Peak and Translated Peak: {calculate_iou(ts, gan_ts, rg, normed=False):.10f}")

# plt.savefig(f"figs/{eng_peak.upper()}_ts.pdf")


In [None]:
# Current-amplitude histograms for detector, translated, and simulated pulses;
# the figure is written to figs/<PEAK>_amp.pdf and the IoUs are printed.
# NOTE: these rcParams changes are global and affect every later figure too.
plt.rcParams['font.size'] = 16
plt.rcParams["figure.figsize"] = (9, 8)

rg = np.linspace(0.00002, 0.00015, 70)  # histogram bin edges
plt.hist(data_ca, label="Detector Peak", bins=rg, histtype="step", linewidth=2.5, color="#1f77b4")  # Blue
plt.hist(gan_ca, label="Translated Peak", bins=rg, alpha=0.3, color="#ff7f0e")  # Orange
plt.hist(sim_ca, label="Simulated Peak", bins=rg, histtype="step", linewidth=2.5, color="#2ca02c")  # Green

plt.xlabel("Current Amplitude")
plt.xticks(rotation=45)
plt.ylabel("# of Events")
# Only one legend call: a later bare plt.legend() would replace this one and
# silently discard the requested location.
plt.legend(loc="upper right")
# plt.grid(True, which='both', linestyle='--', linewidth=0.5)
plt.minorticks_on()
# Lay out after all tick/label changes so rotated labels are not clipped.
plt.tight_layout()

# savefig does not create missing directories.
os.makedirs("figs", exist_ok=True)
plt.savefig(f"figs/{eng_peak.upper()}_amp.pdf")
plt.show()

print(f"IoU between Detector Peak and Simulated Peak: {calculate_iou(data_ca, sim_ca, rg, normed=False):.10f}")
print(f"IoU between Detector Peak and Translated Peak: {calculate_iou(data_ca, gan_ca, rg, normed=False):.10f}")