In [29]:
import os
import numpy as np
import argparse
from itertools import product
import torch
from torch_geometric.data import Data
from torch_geometric.data.collate import collate
import itertools
from enum import Enum

# Constants definitions
MAX_SIMU_TOKENS = 5 

class PitchToken:
    SOS = 0  # Start of sequence, example value

class DurationToken:
    SOS = 0  # Start of sequence, example value

N_DISCRETE_VALUES = 128  
N_SAMPLES = 100  
N_TIMESTEPS = 100  

# EdgeTypes for simplicity
class EdgeTypes(Enum):
    TRACK = 0
    ONSET = 1
    NEXT = 2
    N_EDGE_TYPES = 3  # Total number of edge types

def generate_sample_data(num_samples, n_timesteps):
    np.random.seed(42)  
    emotions = np.random.randint(0, N_DISCRETE_VALUES, (num_samples, n_timesteps))
    locations = np.random.randint(0, N_DISCRETE_VALUES, (num_samples, n_timesteps))
    activities = np.random.randint(0, N_DISCRETE_VALUES, (num_samples, n_timesteps))
    modes = np.random.randint(0, N_DISCRETE_VALUES, (num_samples, n_timesteps))
    return emotions, locations, activities, modes

def preprocess_data(data, n_bars, resolution):
    """
    Preprocess the sample data to create content and structure tensors.
    """
    num_samples, n_timesteps = data[0].shape
    length = n_timesteps

    # Initialize content and structure tensors
    c_tensor = np.zeros((num_samples, length, MAX_SIMU_TOKENS, 2), np.int16)
    s_tensor = np.zeros((num_samples, length), dtype=bool)

    for i, sample in enumerate(zip(*data)):
        for t in range(length):
            c_tensor[i, t, 0, 0] = PitchToken.SOS  # Start of sequence
            c_tensor[i, t, 1, 0] = sample[0][t]  # Emotions as pitch
            c_tensor[i, t, 2, 0] = sample[1][t]  # Locations as pitch
            c_tensor[i, t, 3, 0] = sample[2][t]  # Activities as pitch
            c_tensor[i, t, 4, 0] = sample[3][t]  # Modes as pitch
            c_tensor[i, t, :, 1] = DurationToken.SOS  # Using SOS token for duration for simplicity
            s_tensor[i, t] = True  # Example, setting structure tensor

    # Apply sliding window to generate sequences
    sequences = []
    for i in range(0, length - n_bars * resolution + 1, resolution):
        seq_c_tensor = c_tensor[:, i:i + n_bars * resolution, :, :]
        seq_s_tensor = s_tensor[:, i:i + n_bars * resolution]
        sequences.append((seq_c_tensor, seq_s_tensor))
        print(f"seq_c_tensor shape: {seq_c_tensor.shape}, seq_s_tensor shape: {seq_s_tensor.shape}")
    return sequences

def save_preprocessed_data(filepath, sequences):
    """
    Save the preprocessed data sequences to a file.
    Each sequence is saved as a separate item in the npz file.
    """
    # Create a dict where key is the sequence index and value is the sequence data
    seq_dict = {}
    for i, (seq_c_tensor, seq_s_tensor) in enumerate(sequences):
        seq_dict[f'seq_c_{i}'] = seq_c_tensor
        seq_dict[f'seq_s_{i}'] = seq_s_tensor

    # Save each sequence tensor as a separate item in the npz file
    np.savez(filepath, **seq_dict)



In [30]:
args = argparse.Namespace(n_bars=2, resolution=8)

# Generate sample data
sample_data = generate_sample_data(N_SAMPLES, N_TIMESTEPS)

# Preprocess the data
sequences = preprocess_data(sample_data, args.n_bars, args.resolution)

# Save the preprocessed data
save_preprocessed_data("preprocessed_data.npz", sequences)

seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)


