In [None]:
import os
import pandas as pd
import numpy as np
import glob
import torch
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from torchvision import transforms, utils, transforms
import matplotlib.pyplot as plt

In [None]:
plt.rcdefaults()

# Use Odd spectrograms

In [6]:
from torch.utils.data import Dataset, DataLoader
import os
import glob
import torch
import pandas as pd
import itertools

class SpectrogramDataset(Dataset):
    def __init__(self, data_path, file_ext, window_size, transform=None, time_step=5.12e-4):
        """
        spectrograms: List of spectrogram tensors
        instabilities: List of tuples (start_time, end_time) for each spectrogram
        window_size: Number of time steps in each slice <-> Time in each window = window_size * time_step
        overlap: Fraction of overlap between consecutive slices (0 to 1)
        time_step: Duration of each time step in ms
        data_path: Location of the data
        file_ext: File type of the data
        """
        self.data_path = data_path
        self.file_ext = file_ext
        self.window_size = window_size
        self.transform = transform
        self.time_step = time_step

        # Obtain all shot numbers
        self.data_files = [int(os.path.basename(x.split(f".{file_ext}")[0]))
                           for x in glob.glob(os.path.join(data_path, f"*.{file_ext}"))]

        # Precompute and store all windows with unique IDs using a dictionnary
        self.windows = self.compute_all_windows()

    def __len__(self):
        # Return the total number of windows (the total size of our dataset)
        print(len(self.windows))
        return len(self.windows)

    def __getitem__(self, idx):
        # Return a single window based on the provided idx (unique identifier)
        #found_dict = None # We need to find a dictionnary in the list depending on the value of the unique_id key
        #for my_dict in self.windows:
        #    if my_dict.get('unique_id') == idx:
        #        found_dict = my_dict
        #        break
        #return self.windows[idx]
        # Return a list of windows based on the provided idx (unique identifier)
        found_dict = next((my_dict for my_dict in self.windows if my_dict.get('unique_id') == idx), None)
        return found_dict

    def load_shot(self, shotno):
        file_path = os.path.join(self.data_path, f"{shotno}.{self.file_ext}")
        return pd.read_pickle(file_path)

    def compute_all_windows(self):
        windows = []
        unique_id = 0

        # For each experiment
        for shotno in self.data_files:
            data_shot = self.load_shot(shotno)

            spec_odd = torch.tensor(data_shot["x"]["spectrogram"]["OddN"], dtype=torch.float32).T
            
            frequency = data_shot["x"]["spectrogram"]["frequency"]
            time = data_shot["x"]["spectrogram"]["time"]

            num_windows = len(time) // self.window_size
            
            # Compute non-overlapping sliding windows for OddN
            for i in range(0, num_windows * self.window_size, self.window_size):
                start_idx = i
                end_idx = i + self.window_size

                slice_data = spec_odd[:, start_idx:end_idx]

                windows.append({
                    'unique_id': unique_id,
                    'window_odd': slice_data,
                    'frequency': frequency,
                    'time': time[start_idx:end_idx],
                    'start_idx': start_idx,
                    'end_idx': end_idx,
                    'shotno': shotno
                })

                unique_id += 1
        #print(f"Total number of windows = {unique_id + 1} and the number of elements in windows is: {len(windows)}")
        return windows

# Example usage
DATA_PATH = "data/dataset_pickle"
FILE_EXT = "pickle"
WINDOW_SIZE = 64 # Number of datapoints

dataset = SpectrogramDataset(data_path=DATA_PATH, file_ext=FILE_EXT, window_size=WINDOW_SIZE)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

# Iterate through the DataLoader
#for idx, batch in enumerate(dataloader):
#    # Your training/inference code here
#    print(f"Batch {idx + 1} - Number of Windows: {len(batch)}") # it prints the number of elements in the dictionnary
#    print(batch)
#    break
for idx, batch in enumerate(dataloader):
    print(f"Batch {idx + 1} - Number of Windows: {len(batch)}")
    for window_dict in batch:
        print(window_dict)
    break

