In [None]:
from IPython import get_ipython
from IPython.display import display
# Clone the ECGAN repository and install its dependencies (Colab setup cell).
# NOTE(review): re-running this cell after the clone exists makes `git clone` fail; the rest still runs.
# NOTE(review): no package versions are pinned here, so a later re-run may install different versions.
# NOTE(review): `%pip install` targets the running kernel's environment more reliably than `!pip install`.
!git clone https://github.com/emundo/ecgan.git
# Later cells use absolute paths under /content/ecgan, which presumes this `%cd` runs from /content (Colab default) — confirm.
%cd ecgan
!pip install -e . # installs the cloned package in editable mode
!pip install -r requirements.txt # installs the dependencies listed in the repository requirements.txt
!pip install wandb adabelief-pytorch pylttb wfdb matplotlib scikit-learn # extra packages used by this notebook, e.g. wfdb to read the records

In [None]:
# Importing Libraries
from google.colab import files # Imports the 'files' module from 'google.colab' for managing files uploaded to Google Colab.
import zipfile # Imports the 'zipfile' module for working with ZIP files, allowing compression and decompression.
import os # Imports the 'os' module to interact with the operating system, for example, for file and directory management.

# Upload one or more files from the local computer into the Colab working directory.
def carica_file_zip():
    """Open the Colab upload dialog and save every chosen file in the working directory.

    ``files.upload()`` returns a mapping of file name -> raw bytes; each entry is
    reported (name and size) and written out under its original name.
    """
    caricati = files.upload()
    for nome, contenuto in caricati.items():
        print(f"File caricato: {nome} ({len(contenuto)} bytes)")
        with open(nome, 'wb') as destinazione:
            destinazione.write(contenuto)

# Interactive step: opens the Colab upload dialog, so the dataset ZIP has to be picked by hand
# and this cell cannot run unattended (e.g. under Restart & Run All without a user present).
carica_file_zip()
# Function to decompress the ZIP file
def decomprimi_file_zip(nome_file_zip, cartella_destinazione='intracardiac-atrial-fibrillation-database-1.0.0'):
    """Extract every member of the ZIP archive `nome_file_zip` into a folder.

    Args:
        nome_file_zip: Path of the ZIP archive to extract.
        cartella_destinazione: Target folder, relative to the working directory unless
            absolute; sub-folders are created as needed. Defaults to the folder name
            this notebook has always used, so existing calls behave as before.

    Returns:
        The destination folder, so callers need not repeat the literal.

    Raises:
        FileNotFoundError: if the archive does not exist.
        zipfile.BadZipFile: if the file is not a valid ZIP archive.
    """
    # NOTE: only extract archives from trusted sources (see the zipfile docs' warning on extractall).
    with zipfile.ZipFile(nome_file_zip, 'r') as zip_ref:
        zip_ref.extractall(cartella_destinazione)
    print(f"File ZIP {nome_file_zip} estratto.")
    return cartella_destinazione

# Extract the archive uploaded above: it must sit in the working directory under exactly this file name.
decomprimi_file_zip('intracardiac-atrial-fibrillation-database-1.0.0.zip')

In [None]:
import os
import numpy as np
import wfdb
# Folder produced by the extraction step. The name is repeated because the archive presumably
# contains its own top-level folder of the same name — confirm if the ZIP layout changes.
dataset_dir = '/content/ecgan/intracardiac-atrial-fibrillation-database-1.0.0/intracardiac-atrial-fibrillation-database-1.0.0'
print(f"Contents of {dataset_dir}: {os.listdir(dataset_dir)}")

# Output folder for the converted records (created if missing, reused if present).
csv_dir = os.path.join(dataset_dir, 'csv')
os.makedirs(csv_dir, exist_ok=True)

# Sorted because os.listdir() order is arbitrary: this keeps the conversion (and its log)
# in the same order on every run.
dat_files = sorted(f for f in os.listdir(dataset_dir) if f.endswith('.dat'))

for dat_file in dat_files:
    record_name = os.path.splitext(dat_file)[0]
    # wfdb.rdrecord takes the record path WITHOUT extension (not the .dat file itself).
    record_path = os.path.join(dataset_dir, record_name)
    record = wfdb.rdrecord(record_path)
    signals = record.p_signal  # physical-unit samples; written one row per sample
    csv_file_path = os.path.join(csv_dir, record_name + '.csv')
    np.savetxt(csv_file_path, signals, delimiter=',')
    print(f"File {record_name}.dat converted to {record_name}.csv")