In [31]:
def get_track_edges(s_tensor, ones_idxs=None, node_labels=None):
    track_edges = []

    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    for track in range(s_tensor.size(0)):
        tss = list(ones_idxs[1][ones_idxs[0] == track])
        edge_type = EdgeTypes.TRACK.value + track
        edges = [
            (node_labels[track, t1].item(), node_labels[track, t2].item(), edge_type, t2 - t1)
            for t1, t2 in zip(tss[:-1], tss[1:])
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        track_edges.extend(edges + inverse_edges)

    return torch.tensor(track_edges, dtype=torch.long) if track_edges else torch.empty((0, 4), dtype=torch.long)

def get_onset_edges(s_tensor, ones_idxs=None, node_labels=None):
    onset_edges = []

    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    for ts in range(s_tensor.size(1)):
        tracks = list(ones_idxs[0][ones_idxs[1] == ts])
        combinations = list(itertools.combinations(tracks, 2))
        edges = [
            (node_labels[track1, ts].item(), node_labels[track2, ts].item(), EdgeTypes.ONSET.value, 0)
            for track1, track2 in combinations
        ]
        inverse_edges = [(u, v, t, d) for (v, u, t, d) in edges]
        onset_edges.extend(edges + inverse_edges)

    return torch.tensor(onset_edges, dtype=torch.long) if onset_edges else torch.empty((0, 4), dtype=torch.long)

def get_next_edges(s_tensor, ones_idxs=None, node_labels=None):
    next_edges = []

    if ones_idxs is None:
        ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    if node_labels is None:
        node_labels = get_node_labels(s_tensor, ones_idxs)

    tss = torch.nonzero(torch.any(s_tensor.bool(), dim=0)).squeeze()
    if tss.dim() == 0:
        return torch.empty((0, 4), dtype=torch.long)

    for i in range(tss.size(0)-1):
        t1, t2 = tss[i], tss[i+1]
        t1_tracks = ones_idxs[0][ones_idxs[1] == t1]
        t2_tracks = ones_idxs[0][ones_idxs[1] == t2]
        tracks_product = list(itertools.product(t1_tracks, t2_tracks))
        tracks_product = [(track1, track2) for (track1, track2) in tracks_product if track1 != track2]
        edges = [
            (node_labels[track1, t1].item(), node_labels[track2, t2].item(), EdgeTypes.NEXT.value, t2 - t1)
            for track1, track2 in tracks_product
        ]

        next_edges.extend(edges)

    return torch.tensor(next_edges, dtype=torch.long) if next_edges else torch.empty((0, 4), dtype=torch.long)

def get_track_features(s_tensor):
    # Indices where the binary structure tensor is active
    ones_idxs = torch.nonzero(s_tensor, as_tuple=True)

    n_nodes = len(ones_idxs[0])
    tracks = ones_idxs[0]  # Track indices for each active node

    # Assuming N_TRACKS is the total number of tracks (3 in this case)
    features = torch.zeros((n_nodes, N_TRACKS), dtype=torch.float32)

    # Assign one-hot encoded track information as features
    for i, track_idx in enumerate(tracks):
        features[i, track_idx] = 1.0

    return features

def graph_from_tensor(s_tensor):
    bars = []

    # Iterate over bars and construct a graph for each bar
    for i in range(s_tensor.size(1)):  # Adjusted to iterate over timesteps
        bar = s_tensor[:, i, :]  # Adjusted to select the correct slice

        if not torch.any(bar):
            bar[0, 0] = 1  # Add a fake activation if the bar is empty

        track_edges = get_track_edges(bar)
        onset_edges = get_onset_edges(bar)
        next_edges = get_next_edges(bar)
        edges = [track_edges, onset_edges, next_edges]

        edge_list = torch.cat([x for x in edges if x.numel() > 0]) if any(x.numel() > 0 for x in edges) else torch.LongTensor([[0], [0]])
        edge_index = edge_list[:, :2].t().contiguous()
        attrs = edge_list[:, 2:] if edge_list.numel() > 0 else torch.Tensor([[0, 0]])

        edge_attrs = torch.zeros(attrs.size(0), s_tensor.shape[-1] + 1)
        edge_attrs[:, 0] = attrs[:, 0]
        edge_attrs[torch.arange(edge_attrs.size(0)), attrs.long()[:, 1] + 1] = 1

        node_features = get_track_features(bar)
        num_nodes = torch.sum(bar, dtype=torch.long)

        bars.append(Data(edge_index=edge_index, edge_attrs=edge_attrs, num_nodes=num_nodes, node_features=node_features).to(s_tensor.device))

    graph, _, _ = collate(Data, data_list=bars, increment=True, add_batch=True)
    graph.bars = graph.batch  # Rename to avoid conflict with DataLoader's batch

    return graph

def get_node_labels(s_tensor, ones_idxs):
    # Build a tensor which has node labels in place of each activation in the
    # stucture tensor
    labels = torch.zeros_like(s_tensor, dtype=torch.long, 
                              device=s_tensor.device)
    n_nodes = len(ones_idxs[0])
    labels[ones_idxs] = torch.arange(n_nodes, device=s_tensor.device)
    return labels


In [32]:
class PolyphemusDataset(Dataset):
    def __init__(self, sequences):
        self.sequences = sequences

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        seq_c_tensor, seq_s_tensor = self.sequences[idx]
        c_tensor = torch.tensor(seq_c_tensor, dtype=torch.float)
        s_tensor = torch.tensor(seq_s_tensor, dtype=torch.float)

        # Convert to one-hot encoding if necessary or directly use as features
        # Here, assuming c_tensor is already in the desired format
        # and s_tensor indicates the structure (active/inactive timesteps)

        # Build graph structure from structure tensor
        # Note: The function `graph_from_tensor` should be defined elsewhere in the project
        # and is responsible for converting the structure tensor into a graph object
        # compatible with the GNN model.
        s_tensor = s_tensor.unsqueeze(0)
        graph = graph_from_tensor(s_tensor)

        graph.c_tensor = c_tensor
        graph.s_tensor = s_tensor

        return graph

# Example usage
sample_data = generate_sample_data(N_SAMPLES, N_TIMESTEPS)
sequences = preprocess_data(sample_data, n_bars=2, resolution=8)
dataset = PolyphemusDataset(sequences)

seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)
seq_c_tensor shape: (100, 16, 5, 2), seq_s_tensor shape: (100, 16)