3908
3908
3908
3908
Batch 1 - Number of Windows: 7
unique_id
window_odd
frequency
time
start_idx
end_idx
shotno


In [None]:
class SpectrogramDataset(Dataset):
    def __init__(self, spectrograms, instabilities, window_length=20, overlap=0.5, time_step=5.12e-4):
        """
        spectrograms: List of spectrogram tensors
        instabilities: List of tuples (start_time, end_time) for each spectrogram
        slice_length: Number of time steps in each slice
        overlap: Fraction of overlap between consecutive slices (0 to 1)
        time_step: Duration of each time step in ms
        """
        self.spectrograms = spectrograms
        self.instabilities = instabilities
        self.window_length = slice_length
        self.step = int(slice_length * (1 - overlap))  # Step size for moving the window
        self.time_step = time_step

        # calculate total number of slices across all spectrograms
        self.total_slices = []
        for s in self.spectrograms:
            num_slices = 1 + (s.shape[1] - slice_length) // self.step
            self.total_slices.append(num_slices)

    def __len__(self):
        return sum(self.total_slices)

    def __getitem__(self, idx):
        spectrogram_idx = 0
        # iterate through the spectrograms find which spectrogram idx belongs to
        while idx >= self.total_slices[spectrogram_idx]:
            idx -= self.total_slices[spectrogram_idx]
            spectrogram_idx += 1
        # resulting idx is now the index of the slice in spectrogram spectrogram_idx

        # compute the starting position of the slice
        start = idx * self.sstep
        end = start + self.slice_length

        # compute the time range of the slice
        start_time = start * self.time_step
        end_time = start_time + self.slice_length * self.time_step

        # slice the spectrogram
        slice = self.spectrograms[spectrogram_idx][:, start:end]

        # determine the label based on instabilities
        label = 0
        for instability_start, instability_end in self.instabilities[spectrogram_idx]:
            if start_time < instability_end and end_time > instability_start:
                label = 1
                break

        return slice, label

# spectrograms is a list of 2D PyTorch tensors
# instabilities is a list of lists of tuples [(start_time, end_time), ...]
dataset = SpectrogramDataset(spectrograms, instabilities, overlap=0.5)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)


## Number of windows per batch

In [None]:
# dataset = Way of representing the dataset so it can be loaded with the dataloader

# dataloader = Used to load batches of data from our dataset

In [None]:
first_data = dataset[93][54] # First item of the first batch of 32 windows (which all correspond to a SINGLE shotno)
# First number is the shotno, the second one is which window we are accessing for that shotno.

for key, item in first_data.items():
    if isinstance(item, (list, np.ndarray, torch.Tensor)):
        if isinstance(item, (np.ndarray, torch.Tensor)):
            print(key, item.shape)
        else:
            print(key, len(item))
    elif isinstance(item, int):
        print(key, item)

In [None]:
idx_shotno = np.random.randint(1,94) # Choose a random experiment
idx_windowno = np.random.randint(1,batch_size)
# Plot a random window (even frequencies) from a random shotno
random_sample = dataset[idx_shotno][idx_windowno]
print(f"Experiment number: {random_sample['shotno']}, and window number: {idx_windowno}")

plot_spectrogram(random_sample["window_odd"], title = "Random window (even frequencies) from a random shotno",\
                time = random_sample["time"], frequency = random_sample["frequency"])

### Let's verify that this is correct using the real data

In [None]:
data_shot = load_shot(random_sample['shotno'], DATA_PATH, FILE_EXT)

# Extracting inputs
inputs = data_shot["x"]["spectrogram"]
spec_even = inputs["EvenN"]
spec_odd = inputs["OddN"]
f = inputs["frequency"]
t = inputs["time"]

In [None]:
plot_spectrogram(spec_even, "Even N", t, f)