print("Conversion completed.")

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import butter, filtfilt

# Z-score normalization of ECG signals. If "mean" and "std" are not provided, they are fitted on the
# first array passed to `preprocess` and then reused for every later call.
class ECGPreprocessor:
    """Per-column z-score normalizer for ECG arrays.

    Args:
        normalization_mean: Optional precomputed mean (scalar or one value per column).
        normalization_std: Optional precomputed standard deviation (scalar or per column).

    If either statistic is missing, both are computed column-wise (axis 0) on the
    first call to `preprocess` and stored, so the same scaling is applied to all
    subsequent data.
    """

    def __init__(self, normalization_mean=None, normalization_std=None):
        self.mean = normalization_mean
        self.std = normalization_std

    def preprocess(self, data):
        """Return ``(data - mean) / std``, fitting mean/std on `data` if they are unset.

        Columns whose standard deviation is 0 (flat channels) are divided by 1
        instead, so they map to 0 rather than NaN/inf.

        Returns:
            The normalized array, or None when `data` is None.
        """
        if data is None:
            return None
        data = np.asarray(data)
        if self.mean is None or self.std is None:
            self.mean = np.mean(data, axis=0)
            self.std = np.std(data, axis=0)
        # Guard against division by zero for constant channels.
        safe_std = np.where(np.asarray(self.std) == 0, 1.0, self.std)
        return (data - self.mean) / safe_std

# Zero-phase Butterworth band-pass filter: keeps only frequencies between `lowcut` and `highcut` (Hz).
def bandpass_filter(data, lowcut=0.5, highcut=50, fs=1000, order=4):
    """Apply a zero-phase (forward-backward) Butterworth band-pass filter along axis 0.

    Args:
        data: Signal array with time along axis 0, e.g. shape (samples, channels).
        lowcut: Lower cutoff frequency in Hz.
        highcut: Upper cutoff frequency in Hz.
        fs: Sampling frequency in Hz.
        order: Order of the Butterworth prototype (the band-pass filter has twice this order).

    Returns:
        Filtered array with the same shape as `data`.

    Raises:
        ValueError: if the cutoffs do not satisfy 0 < lowcut < highcut < fs / 2, or if
            `data` is too short for the forward-backward edge padding.
    """
    from scipy.signal import sosfiltfilt  # local import: the cell only imports butter/filtfilt

    nyquist = 0.5 * fs
    if not 0 < lowcut < highcut < nyquist:
        raise ValueError(
            f"Cutoffs must satisfy 0 < lowcut < highcut < fs/2 "
            f"(got lowcut={lowcut}, highcut={highcut}, fs={fs})."
        )
    # Second-order sections stay numerically stable even for very low normalized cutoffs
    # (0.5 Hz at fs=1000 Hz is 0.001 of Nyquist), where the (b, a) polynomial form loses precision.
    sos = butter(order, [lowcut / nyquist, highcut / nyquist], btype='band', output='sos')
    return sosfiltfilt(sos, data, axis=0)

# Mean removal ("constant" detrend): shift each column so that it is centred on zero.
def detrend_data(data):
    """Subtract the per-column mean (taken along axis 0) from `data`.

    Only a constant offset is removed; any linear or slow drift is left untouched.
    """
    column_means = np.mean(data, axis=0)
    return np.subtract(data, column_means)

# Normalize ECG data with a preprocessor and print a short summary of the result.
def preprocess_and_inspect_data(ecg_data, preprocessor):
    """Run ``preprocessor.preprocess(ecg_data)`` and report the result's shape and first rows.

    Any exception raised while preprocessing or inspecting is printed and converted
    into a None return, so callers only need a single ``is None`` check.
    """
    try:
        result = preprocessor.preprocess(ecg_data)
        if result is None:
            print("Preprocessed data is None.")
            return None
        print(f"Shape of preprocessed data: {result.shape}")
        print(f"Sample of preprocessed data: {result[:10]}")
        return result
    except Exception as exc:
        print(f"Error preprocessing data: {exc}")
        return None

