In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy as sp
from dataclasses import dataclass, field

# Nested configuration structure
@dataclass(frozen=True)
class SignalConfig:
    """Carrier, symbol-rate and sampling-rate settings of the simulated signal.

    Rates are given in Hz; ``fs_num`` / ``up_fs_num`` repeat ``fs`` / ``up_fs``
    as integer GHz values for plot labels and rate ratios. ``symbol_duration``
    is derived (``1 / fsymbol``) and cannot be passed to the constructor.
    """
    fc: float         # carrier frequency [Hz]
    fsymbol: float    # symbol rate [1/s]
    fs: float         # base sampling rate [Hz]
    up_fs: float      # upsampled sampling rate [Hz]
    fs_num: int       # fs expressed in GHz
    up_fs_num: int    # up_fs expressed in GHz

    # Derived in __post_init__; excluded from __init__.
    symbol_duration: float = field(init=False)

    def __post_init__(self) -> None:
        # The dataclass is frozen, so derived fields are set via object.__setattr__.
        duration = 1 / self.fsymbol
        object.__setattr__(self, "symbol_duration", duration)

@dataclass(frozen=True)
class CSVConfig:
    """Paths of the CSV files read and written by each processing stage."""
    original: str  # raw input waveform
    t0: str        # input after time re-zeroing and amplitude scaling
    pn_t0: str     # t0 data after phase-noise / time-jitter injection
    fs: str        # trimmed data resampled to SignalConfig.fs
    up_fs: str     # data resampled to SignalConfig.up_fs

@dataclass(frozen=True)
class AWGNConfig:
    """Additive white Gaussian noise settings.

    ``snr_linear`` and ``noise_power`` are derived from the requested SNR and
    the configured ``signal_power`` (an assumed value, not measured from data).
    """
    snr_db: float        # target signal-to-noise ratio [dB]
    signal_power: float  # assumed signal power (linear units)

    # Derived in __post_init__; excluded from __init__.
    snr_linear: float = field(init=False)
    noise_power: float = field(init=False)

    def __post_init__(self) -> None:
        ratio = 10 ** (self.snr_db / 10)  # dB -> linear power ratio
        derived = {
            "snr_linear": ratio,
            "noise_power": self.signal_power / ratio,
        }
        # Frozen dataclass: bypass the generated __setattr__.
        for attr_name, attr_value in derived.items():
            object.__setattr__(self, attr_name, attr_value)

@dataclass(frozen=True)
class PhaseNoiseConfig:
    """Phase noise parameters configuration.

    ``std_rad`` is the phase-noise standard deviation in radians. Two values are
    derived in ``__post_init__``:

    - ``std_degree``: the same deviation in degrees.
    - ``std_time``: the equivalent timing jitter, Δt = Δφ / (2π·fc).

    The carrier frequency is passed explicitly as ``fc``. It must not be read
    from the module-level ``config``, because that name is not yet bound while
    ``config`` itself is being constructed. On a fresh kernel this would raise
    ``NameError``; on a re-run it would silently use the previous carrier.
    If ``fc`` is omitted it defaults to NaN, so ``std_time`` is NaN.
    """
    std_rad: float  # Standard deviation in radians
    fc: float = float('nan')  # Carrier frequency in Hz (only used for std_time)

    # Calculated properties
    std_degree: float = field(init=False)
    std_time: float = field(init=False)

    def __post_init__(self):
        # Frozen dataclass: derived fields are set via object.__setattr__.
        object.__setattr__(self, 'std_degree', self.std_rad * 180 / np.pi)
        # A NaN fc propagates to a NaN std_time instead of raising.
        object.__setattr__(self, 'std_time', self.std_rad / (2 * np.pi * self.fc))

@dataclass(frozen=True)
class Config:
    """Main configuration class grouping all sub-configurations."""
    signal: SignalConfig
    csv: CSVConfig
    awgn: AWGNConfig
    pn: PhaseNoiseConfig


# Create configuration instance.
# The signal section is built first so the phase-noise section can reuse its
# carrier frequency without referring to the not-yet-defined `config`.
signal_config = SignalConfig(
    fc=4e9,
    fsymbol=500e6,
    fs=32e9,
    up_fs=1024e9,
    fs_num=32,
    up_fs_num=1024
)

config = Config(
    signal=signal_config,
    csv=CSVConfig(
        original='../csv/8PSK2PPM_500MBps.csv',
        t0='../csv/t0.csv',
        pn_t0='../csv/pn_t0.csv',
        fs='../csv/fs.csv',
        up_fs='../csv/up_fs.csv'
    ),
    awgn=AWGNConfig(
        snr_db=15,
        signal_power=0.5
    ),
    pn=PhaseNoiseConfig(
        std_rad=0.27,
        fc=signal_config.fc
    )
)

## Pre-process the original CSV

In [None]:
# Pre-process the original file:
# 1. rename the two columns to `time`, `data`
# 2. shift time so it starts at 0
# 3. scale the amplitude so that the 99th percentile of `data` becomes 0.99

# Read CSV file
df = pd.read_csv(config.csv.original)

# The raw export is expected to have exactly two columns (time, value).
# Fail with a clear message instead of a generic length-mismatch error.
if df.shape[1] != 2:
    raise ValueError(
        f"Expected 2 columns in {config.csv.original}, got {df.shape[1]}: {list(df.columns)}"
    )
df.columns = ['time', 'data']

# Ensure time starts from 0 (relative to the first sample)
if len(df) > 0:
    df['time'] = df['time'] - df['time'].iloc[0]

# Scale so that the 99th percentile maps to 0.99. A zero or NaN percentile
# (e.g. empty or all-zero data) would silently produce inf/NaN values.
p99 = df['data'].quantile(0.99)
if not np.isfinite(p99) or p99 == 0:
    raise ValueError(f"Cannot scale data: 99th percentile of 'data' is {p99}")
factor = 0.99 / p99
df['data'] = df['data'] * factor

# Save processed file
df.to_csv(config.csv.t0, index=False)

df.describe()


## Add phase noise

In [None]:


data = pd.read_csv(config.csv.t0)

# Generate phase noise: one independent Gaussian draw per sample (seeded)
np.random.seed(42)
phase_noise = np.random.normal(0, config.pn.std_rad, len(data))

# Relationship between time jitter and phase noise: Δt = Δφ / (2π * f_carrier)
time_jitter = phase_noise / (2 * np.pi * config.signal.fc)

# Original time axis, jittered sampling instants and clean samples
t_original = data['time'].to_numpy()
t_jittered = t_original + time_jitter
data_original = data['data'].to_numpy()

# Linearly interpolate the clean waveform at the jittered instants.
# Instants outside the original time range are set to 0; there is no
# extrapolation. np.interp requires increasing sample times, so sort first
# (scipy's interp1d, previously used here, sorts internally). np.interp also
# avoids the legacy interp1d class.
order = np.argsort(t_original, kind='stable')
noisy_signal = np.interp(t_jittered, t_original[order], data_original[order],
                         left=0.0, right=0.0)

# Output keeps only the original time axis and the jittered signal
output_data = pd.DataFrame({'time': data['time'], 'data': noisy_signal})

# save to csv
output_data.to_csv(config.csv.pn_t0, index=False)

# Plot the beginning of the jittered signal
N_PLOT_POINTS = 3000
plot_data = output_data.head(N_PLOT_POINTS)

plt.figure(figsize=(12, 5))
plt.plot(plot_data['time'] * 1e9, plot_data['data'], label='data')
plt.xlabel('Time (ns)')
plt.ylabel('Signal value')
plt.title(f'First {N_PLOT_POINTS} points of phase noise/time jitter signal')
plt.grid(True, alpha=0.3)
plt.legend()
plt.tight_layout()
plt.show()

## Resampling (down to `fs`, then up to `up_fs`)

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Resample the phase-noise waveform onto a uniform grid at config.signal.fs
target_frequency_hz = config.signal.fs
resampling_interval_s = 1 / target_frequency_hz

# Read the CSV; its first row is the header, which is then replaced by explicit names
df = pd.read_csv(config.csv.pn_t0, header=0)

# Assumes the first column is time and the second is the signal value
df.columns = ['original_time', 'original_data']

# Convert columns to numeric, coercing unparsable entries to NaN ...
df['original_time'] = pd.to_numeric(df['original_time'], errors='coerce')
df['original_data'] = pd.to_numeric(df['original_data'], errors='coerce')

# ... and drop those rows
df.dropna(subset=['original_time', 'original_data'], inplace=True)

# The new uniform time axis starts at 0 (the time column was re-zeroed during
# pre-processing) and runs up to, but excluding, the last original timestamp
start_resample_time = 0
end_resample_time = df['original_time'].max()

new_time_axis = np.arange(start_resample_time, end_resample_time, resampling_interval_s)

# Linear interpolation; np.interp needs the original x-values
# (df['original_time']) to be increasing
assert df['original_time'].is_monotonic_increasing, "Data must be sorted by time"
resampled_data_values = np.interp(new_time_axis, df['original_time'], df['original_data'])

# Create a new DataFrame for the resampled data
df_resampled = pd.DataFrame({'time': new_time_axis, 'data': resampled_data_values})

# Display the resampled data information without saving to file
print(f"Data resampled to {config.signal.fs_num} GHz.")
print(df_resampled.head())

# Compare original and resampled signals over the first `time_limit` seconds.
# Labels are derived from time_limit so they cannot drift from the value used.
time_limit = 1e-8  # 10 ns
time_limit_ns = time_limit * 1e9

mask_original = df['original_time'] <= time_limit
mask_resampled = df_resampled['time'] <= time_limit

# Create comparison plot
plt.figure(figsize=(15, 8))

# Top plot: Original signal
plt.subplot(2, 1, 1)
plt.plot(df.loc[mask_original, 'original_time'] * 1e9,
        df.loc[mask_original, 'original_data'],
        'b-', linewidth=1, alpha=0.8, label='Original Signal')
plt.xlabel('Time (ns)')
plt.ylabel('Amplitude')
plt.title(f'Original Signal - First {time_limit_ns:g}ns')
plt.grid(True, alpha=0.3)
plt.legend()

# Bottom plot: Signal comparison
plt.subplot(2, 1, 2)
plt.plot(df.loc[mask_original, 'original_time'] * 1e9,
        df.loc[mask_original, 'original_data'],
        'b-', linewidth=1, alpha=0.6, label='Original Signal')
plt.plot(df_resampled.loc[mask_resampled, 'time'] * 1e9,
        df_resampled.loc[mask_resampled, 'data'],
        'r-', linewidth=1, alpha=0.8, label=f'Resampled Signal ({config.signal.fs_num} GHz)')
