In [13]:
import os
import re
from datetime import datetime, timedelta
from pathlib import Path

import joblib
import netCDF4
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

In [14]:
class SurfaceTypeUtils:
    """Lookup tables that translate integer codes into human-readable names."""

    # Surface-type code -> label (presumably the ``sp_surface_type`` values; -1 is ocean,
    # 1..7 are land classes). The codes are contiguous from -1 to 7, so the table is built
    # by enumerating the labels starting at -1.
    # NOTE(review): "Artifical" is misspelled but kept verbatim in case callers match on it.
    surface_type_dict = dict(enumerate(
        (
            "Ocean",
            "NaN",
            "Artifical",
            "Barely vegetated",
            "Inland water",
            "Crop",
            "Grass",
            "Shrub",
            "Forest",
        ),
        start=-1,
    ))

    # DDM antenna code (0..3) -> antenna name.
    ddm_antennas = dict(enumerate(("None", "Zenith", "LHCP", "RHCP")))
    
class GeoUtils:
    """Small geographic / array helpers: timestamp shifting, land lookup against a world
    shapefile, ocean-vs-land mix detection, and NaN gap-filling of 4-D stacks."""

    def __init__(self, world_shapefile_path):
        """Load the shapefile used by :meth:`is_land`.

        NOTE(review): ``gpd`` (geopandas) is not imported anywhere in this notebook;
        ``import geopandas as gpd`` must run before this class is instantiated.
        """
        self.world = gpd.read_file(world_shapefile_path)

    @staticmethod
    def add_seconds(time, seconds, fmt="%Y-%m-%d %H:%M:%S"):
        """Shift a timestamp string by ``seconds`` and return it in the same format.

        Args:
            time (str): timestamp formatted according to ``fmt``.
            seconds (float): offset to add (may be negative).
            fmt (str): ``strptime``/``strftime`` format used for both parsing and output.

        Returns:
            str: the shifted timestamp, formatted with ``fmt``.
        """
        shifted = datetime.strptime(time, fmt) + timedelta(seconds=seconds)
        return shifted.strftime(fmt)

    def is_land(self, lat, lon):
        """Return True if the point (lat, lon) lies inside any geometry of ``self.world``.

        NOTE(review): ``Point`` (shapely.geometry) is not imported in this notebook.
        """
        point = Point(lon, lat)  # constructed as (x=lon, y=lat)
        return bool(self.world.contains(point).any())

    @staticmethod
    def check_ocean_and_land(lst):
        """Return True if ``lst`` contains both an ocean code (-1) and a land code (1..7)."""
        has_ocean = -1 in lst
        has_land = any(1 <= num <= 7 for num in lst)
        return has_ocean and has_land

    @staticmethod
    def fill_and_filter(arr):
        """Fill all-NaN slices with the mean of their sibling slices, then drop hopeless items.

        For each item ``i`` along axis 0, every slice ``arr[i, j]`` (over axes 2-3) that is
        entirely NaN is replaced by the NaN-aware mean of the item's other slices. Items whose
        slices are all entirely NaN cannot be filled and are removed.

        Args:
            arr (np.ndarray): 4-D float array (item, slice, *2-D grid*).

        Returns:
            tuple: ``(filtered, discarded)`` -- the filled array without discarded items, and a
            list of the discarded item indices (plain Python ints).
        """
        all_nan = np.all(np.isnan(arr), axis=(2, 3))  # (n_items, n_slices)
        filled = arr.copy()
        for i in range(arr.shape[0]):
            empty = all_nan[i]
            # Only fill when there is something to fill AND something to fill it from.
            if empty.any() and not empty.all():
                filled[i, empty] = np.nanmean(arr[i, ~empty], axis=0)
        discard = all_nan.all(axis=1)
        return filled[~discard], np.flatnonzero(discard).tolist()
    