# Load an ECG CSV file as a 2-D array: one row per line of the file, one column per channel.
def load_ecg_data(file_path):
    """Load comma-separated ECG samples from `file_path`.

    Returns:
        A 2-D float array of shape (rows, columns), or None if the file cannot be
        read or parsed (the error is printed; this is a deliberate best-effort load).
    """
    try:
        # ndmin=2 keeps the file's row/column layout in every case: a single-column file
        # becomes (n, 1) and a single-row file becomes (1, k). Reshaping a 1-D result to
        # (-1, 1) would wrongly turn one k-channel sample into k one-channel samples.
        data = np.loadtxt(file_path, delimiter=',', ndmin=2)
        print(f"Loaded data shape: {data.shape}")
        return data
    except Exception as e:
        print(f"Error loading ECG data: {e}")
        return None

# Cut a signal into fixed-length, possibly overlapping windows along its first axis.
def segment_data(preprocessed_data, window_length=50, window_step_size=25):
    """Return the windows ``preprocessed_data[start:start + window_length]``.

    A new window starts every `window_step_size` samples. A trailing remainder shorter
    than `window_length` is dropped. An empty list is returned (after a printed message)
    when the input is shorter than one window or when segmentation raises.
    """
    try:
        total = len(preprocessed_data)
        if total < window_length:
            print(f"Data length {len(preprocessed_data)} is less than window length {window_length}. No segmentation will be done.")
            return []
        last_start = total - window_length
        starts = range(0, last_start + 1, window_step_size)
        segments = [preprocessed_data[start:start + window_length] for start in starts]
        print(f"Number of segments: {len(segments)}")
        if segments:
            print(f"Sample segment: {segments[0]}")
        return segments
    except Exception as e:
        print(f"Error segmenting data: {e}")
        return []

# Plot the ECG signal at each preprocessing stage: original, filtered, detrended and normalized.
def _plot_first_channels(ax, data, label_prefix, title, ylabel='Amplitude'):
    """Draw channel 1 (and channel 2 if present) of a (samples, channels) array on `ax`."""
    ax.plot(data[:, 0], label=f'{label_prefix} Channel 1')
    if data.shape[1] > 1:
        ax.plot(data[:, 1], label=f'{label_prefix} Channel 2')
    ax.set_title(title)
    ax.set_xlabel('Samples')
    ax.set_ylabel(ylabel)
    ax.legend()


def visualize_data(original, filtered=None, detrended=None, normalized=None):
    """Show one stacked panel per provided preprocessing stage.

    `original` is always drawn. `filtered`, `detrended` and `normalized` each get
    their own panel when they are not None. Arrays are indexed as (samples, channels),
    and at most the first two channels are drawn.
    """
    panels = [(original, 'Original', 'Original ECG Data', 'Amplitude')]
    if filtered is not None:
        panels.append((filtered, 'Filtered', 'Filtered ECG Data', 'Amplitude'))
    if detrended is not None:
        panels.append((detrended, 'Detrended', 'Detrended ECG Data', 'Amplitude'))
    # `normalized` is drawn too: the pipeline passes the z-scored signal here and expects to see it.
    if normalized is not None:
        panels.append((normalized, 'Normalized', 'Normalized ECG Data', 'Normalized Amplitude'))

    fig, axes = plt.subplots(len(panels), 1, figsize=(14, 2.75 * len(panels)), squeeze=False)
    for ax, (data, prefix, title, ylabel) in zip(axes[:, 0], panels):
        _plot_first_channels(ax, data, prefix, title, ylabel)
    fig.tight_layout()
    plt.show()

# Plot one window produced by the segmentation step.
def visualize_segment(segment):
    """Plot channel 1, plus channel 2 when present, of a (samples, channels) window."""
    fig, ax = plt.subplots(figsize=(10, 4))
    ax.plot(segment[:, 0], label='Channel 1')
    if segment.shape[1] > 1:
        ax.plot(segment[:, 1], label='Channel 2')
    ax.set(title='Segmented ECG Data', xlabel='Samples', ylabel='Normalized Amplitude')
    ax.legend()
    plt.show()

# Path to one converted record (written by the .dat -> .csv conversion cell).
file_path = '/content/ecgan/intracardiac-atrial-fibrillation-database-1.0.0/intracardiac-atrial-fibrillation-database-1.0.0/csv/iaf1_afw.csv'

absolute_file_path = os.path.abspath(file_path)
print(f"Looking for file in: {absolute_file_path}")

if os.path.isfile(absolute_file_path):
    ecg_data = load_ecg_data(absolute_file_path)  # None if the file cannot be parsed