plt.xlabel('Time (ns)')
plt.ylabel('Amplitude')
plt.title(f'Signal Comparison - First {time_limit_ns:g}ns (Original vs {config.signal.fs_num} GHz Resampled)')
plt.grid(True, alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

# Print statistics
print(f"\n=== First {time_limit_ns:g}ns Signal Statistics ===")
print(f"Original signal points: {mask_original.sum()}")
print(f"Resampled signal points: {mask_resampled.sum()}")
print(f"Original sampling rate: {1/df['original_time'].diff().mean()/1e9:.2f} GHz (estimated)")
print(f"Resampled rate: {target_frequency_hz/1e9:.2f} GHz")

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

# Resample the phase-noise waveform onto a uniform grid at config.signal.up_fs
target_frequency_hz = config.signal.up_fs
resampling_interval_s = 1 / target_frequency_hz

# Read the CSV; its first row is the header, which is then replaced by explicit names
df = pd.read_csv(config.csv.pn_t0, header=0)

# Assumes the first column is time and the second is the signal value
df.columns = ['original_time', 'original_data']

# Convert columns to numeric, coercing unparsable entries to NaN ...
df['original_time'] = pd.to_numeric(df['original_time'], errors='coerce')
df['original_data'] = pd.to_numeric(df['original_data'], errors='coerce')

# ... and drop those rows
df.dropna(subset=['original_time', 'original_data'], inplace=True)

# Uniform time axis from 0 up to, but excluding, the last original timestamp
start_resample_time = 0
end_resample_time = df['original_time'].max()

new_time_axis = np.arange(start_resample_time, end_resample_time, resampling_interval_s)

# Linear interpolation; np.interp needs increasing x-values
assert df['original_time'].is_monotonic_increasing, "Data must be sorted by time"
resampled_data_values = np.interp(new_time_axis, df['original_time'], df['original_data'])

# Create a new DataFrame for the resampled data
df_upsampled = pd.DataFrame({'time': new_time_axis, 'data': resampled_data_values})

# Display the resampled data information without saving to file
print(f"Data resampled to {config.signal.up_fs_num} GHz.")
print(df_upsampled.head())

# Drop the leading part of the up_fs series and shift its time axis so it starts at ~0.
# One constant drives both the cut and the shift. The cut had been 1.0881e-09
# while the shift used 1.0781e-09, so the retained series did not start at 0.
# NOTE(review): hard-coded; it presumably mirrors the "Time threshold" printed by
# the fs-rate trimming cell (midpoint_time + 1e-9). Confirm, and compute it instead.
UPSAMPLED_TRIM_S = 1.0781e-09
df_upsampled = (
    df_upsampled.loc[df_upsampled['time'] >= UPSAMPLED_TRIM_S]
    .assign(time=lambda d: d['time'] - UPSAMPLED_TRIM_S)
    .reset_index(drop=True)
)

# Quick look at the trimmed series
plt.plot(df_upsampled['time'][:2000], df_upsampled['data'][:2000])
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Trimmed upsampled signal (first 2000 samples)')
plt.show()

print(df_upsampled.describe())

# Compare original and resampled signals over the first `time_limit` seconds.
# Masks are built AFTER trimming so they match df_upsampled's rows. The upsampled
# trace is shifted back by UPSAMPLED_TRIM_S so both traces share the original time axis.
time_limit = 1e-8  # 10 ns
time_limit_ns = time_limit * 1e9

mask_original = df['original_time'] <= time_limit
upsampled_time_on_original_axis = df_upsampled['time'] + UPSAMPLED_TRIM_S
mask_upsampled = upsampled_time_on_original_axis <= time_limit

# Create comparison plot
plt.figure(figsize=(15, 8))

# Top plot: Original signal
plt.subplot(2, 1, 1)
plt.plot(df.loc[mask_original, 'original_time'] * 1e9,
        df.loc[mask_original, 'original_data'],
        'b-', linewidth=1, alpha=0.8, label='Original Signal')
plt.xlabel('Time (ns)')
plt.ylabel('Amplitude')
plt.title(f'Original Signal - First {time_limit_ns:g}ns')
plt.grid(True, alpha=0.3)
plt.legend()

# Bottom plot: Signal comparison
plt.subplot(2, 1, 2)
plt.plot(df.loc[mask_original, 'original_time'] * 1e9,
        df.loc[mask_original, 'original_data'],
        'b-', linewidth=1, alpha=0.6, label='Original Signal')
plt.plot(upsampled_time_on_original_axis[mask_upsampled] * 1e9,
        df_upsampled.loc[mask_upsampled, 'data'],
        'r-', linewidth=1, alpha=0.8, label=f'Resampled Signal ({config.signal.up_fs_num} GHz)')
plt.xlabel('Time (ns)')
plt.ylabel('Amplitude')
plt.title(f'Signal Comparison - First {time_limit_ns:g}ns (Original vs {config.signal.up_fs_num} GHz Resampled)')
plt.grid(True, alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

# Print statistics
print(f"\n=== First {time_limit_ns:g}ns Signal Statistics ===")
print(f"Original signal points: {mask_original.sum()}")
print(f"Resampled signal points: {mask_upsampled.sum()}")
print(f"Original sampling rate: {1/df['original_time'].diff().mean()/1e9:.2f} GHz (estimated)")
print(f"Resampled rate: {target_frequency_hz/1e9:.2f} GHz")

In [None]:
# Fold the squared (power) fs-rate signal over one symbol period to locate where,
# within a symbol, the pulse energy is concentrated.
df_resampled['Data_Squared'] = df_resampled['data'] ** 2

# Folding period = one symbol duration (1 / fsymbol, i.e. 2 ns at 500 MHz)
period = config.signal.symbol_duration
period_ns = period * 1e9

# Use the squared data for folding analysis
data_to_fold = df_resampled['Data_Squared'].dropna()
time_to_fold = df_resampled.loc[data_to_fold.index, 'time']

# Position of each sample within the period
folded_time = time_to_fold % period

# Determine the time resolution (with fallbacks for very short series)
time_resolution = time_to_fold.diff().mean()
if pd.isna(time_resolution):
    time_resolution = (df_resampled['time'].iloc[1] - df_resampled['time'].iloc[0]) if len(df_resampled['time']) > 1 else 1e-12

# One bin per sample position. round() rather than int(): the ratio can come out
# as e.g. 63.999... in floating point, and truncation would drop a bin so that
# one bin accumulates two sample positions.
num_bins = max(1, int(round(period / time_resolution)))
bins = np.linspace(0, period, num_bins + 1)

# Create a dataframe for folding
fold_df = pd.DataFrame({'time': folded_time, 'data': data_to_fold})

# Assign each folded time point to a bin
fold_df['time_bin'] = pd.cut(fold_df['time'], bins=bins, labels=False, include_lowest=True)

# Group by the bins and sum the data
summed_data = fold_df.groupby('time_bin')['data'].sum()

# Time axis for the summed data (middle of each bin)
bin_centers = (bins[:-1] + bins[1:]) / 2

# Reindex the summed data to match the bins, filling missing bins with 0
summed_data = summed_data.reindex(range(len(bin_centers)), fill_value=0)

# Min-max normalise to [0, 1]. A constant profile (zero span) would divide by
# zero and produce NaNs, so map it to all zeros instead.
data_min = summed_data.min()
data_span = summed_data.max() - data_min
if data_span > 0:
    summed_data = (summed_data - data_min) / data_span
else:
    summed_data = summed_data * 0.0

# Plot the folded result
plt.figure(figsize=(12, 6))
plt.plot(bin_centers, summed_data)
plt.xlabel(f'Time (s) within {period_ns:g}ns period')
plt.ylabel('Summed Data (Normalized)')
plt.title(f'Data folded and summed over a {period_ns:g}ns period')
plt.grid(True)
plt.show()

# Find the midpoint of the interval where data is above a threshold
threshold = 0.2

# Boolean array of bins whose normalised sum exceeds the threshold
above_threshold = (summed_data > threshold).to_numpy()

if above_threshold.any():
    # Get the time values for these points
    time_above_threshold = bin_centers[above_threshold]

    # Circular mean of the time points, so an interval wrapping around the
    # period boundary is handled: map times to angles ...
    angles = (time_above_threshold / period) * 2 * np.pi

    # ... average their unit vectors ...
    mean_sin = np.mean(np.sin(angles))
    mean_cos = np.mean(np.cos(angles))
    mean_angle = np.arctan2(mean_sin, mean_cos)

    # ... and map the mean angle back to time within [0, period)
    midpoint_time = (mean_angle / (2 * np.pi)) * period
    if midpoint_time < 0:
        midpoint_time += period

    print(f"The midpoint of the interval where data is > {threshold} is: {midpoint_time:.4e} s")
    print(f"Need to delay the signal by {period - midpoint_time:.4e} s to align with the end of period.")
    print(f"This corresponds to {int((period - midpoint_time) * config.signal.fs)} samples at {config.signal.fs_num} GHz sampling rate.")

    # Plot the result with midpoint visualization
    plt.figure(figsize=(12, 6))
    plt.plot(bin_centers, summed_data, label='Summed Data')
    plt.axhline(y=threshold, color='r', linestyle='--', label=f'Threshold ({threshold})')
    plt.axvline(x=midpoint_time, color='g', linestyle='-', label=f'Midpoint ({midpoint_time:.2e} s)')
    plt.xlabel(f'Time (s) within {period_ns:g}ns period')
    plt.ylabel('Summed Data (Normalized)')
    plt.title(f'Data folded and summed over a {period_ns:g}ns period with Midpoint')
    plt.grid(True)
    plt.legend()
    plt.show()

else:
    # Clear any value left over from an earlier run so the trimming cell
    # fails loudly instead of silently reusing a stale midpoint.
    midpoint_time = None
    print(f"No data points found above the threshold of {threshold}.")


In [None]:
# Discard every fs-rate sample earlier than (pulse midpoint + 1 ns), re-zero the
# time axis at the first retained sample, preview it, and save time/data to CSV.
guard_s = 1e-9  # offset added to midpoint_time to form the cut-off
time_threshold = midpoint_time + guard_s

# Boolean mask of the samples that fall before the cut-off
indices_to_remove = df_resampled['time'] < time_threshold
num_points_to_remove = indices_to_remove.sum()

print(f"Time threshold: {time_threshold:.4e} s")
print(f"Number of data points to remove: {num_points_to_remove}")

kept = df_resampled.loc[~indices_to_remove]
df_trimmed = (
    kept.assign(time=kept['time'] - kept['time'].min())
        .reset_index(drop=True)
)

rows_removed = len(df_resampled) - len(df_trimmed)
print(f"Original data shape: {df_resampled.shape}")
print(f"Trimmed data shape: {df_trimmed.shape}")
print(f"Data points removed: {rows_removed}")
print("First few rows of trimmed data:")
print(df_trimmed.head())

# Preview the first 10 ns with a tick every 2 ns
preview_window_s = 1e-8
tick_step_s = 2e-9
plt.figure(figsize=(12, 6))
plt.plot(df_trimmed['time'], df_trimmed['data'])
plt.xlabel('Time (s)')
plt.ylabel('Data')
plt.title('Trimmed Data vs Time')
plt.grid(True)
plt.xlim(0, preview_window_s)
plt.xticks(np.arange(0, preview_window_s + tick_step_s, tick_step_s))
plt.show()

# Persist only the time and data columns (Data_Squared is left out)
df_to_save = df_trimmed[['time', 'data']].copy()
df_to_save.to_csv(config.csv.fs, index=False)
print(f"Trimmed data saved to {config.csv.fs}")
print(f"Saved data shape: {df_to_save.shape}")
print(f"Columns saved: {list(df_to_save.columns)}")

In [None]:
# Add white Gaussian noise with variance config.awgn.noise_power to the trimmed
# fs-rate signal (new frame df_awgn) and to the up_fs-rate signal (df_upsampled).
# NOTE(review): df_upsampled is modified in place, so re-running this cell adds
# noise a second time unless the upsampling cell is re-run first.
# NOTE(review): the two noise vectors are independent draws, so after this cell
# every up_fs-th sample no longer matches the corresponding fs-rate sample.
noise_std = np.sqrt(config.awgn.noise_power)

df_awgn = df_trimmed[['time', 'data']].copy()
awgn_fs = np.random.normal(loc=0, scale=noise_std, size=len(df_awgn))
df_awgn['data'] = df_awgn['data'] + awgn_fs

print(f"Added AWGN noise: SNR={config.awgn.snr_db} dB (linear={config.awgn.snr_linear:.2f}), signal power={config.awgn.signal_power}, noise power={config.awgn.noise_power:.2e}")
print(df_awgn.head())

awgn_up = np.random.normal(loc=0, scale=noise_std, size=len(df_upsampled))
df_upsampled['data'] = df_upsampled['data'] + awgn_up

In [None]:
# # Method 1: FFT-based zero-padding upsampling from fs_num GHz to 2048 GHz
# print(f"=== FFT Zero-Padding Upsampling: {config.signal.fs_num} GHz → {config.signal.up_fs_num} GHz ===")

# # Use the loaded data from the CSV file
# x_orig = df['data'].to_numpy(dtype=np.float64)
# N = len(x_orig)
# upsample_factor = int(config.signal.up_fs_num / config.signal.fs_num)
# print(f"Original data length: {N}")
# print(f"Original sampling rate: {config.signal.fs_num} GHz")
# print(f"Target sampling rate: {config.signal.up_fs_num} GHz")
# print(f"Upsampling factor: {upsample_factor}×")

# # Step 1: FFT of original data
# X = np.fft.fft(x_orig)

# # Step 2: Create zero-padded frequency domain signal
# N_new = N * upsample_factor
# X_padded = np.zeros(N_new, dtype=complex)

# # For even N: split the Nyquist frequency component
# if N % 2 == 0:
#     # Copy positive frequencies [0, N/2]
#     X_padded[:N//2] = X[:N//2]
#     # Copy negative frequencies [N/2+1, N-1] to the end
#     X_padded[N_new-N//2+1:] = X[N//2+1:]
#     # Split Nyquist frequency (if it exists)
#     X_padded[N//2] = X[N//2] / 2
#     X_padded[N_new-N//2] = X[N//2] / 2
# else:
#     # For odd N: simpler case
#     X_padded[:(N+1)//2] = X[:(N+1)//2]
#     X_padded[N_new-(N-1)//2:] = X[(N+1)//2:]

# # Step 3: IFFT to get upsampled signal
# x_upsampled = np.fft.ifft(X_padded).real * upsample_factor  # Scale by upsampling factor

# # Step 4: Create precise time axis
# Ts_orig = 1 / config.signal.fs  # Original sampling period
# Ts_upsampled = Ts_orig / upsample_factor  # New sampling period
# t_upsampled = np.arange(len(x_upsampled), dtype=np.float64) * Ts_upsampled

# print(f"Upsampled data length: {len(x_upsampled)}")
# print(f"New sampling period: {Ts_upsampled:.2e} s")
# print(f"New sampling rate: {1/Ts_upsampled/1e9:.1f} GHz")

# # Create DataFrame for the upsampled data
# df_upsampled = pd.DataFrame({
#     'time': t_upsampled,
#     'data': x_upsampled
# })

# print(f"Upsampled data shape: {df_upsampled.shape}")
# print(f"First few rows of {config.signal.up_fs_num} GHz data:")
# print(df_upsampled.head())

# # Verify the upsampling quality by checking the time axis
# print(f"\nTime axis verification:")
# print(f"Original max time: {df['time'].max():.2e} s")
# print(f"Upsampled max time: {t_upsampled.max():.2e} s")
# print(f"Time axis ratio: {t_upsampled.max() / df['time'].max():.6f} (should be close to 1.0)")

In [None]:
# upsample_factor = int(config.signal.up_fs_num / config.signal.fs_num)

# x_orig = df['data'].to_numpy(dtype=np.float64)
# x_upsampled = np.zeros(x_orig.shape[0] * upsample_factor)
# x_upsampled[::upsample_factor] = x_orig * upsample_factor

# Ts_orig = 1 / config.signal.fs  # Original sampling period
# Ts_upsampled = Ts_orig / upsample_factor  # New sampling period
# t_upsampled = np.arange(len(x_upsampled), dtype=np.float64) * Ts_upsampled

# df_upsampled = pd.DataFrame({
#     'time': t_upsampled,
#     'data': x_upsampled
# })

# plt.plot(df_upsampled['time'][:2000], df_upsampled['data'][:2000])
# plt.show()

In [None]:
import os

# Ratio between the up_fs and fs sampling rates (e.g. 1024 GHz / 32 GHz = 32)
upsample_factor = int(config.signal.up_fs_num / config.signal.fs_num)

# Plot comparison between original and upsampled data
plt.figure(figsize=(15, 10))

# Plot 1: First 1e-8 seconds comparison
plt.subplot(2, 1, 1)
time_limit = 1e-8

# fs_num GHz data
mask_orig = df_awgn['time'] <= time_limit
plt.plot(df_awgn.loc[mask_orig, 'time'], df_awgn.loc[mask_orig, 'data'],
         'o-', markersize=3, linewidth=1, label=f'{config.signal.fs_num} GHz (original)', alpha=0.7)

# up_fs_num GHz data
mask_upsampled = df_upsampled['time'] <= time_limit
plt.plot(df_upsampled.loc[mask_upsampled, 'time'], df_upsampled.loc[mask_upsampled, 'data'],
         '-', linewidth=0.8, label=f'{config.signal.up_fs_num} GHz (upsampled)', alpha=0.9)

plt.xlabel('Time (s)')
plt.ylabel('Data')
plt.title(f'Comparison: {config.signal.fs_num} GHz vs {config.signal.up_fs_num} GHz Data (First 10 ns)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.xlim(0, time_limit)
plt.xticks(np.arange(0, time_limit + 2e-9, 2e-9))

# Plot 2: Zoomed view of one pulse
plt.subplot(2, 1, 2)
time_start = 0
time_end = 2e-9  # First 2 ns only

mask_orig_zoom = (df_awgn['time'] >= time_start) & (df_awgn['time'] <= time_end)
mask_upsampled_zoom = (df_upsampled['time'] >= time_start) & (df_upsampled['time'] <= time_end)

plt.plot(df_awgn.loc[mask_orig_zoom, 'time'], df_awgn.loc[mask_orig_zoom, 'data'],
         'o-', markersize=4, linewidth=1.5, label=f'{config.signal.fs_num} GHz (original)')

plt.plot(df_upsampled.loc[mask_upsampled_zoom, 'time'], df_upsampled.loc[mask_upsampled_zoom, 'data'],
         '-', linewidth=1, label=f'{config.signal.up_fs_num} GHz (upsampled)')

plt.xlabel('Time (s)')
plt.ylabel('Data')
plt.title('Zoomed View: Single Pulse (First 2 ns)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.xlim(time_start, time_end)

plt.tight_layout()
plt.show()

# Save the up_fs_num GHz data to CSV
df_upsampled.to_csv(config.csv.up_fs, index=False)
print(f"\n{config.signal.up_fs_num} GHz upsampled data saved to: {config.csv.up_fs}")
# Report the real on-disk size; a CSV is text, not two 8-byte floats per row
print(f"File size: {os.path.getsize(config.csv.up_fs) / 1024**2:.1f} MB")

# Calculate some quality metrics
print(f"\nUpsampling quality check:")
print(f"Original data points: {len(df_awgn)}")
print(f"Upsampled data points: {len(df_upsampled)} (should be {len(df_awgn) * upsample_factor})")
print(f"Ratio: {len(df_upsampled) / len(df_awgn):.1f} (should be {upsample_factor:.1f})")

# Compare every upsample_factor-th up_fs sample with the fs-rate samples over
# the overlap of both series. Lengths need not match exactly; the previous
# version raised a broadcast error whenever the decimated series was shorter.
decimated = df_upsampled['data'].to_numpy()[::upsample_factor]
reference = df_awgn['data'].to_numpy()
n_common = min(len(decimated), len(reference))
if n_common > 0:
    max_diff = np.max(np.abs(decimated[:n_common] - reference[:n_common]))
    # NOTE(review): the AWGN cell adds independent noise draws to the two series,
    # so this difference is not expected to be ~0 once noise has been added.
    print(f"Maximum difference at original sample points: {max_diff:.2e} (should be ~0)")


In [None]:
df = pd.read_csv(config.csv.up_fs)

print(f"Reading file: {config.csv.up_fs}")  # confirm which file was loaded

N_PREVIEW = 10000  # number of leading samples shown in each preview


def _preview(column, ylabel, title):
    """Plot the first N_PREVIEW samples of df[column] against df['time']."""
    plt.figure(figsize=(10, 6))
    plt.plot(df['time'][:N_PREVIEW], df[column][:N_PREVIEW])
    plt.xlabel('Time (s)')
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()


_preview('data', 'Data', 'Data vs Time')

# Squared signal; the low-pass filter cell consumes this column
df['Data_Squared'] = df['data'] ** 2
_preview('Data_Squared', 'Data Squared', 'Data Squared vs Time')

In [None]:
# Butterworth low-pass filter applied to the squared up_fs-rate signal
from scipy.signal import butter, sosfiltfilt

native_sampling_rate = config.signal.up_fs  # Use the upsampled frequency
cutoff = 2e9  # 2 GHz cutoff frequency
N = 4  # Filter order
nyq = native_sampling_rate / 2
cutoff_norm = cutoff / nyq  # cutoff as a fraction of the Nyquist frequency

# Design the filter as second-order sections. With the cutoff this small relative
# to Nyquist, the (b, a) transfer-function form is numerically fragile, and
# SciPy recommends the SOS representation.
sos = butter(N, cutoff_norm, btype='low', output='sos')

# Zero-phase (forward-backward) filtering, so the filtered envelope is not delayed
df['Data_MA'] = sosfiltfilt(sos, df['Data_Squared'].to_numpy())

plt.figure(figsize=(18, 6))
plt.plot(df['time'][:10000], df['Data_Squared'][:10000], label='Data Squared', alpha=0.7)
plt.plot(df['time'][:10000], df['Data_MA'][:10000], label=f'Butterworth LPF (N={N})', linewidth=2)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Butterworth Low-pass Filter Effect')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
from scipy import signal
import numpy as np
import matplotlib.pyplot as plt

# Locate pulses in the low-pass-filtered power signal:
# 1. threshold it and find the upward/downward crossings;
# 2. take the sample halfway between each rising edge and the falling edge
#    that closes it.
threshold = 0.1
native_sampling_rate = config.signal.up_fs  # Use the upsampled frequency
symbol_rate = config.signal.fsymbol  # 500 MHz

filtered_data = df["Data_MA"].values

above = filtered_data > threshold
# +1 at an upward crossing, -1 at a downward crossing (index of the first sample after it)
transitions = np.diff(above.astype(int))
rising_edges = np.flatnonzero(transitions == 1) + 1
falling_edges = np.flatnonzero(transitions == -1) + 1

print(f"Found {len(rising_edges)} rising edges")
print(f"Found {len(falling_edges)} falling edges")

if rising_edges.size > 0 and falling_edges.size > 0:
    # Crossings of a thresholded signal strictly alternate. Once falling edges
    # that precede the first rising edge are discarded, the k-th remaining
    # falling edge closes the k-th rising edge.
    closing_edges = falling_edges[falling_edges > rising_edges[0]]
    n_pairs = min(rising_edges.size, closing_edges.size)
    centres = (rising_edges[:n_pairs] + closing_edges[:n_pairs]) // 2

    mid_indices = np.array(centres.tolist())
    mid_times = np.array([df.loc[i, 'time'] for i in mid_indices.tolist()])

    print(f"Found {len(mid_indices)} valid midpoints between rising and falling edges")
    if len(mid_times) > 0:
        print(f"First 20 midpoint times: {mid_times[:20]}")
else:
    mid_indices = np.array([])
    mid_times = np.array([])
    print("No rising/falling edge pairs detected")

# Visualization: first 40960 filtered samples with up to 20 markers of each kind
plot_window = 20480 * 2
max_markers = 20

plt.figure(figsize=(12, 6))
plt.plot(df['time'][:plot_window], df['Data_MA'][:plot_window], label='Filtered Data (Data_MA)')
plt.axhline(y=threshold, color='gray', linestyle='--', alpha=0.5, label=f'Threshold ({threshold})')

marker_specs = (
    (rising_edges, '^g', 8, 'Rising Edge'),
    (falling_edges, 'vr', 8, 'Falling Edge'),
    (mid_indices, 'ko', 6, 'Midpoints'),
)
for positions, fmt, size, name in marker_specs:
    if positions.size > 0:
        plt.plot(df.loc[positions, 'time'][:max_markers], filtered_data[positions][:max_markers],
                 fmt, markersize=size, label=name)

plt.xlabel('Time (s)')
plt.ylabel('Filtered Data (Moving Average)')
plt.title('Rising/Falling Edges and Their Midpoints')
plt.grid(True)
plt.legend()
plt.show()


In [None]:
# # Fold the lowpass filtered data by symbol duration
# # and check rising/falling edges for each symbol

# print("=== Folding Low-pass Filtered Data by Symbol Duration ===")

# # Get the symbol duration in samples
# symbol_duration_samples = int(config.signal.up_fs * config.signal.symbol_duration)
# print(f"Symbol duration: {config.signal.symbol_duration:.2e} s")
# print(f"Symbol duration in samples: {symbol_duration_samples}")

# # Get the filtered data
# filtered_data = df['Data_MA'].values
# data_length = len(filtered_data)

# print(f"Total data length: {data_length} samples")
# print(f"Expected symbols: {data_length / symbol_duration_samples:.2f}")

# # Calculate number of complete symbols and trim data if necessary
# num_symbols = data_length // symbol_duration_samples
# trimmed_length = num_symbols * symbol_duration_samples

# if data_length % symbol_duration_samples != 0:
#     print(f"⚠ Data length ({data_length}) is not divisible by symbol duration ({symbol_duration_samples})")
#     print(f"Trimming data to {trimmed_length} samples ({num_symbols} complete symbols)")
#     filtered_data = filtered_data[:trimmed_length]
#     data_length = trimmed_length

# print(f"Using data length: {data_length} samples")
# print(f"Number of complete symbols: {num_symbols}")

# # Now assert that the trimmed data length is exactly divisible by symbol duration
# assert data_length % symbol_duration_samples == 0, \
#     f"Trimmed data length ({data_length}) is not divisible by symbol duration ({symbol_duration_samples})"

# # Fold the data - reshape to (num_symbols, symbol_duration_samples)
# folded_data = filtered_data.reshape(num_symbols, symbol_duration_samples)
# print(f"Folded data shape: {folded_data.shape}")

# # Check rising and falling edges for each symbol
# threshold = 0.1  # Same threshold as before
# rising_edges_per_symbol = []
# falling_edges_per_symbol = []
# midpoints_per_symbol = []

# for symbol_idx in range(num_symbols):
#     symbol_data = folded_data[symbol_idx]
    
#     # Find rising and falling edges within this symbol
#     above = symbol_data > threshold
#     rising_edges = np.where(np.diff(above.astype(int)) == 1)[0] + 1
#     falling_edges = np.where(np.diff(above.astype(int)) == -1)[0] + 1
    
#     rising_edges_per_symbol.append(rising_edges)
#     falling_edges_per_symbol.append(falling_edges)
    
#     # Calculate midpoint for this symbol (400 samples left of falling edge)
#     if len(falling_edges) > 0:
#         midpoint = falling_edges[0] - 400
#         if midpoint >= 0:
#             midpoints_per_symbol.append(midpoint)
#         else:
#             midpoints_per_symbol.append(0)  # If midpoint would be negative, use 0
#     else:
#         midpoints_per_symbol.append(None)  # No falling edge found

# # Print statistics
# print(f"\n=== Edge Detection Statistics ===")
# rising_counts = [len(edges) for edges in rising_edges_per_symbol]
# falling_counts = [len(edges) for edges in falling_edges_per_symbol]

# print(f"Rising edges per symbol - Min: {min(rising_counts)}, Max: {max(rising_counts)}, Mean: {np.mean(rising_counts):.2f}")
# print(f"Falling edges per symbol - Min: {min(falling_counts)}, Max: {max(falling_counts)}, Mean: {np.mean(falling_counts):.2f}")

# # Count symbols with exactly 1 rising and 1 falling edge
# perfect_symbols = sum(1 for i in range(num_symbols) if rising_counts[i] == 1 and falling_counts[i] == 1)
# print(f"Symbols with exactly 1 rising and 1 falling edge: {perfect_symbols}/{num_symbols} ({perfect_symbols/num_symbols*100:.1f}%)")

# # Assert that each symbol has exactly one rising and one falling edge
# try:
#     for i in range(num_symbols):
#         assert rising_counts[i] == 1, f"Symbol {i} has {rising_counts[i]} rising edges (expected 1)"
#         assert falling_counts[i] == 1, f"Symbol {i} has {falling_counts[i]} falling edges (expected 1)"
#     print("✓ All symbols have exactly 1 rising and 1 falling edge")
# except AssertionError as e:
#     print(f"⚠ Warning: {e}")
#     print("Not all symbols have exactly 1 rising and 1 falling edge")

# # Visualize the first few symbols
# plt.figure(figsize=(15, 10))

# # Plot first 8 symbols
# for i in range(min(8, num_symbols)):
#     plt.subplot(2, 4, i+1)
#     symbol_data = folded_data[i]
#     x_axis = np.arange(len(symbol_data))
    
#     plt.plot(x_axis, symbol_data, 'b-', alpha=0.7, label='Data')
#     plt.axhline(y=threshold, color='r', linestyle='--', alpha=0.5, label='Threshold')
    
#     # Plot rising edges
#     if len(rising_edges_per_symbol[i]) > 0:
#         for edge in rising_edges_per_symbol[i]:
#             plt.axvline(x=edge, color='g', linestyle='-', alpha=0.7, label='Rising' if edge == rising_edges_per_symbol[i][0] else '')
    
#     # Plot falling edges
#     if len(falling_edges_per_symbol[i]) > 0:
#         for edge in falling_edges_per_symbol[i]:
#             plt.axvline(x=edge, color='r', linestyle='-', alpha=0.7, label='Falling' if edge == falling_edges_per_symbol[i][0] else '')
    
#     # Plot midpoint
#     if midpoints_per_symbol[i] is not None:
#         plt.axvline(x=midpoints_per_symbol[i], color='k', linestyle=':', alpha=0.8, label='Midpoint')
    
#     plt.title(f'Symbol {i}\nR:{len(rising_edges_per_symbol[i])}, F:{len(falling_edges_per_symbol[i])}')
#     plt.grid(True, alpha=0.3)
#     if i == 0:
#         plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# plt.tight_layout()
# plt.show()

# # Store results for later use
# print(f"\n=== Results Summary ===")
# print(f"Folded data shape: {folded_data.shape}")
# print(f"Midpoints per symbol: {len([mp for mp in midpoints_per_symbol if mp is not None])}/{num_symbols}")
# print(f"First 10 midpoints: {midpoints_per_symbol[:10]}")

In [None]:
# Cluster the detected mid-point positions, folded onto one symbol period, with K-means.
#
# Reads `mid_indices` / `mid_times` (seconds) produced by the mid point detection cell.
# Produces for the cells below: `cluster_labels` (time-sorted cluster id per mid point),
# `sorted_centers`, `first_cluster_center`, `second_cluster_center` (ns, folded), and the
# legacy aliases `first_8_average` / `last_8_average`.
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans

# Configuration parameters
N_CLUSTERS = 2  # Number of clusters for K-means analysis
# The position analysis below and the alignment cell (cluster 0 -> x0, everything else -> x1)
# only make sense for exactly two clusters; fail loudly instead of silently mis-assigning.
assert N_CLUSTERS == 2, "Cluster position analysis and symbol alignment assume exactly 2 clusters"
# Folding period in ns = one symbol duration (2 ns at 500 MBd). 1e9 / fsymbol is exact for the
# configured rate, so the folding matches the previous hard-coded "% 2".
SYMBOL_PERIOD_NS = 1e9 / config.signal.fsymbol

# Extract mid positions from the previous analysis
if 'mid_indices' in locals() and 'mid_times' in locals():
    # List of mid indices in the dataframe
    mid_indices_list = mid_indices.tolist()

    # List of mid time positions (seconds)
    mid_positions = mid_times.tolist()

    print("Mid Point Analysis Results:")
    print(f"Number of mid points found: {len(mid_positions)}")
    print(f"Mid indices: {mid_indices_list}")
    print(f"Mid time positions (seconds): {mid_positions}")

    # Convert to more readable format (nanoseconds)
    mid_positions_ns = [pos * 1e9 for pos in mid_positions]
    print(f"Mid time positions (nanoseconds): {mid_positions_ns}")

    # Relative positions from the first mid point and consecutive intervals (ns)
    if len(mid_positions) > 1:
        relative_positions = [(pos - mid_positions[0]) * 1e9 for pos in mid_positions]
        print(f"Relative positions from first mid point (ns): {relative_positions}")

        intervals_ns = [pos * 1e9 for pos in np.diff(mid_positions)]
        print(f"Intervals between consecutive mid points (ns): {intervals_ns}")

    # K-means on positions folded modulo one symbol period (periodic analysis)
    if len(mid_positions_ns) > N_CLUSTERS:  # Need more than N_CLUSTERS points for clustering
        mid_positions_mod = np.array([pos % SYMBOL_PERIOD_NS for pos in mid_positions_ns]).reshape(-1, 1)

        # Apply K-means with N_CLUSTERS clusters (fixed seed for reproducibility)
        kmeans = KMeans(n_clusters=N_CLUSTERS, random_state=42)
        cluster_labels_original = kmeans.fit_predict(mid_positions_mod)
        cluster_centers = kmeans.cluster_centers_.flatten()

        # K-means ids are arbitrary: renumber them so id 0 is the earliest center in time.
        # order[k] is the original id of the k-th earliest center; old_to_new is its inverse.
        # kind='stable' keeps the tie-breaking of a Python stable sort.
        order = np.argsort(cluster_centers, kind='stable')
        old_to_new = np.empty_like(order)
        old_to_new[order] = np.arange(N_CLUSTERS)
        cluster_labels = old_to_new[cluster_labels_original]

        # Cluster centers sorted by time position (ns, folded)
        sorted_centers = list(cluster_centers[order])

        print(f"\n=== K-means Clustering Results ({N_CLUSTERS} clusters) - Time Sorted ===")
        for i, center in enumerate(sorted_centers):
            cluster_size = np.sum(cluster_labels == i)
            print(f"Cluster {i}: Time = {center:.4f} ns (modulo {SYMBOL_PERIOD_NS:g}), Points = {cluster_size}")

        # Cluster centers as space-separated lists, in picoseconds and nanoseconds
        centers_ps = " ".join(f"{center * 1000:.2f}" for center in sorted_centers)
        print(f"\nCluster centers in picoseconds (time sorted): [{centers_ps}]")

        centers_ns = " ".join(f"{center:.8f}" for center in sorted_centers)
        print(f"\nCluster centers in nanoseconds (time sorted): [{centers_ns}]")

        # Cluster positions (exactly 2 clusters, see the assert above)
        first_cluster_center = sorted_centers[0]
        second_cluster_center = sorted_centers[1]

        print(f"\n=== Cluster Position Analysis ===")
        print(f"Position of cluster 0: {first_cluster_center:.6f} ns")
        print(f"Position of cluster 1: {second_cluster_center:.6f} ns")
        print(f"Difference between the two clusters: {second_cluster_center - first_cluster_center:.6f} ns")

        # Also show in picoseconds
        print(f"\nPosition of cluster 0: {first_cluster_center * 1000:.2f} ps")
        print(f"Position of cluster 1: {second_cluster_center * 1000:.2f} ps")
        print(f"Difference between the two clusters: {(second_cluster_center - first_cluster_center) * 1000:.2f} ps")

        # For compatibility with downstream code, assign the cluster centers to the old variable names
        first_8_average = first_cluster_center
        last_8_average = second_cluster_center

        # Visualize clustering results
        plt.figure(figsize=(12, 8))
        colors = ['#1f77b4', '#ff7f0e']

        # Plot cluster centers (time-sorted)
        plt.scatter(sorted_centers, range(N_CLUSTERS), color='red', marker='x', s=150,
                    linewidths=3, label='Cluster Centers')

        # Plot cluster points, one row per cluster id
        for i in range(N_CLUSTERS):
            cluster_points = mid_positions_mod[cluster_labels == i]
            plt.scatter(cluster_points, np.ones(len(cluster_points)) * i,
                        color=colors[i % len(colors)], alpha=0.7, s=30, label=f'Cluster {i}')

        # Vertical lines for the cluster centers
        plt.axvline(x=float(first_cluster_center), color='blue', linestyle='--', linewidth=3, alpha=0.8,
                    label=f'Cluster 0 center: {first_cluster_center:.4f} ns')
        plt.axvline(x=float(second_cluster_center), color='orange', linestyle='--', linewidth=3, alpha=0.8,
                    label=f'Cluster 1 center: {second_cluster_center:.4f} ns')

        # Text annotations for the cluster centers
        plt.text(float(first_cluster_center) + 0.05, N_CLUSTERS/2 - 0.3, f'Cluster 0\nCenter: {first_cluster_center:.4f} ns',
                 rotation=90, verticalalignment='center', fontsize=10,
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='blue', alpha=0.1))
        plt.text(float(second_cluster_center) + 0.05, N_CLUSTERS/2 + 0.3, f'Cluster 1\nCenter: {second_cluster_center:.4f} ns',
                 rotation=90, verticalalignment='center', fontsize=10,
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='orange', alpha=0.1))

        plt.xlabel(f'Mid Position (ns, modulo {SYMBOL_PERIOD_NS:g})')
        plt.ylabel('Cluster ID (Time Sorted)')
        plt.title(f'K-means Clustering of Mid Positions ({N_CLUSTERS} clusters, Time Sorted)')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    else:
        print(f"Not enough mid points for K-means clustering (found {len(mid_positions_ns)}, need > {N_CLUSTERS})")

else:
    print("Please run the mid point detection cell first to generate mid point data.")

In [None]:
# Simplified classification: split the folded mid positions at a fixed threshold instead of K-means.
#
# Reads `mid_indices` / `mid_times` (seconds). Overwrites `cluster_labels`, `sorted_centers`,
# `first_cluster_center` and `second_cluster_center` (ns, folded) used by the alignment cell.
import numpy as np
import matplotlib.pyplot as plt

# Configuration parameters
# Folding period in ns = one symbol duration (2 ns at 500 MBd); exact for the configured rate.
SYMBOL_PERIOD_NS = 1e9 / config.signal.fsymbol
THRESHOLD_NS = SYMBOL_PERIOD_NS / 2  # 1ns threshold for classification (half a symbol period)


def _robust_center(positions, default):
    """Return the midpoint of the 5th and 95th percentiles of `positions` (ns).

    The most extreme 5% on each side do not affect the result. Returns `default`
    when `positions` is empty.
    """
    if len(positions) == 0:
        return default
    return (np.percentile(positions, 5) + np.percentile(positions, 95)) / 2


# Extract mid positions from the previous analysis
if 'mid_indices' in locals() and 'mid_times' in locals():
    # List of mid indices in the dataframe
    mid_indices_list = mid_indices.tolist()

    # List of mid time positions (seconds)
    mid_positions = mid_times.tolist()

    print("Mid Point Analysis Results:")
    print(f"Number of mid points found: {len(mid_positions)}")
    print(f"Mid indices: {mid_indices_list}")
    print(f"Mid time positions (seconds): {mid_positions}")

    # Convert to more readable format (nanoseconds)
    mid_positions_ns = [pos * 1e9 for pos in mid_positions]
    print(f"Mid time positions (nanoseconds): {mid_positions_ns}")

    # Relative positions from the first mid point and consecutive intervals (ns)
    if len(mid_positions) > 1:
        relative_positions = [(pos - mid_positions[0]) * 1e9 for pos in mid_positions]
        print(f"Relative positions from first mid point (ns): {relative_positions}")

        intervals_ns = [pos * 1e9 for pos in np.diff(mid_positions)]
        print(f"Intervals between consecutive mid points (ns): {intervals_ns}")

    if len(mid_positions_ns) > 0:
        # Fold positions onto one symbol period
        mid_positions_mod = np.array([pos % SYMBOL_PERIOD_NS for pos in mid_positions_ns])

        # Label 0 = first half of the symbol period, 1 = second half
        cluster_labels = np.where(mid_positions_mod < THRESHOLD_NS, 0, 1)

        cluster_0_positions = mid_positions_mod[cluster_labels == 0]
        cluster_1_positions = mid_positions_mod[cluster_labels == 1]

        # Robust center per cluster; empty clusters fall back to the quarter / three-quarter
        # points of the period (0.5 ns / 1.5 ns at 500 MBd)
        first_cluster_center = _robust_center(cluster_0_positions, SYMBOL_PERIOD_NS / 4)
        second_cluster_center = _robust_center(cluster_1_positions, 3 * SYMBOL_PERIOD_NS / 4)

        sorted_centers = [first_cluster_center, second_cluster_center]

        print(f"\n=== Simplified Classification Results (Threshold: {THRESHOLD_NS} ns) ===")
        cluster_0_count = np.sum(cluster_labels == 0)
        cluster_1_count = np.sum(cluster_labels == 1)
        print(f"Cluster 0: Time = {first_cluster_center:.4f} ns (modulo {SYMBOL_PERIOD_NS:g}), Points = {cluster_0_count}")
        print(f"Cluster 1: Time = {second_cluster_center:.4f} ns (modulo {SYMBOL_PERIOD_NS:g}), Points = {cluster_1_count}")

        # Maintain output format compatible with the K-means cell
        print(f"\nCluster centers in picoseconds (time sorted): [{first_cluster_center * 1000:.2f} {second_cluster_center * 1000:.2f}]")
        print(f"\nCluster centers in nanoseconds (time sorted): [{first_cluster_center:.8f} {second_cluster_center:.8f}]")

        print(f"\n=== Cluster Position Analysis ===")
        print(f"Position of cluster 0: {first_cluster_center:.6f} ns")
        print(f"Position of cluster 1: {second_cluster_center:.6f} ns")
        print(f"Difference between the two clusters: {second_cluster_center - first_cluster_center:.6f} ns")

        # Also show in picoseconds
        print(f"\nPosition of cluster 0: {first_cluster_center * 1000:.2f} ps")
        print(f"Position of cluster 1: {second_cluster_center * 1000:.2f} ps")
        print(f"Difference between the two clusters: {(second_cluster_center - first_cluster_center) * 1000:.2f} ps")

        # Visualize classification results
        plt.figure(figsize=(12, 8))
        colors = ['#1f77b4', '#ff7f0e']

        # Plot cluster centers
        plt.scatter(sorted_centers, range(2), color='red', marker='x', s=150,
                    linewidths=3, label='Cluster Centers')

        # Plot cluster points, one row per cluster id
        for i in range(2):
            cluster_points = mid_positions_mod[cluster_labels == i]
            plt.scatter(cluster_points, np.ones(len(cluster_points)) * i,
                        color=colors[i % len(colors)], alpha=0.7, s=30, label=f'Cluster {i}')

        # Vertical line for the threshold
        plt.axvline(x=THRESHOLD_NS, color='green', linestyle=':', linewidth=2, alpha=0.8,
                    label=f'Threshold: {THRESHOLD_NS} ns')

        # Vertical lines for the cluster centers
        plt.axvline(x=float(first_cluster_center), color='blue', linestyle='--', linewidth=3, alpha=0.8,
                    label=f'Cluster 0 center: {first_cluster_center:.4f} ns')
        plt.axvline(x=float(second_cluster_center), color='orange', linestyle='--', linewidth=3, alpha=0.8,
                    label=f'Cluster 1 center: {second_cluster_center:.4f} ns')

        # Text annotations for the cluster centers
        plt.text(float(first_cluster_center) + 0.05, 0.7, f'Cluster 0\nCenter: {first_cluster_center:.4f} ns',
                 rotation=90, verticalalignment='center', fontsize=10,
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='blue', alpha=0.1))
        plt.text(float(second_cluster_center) + 0.05, 1.3, f'Cluster 1\nCenter: {second_cluster_center:.4f} ns',
                 rotation=90, verticalalignment='center', fontsize=10,
                 bbox=dict(boxstyle='round,pad=0.3', facecolor='orange', alpha=0.1))

        plt.xlabel(f'Mid Position (ns, modulo {SYMBOL_PERIOD_NS:g})')
        plt.ylabel('Cluster ID')
        plt.title(f'Simplified Classification Results (Threshold: {THRESHOLD_NS} ns)')
        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.show()

    else:
        print("No mid points found for classification")

else:
    print("Please run the mid point detection cell first to generate mid point data.")


In [None]:
# Symbol classification and alignment for 2-cluster PPM demodulation
#
# Every `samples_per_frame`-sample frame of df['data'] is one symbol. Frames whose mid point
# was labelled 0 are shifted by x0 samples and label-1 frames by x1. This moves both PPM pulse
# positions onto the frame centre before IQ demodulation.
# NOTE(review): this assumes mid point i (entry i of `cluster_labels`) belongs to frame i, i.e.
# one mid point per symbol starting with the first frame. Confirm against the detection cell.


def _shift_frame(frame, shift):
    """Return a copy of `frame` moved `shift` samples right (left if negative), zero-filling the gap."""
    shifted = np.zeros_like(frame)
    if shift > 0:
        shifted[shift:] = frame[:-shift]
    elif shift < 0:
        shifted[:shift] = frame[-shift:]
    else:
        shifted[:] = frame
    return shifted


def _percent(count, total):
    """Return count / total as a percentage, or 0.0 when total is 0 (avoids ZeroDivisionError)."""
    return count / total * 100 if total else 0.0


# Symbol classification based on 2 clusters
print(f"\n=== Symbol Classification ===")

# Direct classification: cluster 0 = bit 0, cluster 1 = bit 1
symbol_classifications = cluster_labels

# Data reshaping and symbol alignment
print(f"\n=== Data Reshaping and Symbol Alignment ===")

# Get original signal data and parameters
ori_signal = df['data'].values
ori_len = len(ori_signal)
# round() instead of int(), so float error in the ratio cannot truncate the frame length by
# one sample. This also matches the IQ demodulation cell.
samples_per_frame = round(config.signal.up_fs / config.signal.fsymbol)

# Reshape data into whole frames, dropping any trailing partial frame
trimmed_len = ori_len - (ori_len % samples_per_frame)
trimmed_signal = ori_signal[:trimmed_len]
frames = trimmed_signal.reshape(-1, samples_per_frame)

print(f"Original signal length: {ori_len}")
print(f"Samples per symbol: {samples_per_frame}")
print(f"Trimmed signal length: {trimmed_len}")
print(f"Frames shape: {frames.shape}")

# Target position: the centre sample of each frame
symbol_center = samples_per_frame // 2

# Convert cluster positions (ns within one symbol) to sample indices within a frame
sampling_period = 1 / config.signal.up_fs  # seconds per sample
cluster_0_samples = first_cluster_center * 1e-9 / sampling_period
cluster_1_samples = second_cluster_center * 1e-9 / sampling_period

# Shift needed to bring each cluster onto the centre (positive = right shift)
x0 = int(round(symbol_center - cluster_0_samples))
x1 = int(round(symbol_center - cluster_1_samples))
# print(x0, x1)  # 214, -170
# x0, x1 = 214, -174

print(f"Symbol center position: {symbol_center}")
print(f"Cluster 0 position (samples): {cluster_0_samples:.2f}")
print(f"Cluster 1 position (samples): {cluster_1_samples:.2f}")
print(f"Cluster 0 needs to move: {x0} samples ({'right shift' if x0 > 0 else 'left shift' if x0 < 0 else 'no movement'})")
print(f"Cluster 1 needs to move: {x1} samples ({'right shift' if x1 > 0 else 'left shift' if x1 < 0 else 'no movement'})")

# Shift each classified frame; frames beyond the available classifications stay unchanged
aligned_frames = frames.copy()
for i, label in enumerate(symbol_classifications[:len(frames)]):
    aligned_frames[i] = _shift_frame(frames[i], x0 if label == 0 else x1)

print(f"Symbol alignment completed, aligned frames shape: {aligned_frames.shape}")

# Visualize the alignment effect on the first few symbols
plt.figure(figsize=(15, 10))
num_examples = min(8, len(frames))

for i in range(num_examples):
    # Original frames (top row)
    plt.subplot(2, num_examples, i+1)
    plt.plot(frames[i], 'b-', alpha=0.7, label='Original')
    plt.axvline(x=symbol_center, color='gray', linestyle='--', alpha=0.5, label='Center')
    if i < len(symbol_classifications):
        cluster_name = f'Cluster {symbol_classifications[i]}'
        plt.title(f'Original Symbol {i}\n({cluster_name})')
    plt.grid(True, alpha=0.3)
    if i == 0:
        plt.legend()

    # Aligned frames (bottom row)
    plt.subplot(2, num_examples, i+1+num_examples)
    plt.plot(aligned_frames[i], 'r-', alpha=0.7, label='Aligned')
    plt.axvline(x=symbol_center, color='gray', linestyle='--', alpha=0.5, label='Center')
    if i < len(symbol_classifications):
        cluster_name = f'Cluster {symbol_classifications[i]}'
        shift = x0 if symbol_classifications[i] == 0 else x1
        shift_info = f'Move {shift}' if shift != 0 else 'No movement'
        plt.title(f'Aligned Symbol {i}\n({cluster_name}, {shift_info})')
    plt.grid(True, alpha=0.3)
    if i == 0:
        plt.legend()

plt.suptitle('Symbol Alignment: Before vs After (2-Cluster PPM)')
plt.tight_layout()
plt.show()

# Count symbols in each cluster (percentages are 0.0 when nothing was detected)
cluster_0_count = np.sum(symbol_classifications == 0)
cluster_1_count = np.sum(symbol_classifications == 1)
total_symbols = len(symbol_classifications)
cluster_0_pct = _percent(cluster_0_count, total_symbols)
cluster_1_pct = _percent(cluster_1_count, total_symbols)

print(f"Total detected symbols: {total_symbols}")
print(f"Cluster 0: {cluster_0_count} symbols ({cluster_0_pct:.1f}%)")
print(f"Cluster 1: {cluster_1_count} symbols ({cluster_1_pct:.1f}%)")

# Show first 50 classifications
print(f"\nClassification results for first 50 symbols:")
print(f"Symbol indices: {list(range(min(50, len(symbol_classifications))))}")
print(f"Classifications: {symbol_classifications[:50].tolist()}")

# Symbol classifications over time and their distribution
plt.figure(figsize=(15, 8))

# Plot 1: Symbol classification over time
plt.subplot(2, 1, 1)
plt.plot(range(len(symbol_classifications)), symbol_classifications, 'o-', markersize=3, alpha=0.7)
plt.xlabel('Symbol Index')
plt.ylabel('Classification (0=Cluster 0, 1=Cluster 1)')
plt.title('2-PPM Symbol Classification Over Time')
plt.grid(True, alpha=0.3)
plt.yticks([0, 1], ['Cluster 0', 'Cluster 1'])

# Plot 2: Histogram of classifications
plt.subplot(2, 1, 2)
plt.bar(['Cluster 0', 'Cluster 1'],
        [cluster_0_count, cluster_1_count],
        color=['blue', 'orange'], alpha=0.7)
plt.ylabel('Symbol Count')
plt.title('Symbol Distribution by Cluster')
plt.grid(True, alpha=0.3)

# Count labels on the bars
plt.text(0, cluster_0_count + total_symbols*0.01, f'{cluster_0_count}\n({cluster_0_pct:.1f}%)',
         ha='center', va='bottom')
plt.text(1, cluster_1_count + total_symbols*0.01, f'{cluster_1_count}\n({cluster_1_pct:.1f}%)',
         ha='center', va='bottom')

plt.tight_layout()
plt.show()

# Bit sequence (direct mapping: cluster 0 -> bit 0, cluster 1 -> bit 1)
bit_sequence = symbol_classifications
print(f"\nDemodulated bit sequence (first 100 bits): {bit_sequence[:100].tolist()}")

# Transition statistics; the rate is 0.0 when there are fewer than 2 symbols
transitions = np.diff(symbol_classifications)
num_transitions = np.sum(transitions != 0)
transition_rate = _percent(num_transitions, max(len(symbol_classifications) - 1, 0))
print(f"\nStatistics:")
print(f"Total transitions: {num_transitions}")
print(f"Transition rate: {transition_rate:.1f}%")

## IQ demodulation

In [None]:
# IQ demodulation: mix every aligned symbol frame with a carrier reference at fc.
samples_per_frame = round(config.signal.up_fs / config.signal.fsymbol)

# One symbol's worth of carrier reference, reused for every frame. This is phase-continuous
# across frames because fc is an integer multiple of fsymbol in the config
# (4 GHz / 500 MBd = 8 carrier cycles per symbol).
t_one_frame = np.linspace(0, 1/config.signal.fsymbol, samples_per_frame, endpoint=False)
cosine_wave = np.cos(2 * np.pi * config.signal.fc * t_one_frame)
sine_wave = np.sin(2 * np.pi * config.signal.fc * t_one_frame)

# The (samples_per_frame,) references broadcast over the (n_frames, samples_per_frame) frames:
# I = x*cos, Q = x*sin. Flatten back into one continuous stream.
demod_signal = cosine_wave * aligned_frames + sine_wave * aligned_frames * 1j
demod_signal = demod_signal.flatten()

# Plot the first samples. demod_signal is trimmed to whole frames, so it can be shorter than df.
n_plot = min(10000, len(demod_signal), len(df))
time_plot = df['time'].iloc[:n_plot]

plt.figure(figsize=(12, 6))
plt.plot(time_plot, demod_signal.real[:n_plot], label='Demodulated Signal (Real Part)', alpha=0.7)
plt.plot(time_plot, demod_signal.imag[:n_plot], label='Demodulated Signal (Imaginary Part)', alpha=0.7)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title(f'Demodulated Signal from 8PSK-2PPM (First {n_plot} samples)')
plt.grid(True, alpha=0.3)
plt.legend()
plt.show()

In [None]:
# Lowpass-filter the mixed signal, keeping the baseband component and rejecting the
# sum-frequency mixing product (around 2*fc).
from scipy import signal

# Filter parameters
cutoff_freq = 2e9  # 2 GHz cutoff frequency
sampling_rate = config.signal.up_fs  # Use the upsampled frequency
nyq = sampling_rate / 2
normalized_cutoff = cutoff_freq / nyq
assert 0 < normalized_cutoff < 1, f"Cutoff {cutoff_freq} Hz must lie between 0 and Nyquist ({nyq} Hz)"

# Design the Butterworth lowpass as second-order sections. The cutoff is tiny relative to
# Nyquist (2 GHz / 512 GHz ~ 0.004), and the (b, a) polynomial form is numerically fragile there.
# SciPy recommends the SOS form for filter designs like this.
filter_order = 4
sos = signal.butter(filter_order, normalized_cutoff, btype='low', output='sos')

# Zero-phase (forward-backward) filtering of the real and imaginary parts
demod_signal_filtered = (signal.sosfiltfilt(sos, demod_signal.real)
                         + 1j * signal.sosfiltfilt(sos, demod_signal.imag))

print(f"Applied lowpass filter with cutoff frequency: {cutoff_freq/1e9:.1f} GHz")
print(f"Filter order: {filter_order}")
print(f"Sampling rate: {sampling_rate/1e9:.1f} GHz")
print(f"Normalized cutoff: {normalized_cutoff:.4f}")

# Compare original and filtered demodulated signal over the first samples.
# demod_signal is trimmed to whole frames, so it can be shorter than df.
n_plot = min(10000, len(demod_signal), len(df))
time_plot = df['time'].iloc[:n_plot]

plt.figure(figsize=(15, 8))

# Plot 1: Real part comparison
plt.subplot(2, 1, 1)
plt.plot(time_plot, demod_signal.real[:n_plot], label='Original Demodulated (Real)', alpha=0.7)
plt.plot(time_plot, demod_signal_filtered.real[:n_plot], label='Filtered Demodulated (Real)', alpha=0.8)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Demodulated Signal Comparison - Real Part')
plt.grid(True, alpha=0.3)
plt.legend()

# Plot 2: Imaginary part comparison
plt.subplot(2, 1, 2)
plt.plot(time_plot, demod_signal.imag[:n_plot], label='Original Demodulated (Imag)', alpha=0.7)
plt.plot(time_plot, demod_signal_filtered.imag[:n_plot], label='Filtered Demodulated (Imag)', alpha=0.8)
plt.xlabel('Time (s)')
plt.ylabel('Amplitude')
plt.title('Demodulated Signal Comparison - Imaginary Part')
plt.grid(True, alpha=0.3)
plt.legend()

plt.tight_layout()
plt.show()

In [None]:
# Inspect the df_upsampled symbol windows whose mid point was assigned to one selected cluster.
#
# Titles use 1-based naming (label 7 is shown as "Cluster 8"). Each selected mid index is
# snapped to the centre of its symbol frame, and the full frame around it is plotted.
import numpy as np
import matplotlib.pyplot as plt

TARGET_LABEL = 7  # 0-based cluster label to inspect; displayed as "Cluster {TARGET_LABEL + 1}"
# NOTE(review): the classification cells above only emit labels 0 and 1, so with the current
# 2-cluster flow this selects no samples. Label 7 presumably comes from an earlier multi-cluster
# analysis. Set TARGET_LABEL to 0 or 1 to inspect the current clusters.
cluster_name = f"Cluster {TARGET_LABEL + 1}"


def _window_start(n_rows, mid_idx, frame_len):
    """Return the start row of the `frame_len`-sample window centred on `mid_idx`.

    Returns None when the window would be truncated by either end of an `n_rows`-long signal.
    """
    start = max(0, mid_idx - frame_len // 2)
    end = min(n_rows, start + frame_len)
    return start if end - start == frame_len else None


# Check if clustering results exist
if 'cluster_labels' in locals() and 'mid_indices' in locals():
    # Number of samples per symbol (frame); replaces the previously hard-coded 2048 / 1024
    samples_per_frame = round(config.signal.up_fs / config.signal.fsymbol)
    n_rows = len(df_upsampled)

    # Mid points belonging to the target cluster
    cluster_8_mask = cluster_labels == TARGET_LABEL
    cluster_8_mid_indices = mid_indices[cluster_8_mask]

    # Snap each mid index to the centre of its frame, and drop any that land past the end
    # of df_upsampled (these previously raised a KeyError in the time lookup).
    cluster_8_mid_indices = [
        int(idx) // samples_per_frame * samples_per_frame + samples_per_frame // 2
        for idx in cluster_8_mid_indices
    ]
    cluster_8_mid_indices = [idx for idx in cluster_8_mid_indices if idx < n_rows]

    print(f"Found {len(cluster_8_mid_indices)} samples belonging to cluster {TARGET_LABEL + 1}")

    if len(cluster_8_mid_indices) > 0:
        print(f"Samples per frame (symbol): {samples_per_frame}")

        # Full-window start row for every selected symbol (None = truncated at the data edge)
        window_starts = [_window_start(n_rows, idx, samples_per_frame) for idx in cluster_8_mid_indices]

        # Show the first 8 symbols belonging to the cluster
        plt.figure(figsize=(16, 10))
        num_examples = min(8, len(cluster_8_mid_indices))

        for i in range(num_examples):
            mid_idx = cluster_8_mid_indices[i]
            symbol_start = window_starts[i]
            if symbol_start is None:  # skip symbols cut off at the data edges
                continue
            symbol_data = df_upsampled.iloc[symbol_start:symbol_start + samples_per_frame]

            plt.subplot(2, 4, i+1)
            plt.plot(symbol_data['time'] - symbol_data['time'].iloc[0], symbol_data['data'], 'b-', alpha=0.8)
            plt.axvline(x=(mid_idx - symbol_start) * (1/config.signal.up_fs),
                        color='red', linestyle='--', alpha=0.7, label='Mid point')
            plt.xlabel('Time (s)')
            plt.ylabel('Amplitude')
            plt.title(f'{cluster_name} Sample {i+1}\nMid idx: {mid_idx}')
            plt.grid(True, alpha=0.3)
            if i == 0:
                plt.legend()

        plt.suptitle(f'Samples from {cluster_name} in df_upsampled', fontsize=14)
        plt.tight_layout()
        plt.show()

        # Cluster statistics. Lookups are positional (iloc), consistent with the window slicing.
        time_values = df_upsampled['time'].to_numpy()
        data_values = df_upsampled['data'].to_numpy()
        print(f"\n{cluster_name} Statistics:")
        print(f"Mid indices: {cluster_8_mid_indices[:10]}...")  # Show only first 10
        print(f"Mid times (ns): {(time_values[cluster_8_mid_indices[:10]] * 1e9).tolist()}...")

        # Average time position of the cluster's midpoints, folded onto one symbol period
        symbol_period_ns = 1e9 / config.signal.fsymbol  # 2 ns at 500 MBd
        cluster_8_times = time_values[cluster_8_mid_indices]
        cluster_8_times_mod = (cluster_8_times * 1e9) % symbol_period_ns
        avg_time_mod = np.mean(cluster_8_times_mod)
        print(f"{cluster_name} average time position (mod {symbol_period_ns:g}ns): {avg_time_mod:.4f} ns")

        # Overlay of many symbols from the cluster (time relative to each window's start, in ns)
        plt.figure(figsize=(12, 6))

        overlay_count = min(10000, len(cluster_8_mid_indices))
        for symbol_start in window_starts[:overlay_count]:
            if symbol_start is None:
                continue
            symbol_end = symbol_start + samples_per_frame
            time_normalized = time_values[symbol_start:symbol_end] - time_values[symbol_start]
            plt.plot(time_normalized * 1e9, data_values[symbol_start:symbol_end], alpha=0.3, color='blue')

        # Stack every complete window to compute the average waveform (kept in `avg_symbol`)
        all_symbols = [data_values[start:start + samples_per_frame] for start in window_starts if start is not None]
        if all_symbols:
            all_symbols = np.array(all_symbols)
            avg_symbol = np.mean(all_symbols, axis=0)
            time_axis = np.arange(len(avg_symbol)) * (1/config.signal.up_fs) * 1e9  # in ns
            plt.axvline(x=time_axis[len(avg_symbol)//2], color='red', linestyle='--',
                        alpha=0.7, label='Expected mid point')

        plt.xlabel('Time (ns)')
        plt.ylabel('Amplitude')
        plt.title(f'Overlay of {overlay_count} symbols from {cluster_name}')
        plt.grid(True, alpha=0.3)
        if len(all_symbols) > 0:  # the only labelled artist is the mid-point line
            plt.legend()
        plt.show()

    else:
        print(f"No samples found for cluster {TARGET_LABEL + 1}")

else:
    print("Clustering results not available. Please run K-means clustering analysis first.")


In [None]:
# Symbol-rate sampling: take one complex sample per symbol from the filtered baseband stream.
# Reshape so each row is one symbol frame. Re-running this cell is safe: reshaping an already
# (n_frames, samples_per_frame) array to the same frame length is a no-op.
demod_signal_filtered = demod_signal_filtered.reshape(-1, samples_per_frame)

# Take the centre sample of every frame (index 1024 for 2048-sample frames). The alignment cell
# shifted both PPM pulse positions onto this centre. .copy() keeps the result independent of
# demod_signal_filtered, as the previous .flatten() did.
symbol_sample_index = samples_per_frame // 2
sampled_signal = demod_signal_filtered[:, symbol_sample_index].copy()

# Plot real and imaginary parts of the per-symbol samples
fig, (ax_real, ax_imag) = plt.subplots(2, 1, figsize=(15, 8))
symbol_idx = np.arange(len(sampled_signal))

ax_real.plot(symbol_idx, sampled_signal.real, 'o-', markersize=4)
ax_real.set_xlabel('Symbol Index')
ax_real.set_ylabel('Real Part')
ax_real.set_title('Sampled Demodulated Signal - Real Part')
ax_real.grid(True, alpha=0.3)

ax_imag.plot(symbol_idx, sampled_signal.imag, 'o-', markersize=4)
ax_imag.set_xlabel('Symbol Index')
ax_imag.set_ylabel('Imaginary Part')
ax_imag.set_title('Sampled Demodulated Signal - Imaginary Part')
ax_imag.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"Sampled signal shape: {sampled_signal.shape}")
print(f"Number of symbols: {len(sampled_signal)}")

In [None]:
# Additional analysis: Cluster distribution in 8-PSK constellation space
if 'cluster_labels' in locals() and len(cluster_labels) > 0:
    n_points = min(len(sampled_signal), len(cluster_labels))
    
    # Create a subplot showing cluster statistics
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # 1. Cluster phase distribution
    ax1.set_title('Cluster Phase Distribution')
    cluster_phases = []
    cluster_ids = []
    
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_points = sampled_signal[:n_points][cluster_mask]
        
        if len(cluster_points) > 0:
            phases = np.angle(cluster_points, deg=True)
            cluster_phases.extend(phases)
            cluster_ids.extend([cluster_id] * len(phases))
    
    colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f',
              '#bcbd22', '#17becf', '#aec7e8', '#ffbb78', '#98df8a', '#ff9896', '#c5b0d5', '#c49c94']
    
    for cluster_id in range(16):
        cluster_mask = np.array(cluster_ids) == cluster_id
        if np.any(cluster_mask):
            phases_for_cluster = np.array(cluster_phases)[cluster_mask]
            ax1.scatter([cluster_id] * len(phases_for_cluster), phases_for_cluster, 
                       alpha=0.6, s=20, c=colors[cluster_id % len(colors)], label=f'Cluster {cluster_id}')
    
    ax1.set_xlabel('Cluster ID')
    ax1.set_ylabel('Phase (degrees)')
    ax1.grid(True, alpha=0.3)
    ax1.set_xticks(range(16))
    
    # 2. Cluster magnitude distribution
    ax2.set_title('Cluster Magnitude Distribution')
    cluster_magnitudes = []
    cluster_ids_mag = []
    
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_points = sampled_signal[:n_points][cluster_mask]
        
        if len(cluster_points) > 0:
            magnitudes = np.abs(cluster_points)
            cluster_magnitudes.extend(magnitudes)
            cluster_ids_mag.extend([cluster_id] * len(magnitudes))
    
    for cluster_id in range(16):
        cluster_mask = np.array(cluster_ids_mag) == cluster_id
        if np.any(cluster_mask):
            mags_for_cluster = np.array(cluster_magnitudes)[cluster_mask]
            ax2.scatter([cluster_id] * len(mags_for_cluster), mags_for_cluster, 
                       alpha=0.6, s=20, c=colors[cluster_id % len(colors)])
    
    ax2.set_xlabel('Cluster ID')
    ax2.set_ylabel('Magnitude')
    ax2.grid(True, alpha=0.3)
    ax2.set_xticks(range(16))
    
    # 3. Cluster size distribution
    ax3.set_title('Number of Points per Cluster')
    cluster_sizes = []
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_sizes.append(np.sum(cluster_mask))
    
    bars = ax3.bar(range(16), cluster_sizes, color=[colors[i % len(colors)] for i in range(16)], alpha=0.7)
    ax3.set_xlabel('Cluster ID')
    ax3.set_ylabel('Number of Points')
    ax3.grid(True, alpha=0.3)
    ax3.set_xticks(range(16))
    
    # Add value labels on bars
    for i, (bar, size) in enumerate(zip(bars, cluster_sizes)):
        if size > 0:
            ax3.text(bar.get_x() + bar.get_width()/2, bar.get_height() + max(cluster_sizes)*0.01,
                    str(size), ha='center', va='bottom', fontsize=8)
    
    # 4. Average phase per cluster (polar plot)
    ax4 = plt.subplot(2, 2, 4, projection='polar')
    ax4.set_title('Average Phase per Cluster\\n(Polar Plot)')
    
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_points = sampled_signal[:n_points][cluster_mask]
        
        if len(cluster_points) > 0:
            avg_phase = np.mean(np.angle(cluster_points))
            avg_magnitude = np.mean(np.abs(cluster_points))
            ax4.scatter(avg_phase, avg_magnitude, 
                       s=100, c=colors[cluster_id % len(colors)], 
                       label=f'Cluster {cluster_id}', alpha=0.8)
    
    # Add reference 8-PSK points on polar plot
    reference_angles = np.linspace(0+np.pi/8, 2*np.pi+np.pi/8, 8, endpoint=False)
    reference_radius = np.mean(np.abs(sampled_signal))
    ax4.scatter(reference_angles, [reference_radius]*8, 
               marker='x', s=150, c='red', linewidth=3, label='8-PSK Reference')
    
    plt.tight_layout()
    plt.show()
    
    # Print detailed statistics
    print("\\n=== Detailed Cluster Analysis ===")
    print("Cluster ID | Points | Avg Phase (°) | Avg Magnitude | Phase Std (°) | Mag Std")
    print("-" * 80)
    
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_points = sampled_signal[:n_points][cluster_mask]
        
        if len(cluster_points) > 0:
            phases = np.angle(cluster_points, deg=True)
            magnitudes = np.abs(cluster_points)
            
            print(f"Cluster {cluster_id:2d} | {len(cluster_points):6d} | {np.mean(phases):9.2f} | "
                  f"{np.mean(magnitudes):11.4f} | {np.std(phases):9.2f} | {np.std(magnitudes):7.4f}")
        else:
            print(f"Cluster {cluster_id:2d} | {0:6d} | {'N/A':>9} | {'N/A':>11} | {'N/A':>9} | {'N/A':>7}")
    
    # Calculate which 8-PSK symbol each cluster most likely represents
    print("\\n=== Cluster to 8-PSK Symbol Mapping ===")
    reference_angles_deg = np.degrees(np.linspace(0+np.pi/8, 2*np.pi+np.pi/8, 8, endpoint=False))
    
    for cluster_id in range(16):
        cluster_mask = cluster_labels[:n_points] == cluster_id
        cluster_points = sampled_signal[:n_points][cluster_mask]
        
        if len(cluster_points) > 0:
            avg_phase_deg = np.mean(np.angle(cluster_points, deg=True))
            # Find closest 8-PSK symbol
            phase_differences = np.abs(reference_angles_deg - avg_phase_deg)
            closest_symbol = np.argmin(phase_differences)
            min_diff = phase_differences[closest_symbol]
            
            print(f"Cluster {cluster_id:2d}: Average phase {avg_phase_deg:6.1f}° → "
                  f"8-PSK Symbol {closest_symbol} (ref: {reference_angles_deg[closest_symbol]:6.1f}°, "
                  f"diff: {min_diff:5.1f}°)")
    
else:
    print("Cluster labels not available. Please run the K-means clustering analysis first.")


In [None]:
# Constellation diagram of the sampled signal, overlaid with ideal 8-PSK points.
# Per-sample magnitude and phase are computed once and reused by the reference
# overlay and by every printed statistic below.
magnitudes = np.abs(sampled_signal)
phases = np.angle(sampled_signal, deg=True)
mag_mean = np.mean(magnitudes)
mag_std = np.std(magnitudes)

fig_const, ax_const = plt.subplots(figsize=(10, 8))

# Received samples
ax_const.scatter(sampled_signal.real, sampled_signal.imag,
                 alpha=0.7, s=30, c='blue', edgecolors='black', linewidth=0.5)

ax_const.grid(True, alpha=0.3)
ax_const.set_xlabel('In-phase (I)')
ax_const.set_ylabel('Quadrature (Q)')
ax_const.set_title('Constellation Diagram of Sampled Signal')
ax_const.axis('equal')

# Ideal 8-PSK overlay: eight angles starting at pi/8, spaced pi/4 apart, on a
# circle whose radius is the mean sample magnitude.
radius = mag_mean
angles = np.linspace(0+np.pi/8, 2*np.pi+np.pi/8, 8, endpoint=False)
reference_points = radius * np.exp(1j * angles)
ax_const.scatter(reference_points.real, reference_points.imag,
                 marker='x', s=100, c='red', linewidth=2, label='Reference 8-PSK points')

ax_const.legend()
fig_const.tight_layout()
plt.show()

print(f"Constellation diagram plotted for {len(sampled_signal)} symbols")
print(f"Mean magnitude: {mag_mean:.4f}")
print(f"Standard deviation of magnitude: {mag_std:.4f}")

# NOTE(review): the phase mean/std below are plain linear statistics of angles
# wrapped to [-180, 180] degrees (np.angle's range). For a multi-point PSK
# constellation they are not circular statistics; interpret with care.
# NOTE(review): the "SNR estimate" uses only the spread of |r| (radial noise).
# Angular noise, including any phase noise, does not enter it. If the noise is
# circular AWGN, the radial part carries roughly half the noise power, so this
# figure presumably reads ~3 dB high. Confirm before comparing to config.awgn.snr_db.
print(f"\nSignal statistics:")
print(f"Magnitude - Mean: {mag_mean:.4f}, Std: {mag_std:.4f}")
print(f"Phase - Mean: {np.mean(phases):.2f}°, Std: {np.std(phases):.2f}°")
print(f"SNR estimate: {20 * np.log10(mag_mean / mag_std):.2f} dB")

In [None]:
# 8-PSK Demodulation Logic
# Map phase angles to corresponding symbols (0-7)

def demodulate_8psk(complex_signal, phase_offset):
    """
    Hard-decision 8-PSK slicer: map each sample's phase to a symbol index 0..7.

    Processing steps:
      1. Rotate the phase of every sample (from ``np.angle``) by ``phase_offset``.
      2. Wrap the result into [0, 2*pi).
      3. Assign each sample the index of the nearest multiple of pi/4.

    The decision boundaries therefore sit at odd multiples of pi/8 minus
    ``phase_offset``. For example, with ``phase_offset=np.pi/8`` they fall on
    multiples of pi/4, and the constellation point at pi/8 is labelled 1.

    Args:
        complex_signal: Complex-valued signal samples
        phase_offset: Rotation in radians added to every phase before slicing

    Returns:
        symbols: Integer array (same shape as the input) with values in 0..7
    """
    n_symbols = 8
    sector_width = 2 * np.pi / n_symbols  # pi/4 rad (45 degrees) per symbol

    wrapped = np.mod(np.angle(complex_signal) + phase_offset, 2 * np.pi)

    # Nearest-multiple slicing (round half to even, as np.round does).
    # The outer mod folds index 8 (a phase just below 2*pi) back to 0.
    nearest = np.rint(wrapped / sector_width).astype(int)
    return np.mod(nearest, n_symbols)

# Apply demodulation to the sampled signal.
# NOTE(review): with phase_offset=np.pi/8 the decision boundaries sit on multiples
# of pi/4, so each region is centred on one of the reference points at
# pi/8 + k*pi/4 drawn in the constellation plot. With this sign, the point at pi/8
# is labelled symbol 1, not 0 (the cluster-mapping cell calls it symbol 0).
# Confirm this labelling against the transmitter's symbol mapping.
demodulated_symbols = demodulate_8psk(sampled_signal, phase_offset=np.pi/8)
n_demod = len(demodulated_symbols)
# Fail with a clear message rather than an opaque ValueError from .min() on an empty array.
assert n_demod > 0, "sampled_signal is empty; nothing to demodulate"

print(f"Demodulated {n_demod} symbols")
print(f"Symbol range: {demodulated_symbols.min()} to {demodulated_symbols.max()}")
print(f"First 20 symbols: {demodulated_symbols[:20]}")

# Count symbol occurrences (minlength=8 keeps never-seen symbols as zero rows)
symbol_counts = np.bincount(demodulated_symbols, minlength=8)
print(f"\nSymbol distribution:")
for i, count in enumerate(symbol_counts):
    print(f"Symbol {i}: {count} occurrences ({count/n_demod*100:.1f}%)")

# Plot the symbol histogram and the leading part of the symbol sequence
N_PREVIEW = 100  # maximum number of leading symbols shown in the sequence plot
n_preview = min(N_PREVIEW, n_demod)

fig_demod, (ax_hist, ax_seq) = plt.subplots(1, 2, figsize=(12, 5))

ax_hist.bar(range(8), symbol_counts)
ax_hist.set_xlabel('Symbol')
ax_hist.set_ylabel('Count')
ax_hist.set_title('Symbol Distribution')
ax_hist.grid(True, alpha=0.3)
ax_hist.set_xticks(range(8))

ax_seq.plot(np.arange(n_preview), demodulated_symbols[:n_preview], 'o-', markersize=4)
ax_seq.set_xlabel('Symbol Index')
ax_seq.set_ylabel('Demodulated Symbol')
# Title reports how many symbols are actually plotted (fewer than N_PREVIEW for short captures)
ax_seq.set_title(f'First {n_preview} Demodulated Symbols')
ax_seq.grid(True, alpha=0.3)
ax_seq.set_yticks(range(8))

fig_demod.tight_layout()
plt.show()

# Calculate phase angles for verification (degrees, in [-180, 180])
phases_deg = np.angle(sampled_signal, deg=True)
print(f"\nPhase statistics:")
print(f"Phase range: {phases_deg.min():.1f}° to {phases_deg.max():.1f}°")
print(f"Phase mean: {phases_deg.mean():.1f}°")
print(f"Phase std: {phases_deg.std():.1f}°")

In [None]:
# # Diagnose and fix KeyError: 'data' issue

# # Check current DataFrame column names
# print("df column names:", df.columns.tolist())
# print("df_upsampled column names:", df_upsampled.columns.tolist())

# # Fix original code - change 'data' to 'original_data'
# original_indices = np.arange(0, len(df_upsampled), upsample_factor)
# if len(original_indices) <= len(df):
#     # Use correct column name 'original_data' instead of 'data'
#     max_diff = np.max(np.abs(df_upsampled.iloc[original_indices]['original_data'].values[:len(df)] - df['original_data'].values))
#     print(f"Maximum difference at original sample points: {max_diff:.2e} (should be ~0)")
# else:
#     print("Warning: Not enough upsampled points to compare")