class NetCDFPreprocessor:
    """Turn the L1 netCDF files of a directory into flat DDM rows and binary land labels.

    ``preprocessing_method`` selects the per-file routine used by
    :meth:`process_all_files_random_picked`:

    * ``'filtered'``      -> :meth:`preprocess` (``raw_counts`` DDMs, gain/SNR/distance filter)
    * ``'with_lat_lons'`` -> :meth:`preprocess_w_lat_lons` (same filter, also returns SP lat/lon)
    * ``'unfiltered'``    -> :meth:`preprocess_snr_unfiltered` (``L1a_power_ddm``, no SNR filter)

    Labels are 1 when ``sp_surface_type`` is in 1..7 and 0 otherwise (NaN counts as 0).
    """

    def __init__(self, root_dir, preprocessing_method=str):
        """
        Args:
            root_dir: directory containing the netCDF files.
            preprocessing_method (str): 'filtered', 'with_lat_lons' or 'unfiltered'.
                NOTE(review): the default is the ``str`` type object, not a valid name,
                so omitting this argument raises ValueError.

        Raises:
            ValueError: if ``preprocessing_method`` is not a supported name.
        """
        self.root_dir = root_dir
        # Sorted so the seeded random file selection does not depend on the unspecified,
        # OS-dependent order returned by os.listdir.
        self.netcdf_file_list = sorted(os.listdir(root_dir))
        self.preprocessing_method = preprocessing_method
        if self.preprocessing_method not in ['filtered', 'with_lat_lons', 'unfiltered']:
            raise ValueError("Invalid preprocessing method. Choose from 'filtered', 'with_lat_lons', or 'unfiltered'.")

    @staticmethod
    def check_integrity(f, ddm_variable='L1a_power_ddm'):
        """Check that ``f`` is a netCDF4 dataset holding every variable the preprocessors read.

        Args:
            f: dataset to validate.
            ddm_variable (str): name of the DDM variable the caller is about to read
                (``'raw_counts'`` or ``'L1a_power_ddm'``); it must exist and be 4-D.

        Raises:
            ValueError: if ``f`` is not a ``netCDF4.Dataset`` or the DDM variable is not 4-D.
            KeyError: if a required variable is missing.
        """
        if not isinstance(f, netCDF4.Dataset):
            raise ValueError("Input must be a netCDF4.Dataset object")
        if ddm_variable not in f.variables:
            raise KeyError(f"The netCDF file does not contain '{ddm_variable}' variable")
        if 'sp_alt' not in f.variables or 'sp_inc_angle' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_alt' or 'sp_inc_angle' variables")
        if 'sp_rx_gain_copol' not in f.variables or 'sp_rx_gain_xpol' not in f.variables or 'ddm_snr' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_rx_gain_copol', 'sp_rx_gain_xpol' or 'ddm_snr' variables")
        if 'sp_lat' not in f.variables or 'sp_lon' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_lat' or 'sp_lon' variables")
        if 'sp_surface_type' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_surface_type' variable")
        if 'ac_alt' not in f.variables:
            raise KeyError("The netCDF file does not contain 'ac_alt' variable")
        if f.variables[ddm_variable].ndim != 4:
            raise ValueError(f"The '{ddm_variable}' variable must have 4 dimensions")

    @staticmethod
    def _flatten_kept_ddms(ddm, keep_mask):
        """Blank out rejected DDMs, flatten (time, sample) and keep NaN-free, positive-sum ones.

        Args:
            ddm: 4-D DDM array (time, sample, *DDM grid*).
            keep_mask: boolean (time, sample) mask of DDMs that passed the physical filters.

        Returns:
            tuple: ``(fit_data, valid_mask)`` -- float32 array (n_valid, n_features) and the
            boolean mask over the flattened time*sample axis that selected those rows.
        """
        output_array = np.full(ddm.shape, np.nan, dtype=np.float32)
        i_indices, j_indices = np.where(keep_mask)
        output_array[i_indices, j_indices] = ddm[i_indices, j_indices]

        n_time, n_samples = ddm.shape[:2]
        flat = output_array.reshape(n_time * n_samples, *ddm.shape[2:])
        valid_mask = ~np.any(np.isnan(flat), axis=(1, 2)) & (np.sum(flat, axis=(1, 2)) > 0)
        # Explicit feature count: reshape(0, -1) cannot infer -1 and would raise on files
        # where no DDM survives.
        n_features = int(np.prod(ddm.shape[2:]))
        fit_data = flat[valid_mask].reshape(int(valid_mask.sum()), n_features)
        return fit_data, valid_mask

    @staticmethod
    def _land_labels(surface_types, valid_mask):
        """Return int32 labels (1 for surface type 1..7, else 0) for the rows in ``valid_mask``."""
        flat = np.nan_to_num(surface_types, nan=0).ravel()
        labels = np.isin(flat, np.arange(1, 8)).astype(np.int32)
        return labels[valid_mask]

    @staticmethod
    def _write_parquet(df, path):
        """Write ``df`` (without its index) to ``path`` as zstd-compressed parquet.

        NOTE(review): ``pa``/``pq`` (pyarrow, pyarrow.parquet) are not imported in this
        notebook; add ``import pyarrow as pa`` and ``import pyarrow.parquet as pq``.
        """
        table = pa.Table.from_pandas(df, preserve_index=False)
        pq.write_table(table, path, compression='zstd', use_dictionary=True)

    def preprocess(self, f):
        """Filter the ``raw_counts`` DDMs of ``f`` and return ``(fit_data, label_data)``.

        A (time, sample) DDM is kept when both SP gains are >= 5, ``ddm_snr`` > 0, the
        aircraft-to-SP distance lies in [2000, 10000] and none of these inputs is NaN; the
        kept DDM must also be NaN-free with a positive sum.

        Returns:
            fit_data: float32 array (n_kept, n_features) of flattened DDMs.
            label_data: int32 array (n_kept,) -- 1 for surface type 1..7, else 0.
        """
        # Validate the variable this method actually reads.
        self.check_integrity(f, ddm_variable='raw_counts')

        raw_counts = f.variables['raw_counts'][:]
        ac_alt = f.variables['ac_alt'][:]
        sp_alt = f.variables['sp_alt'][:]
        copol = f.variables['sp_rx_gain_copol'][:]
        xpol = f.variables['sp_rx_gain_xpol'][:]
        snr = f.variables['ddm_snr'][:]
        sp_inc_angle = f.variables['sp_inc_angle'][:]

        # Slant distance aircraft -> specular point: altitude difference / cos(incidence).
        distance_2d = (ac_alt[:, np.newaxis] - sp_alt) / np.cos(np.deg2rad(sp_inc_angle))

        keep_mask = (
            (copol >= 5) &  # SP co-polarized gain
            (xpol >= 5) &  # SP cross-polarized gain
            (snr > 0) &  # positive signal-to-noise ratio
            (distance_2d >= 2000) &  # SP distance min
            (distance_2d <= 10000) &  # SP distance max
            ~np.isnan(copol) &
            ~np.isnan(xpol) &
            ~np.isnan(snr) &
            ~np.isnan(distance_2d)
        )

        fit_data, valid_mask = self._flatten_kept_ddms(raw_counts, keep_mask)
        label_data = self._land_labels(f.variables["sp_surface_type"][:], valid_mask)

        assert fit_data.shape[0] == len(label_data), \
            f"Shape mismatch: fit_data {fit_data.shape[0]}, label_data {len(label_data)}"

        return fit_data, label_data

    def preprocess_w_lat_lons(self, f):
        """Like :meth:`preprocess` but also require finite SP lat/lon and return them.

        Returns:
            fit_data: float32 array (n_kept, n_features); a 1-D empty array when nothing
                is kept (historical behaviour, relied on by the shape filter downstream).
            label_data: list of int -- 1 for surface type 1..7, else 0.
            specular_point_lats, specular_point_lons: SP coordinates of the kept rows.
        """
        self.check_integrity(f, ddm_variable='raw_counts')
        raw_counts = np.array(f.variables['raw_counts'])

        # Broadcast the per-time aircraft altitude over the sample axis (replaces a
        # hard-coded repeat to 20 samples; identical values when there are 20).
        ac_alt = np.array(f.variables['ac_alt'])
        distance_2d = (ac_alt[:, np.newaxis] - f.variables['sp_alt'][:]) / np.cos(np.deg2rad(f.variables['sp_inc_angle'][:]))

        copol = f.variables['sp_rx_gain_copol'][:]
        xpol = f.variables['sp_rx_gain_xpol'][:]
        snr = f.variables['ddm_snr'][:]
        dist = distance_2d[:]
        specular_point_lat = f.variables['sp_lat'][:]
        specular_point_lon = f.variables['sp_lon'][:]

        keep_mask = (
            (copol >= 5) & (xpol >= 5) & (snr > 0) & ((dist >= 2000) & (dist <= 10000))
            & (~np.isnan(copol.data) & ~np.isnan(xpol.data) & ~np.isnan(snr.data)
               & ~np.isnan(dist.data) & ~np.isnan(specular_point_lat.data)
               & ~np.isnan(specular_point_lon.data))
        )
        to_keep_indices = np.argwhere(keep_mask)
        rows, cols = to_keep_indices[:, 0], to_keep_indices[:, 1]

        # Single vectorised copy instead of a per-element loop (which also re-copied the
        # whole array on every iteration).
        output_array = np.full(raw_counts.shape, np.nan, dtype=np.float32)
        output_array[rows, cols] = raw_counts[rows, cols]

        n_time, n_samples = raw_counts.shape[:2]
        flat = output_array.reshape(n_time * n_samples, *raw_counts.shape[2:])
        del output_array

        keep_indices = np.where(
            np.all(~np.isnan(flat), axis=(1, 2)) & (np.sum(flat, axis=(1, 2)) > 0)
        )[0]
        if len(keep_indices):
            fit_data = flat[keep_indices].reshape(len(keep_indices), -1)
        else:
            fit_data = np.array([])

        specular_point_lats = specular_point_lat.ravel()[keep_indices]
        specular_point_lons = specular_point_lon.ravel()[keep_indices]

        surface_types = np.nan_to_num(f.variables["sp_surface_type"][:], nan=0).ravel()
        # Vectorised membership + fancy indexing (was an O(n * k) Python filter).
        label_data = np.isin(surface_types, np.arange(1, 8)).astype(int)[keep_indices].tolist()

        assert fit_data.shape[0] == len(label_data) == specular_point_lats.shape[0] == specular_point_lons.shape[0], \
            f"Shape mismatch: fit_data {fit_data.shape[0]}, label_data {len(label_data)}, lats {specular_point_lats.shape[0]}, lons {specular_point_lons.shape[0]}"

        return fit_data, label_data, specular_point_lats, specular_point_lons

    def preprocess_snr_unfiltered(self, f):
        """Like :meth:`preprocess` but on ``L1a_power_ddm`` and without the SNR criterion.

        Returns:
            fit_data: float32 array (n_kept, n_features) of flattened DDMs.
            label_data: int32 array (n_kept,) -- 1 for surface type 1..7, else 0.
        """
        self.check_integrity(f)

        L1a_power_ddm = f.variables['L1a_power_ddm'][:]
        ac_alt = f.variables['ac_alt'][:]
        sp_alt = f.variables['sp_alt'][:]
        sp_inc_angle = f.variables['sp_inc_angle'][:]
        copol = f.variables['sp_rx_gain_copol'][:]
        xpol = f.variables['sp_rx_gain_xpol'][:]

        # Slant distance aircraft -> specular point.
        distance_2d = (ac_alt[:, np.newaxis] - sp_alt) / np.cos(np.deg2rad(sp_inc_angle))
        # Same filter as preprocess() minus the ddm_snr terms.
        keep_mask = (
            (copol >= 5) &
            (xpol >= 5) &
            (distance_2d >= 2000) &
            (distance_2d <= 10000) &
            ~np.isnan(copol) &
            ~np.isnan(xpol) &
            ~np.isnan(distance_2d)
        )

        fit_data, valid_mask = self._flatten_kept_ddms(L1a_power_ddm, keep_mask)
        label_data = self._land_labels(f.variables["sp_surface_type"][:], valid_mask)

        assert fit_data.shape[0] == len(label_data), \
            f"Shape mismatch: fit_data {fit_data.shape[0]}, label_data {len(label_data)}"

        return fit_data, label_data

    def process_all_files_random_picked(self, chunk_size=int, sample_fraction=float, n_files_to_pick=int,
                                        remove_chunks=bool, output_dir='geok_test_data/binary_classification'):
        """Process a seeded random subset of the netCDF files and save the results as parquet.

        Each selected ``.nc`` file is read with the configured ``preprocessing_method``.
        Results that are not 2-D with 200 features per row are dropped (200 = one flattened
        DDM, presumably 40 x 5). The remaining per-file arrays are grouped ``chunk_size`` at
        a time. Each chunk is written to ``fit_data_chunk_{idx}.parquet`` /
        ``labels_chunk_{idx}.parquet``, and a stratified ``sample_fraction`` of it is kept.
        The concatenated sample is written to ``fit_data_binary_test.parquet`` /
        ``labels_binary_test.parquet``.

        Args:
            chunk_size (int): number of per-file arrays stacked per chunk.
            sample_fraction (float): fraction of each chunk kept (``train_test_split`` test_size).
            n_files_to_pick (int): files drawn (seed 42) when more are available.
            remove_chunks (bool): delete the per-chunk parquet files at the end.
            output_dir (str): directory receiving every parquet file.

        Returns:
            tuple: ``(data, labels)`` -- the stratified sample as 2-D / 1-D arrays.

        Raises:
            ValueError: if no chunk produced any data to sample.

        NOTE(review): the first four defaults are type objects, not usable values, so callers
        must pass all four explicitly.
        """
        full_data = []
        full_labels = []
        counter = 0

        if len(self.netcdf_file_list) > n_files_to_pick:
            # Local RandomState(42) draws exactly what np.random.seed(42) + np.random.choice
            # did, without resetting the global numpy RNG.
            rng = np.random.RandomState(42)
            selected_files = rng.choice(self.netcdf_file_list, n_files_to_pick, replace=False)
            print(f'Selezionati {n_files_to_pick} file netCDF casuali dalla lista')
        else:
            selected_files = self.netcdf_file_list

        for file_name in tqdm(selected_files, desc="Processing files"):
            if not file_name.endswith('.nc'):
                continue
            file_path = os.path.join(self.root_dir, file_name)
            try:
                # The context manager closes the file even if preprocessing raises.
                with netCDF4.Dataset(file_path) as f:
                    if self.preprocessing_method == 'unfiltered':
                        data, labels = self.preprocess_snr_unfiltered(f)
                    elif self.preprocessing_method == 'with_lat_lons':
                        data, labels, _, _ = self.preprocess_w_lat_lons(f)
                    else:
                        data, labels = self.preprocess(f)
                assert len(data) == len(labels), f"Data and labels length mismatch in file {file_name}: {len(data)} != {len(labels)}"
            except Exception as e:
                print(f"Error processing file {file_name}: {e}")
                continue
            full_data.append(data)
            full_labels.append(labels)
            counter += 1
            if counter == n_files_to_pick:
                break
        print(f"Processed {counter} files out of {len(selected_files)} selected files.")

        valid_indices = [i for i, arr in enumerate(full_data) if arr.ndim == 2 and arr.shape[1] == 200]
        full_data_clean = [full_data[i] for i in valid_indices]
        full_labels_clean = [full_labels[i] for i in valid_indices]
        print(f"Number of valid data arrays after filtering: {len(full_data_clean)}")

        os.makedirs(output_dir, exist_ok=True)
        sampled_data = []
        sampled_labels = []

        num_chunks = int(np.ceil(len(full_data_clean) / chunk_size))
        print(f"Total number of chunks: {num_chunks}")
        for idx in range(num_chunks):
            start = idx * chunk_size
            end = min(start + chunk_size, len(full_data_clean))
            chunk_data = np.vstack(full_data_clean[start:end])
            chunk_labels = np.hstack(full_labels_clean[start:end])
            if chunk_data.size == 0 or chunk_labels.size == 0:
                print(f"Skipping empty chunk {idx + 1}/{num_chunks}")
                continue
            print(f"Chunk {idx + 1}/{num_chunks} processed with shape {chunk_data.shape} and labels shape {chunk_labels.shape}")

            self._write_parquet(pd.DataFrame(chunk_data),
                                os.path.join(output_dir, f'fit_data_chunk_{idx}.parquet'))
            self._write_parquet(pd.DataFrame(chunk_labels, columns=['label']),
                                os.path.join(output_dir, f'labels_chunk_{idx}.parquet'))

            # The "test" split of train_test_split is the stratified sample we keep.
            _, X_sampled, _, y_sampled = train_test_split(
                chunk_data, chunk_labels,
                test_size=sample_fraction,
                stratify=chunk_labels,
                random_state=42
            )
            sampled_data.append(X_sampled)
            sampled_labels.append(y_sampled)

        del full_data, full_labels

        if not sampled_data:
            raise ValueError(f"No valid data was produced from the selected files; nothing to save in {output_dir}")

        full_data_sampled_stratified = np.vstack(sampled_data)
        full_labels_sampled_stratified = np.hstack(sampled_labels)
        del sampled_data, sampled_labels
        print(f"Shape of sampled data after chunking and sampling: {full_data_sampled_stratified.shape}")
        print(f"Shape of sampled labels after chunking and sampling: {full_labels_sampled_stratified.shape}")

        self._write_parquet(pd.DataFrame(full_data_sampled_stratified),
                            os.path.join(output_dir, 'fit_data_binary_test.parquet'))
        self._write_parquet(pd.DataFrame(full_labels_sampled_stratified, columns=['label']),
                            os.path.join(output_dir, 'labels_binary_test.parquet'))
        print(f"Data and labels saved in {output_dir} directory.")

        # Best effort: a failure to delete chunk files is reported, not raised.
        if remove_chunks:
            try:
                for fname in os.listdir(output_dir):
                    if fname.startswith(('fit_data_chunk_', 'labels_chunk_')):
                        os.remove(os.path.join(output_dir, fname))
                print("All chunk files removed.")
            except Exception as e:
                print(f"Error removing chunk files: {e}")

        return full_data_sampled_stratified, full_labels_sampled_stratified