else:
    # Deliberately no stand-in file: writing random numbers into the dataset's csv folder
    # would be picked up as a real ECG record by every later run.
    print(f"File {absolute_file_path} not found.")
    ecg_data = None

# The failed-load check must come before any attribute access on ecg_data.
if ecg_data is not None:
    print(f"ECG data shape: {ecg_data.shape}")
    print(f"ECG data sample: {ecg_data[:10]}")
    print(f"NaN in ECG data: {np.isnan(ecg_data).any()}")

    # fs=1000 Hz is assumed for this record — TODO confirm against the record's .hea header.
    filtered_data = bandpass_filter(ecg_data, lowcut=0.5, highcut=50, fs=1000)
    print(f"Filtered data shape: {filtered_data.shape}")
    print(f"Filtered data sample: {filtered_data[:10]}")
    print(f"NaN in filtered data: {np.isnan(filtered_data).any()}")
    detrended_data = detrend_data(filtered_data)
    preprocessor = ECGPreprocessor()  # mean/std are fitted on this record
    preprocessed_data = preprocess_and_inspect_data(detrended_data, preprocessor)
    visualize_data(ecg_data, filtered_data, detrended_data, preprocessed_data)

    if preprocessed_data is not None:
        window_length = 50
        window_step_size = 25  # consecutive windows overlap by window_length - window_step_size samples
        segments = segment_data(preprocessed_data, window_length, window_step_size)
        if len(segments) > 0:
            # Seeded so the same segment is shown on every Restart & Run All.
            random_segment = segments[np.random.default_rng(0).integers(len(segments))]
            visualize_segment(random_segment)
        else:
            print("No segments were created.")
else:
    print("Failed to load ECG data.")


"""We obtain graphs, from the first (original) we get
the ECG data exactly as loaded from the CSV file, before any
preprocessing step is applied: the raw signal(s) over time (in samples).
If the data has more than one channel (original.shape[1] > 1 in the code),
one line is drawn per channel (at most the first two).
The second panel (Filtered) shows the signal after the band-pass filter,
which keeps frequencies between 0.5 Hz and 50 Hz (the values passed in the
code) and attenuates everything outside that range.
The third panel (Detrended) shows the filtered signal after its mean has
been subtracted. This only removes a constant offset, centring the signal
on zero; slow baseline drift below 0.5 Hz is already attenuated by the
band-pass filter.
The last plot shows one randomly chosen segment of the normalized data:
the signal is cut into fixed-length windows of 50 samples (consecutive
windows start 25 samples apart, so they overlap by half), and these
windows are the data the GAN will learn from."""

In [None]:
import torch # Imports the PyTorch library, an open-source machine learning library.
import torch.nn as nn # Imports the 'nn' module from PyTorch, which contains classes and functions for building neural networks.
import torch.optim as optim # Imports the 'optim' module from PyTorch, which contains optimizers for training neural networks.
from torch.utils.data import DataLoader, TensorDataset # Imports the 'DataLoader' and 'TensorDataset' classes from PyTorch's 'data' module, used for data management during training.
import numpy as np # Imports the NumPy library, a fundamental library for scientific computing in Python.
import matplotlib.pyplot as plt # Imports the 'pyplot' module from Matplotlib, a library for creating graphs and visualizations in Python.

