In [1]:
import os
import pandas as pd
import numpy as np
import glob
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils, transforms
import matplotlib.pyplot as plt

In [2]:
plt.rcdefaults()

In [8]:
DATA_PATH = "data/dataset_pickle"
FILE_EXT = "pickle"

def load_shot(shotno, data_path, file_ext):
    file_path = os.path.join(data_path, f"{shotno}.{file_ext}")
    return pd.read_pickle(file_path)

class SpectrogramDataset(Dataset):
    def __init__(self, data_path, file_ext, window_size, transform=None):
        # data loading
        self.data_path = data_path
        self.file_ext = file_ext
        self.window_size = window_size
        self.transform = transform

        # Obtain all shot numbers
        self.all_shots = [int(os.path.basename(x.split(f".{file_ext}")[0]))
                          for x in glob.glob(os.path.join(data_path, f"*.{file_ext}"))]

    def __len__(self):
        return len(self.all_shots)

    def __getitem__(self, idx):
        shotno = self.all_shots[idx]

        # Load data for the experiment
        data_shot = load_shot(shotno, self.data_path, self.file_ext)

        # Extract inputs
        inputs = data_shot["x"]["spectrogram"]
        spec_even = inputs["EvenN"]
        spec_odd = inputs["OddN"]
        f = inputs["frequency"]
        t = inputs["time"]

        # Calculate the number of windows based on window size
        num_windows = spec_even.shape[1] // self.window_size

        # Extract windows along with their start and end indices
        windows = []
        for i in range(num_windows):
            start_idx = i * self.window_size
            end_idx = start_idx + self.window_size

            window_even = spec_even[:, start_idx:end_idx]
            window_odd = spec_odd[:, start_idx:end_idx]

            if self.transform:
                window_even = self.transform(window_even)
                window_odd = self.transform(window_odd)

            windows.append({
                'window_even': window_even,
                'window_odd': window_odd,
                'frequency': f,
                'time': t[start_idx:end_idx],
                'start_idx': start_idx,
                'end_idx': end_idx,
                'shotno': shotno
            })

        return windows

# Example usage:
# Define hyperparameters
DATA_PATH = "data/dataset_pickle"
FILE_EXT = "pickle"
window_size = 64  # Adjust as needed
batch_size = 32  # Adjust as needed

# Create dataset and dataloader
#transform = transforms.Compose([
    # Add any additional transformations you need
#])

dataset = SpectrogramDataset(DATA_PATH, FILE_EXT, window_size, transform=None)
dataloader = DataLoader(dataset = dataset, batch_size=batch_size, shuffle=True, num_workers=0)

In [52]:
first_data = dataset[93][31] # First item of the first batch of 32 windows (which all correspond to a SINGLE shotno)
# First number is the shotno, the second one is which window we are accessing for that shotno.

for key, item in first_data.items():
    if isinstance(item, (list, np.ndarray, torch.Tensor)):
        if isinstance(item, (np.ndarray, torch.Tensor)):
            print(key, item.shape)
        else:
            print(key, len(item))
    elif isinstance(item, int):
        print(key, item)

window_even (2280, 64)
window_odd (2280, 64)
frequency (2049,)
time (64,)
start_idx 1984
end_idx 2048
shotno 73023


In [34]:
# Iterate over the dataloader
for batch in dataloader:
    # Your training loop or processing logic here
    for sample in batch:
        print(sample['window_even'].shape)  # Adjust this based on the shape of your spectrograms
        print(sample['start_idx'], sample['end_idx'])
        print(sample['time'])
        print(sample['shotno'])


RuntimeError: stack expects each tensor to be equal size, but got [2100, 64] at entry 0 and [3016, 64] at entry 1