In [None]:
class DDMProcessor:
    """
    Class for processing and compressing DDM (Delay Doppler Map) files from NetCDF format.

    Per-file pipeline (see ``process_all_files``): read and filter DDMs
    (``process_single_file``) -> min-max normalise (``normalize_data``) -> encode in
    batches (``compress_data``) -> save ``.npz`` (``save_compressed_data``).
    """
    @staticmethod
    def check_integrity(f, ddm_variable='raw_counts'):
        """Check that ``f`` is a netCDF4 dataset with the variables this pipeline needs.

        Args:
            f: dataset to validate.
            ddm_variable (str): DDM variable that must exist and be 4-D.

        Raises:
            ValueError: if ``f`` is not a ``netCDF4.Dataset`` or the DDM variable is not 4-D.
            KeyError: if a required variable is missing.
        """
        if not isinstance(f, netCDF4.Dataset):
            raise ValueError("Input must be a netCDF4.Dataset object")
        if ddm_variable not in f.variables:
            raise KeyError(f"The netCDF file does not contain '{ddm_variable}' variable")
        if 'sp_alt' not in f.variables or 'sp_inc_angle' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_alt' or 'sp_inc_angle' variables")
        if 'sp_rx_gain_copol' not in f.variables or 'sp_rx_gain_xpol' not in f.variables or 'ddm_snr' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_rx_gain_copol', 'sp_rx_gain_xpol' or 'ddm_snr' variables")
        if 'sp_lat' not in f.variables or 'sp_lon' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_lat' or 'sp_lon' variables")
        if 'sp_surface_type' not in f.variables:
            raise KeyError("The netCDF file does not contain 'sp_surface_type' variable")
        if 'ac_alt' not in f.variables:
            raise KeyError("The netCDF file does not contain 'ac_alt' variable")
        if f.variables[ddm_variable].ndim != 4:
            raise ValueError(f"The '{ddm_variable}' variable must have 4 dimensions")

    def __init__(self, input_folder, output_folder, device=None):
        """
        Initialize the DDM Processor.

        Args:
            input_folder (str): Path to folder containing input NetCDF files
            output_folder (str): Path to folder where compressed files will be saved (created if missing)
            device (torch.device): Device for computation; defaults to CUDA when available, else CPU
        """
        self.input_folder = Path(input_folder)
        self.output_folder = Path(output_folder)
        self.output_folder.mkdir(parents=True, exist_ok=True)

        if device is None:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        else:
            self.device = device

        print(f"Using device: {self.device}")

        self.scaler = MinMaxScaler()
        self.scaler_fitted = False

        # Set later via set_encoder() / load_encoder().
        self.encoder = None

    def set_encoder(self, encoder_model):
        """
        Set the encoder model for compression (moved to ``self.device`` and put in eval mode).

        Args:
            encoder_model: PyTorch model for encoding/compression
        """
        self.encoder = encoder_model.to(self.device)
        self.encoder.eval()

    def load_encoder(self, model_path):
        """
        Load a pre-trained encoder saved as a whole pickled module.

        Args:
            model_path (str): Path to the saved encoder model

        NOTE(review): ``torch.load`` unpickles arbitrary objects here -- only use trusted files.
        Recent PyTorch versions default to ``weights_only=True``, which rejects whole-module
        checkpoints; prefer saving a state_dict and calling ``set_encoder`` instead.
        """
        self.encoder = torch.load(model_path, map_location=self.device)
        self.encoder.eval()

    @staticmethod
    def _extract_valid_ddms(ddm, keep_mask, surface):
        """Return ``(rows, labels)`` for DDMs passing ``keep_mask`` that are NaN-free with positive sum.

        Data rows and labels come from the same validity mask, so they always align.

        Args:
            ddm: 4-D DDM array (time, sample, *DDM grid*).
            keep_mask: boolean (time, sample) mask from the physical filters.
            surface: (time, sample) surface-type codes.

        Returns:
            rows: float32 array (n_valid, n_features) of flattened DDMs.
            labels: int32 array (n_valid,) -- 1 for surface type 1..7, else 0.
        """
        output_array = np.full(ddm.shape, np.nan, dtype=np.float32)
        i_indices, j_indices = np.where(keep_mask)
        output_array[i_indices, j_indices] = ddm[i_indices, j_indices]

        n_time, n_samples = ddm.shape[:2]
        flat = output_array.reshape(n_time * n_samples, *ddm.shape[2:])
        valid_mask = ~np.any(np.isnan(flat), axis=(1, 2)) & (np.sum(flat, axis=(1, 2)) > 0)
        # Explicit feature count: reshape(0, -1) would raise when nothing is valid.
        n_features = int(np.prod(ddm.shape[2:]))
        rows = flat[valid_mask].reshape(int(valid_mask.sum()), n_features)

        surface_types = np.nan_to_num(surface, nan=0).ravel()
        labels = np.isin(surface_types, np.arange(1, 8)).astype(np.int32)[valid_mask]
        return rows, labels

    def process_single_file(self, file_path):
        """
        Process a single NetCDF file and extract DDM data.

        A (time, sample) DDM is kept when both SP gains are >= 5, the aircraft-to-SP distance
        lies in [2000, 10000], none of those inputs nor the surface type is NaN, and the DDM
        itself is NaN-free with a positive sum (no SNR criterion).

        Args:
            file_path (str | Path): Path to the NetCDF file

        Returns:
            tuple: ``(ddm_data_raw, label_data)`` -- float32 (n_valid, n_features) DDM rows and
            int32 labels; ``(None, None)`` if the file cannot be read or has no valid DDM.
        """
        file_name = Path(file_path).name
        try:
            # Read everything inside the ``with``: the file is always closed, and a missing
            # variable skips this file instead of aborting the whole batch.
            with netCDF4.Dataset(file_path) as ds:
                self.check_integrity(ds, ddm_variable='L1a_power_ddm')
                track_id = ds["track_id"][:]
                ac_alt = ds['ac_alt'][:]
                sp_alt = ds['sp_alt'][:]
                sp_inc_angle = ds['sp_inc_angle'][:]
                copol = ds['sp_rx_gain_copol'][:]
                xpol = ds['sp_rx_gain_xpol'][:]
                L1a_power_ddm = ds['L1a_power_ddm'][:]
                surface = ds.variables["sp_surface_type"][:]
        except Exception as e:
            print(f"Error loading {file_path}: {e}")
            return None, None

        # Informational only: processing continues even on a mismatch.
        if L1a_power_ddm.shape[0] != track_id.shape[0]:
            print("DDM and track_id must have the same first dimension")

        # Slant distance aircraft -> specular point: altitude difference / cos(incidence).
        distance_2d = (ac_alt[:, np.newaxis] - sp_alt) / np.cos(np.deg2rad(sp_inc_angle))

        keep_mask = (
            (copol >= 5) &
            (xpol >= 5) &
            (distance_2d >= 2000) &
            (distance_2d <= 10000) &
            ~np.isnan(copol) &
            ~np.isnan(xpol) &
            ~np.isnan(distance_2d) &
            ~np.isnan(surface)
        )

        ddm_data_raw, label_data = self._extract_valid_ddms(L1a_power_ddm, keep_mask, surface)
        if ddm_data_raw.shape[0] == 0:
            print(f"No valid DDM samples found in {file_name}")
            return None, None

        return ddm_data_raw, label_data

    def normalize_data(self, ddm_data_raw):
        """
        Normalize DDM data to [0, 1] range with ``self.scaler``.

        Args:
            ddm_data_raw (np.ndarray): Raw DDM data (n_samples, n_features)

        Returns:
            np.ndarray: Normalized DDM data

        NOTE(review): the scaler is re-fitted on every call, so each file is scaled with its
        own min/max and the scaler saved by ``process_all_files`` reflects only the last
        file. Confirm this matches how the encoder was trained.
        """
        # The 1e13 factor is applied before scaling; its purpose is not visible here.
        ddm_data = self.scaler.fit_transform(ddm_data_raw * 1e13)
        self.scaler_fitted = True
        return ddm_data

    def compress_data(self, tensor_data, batch_size=32):
        """
        Compress data using the encoder model.

        Args:
            tensor_data (torch.Tensor): Input tensor data
            batch_size (int): Rows per forward pass

        Returns:
            np.ndarray: Compressed data, concatenated along the first axis

        Raises:
            ValueError: if no encoder has been set.
        """
        if self.encoder is None:
            raise ValueError("Encoder model not set. Use set_encoder() or load_encoder() first.")

        dataloader = DataLoader(TensorDataset(tensor_data), batch_size=batch_size, shuffle=False)

        compressed_data = []
        with torch.no_grad():
            for batch in dataloader:
                inputs = batch[0].to(self.device)
                compressed = self.encoder(inputs)
                compressed_data.append(compressed.cpu().numpy())

        return np.concatenate(compressed_data, axis=0)

    def save_compressed_data(self, compressed_data, original_filename):
        """
        Save compressed data to ``<output_folder>/<original_filename>_compressed.npz`` (key ``data``).

        Args:
            compressed_data (np.ndarray): Compressed data to save
            original_filename (str): Original filename (without extension)
        """
        output_path = self.output_folder / f"{original_filename}_compressed.npz"
        np.savez_compressed(output_path, data=compressed_data)

    def process_all_files(self, file_extension='.nc', save_scaler=True, max_files=100):
        """
        Process NetCDF files in the input folder.

        Args:
            file_extension (str): Extension of files to process (default: '.nc')
            save_scaler (bool): Whether to save the scaler for future use
            max_files (int | None): Process at most this many files, in sorted order
                (default 100, the previous hard-coded limit); ``None`` processes all.

        Returns:
            defaultdict: ``{file_stem: {'compressed_data': ..., 'labels': ...}}`` for files
            compressed with an encoder; empty when no file matched.
        """
        from collections import defaultdict

        full_data_dict = defaultdict(dict)

        # Sorted so the selected subset is deterministic across platforms.
        file_list = sorted(self.input_folder.glob(f'*{file_extension}'))
        if max_files is not None:
            file_list = file_list[:max_files]

        if len(file_list) == 0:
            print(f"No files with extension '{file_extension}' found in {self.input_folder}")
            return full_data_dict

        print(f"Found {len(file_list)} files to process")
        for file_path in tqdm(file_list, desc="Processing files"):
            if not file_path.is_file():
                continue
            ddm_data_raw, label_data = self.process_single_file(file_path)
            if ddm_data_raw is None or label_data is None:
                continue

            ddm_data_normalized = self.normalize_data(ddm_data_raw)
            tensor_data = torch.tensor(ddm_data_normalized, dtype=torch.float32)
            filename_without_ext = file_path.stem

            if self.encoder is not None:
                compressed_data = self.compress_data(tensor_data)
                self.save_compressed_data(compressed_data, filename_without_ext)
                full_data_dict[str(filename_without_ext)]['compressed_data'] = compressed_data
                full_data_dict[str(filename_without_ext)]['labels'] = label_data
            else:
                # No encoder: persist the normalised data instead (not added to the result dict).
                print("  No encoder set - saving normalized data instead")
                output_path = self.output_folder / f"{filename_without_ext}_normalized.npz"
                np.savez_compressed(output_path, data=ddm_data_normalized)
                print(f"  Saved normalized data to {output_path}")

        if save_scaler and self.scaler_fitted:
            scaler_path = self.output_folder / "scaler_encoder.pkl"
            joblib.dump(self.scaler, scaler_path)
            print(f"\nScaler saved to {scaler_path}")

        print(f"\n{'='*50}")
        print(f"Processing complete! Output files saved to {self.output_folder}")
        return full_data_dict

        