# Generator: maps a noise vector to a synthetic sample.
class Generator(nn.Module):
    """Fully connected GAN generator.

    Architecture: input_dim -> 128 -> 256 -> output_dim, with ReLU after each hidden
    layer and a final Tanh, so every output value lies in [-1, 1].

    NOTE(review): ECGPreprocessor in this notebook produces z-scores, which are not
    bounded to [-1, 1]; if those are the training targets, confirm the ranges match.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        hidden_1, hidden_2 = 128, 256
        layers = [
            nn.Linear(input_dim, hidden_1),
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),
            nn.ReLU(),
            nn.Linear(hidden_2, output_dim),
            nn.Tanh(),
        ]
        # Kept under the attribute name `fc` so state_dict keys (fc.0.weight, ...) are unchanged.
        self.fc = nn.Sequential(*layers)

    def forward(self, x):
        return self.fc(x)

# Discriminator: scores how "real" a sample looks.
class Discriminator(nn.Module):
    """Fully connected GAN discriminator.

    Architecture: input_dim -> 256 -> 128 -> 1, with LeakyReLU(0.2) between layers
    and a final Sigmoid, so each output is a score in [0, 1] (as BCELoss expects).
    """

    def __init__(self, input_dim):
        super().__init__()
        widths = [input_dim, 256, 128]
        blocks = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.LeakyReLU(0.2)]
        blocks += [nn.Linear(widths[-1], 1), nn.Sigmoid()]
        # Attribute name `fc` preserved: state_dict keys stay fc.0.*, fc.2.*, fc.4.*.
        self.fc = nn.Sequential(*blocks)

    def forward(self, x):
        return self.fc(x)

# Wrap preprocessed samples in a DataLoader of float32 tensors.
def create_dataloader(preprocessed_data, batch_size=64, shuffle=True):
    """Build a DataLoader over `preprocessed_data`.

    Args:
        preprocessed_data: Array-like of samples whose first axis indexes samples. A list
            of equally shaped NumPy windows (e.g. segmentation output) is stacked first.
        batch_size: Samples per batch.
        shuffle: Reshuffle every epoch (default True, as before). Pass False for
            deterministic passes such as threshold estimation.

    Returns:
        A DataLoader yielding 1-tuples ``(batch,)`` of float32 tensors.
    """
    # Stack into one float32 array before converting: torch.tensor() on a Python list of
    # ndarrays is very slow and emits a warning.
    data_array = np.asarray(preprocessed_data, dtype=np.float32)
    data_tensor = torch.tensor(data_array)  # copies, so later edits to the NumPy data don't leak in
    dataset = TensorDataset(data_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Anomaly score: mean squared error between real and generated data.
def calculate_anomalies(real_data, fake_data):
    """Return mean((real_data - fake_data) ** 2) over all elements as a Python float.

    The two tensors are broadcast against each other before the difference is taken.
    """
    squared_error = (real_data - fake_data).pow(2)
    return squared_error.mean().item()

# Anomaly threshold: a percentile (95th by default) of per-sample MSE scores.
def determine_optimal_threshold(generator, dataloader, noise_dim=100, percentile=95, device=None):
    """Estimate an anomaly threshold from real data and fresh generator samples.

    For every real sample, a generator output is drawn from random noise. The MSE
    between the two, averaged over all of that sample's elements, is recorded, and the
    threshold is the `percentile`-th percentile of those errors.

    NOTE(review): the generated sample is not conditioned on the real one, so this is
    not a reconstruction error in the usual sense — confirm this is the intended score.

    Args:
        generator: Module mapping (batch, noise_dim) noise to samples; assumes its
            output has the same per-sample shape as the real data.
        dataloader: Yields tuples whose first element is a batch of real samples.
        noise_dim: Size of the noise vector.
        percentile: Percentile in [0, 100] used as the threshold.
        device: Device for the computation; defaults to the generator's parameter
            device (CPU if it has no parameters).

    Returns:
        The threshold as a NumPy float.

    Raises:
        ValueError: if `dataloader` yields no samples.
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    # Remember the mode so a generator that is being trained is not left in eval mode.
    was_training = generator.training
    generator.eval()
    errors = []
    try:
        with torch.no_grad():
            for data in dataloader:
                real_data = data[0].to(device)
                batch_size = real_data.size(0)
                noise = torch.randn(batch_size, noise_dim).to(device)
                fake_data = generator(noise)
                # Per-sample MSE for the whole batch at once (one value per row).
                squared = (real_data - fake_data) ** 2
                errors.extend(squared.reshape(batch_size, -1).mean(dim=1).tolist())
    finally:
        generator.train(was_training)

    if not errors:
        raise ValueError("Cannot determine a threshold: the dataloader yielded no samples.")
    return np.percentile(errors, percentile)