In [33]:
dataset[0]

Data(edge_index=[2, 3000], edge_attrs=[3000, 17], num_nodes=1600, node_features=[1600, 4], batch=[1600], ptr=[101], bars=[1600], c_tensor=[100, 16, 5, 2], s_tensor=[1, 100, 16])

In [34]:
# Print the contents of the first few items in the dataset
for i in range(min(len(dataset), 5)):  # Just as an example, print first 5 items
    graph = dataset[i]
    print(f"Graph {i}:")
    print(f"c_tensor shape: {graph.c_tensor.shape}")
    print(f"s_tensor shape: {graph.s_tensor.shape}")
    # Add any other properties you wish to inspect
    print("----------")

Graph 0:
c_tensor shape: torch.Size([100, 16, 5, 2])
s_tensor shape: torch.Size([1, 100, 16])
----------
Graph 1:
c_tensor shape: torch.Size([100, 16, 5, 2])
s_tensor shape: torch.Size([1, 100, 16])
----------
Graph 2:
c_tensor shape: torch.Size([100, 16, 5, 2])
s_tensor shape: torch.Size([1, 100, 16])
----------
Graph 3:
c_tensor shape: torch.Size([100, 16, 5, 2])
s_tensor shape: torch.Size([1, 100, 16])
----------
Graph 4:
c_tensor shape: torch.Size([100, 16, 5, 2])
s_tensor shape: torch.Size([1, 100, 16])
----------