In [29]:
# Computation device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
class Encoder(nn.Module):
    """Map a flat 200-value DDM to a 20-value non-negative code (200 -> 100 -> 20, ReLU after each)."""

    def __init__(self):
        super().__init__()
        # The attribute name ``net`` and the layer order must stay as they are: saved
        # state_dicts are keyed ``net.0.*`` / ``net.2.*`` and would not load otherwise.
        layers = [
            nn.Linear(200, 100),
            nn.ReLU(),
            nn.Linear(100, 20),
            nn.ReLU(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` (last dimension 200) into codes with last dimension 20."""
        encoded = self.net(x)
        return encoded

class Decoder(nn.Module):
    """Map a 20-value code back to a flat 200-value DDM (20 -> 100 -> 200, ReLU only in between)."""

    def __init__(self):
        super().__init__()
        # Keep ``net`` and this layer order so saved state_dict keys (net.0.*, net.2.*) match.
        layers = [
            nn.Linear(20, 100),
            nn.ReLU(),
            nn.Linear(100, 200),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Decode ``x`` (last dimension 20) into reconstructions with last dimension 200."""
        decoded = self.net(x)
        return decoded

# ----------------------
# Load the saved models
# ----------------------
def load_model(model_class, path, target_device=None):
    """Instantiate ``model_class``, load its saved state_dict from ``path`` and return it in eval mode.

    Args:
        model_class: zero-argument ``nn.Module`` subclass matching the checkpoint's parameters.
        path: file path (or file-like object) accepted by ``torch.load``.
        target_device: device for the model and loaded tensors; defaults to the
            notebook-level ``device``.

    Returns:
        The model with loaded weights, on ``target_device``, in eval mode.
    """
    if target_device is None:
        target_device = device
    model = model_class().to(target_device)
    # The checkpoint is a plain state_dict, so refuse to unpickle arbitrary objects.
    state_dict = torch.load(path, map_location=target_device, weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()
    return model


In [30]:
import json

# Example usage
# In a Jupyter kernel __name__ is "__main__", so this guard always runs here.
if __name__ == "__main__":
    # Locations can be overridden through environment variables. The defaults are the
    # original hard-coded paths, so behaviour is unchanged when the variables are unset.
    input_folder = os.environ.get("RONGOWAI_INPUT_FOLDER", "E:/data/RONGOWAI_L1_SDR_V1.0")
    output_folder = os.environ.get("GEO_K_OUTPUT_FOLDER", "E:/data/geo_k_compressed")
    encoder_path = os.environ.get(
        "GEO_K_ENCODER_PATH",
        "C:/Users/atogni/Desktop/rongowai/geo_k/encoder_all_surface2.pth",
    )

    # Initialize processor
    processor = DDMProcessor(input_folder, output_folder)

    # Load the pre-trained encoder before processing, so a bad path fails fast
    encoder = load_model(Encoder, encoder_path)
    processor.set_encoder(encoder)

    # Process all NetCDF files in the input folder
    full_data_dict = processor.process_all_files(file_extension='.nc', save_scaler=True)

Using device: cuda
Found 100 files to process


Processing files:   1%|          | 1/100 [00:00<01:04,  1.53it/s]

  Saving normalized data to 20221026-100450_NZRO-NZAA_L1_normalized.npz


Processing files:   2%|▏         | 2/100 [00:01<00:53,  1.84it/s]

  Saving normalized data to 20221026-112526_NZAA-NZTG_L1_normalized.npz


Processing files:   3%|▎         | 3/100 [00:01<00:53,  1.81it/s]

  Saving normalized data to 20221026-122454_NZTG-NZAA_L1_normalized.npz


Processing files:   4%|▍         | 4/100 [00:02<01:08,  1.40it/s]

  Saving normalized data to 20221026-134535_NZAA-NZGS_L1_normalized.npz


Processing files:   5%|▌         | 5/100 [00:03<01:20,  1.18it/s]

  Saving normalized data to 20221026-150826_NZGS-NZAA_L1_normalized.npz


Processing files:   6%|▌         | 6/100 [00:04<01:19,  1.18it/s]

  Saving normalized data to 20221026-165902_NZAA-NZAP_L1_normalized.npz


Processing files:   7%|▋         | 7/100 [00:05<01:17,  1.20it/s]

  Saving normalized data to 20221026-181252_NZAP-NZAA_L1_normalized.npz


Processing files:   8%|▊         | 8/100 [00:06<01:33,  1.02s/it]

  Saving normalized data to 20221026-202809_NZAA-NZNS_L1_normalized.npz


Processing files:   9%|▉         | 9/100 [00:08<01:41,  1.11s/it]

  Saving normalized data to 20221027-061109_NZNS-NZAA_L1_normalized.npz


Processing files:  10%|█         | 10/100 [00:09<01:35,  1.06s/it]

  Saving normalized data to 20221027-082948_NZAA-NZGS_L1_normalized.npz


Processing files:  11%|█         | 11/100 [00:10<01:35,  1.07s/it]

  Saving normalized data to 20221027-095646_NZGS-NZAA_L1_normalized.npz


Processing files:  12%|█▏        | 12/100 [00:11<01:41,  1.15s/it]

  Saving normalized data to 20221027-123430_NZAA-NZNS_L1_normalized.npz


Processing files:  13%|█▎        | 13/100 [00:12<01:31,  1.05s/it]

  Saving normalized data to 20221027-143322_NZNS-NZCH_L1_normalized.npz


Processing files:  14%|█▍        | 14/100 [00:12<01:20,  1.07it/s]

  Saving normalized data to 20221027-163735_NZCH-NZHK_L1_normalized.npz


Processing files:  15%|█▌        | 15/100 [00:13<01:07,  1.26it/s]

  Saving normalized data to 20221027-174645_NZHK-NZCH_L1_normalized.npz


Processing files:  16%|█▌        | 16/100 [00:14<01:05,  1.29it/s]

  Saving normalized data to 20221028-083519_NZCH-NZNS_L1_normalized.npz


Processing files:  17%|█▋        | 17/100 [00:15<01:11,  1.16it/s]

  Saving normalized data to 20221028-103953_NZNS-NZCH_L1_normalized.npz


Processing files:  18%|█▊        | 18/100 [00:16<01:08,  1.20it/s]

  Saving normalized data to 20221028-125424_NZCH-NZNS_L1_normalized.npz


Processing files:  19%|█▉        | 19/100 [00:16<00:57,  1.41it/s]

  Saving normalized data to 20221028-144937_NZNS-NZWN_L1_normalized.npz


Processing files:  20%|██        | 20/100 [00:17<01:00,  1.32it/s]

  Saving normalized data to 20221028-161003_NZWN-NZNS_L1_normalized.npz


Processing files:  21%|██        | 21/100 [00:17<00:50,  1.57it/s]

  Saving normalized data to 20221028-173441_NZNS-NZWN_L1_normalized.npz


Processing files:  22%|██▏       | 22/100 [00:18<00:59,  1.31it/s]

  Saving normalized data to 20221028-185455_NZWN-NZGS_L1_normalized.npz


Processing files:  23%|██▎       | 23/100 [00:19<01:04,  1.20it/s]

  Saving normalized data to 20221029-063152_NZGS-NZAA_L1_normalized.npz


Processing files:  24%|██▍       | 24/100 [00:20<01:01,  1.24it/s]

  Saving normalized data to 20221029-082733_NZAA-NZAP_L1_normalized.npz


Processing files:  25%|██▌       | 25/100 [00:21<01:03,  1.18it/s]

  Saving normalized data to 20221029-100254_NZAP-NZAA_L1_normalized.npz


Processing files:  26%|██▌       | 26/100 [00:22<01:00,  1.23it/s]

  Saving normalized data to 20221029-112638_NZAA-NZAP_L1_normalized.npz


Processing files:  27%|██▋       | 27/100 [00:22<01:00,  1.21it/s]

  Saving normalized data to 20221029-123547_NZAP-NZAA_L1_normalized.npz


Processing files:  28%|██▊       | 28/100 [00:23<00:52,  1.37it/s]

  Saving normalized data to 20221029-141744_NZAA-NZWR_L1_normalized.npz


Processing files:  29%|██▉       | 29/100 [00:23<00:45,  1.57it/s]

  Saving normalized data to 20221029-151744_NZWR-NZAA_L1_normalized.npz


Processing files:  30%|███       | 30/100 [00:24<00:41,  1.68it/s]

  Saving normalized data to 20221029-165836_NZAA-NZWR_L1_normalized.npz


Processing files:  31%|███       | 31/100 [00:24<00:39,  1.75it/s]

  Saving normalized data to 20221101-064735_NZAA-NZTG_L1_normalized.npz


Processing files:  32%|███▏      | 32/100 [00:26<00:49,  1.36it/s]

  Saving normalized data to 20221101-074443_NZTG-NZWN_L1_normalized.npz


Processing files:  33%|███▎      | 33/100 [00:27<00:58,  1.14it/s]

  Saving normalized data to 20221101-092955_NZWN-NZTG_L1_normalized.npz


Processing files:  34%|███▍      | 34/100 [00:27<00:51,  1.29it/s]

  Saving normalized data to 20221101-111110_NZTG-NZAA_L1_normalized.npz


Processing files:  35%|███▌      | 35/100 [00:28<00:51,  1.27it/s]

  Saving normalized data to 20221101-123329_NZAA-NZGS_L1_normalized.npz


Processing files:  36%|███▌      | 36/100 [00:29<00:55,  1.16it/s]

  Saving normalized data to 20221101-141321_NZGS-NZAA_L1_normalized.npz


Processing files:  37%|███▋      | 37/100 [00:31<01:08,  1.09s/it]

  Saving normalized data to 20221101-154832_NZAA-NZWB_L1_normalized.npz


Processing files:  38%|███▊      | 38/100 [00:32<01:09,  1.12s/it]

  Saving normalized data to 20221101-173952_NZWB-NZAA_L1_normalized.npz


Processing files:  39%|███▉      | 39/100 [00:33<00:57,  1.05it/s]

  Saving normalized data to 20221101-194320_NZAA-NZKK_L1_normalized.npz


Processing files:  40%|████      | 40/100 [00:33<00:51,  1.16it/s]

  Saving normalized data to 20221102-060644_NZKK-NZAA_L1_normalized.npz


Processing files:  41%|████      | 41/100 [00:34<00:49,  1.19it/s]

  Saving normalized data to 20221102-072913_NZAA-NZNP_L1_normalized.npz


Processing files:  42%|████▏     | 42/100 [00:35<00:45,  1.28it/s]

  Saving normalized data to 20221102-084043_NZNP-NZAA_L1_normalized.npz


Processing files:  43%|████▎     | 43/100 [00:35<00:45,  1.25it/s]

  Saving normalized data to 20221102-105117_NZAA-NZNP_L1_normalized.npz


Processing files:  44%|████▍     | 44/100 [00:36<00:43,  1.27it/s]

  Saving normalized data to 20221102-120951_NZNP-NZWN_L1_normalized.npz


Processing files:  45%|████▌     | 45/100 [00:36<00:34,  1.58it/s]

  Saving normalized data to 20221102-141125_NZWN-NZWB_L1_normalized.npz


Processing files:  46%|████▌     | 46/100 [00:37<00:28,  1.87it/s]

  Saving normalized data to 20221102-162321_NZWB-NZWN_L1_normalized.npz


Processing files:  47%|████▋     | 47/100 [00:37<00:24,  2.20it/s]

  Saving normalized data to 20221102-172207_NZWN-NZWB_L1_normalized.npz


Processing files:  48%|████▊     | 48/100 [00:37<00:21,  2.42it/s]

  Saving normalized data to 20221102-181605_NZWB-NZWN_L1_normalized.npz


Processing files:  49%|████▉     | 49/100 [00:39<00:35,  1.44it/s]

  Saving normalized data to 20221102-195358_NZWN-NZTG_L1_normalized.npz


Processing files:  50%|█████     | 50/100 [00:40<00:46,  1.08it/s]

  Saving normalized data to 20221103-065105_NZTG-NZWN_L1_normalized.npz


Processing files:  51%|█████     | 51/100 [00:42<01:02,  1.28s/it]

  Saving normalized data to 20221103-091459_NZWN-NZNV_L1_normalized.npz


Processing files:  52%|█████▏    | 52/100 [00:44<01:01,  1.27s/it]

  Saving normalized data to 20221103-121416_NZNV-NZCH_L1_normalized.npz


Processing files:  53%|█████▎    | 53/100 [00:45<01:09,  1.48s/it]

  Saving normalized data to 20221103-151412_NZCH-NZTG_L1_normalized.npz


Processing files:  54%|█████▍    | 54/100 [00:46<00:55,  1.20s/it]

  Saving normalized data to 20221103-173530_NZTG-NZAA_L1_normalized.npz


Processing files:  55%|█████▌    | 55/100 [00:47<00:49,  1.10s/it]

  Saving normalized data to 20221103-185508_NZAA-NZNR_L1_normalized.npz


Processing files:  56%|█████▌    | 56/100 [00:48<00:50,  1.14s/it]

  Saving normalized data to 20221104-071142_NZNR-NZAA_L1_normalized.npz


Processing files:  57%|█████▋    | 57/100 [00:49<00:41,  1.04it/s]

  Saving normalized data to 20221104-085735_NZAA-NZRO_L1_normalized.npz


Processing files:  58%|█████▊    | 58/100 [00:49<00:37,  1.13it/s]

  Saving normalized data to 20221104-100424_NZRO-NZAA_L1_normalized.npz


Processing files:  59%|█████▉    | 59/100 [00:51<00:39,  1.04it/s]

  Saving normalized data to 20221104-123446_NZAA-NZNS_L1_normalized.npz


Processing files:  60%|██████    | 60/100 [00:51<00:35,  1.12it/s]

  Saving normalized data to 20221104-144000_NZNS-NZCH_L1_normalized.npz


Processing files:  61%|██████    | 61/100 [00:52<00:32,  1.21it/s]

  Saving normalized data to 20221105-085244_NZCH-NZNS_L1_normalized.npz


Processing files:  62%|██████▏   | 62/100 [00:53<00:28,  1.32it/s]

  Saving normalized data to 20221105-101226_NZNS-NZWN_L1_normalized.npz


Processing files:  63%|██████▎   | 63/100 [00:53<00:25,  1.46it/s]

  Saving normalized data to 20221105-123553_NZWN-NZNS_L1_normalized.npz


Processing files:  64%|██████▍   | 64/100 [00:54<00:29,  1.23it/s]

  Saving normalized data to 20221105-143106_NZNS-NZAA_L1_normalized.npz


Processing files:  65%|██████▌   | 65/100 [00:55<00:28,  1.23it/s]

  Saving normalized data to 20221105-165902_NZAA-NZGS_L1_normalized.npz


Processing files:  66%|██████▌   | 66/100 [00:56<00:29,  1.14it/s]

  Saving normalized data to 20221106-093628_NZGS-NZAA_L1_normalized.npz


Processing files:  67%|██████▋   | 67/100 [00:56<00:24,  1.37it/s]

  Saving normalized data to 20221106-115743_NZAA-NZTG_L1_normalized.npz


Processing files:  68%|██████▊   | 68/100 [00:57<00:21,  1.48it/s]

  Saving normalized data to 20221106-125610_NZTG-NZAA_L1_normalized.npz


Processing files:  69%|██████▉   | 69/100 [00:58<00:23,  1.34it/s]

  Saving normalized data to 20221106-151123_NZAA-NZGS_L1_normalized.npz


Processing files:  70%|███████   | 70/100 [00:59<00:26,  1.14it/s]

  Saving normalized data to 20221106-163130_NZGS-NZWN_L1_normalized.npz


Processing files:  71%|███████   | 71/100 [01:00<00:22,  1.29it/s]

  Saving normalized data to 20221106-183016_NZWN-NZNS_L1_normalized.npz


Processing files:  72%|███████▏  | 72/100 [01:01<00:26,  1.07it/s]

  Saving normalized data to 20221107-060333_NZNS-NZAA_L1_normalized.npz


Processing files:  73%|███████▎  | 73/100 [01:02<00:24,  1.08it/s]

  Saving normalized data to 20221107-083011_NZAA-NZGS_L1_normalized.npz


Processing files:  74%|███████▍  | 74/100 [01:03<00:24,  1.05it/s]

  Saving normalized data to 20221107-095603_NZGS-NZAA_L1_normalized.npz


Processing files:  75%|███████▌  | 75/100 [01:03<00:19,  1.26it/s]

  Saving normalized data to 20221107-115912_NZAA-NZWR_L1_normalized.npz


Processing files:  76%|███████▌  | 76/100 [01:04<00:17,  1.41it/s]

  Saving normalized data to 20221107-130025_NZWR-NZAA_L1_normalized.npz


Processing files:  77%|███████▋  | 77/100 [01:04<00:15,  1.45it/s]

  Saving normalized data to 20221107-143614_NZAA-NZNP_L1_normalized.npz


Processing files:  78%|███████▊  | 78/100 [01:05<00:15,  1.42it/s]

  Saving normalized data to 20221107-154457_NZNP-NZAA_L1_normalized.npz


Processing files:  79%|███████▉  | 79/100 [01:06<00:13,  1.57it/s]

  Saving normalized data to 20221107-172323_NZAA-NZTG_L1_normalized.npz


Processing files:  80%|████████  | 80/100 [01:06<00:13,  1.53it/s]

  Saving normalized data to 20221107-181821_NZTG-NZAA_L1_normalized.npz


Processing files:  81%|████████  | 81/100 [01:07<00:14,  1.35it/s]

  Saving normalized data to 20221107-194029_NZAA-NZKK_L1_normalized.npz


Processing files:  82%|████████▏ | 82/100 [01:08<00:12,  1.40it/s]

  Saving normalized data to 20221108-060603_NZKK-NZAA_L1_normalized.npz


Processing files:  83%|████████▎ | 83/100 [01:09<00:12,  1.37it/s]

  Saving normalized data to 20221108-073331_NZAA-NZNP_L1_normalized.npz


Processing files:  84%|████████▍ | 84/100 [01:09<00:12,  1.31it/s]

  Saving normalized data to 20221108-084629_NZNP-NZAA_L1_normalized.npz


Processing files:  85%|████████▌ | 85/100 [01:10<00:10,  1.38it/s]

  Saving normalized data to 20221108-100736_NZAA-NZRO_L1_normalized.npz


Processing files:  86%|████████▌ | 86/100 [01:11<00:09,  1.44it/s]

  Saving normalized data to 20221108-110604_NZRO-NZAA_L1_normalized.npz


Processing files:  87%|████████▋ | 87/100 [01:11<00:09,  1.43it/s]

  Saving normalized data to 20221108-123914_NZAA-NZWR_L1_normalized.npz


Processing files:  88%|████████▊ | 88/100 [01:12<00:07,  1.55it/s]

  Saving normalized data to 20221108-140534_NZWR-NZAA_L1_normalized.npz


Processing files:  89%|████████▉ | 89/100 [01:12<00:06,  1.66it/s]

  Saving normalized data to 20221108-152653_NZAA-NZTG_L1_normalized.npz


Processing files:  90%|█████████ | 90/100 [01:14<00:10,  1.02s/it]

  Saving normalized data to 20221108-163135_NZTG-NZCH_L1_normalized.npz


Processing files:  91%|█████████ | 91/100 [01:16<00:09,  1.11s/it]

  Saving normalized data to 20221108-192429_NZCH-NZNP_L1_normalized.npz


Processing files:  92%|█████████▏| 92/100 [01:17<00:08,  1.03s/it]

  Saving normalized data to 20221109-070037_NZNP-NZWN_L1_normalized.npz


Processing files:  93%|█████████▎| 93/100 [01:17<00:06,  1.05it/s]

  Saving normalized data to 20221109-083828_NZWN-NZNP_L1_normalized.npz


Processing files:  94%|█████████▍| 94/100 [01:18<00:05,  1.14it/s]

  Saving normalized data to 20221109-103050_NZNP-NZWN_L1_normalized.npz


Processing files:  95%|█████████▌| 95/100 [01:19<00:04,  1.18it/s]

  Saving normalized data to 20221109-123628_NZWN-NZNP_L1_normalized.npz


Processing files:  96%|█████████▌| 96/100 [01:20<00:03,  1.15it/s]

  Saving normalized data to 20221109-140730_NZNP-NZWN_L1_normalized.npz


Processing files:  97%|█████████▋| 97/100 [01:20<00:02,  1.34it/s]

  Saving normalized data to 20221109-160642_NZWN-NZNS_L1_normalized.npz


Processing files:  98%|█████████▊| 98/100 [01:21<00:01,  1.51it/s]

  Saving normalized data to 20221109-171054_NZNS-NZWN_L1_normalized.npz


Processing files:  99%|█████████▉| 99/100 [01:21<00:00,  1.62it/s]

  Saving normalized data to 20221109-183125_NZWN-NZNS_L1_normalized.npz


Processing files: 100%|██████████| 100/100 [01:22<00:00,  1.22it/s]

  Saving normalized data to 20221110-065522_NZNS-NZWN_L1_normalized.npz

Scaler saved to E:\data\geo_k_compressed\scaler_encoder.pkl

Processing complete! Output files saved to E:\data\geo_k_compressed





In [31]:
def _json_fallback(obj):
    """Serialise values json can't handle: use ``obj.tolist()`` when available, else ``str(obj)``."""
    if hasattr(obj, "tolist"):
        return obj.tolist()
    return str(obj)


with open("E:/data/geo_k_compressed/full_data_dict.json", "w") as json_file:
    json.dump(full_data_dict, json_file, indent=2, default=_json_fallback)

In [32]:
# Number of per-file entries collected by process_all_files
len(full_data_dict)

100

In [2]:
import json

# Reload the dictionary exported earlier from disk (fresh kernel session).
json_path = r"E:\data\geo_k_compressed\full_data_dict.json"
with open(json_path) as json_file:
    loaded_data = json.load(json_file)

In [4]:
# Show the top-level (file-name) keys present in the reloaded JSON.
loaded_data.keys()

dict_keys(['20221110-065522_NZNS-NZWN_L1.nc'])

In [None]:
# pandas is not imported in the top import cell, and this section runs in a restarted
# kernel (In [2]), so import it here to make the cell self-contained.
import pandas as pd

# Build a DataFrame from one flight's compressed vectors. The reloaded JSON can hold
# fewer files than were processed, so fail with an explicit message listing what exists.
_flight_key = '20221026-150826_NZGS-NZAA_L1.nc'
if _flight_key not in loaded_data:
    raise KeyError(
        f"{_flight_key!r} not found in loaded_data; available keys: {sorted(loaded_data)[:10]}"
    )
df = pd.DataFrame(loaded_data[_flight_key]["compressed_data"])

In [None]:
# Attach the per-row labels stored alongside the compressed data for this flight.
df = df.assign(labels=loaded_data['20221026-150826_NZGS-NZAA_L1.nc']["labels"])

In [None]:
# List the fields stored for this flight (the original empty subscript `[]` was a SyntaxError).
full_data_dict['20221026-150826_NZGS-NZAA_L1.nc'].keys()


In [None]:
# Example: display the full entry (compressed data, labels, ...) stored for one file.
full_data_dict['20221026-150826_NZGS-NZAA_L1.nc']

In [None]:
import json

# Save full_data_dict to the same file the earlier export cell writes and the reload
# cell reads. A bare "full_data_dict.json" would land in the current working directory,
# where the reload cell never looks.
with open("E:/data/geo_k_compressed/full_data_dict.json", "w") as f:
    json.dump(full_data_dict, f, indent=2, default=lambda x: x.tolist() if hasattr(x, "tolist") else str(x))
    

In [None]:
# Count the entries stored under 'kept_track_id' for this flight.
kept_ids = full_data_dict['20221026-100450_NZRO-NZAA_L1.nc']['kept_track_id']
len(kept_ids)


In [None]:
# Display the values stored under 'discarded_track_id' for this flight.
full_data_dict['20221026-100450_NZRO-NZAA_L1.nc']['discarded_track_id']