# Generation and saving of samples: generated vs real grid, plus an optional MSE histogram.
def generate_and_save_samples(generator, dataloader, epoch, noise_dim=100, threshold=None, device=None):
    """Plot 16 generated samples next to real ones and optionally flag anomalous pairs.

    Side effects: writes ``gan_samples_epoch_{epoch+1}.png`` and, when `threshold`
    is given, ``anomaly_detection_epoch_{epoch+1}.png`` to the working directory,
    and shows both figures.

    Args:
        generator: Module mapping (n, noise_dim) noise to samples shaped like the real ones.
        dataloader: Yields tuples whose first element is a batch of real samples;
            only the first batch is used.
        epoch: Zero-based epoch index (file names use ``epoch + 1``).
        noise_dim: Size of the noise vector fed to the generator.
        threshold: MSE above which a generated/real pair counts as an anomaly;
            None disables anomaly counting and the histogram.
        device: Device used for generation; defaults to the generator's parameter
            device (CPU if it has no parameters).

    Raises:
        StopIteration: if `dataloader` yields no batch.
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            noise = torch.randn(16, noise_dim).to(device)
            fake_data = generator(noise).cpu().numpy()
    finally:
        # Restore the caller's mode (presumably called between training epochs).
        generator.train(was_training)
    real_batch = next(iter(dataloader))[0].cpu().numpy()

    # Score only complete (generated, real) pairs. An `inf` placeholder for missing real
    # samples would be counted as an anomaly (so the count could exceed num_samples)
    # and would give plt.hist a non-finite range.
    num_samples = min(len(fake_data), len(real_batch))
    mse_values = [float(np.mean((real_batch[i] - fake_data[i]) ** 2)) for i in range(num_samples)]
    num_anomalies = None if threshold is None else sum(mse > threshold for mse in mse_values)

    # 4x4 grid: generated (solid blue) vs real (dashed red) for each pair.
    fig = plt.figure(figsize=(15, 10))
    for i in range(num_samples):
        plt.subplot(4, 4, i + 1)
        plt.plot(fake_data[i], label='Generated', color='blue')
        plt.plot(real_batch[i], label='Real', color='red', linestyle='--')
        plt.title(f'Sample {i+1}\nMSE: {mse_values[i]:.4f}')
        plt.axis('off')
    if num_anomalies is not None:
        plt.suptitle(f'Anomalies detected: {num_anomalies}/{num_samples}\nThreshold: {threshold:.4f}', fontsize=16)
    else:
        plt.suptitle('No threshold applied', fontsize=16)
    fig.savefig(f'gan_samples_epoch_{epoch+1}.png')
    plt.show()

    # Distribution of pair scores relative to the threshold.
    if threshold is not None:
        fig = plt.figure(figsize=(10, 6))
        plt.hist(mse_values, bins=20, color='blue', alpha=0.7, label='MSE values')
        plt.axvline(threshold, color='red', linestyle='dashed', linewidth=1.5, label='Threshold')
        plt.xlabel('MSE Value')
        plt.ylabel('Frequency')
        plt.title('Distribution of MSE Values')
        plt.legend()
        fig.savefig(f'anomaly_detection_epoch_{epoch+1}.png')
        plt.show()

# Training the GAN
def train_gan(generator, discriminator, dataloader, num_epochs=10, lr=0.0002, beta1=0.5): # Defines a function called 'train_gan' that trains the GAN.
    """Train a GAN with binary cross-entropy losses, then plot and save the loss curves.

    Relies on names defined in other cells: the globals ``noise_dim`` and ``device``,
    and the helpers ``determine_optimal_threshold`` and ``generate_and_save_samples``.

    Args:
        generator: module mapping a (batch_size, noise_dim) noise tensor to fake samples.
        discriminator: module scoring samples. Because ``nn.BCELoss`` is used against
            (batch_size, 1) label tensors, its output must be probabilities in [0, 1]
            with shape (batch_size, 1).
        dataloader: iterable of batches; ``data[0]`` of each batch is the real-data tensor.
        num_epochs: number of passes over ``dataloader``.
        lr: Adam learning rate for both optimizers.
        beta1: Adam beta1 for both optimizers (beta2 is fixed at 0.999).

    Returns:
        ``(loss_d_list, loss_g_list)``: per-epoch mean discriminator loss (real + fake
        terms) and mean generator loss.

    Raises:
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    criterion = nn.BCELoss()
    optim_gen = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optim_disc = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

    # Per-epoch loss history (one entry per epoch).
    loss_d_list = []
    loss_g_list = []

    # Threshold used by the per-epoch anomaly report.
    # NOTE(review): it is computed once, from the *untrained* generator, and reused for
    # every epoch; confirm that is intended rather than recomputing after training.
    threshold = determine_optimal_threshold(generator, dataloader, noise_dim=noise_dim)

    for epoch in range(num_epochs):
        # Re-enable training mode each epoch in case a helper switched the models to
        # eval mode (matters only for layers such as dropout/batch-norm).
        generator.train()
        discriminator.train()

        epoch_loss_d = 0.0
        epoch_loss_g = 0.0
        num_batches = 0

        for data in dataloader:
            real_data = data[0].to(device)
            batch_size = real_data.size(0)

            # Target labels: 1 for real samples, 0 for generated ones.
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # --- Discriminator step: real batch vs. freshly generated batch ---
            discriminator.zero_grad()
            output = discriminator(real_data)
            real_loss = criterion(output, real_labels)
            real_loss.backward()

            noise = torch.randn(batch_size, noise_dim).to(device)
            fake_data = generator(noise)
            # detach(): this step must not push gradients into the generator.
            output = discriminator(fake_data.detach())
            fake_loss = criterion(output, fake_labels)
            fake_loss.backward()
            optim_disc.step()

            # --- Generator step: push the discriminator to label fakes as real ---
            generator.zero_grad()
            output = discriminator(fake_data)
            gen_loss = criterion(output, real_labels)
            gen_loss.backward()
            optim_gen.step()

            # Accumulate so the epoch loss is a mean over batches, not the last batch.
            epoch_loss_d += real_loss.item() + fake_loss.item()
            epoch_loss_g += gen_loss.item()
            num_batches += 1

        if num_batches == 0:
            raise ValueError('dataloader yielded no batches; cannot train the GAN')

        loss_d_list.append(epoch_loss_d / num_batches)
        loss_g_list.append(epoch_loss_g / num_batches)

        print(f'Epoch [{epoch+1}/{num_epochs}] | Loss D: {loss_d_list[-1]} | Loss G: {loss_g_list[-1]}')
        generate_and_save_samples(generator, dataloader, epoch, noise_dim=noise_dim, threshold=threshold)

    # Plot the loss curves and save them to disk.
    plt.figure(figsize=(12, 6))
    plt.plot(loss_d_list, label='Discriminator Loss')
    plt.plot(loss_g_list, label='Generator Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Losses During Training')
    plt.legend()
    plt.savefig('training_losses.png')
    plt.show()

    return loss_d_list, loss_g_list

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Fix the torch RNG so weight initialisation, noise sampling and (if the loader
# shuffles) batch order are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# Dimensions
noise_dim = 100  # length of the latent noise vector fed to the generator
data_dim = preprocessed_data.shape[1]  # number of columns (features) per sample in preprocessed_data

# Build the models and move them to the selected device
generator = Generator(noise_dim, data_dim).to(device)
discriminator = Discriminator(data_dim).to(device)

# Build the DataLoader over the preprocessed data
dataloader = create_dataloader(preprocessed_data)

# Train the GAN
train_gan(generator, discriminator, dataloader, num_epochs=10)

In [None]:
import torch # Imports the PyTorch library for deep learning.
import numpy as np # Imports the NumPy library for numerical computation.
import matplotlib.pyplot as plt # Imports Matplotlib for data visualization.
from sklearn.metrics import mean_squared_error, roc_curve, auc, precision_recall_curve, average_precision_score, confusion_matrix, classification_report # Imports metrics from scikit-learn for model evaluation.

# Displays real and generated samples side by side for comparison. Each row shows a real sample (left) next to the generated sample with the same index (right).
def plot_generated_vs_real(real_data, generated_data, num_samples=5): # Defines a function to visualize real and generated data.
    """Plot real samples (left column) next to generated samples (right column).

    Args:
        real_data: indexable sequence of 1-D real samples.
        generated_data: indexable sequence of 1-D generated samples.
        num_samples: maximum number of rows to draw. It is clamped to the number of
            samples actually available in both inputs, so short inputs no longer
            raise ``IndexError``.

    Returns:
        None. If there is nothing to plot, the function returns without creating a figure.
    """
    n_rows = min(num_samples, len(real_data), len(generated_data))
    if n_rows <= 0:
        return None

    # squeeze=False keeps `axes` 2-D even when only one row is drawn.
    fig, axes = plt.subplots(n_rows, 2, figsize=(15, 10), squeeze=False)
    for i in range(n_rows):
        ax_real, ax_gen = axes[i]
        ax_real.plot(real_data[i], color='blue', label='Real')
        ax_real.set_title(f'Real Data Sample {i+1}')
        ax_real.axis('off')
        ax_gen.plot(generated_data[i], color='red', label='Generated')
        ax_gen.set_title(f'Generated Data Sample {i+1}')
        ax_gen.axis('off')
    fig.tight_layout()  # prevent titles from overlapping neighbouring panels
    plt.show()
    return None

# Calculation of the mean squared error between real and generated data (how much the generated data differs from the real data)
def calculate_mse(real_data, generated_data): # Defines a function to calculate the mean squared error (MSE).
    """Return the mean of the squared element-wise difference of the two inputs.

    The inputs must support ``-`` between them (e.g. NumPy arrays of
    broadcast-compatible shape). The mean is taken over every element.
    """
    residual = real_data - generated_data
    return np.mean(residual * residual)

# Plot of D and G losses during training, saving the graph as an image
def plot_losses(loss_d_list, loss_g_list): # Defines a function to plot discriminator and generator losses during training.
    """Draw the discriminator and generator loss curves against epoch index.

    The figure is written to 'training_losses.png' in the working directory
    and then displayed.
    """
    fig, ax = plt.subplots(figsize=(12, 6))
    curves = ((loss_d_list, 'Discriminator Loss'), (loss_g_list, 'Generator Loss'))
    for series, series_name in curves:
        ax.plot(series, label=series_name)
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('Losses During Training')
    ax.legend()
    fig.savefig('training_losses.png')
    plt.show()

# Function to calculate reconstruction errors between real and generated data using MSE for each pair of data
def calculate_reconstruction_errors(real_data, generated_data): # Defines a function to calculate reconstruction errors.
    """Return a NumPy array with one MSE per (real, generated) pair.

    Each sample is flattened before scikit-learn's ``mean_squared_error`` is applied.
    Pairing stops at the shorter of the two inputs (``zip`` semantics).
    """
    per_sample_errors = []
    for real_sample, generated_sample in zip(real_data, generated_data):
        per_sample_errors.append(
            mean_squared_error(real_sample.flatten(), generated_sample.flatten())
        )
    return np.array(per_sample_errors)

# Histogram showing reconstruction errors and the threshold used to detect anomalies
def plot_reconstruction_errors(errors, threshold): # Defines a function to plot reconstruction errors.
    """Histogram of reconstruction errors (50 bins) with a dashed red line at ``threshold``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.hist(errors, bins=50, color='blue', alpha=0.7)
    ax.axvline(threshold, color='red', linestyle='dashed', linewidth=1.5, label='Threshold')
    ax.set_xlabel('Reconstruction Error')
    ax.set_ylabel('Frequency')
    ax.set_title('Distribution of Reconstruction Errors')
    ax.legend()  # only the threshold line carries a label
    plt.show()

# Function to prepare labels and reconstruction-error (MSE) scores: label 1 when the MSE is at or below the threshold, 0 otherwise
def prepare_labels_and_scores(real_data, generated_data, threshold): # Defines a function to prepare labels and scores for evaluation.
    """Compute per-pair MSE scores and threshold-derived labels.

    Returns:
        ``(labels, scores)`` as NumPy arrays. ``scores[i]`` is
        ``calculate_mse(real_data[i], generated_data[i])``. ``labels[i]`` is 0 when
        ``scores[i] > threshold`` and 1 otherwise.

    NOTE(review): the labels are derived from the very scores they are paired with,
    so metrics computed from this pair measure agreement with the threshold, not
    detection accuracy against independent ground truth. Higher scores also map
    to label 0, whereas scikit-learn ranking metrics treat higher scores as
    evidence for label 1. Confirm the intended convention.
    """
    scores = [calculate_mse(real, fake) for real, fake in zip(real_data, generated_data)]
    labels = [0 if score > threshold else 1 for score in scores]
    return np.array(labels), np.array(scores)

# Plots the precision-recall curve and reports the average precision (AP) score in the legend
def plot_precision_recall_curve(true_labels, scores): # Defines a function to plot the precision-recall curve.
    """Plot scikit-learn's precision-recall curve, with the average precision shown in the legend.

    scikit-learn treats label 1 as the positive class and higher ``scores`` as
    stronger evidence for it. If ``scores`` are anomaly scores (higher = more
    anomalous), label 1 should denote the anomalous class; verify against callers.
    """
    precision, recall, _thresholds = precision_recall_curve(true_labels, scores)
    ap_score = average_precision_score(true_labels, scores)

    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(recall, precision, marker='.', label=f'Precision-Recall curve (AP = {ap_score:0.2f})')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_title('Precision-Recall Curve')
    ax.legend()
    ax.grid()
    plt.show()

# Plot of the confusion matrix