In [1]:
import fcwt
import numpy as np
import matplotlib.pyplot as plt
import pickle
from scipy import signal
from PIL import Image
import pandas as pd
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from vit_pytorch.vit_3d import ViT
from scipy.ndimage import zoom
import torch.cuda.amp as amp

# pick the compute device: the CUDA GPU when torch reports one, the CPU otherwise
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [2]:
# load dataset
# the context manager closes the file handle as soon as the pickle has been read
# NOTE(review): unpickling can execute arbitrary code -- only load .dat files from a trusted source
with open('data_preprocessed_python/s01.dat', 'rb') as dataset_file:
    x = pickle.load(dataset_file, encoding='latin1')
data = x['data']
labels = x['labels']

# keep the first 32 channels of every trial and the first two label columns
# (the two columns are read below as valence and arousal, in that order)
relevant_channels = data[:, :32, :]
relevant_labels = labels[:, :2]

# 4 valence-arousal classes, one per trial (split at 4.5 on each axis):
#   0 = low valence,  low arousal      1 = low valence,  high arousal
#   2 = high valence, low arousal      3 = high valence, high arousal
classes = []
for valence, arousal in relevant_labels:
    if valence < 4.5 and arousal < 4.5:
        cls = 0
    elif valence < 4.5:
        cls = 1
    elif arousal < 4.5:
        cls = 2
    else:
        cls = 3
    classes.append(cls)

In [3]:
# constants shared by the rest of the notebook

# --- CWT parameters: fixed, not to be changed ---
fs = 128  # sampling-frequency argument handed to fcwt.cwt
f0, f1 = 4, 45  # lowest and highest frequency of the transform
fn = 32  # number of frequencies; matches the channel count so each frame is square
target_shape = (32, 64, 256)

# --- model / training parameters: may be tuned to balance performance against cost ---
# TODO: consider evolutionary approach to finding optimal parameters
BATCH_SIZE = 1
IMAGE_SIZE = 2048
PATCH_SIZE = 64

In [22]:
# CWT magnitude of every channel of the first trial, stacked vertically:
# rows [channel * fn, (channel + 1) * fn) of total_cwt hold that channel's fn frequency rows,
# columns are time samples
trial_signals = relevant_channels[0]
n_channels, n_samples = trial_signals.shape
total_cwt = np.zeros((n_channels * fn, n_samples))

# the loop variable is deliberately NOT called `signal`: that name is the scipy.signal
# module imported at the top of the notebook and must not be overwritten
for channel, channel_signal in enumerate(trial_signals):
    _, current_cwt = fcwt.cwt(channel_signal, fs, f0, f1, fn)
    start = channel * fn
    end = start + fn
    # store only the magnitude of the transform
    total_cwt[start:end, :] = np.abs(current_cwt)

In [None]:
# convert the 2D (channel-frequency x time sample) array total_cwt into one 2D frame per time sample
# each frame has shape (fn, n_channels): row = frequency index, column = channel,
# i.e. frame[f, c] == total_cwt[c * fn + f, sample]
n_channels = total_cwt.shape[0] // fn
n_samples = total_cwt.shape[1]

# view the rows as (channel, frequency), then reorder the axes to (time, frequency, channel)
frames = total_cwt.reshape(n_channels, fn, n_samples).transpose(2, 1, 0)

# formatting for .gif format, not necessary
# use `frames` directly if creating data for training

# normalize every frame to 0-1 independently, using its own min and max
frame_min = frames.min(axis=(1, 2), keepdims=True)
frame_span = frames.max(axis=(1, 2), keepdims=True) - frame_min
# a constant frame has zero span; divide by 1 there so it maps to 0 instead of NaN (0 / 0)
norm_frames = (frames - frame_min) / np.where(frame_span > 0, frame_span, 1)
# scale to 0-255
scaled_frames = (norm_frames * 255).astype(np.uint8)

# repeat the grayscale values on a new axis 1 for RGB -> (time, 3, frequency, channel)
rgb_frames = np.repeat(scaled_frames[:, np.newaxis, :, :], 3, axis=1)

# one (3, fn, n_channels) uint8 array per time sample
all_frames = list(rgb_frames)

if all_frames:
    print("Frame shape:", all_frames[0].shape)

Frame shape: (3, 32, 32)


In [35]:
# write .gif file
from array2gif import write_gif

# pack the list of per-sample RGB frames into a single array for write_gif
gif_dataset = np.asarray(all_frames)
print("Dataset shape:", gif_dataset.shape)

# NOTE(review): GIF frame delays are stored in hundredths of a second, so 128 fps may not be
# representable exactly -- confirm the playback speed of the written file
write_gif(gif_dataset, 'test.gif', fps=128)

Dataset shape: (8064, 3, 32, 32)
