In [58]:
import os

import numpy as np
import pandas as pd

import torch
import torchvision

import matplotlib.pyplot as plt
import lightning as L

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from sklearn.model_selection import train_test_split


In [59]:
# Notebook-wide switches and reproducibility settings.
TEST = True  # NOTE(review): appears unused in this notebook — confirm it is still needed

random_seed = 42
# Lightning seeds the global RNGs (Python `random`, NumPy, torch). The np.random.choice
# and train_test_split calls later in the notebook draw from NumPy's global RNG.
L.seed_everything(seed=random_seed)

Seed set to 42


42

In [60]:
# Fractions of the downsampled data assigned to train / validation / test.
TRAIN_SIZE, VAL_SIZE, TEST_SIZE = 0.8, 0.1, 0.1
# The split derives the validation fraction relative to train+val
# (VAL_SIZE / (TRAIN_SIZE + VAL_SIZE)), which is only right if the fractions partition the data.
assert np.isclose(TRAIN_SIZE + VAL_SIZE + TEST_SIZE, 1.0), "split fractions must sum to 1"

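## Train / validation / test split

Each class is capped at 10,000 randomly chosen samples, and the result is split 80/10/10 into stratified train/validation/test sets. Up to 5,000 further unused samples per class are kept as an `additional` pool. The arrays, and the indices of their samples in the full dataset, are saved as `.npz` files.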

In [61]:
# Folder with the preprocessed torso arrays (relative to the notebook's working directory).
dataset_path = os.path.join(
    "dataset",
    "processed_concat_data",
)

In [62]:
# Full (not yet downsampled) torso FFT features and the matching label file.
fft_file, label_file = (
    os.path.join(dataset_path, file_name)
    for file_name in ("torso_fft.npy", "torso_label.npy")
)

In [63]:
# Output files are named "<prefix><stem><postfix>.npz" inside dataset_path.
prefix = "torso_"
postfix = "_fft"

# Features + labels of the train/val/test splits and of the extra "additional" pool.
# NOTE(review): the "10000_5classes" in the variable names is not part of the file
# name, so re-running with a different per-class cap silently overwrites these files.
train_val_test_10000_5classes_file = os.path.join(
    dataset_path, f"{prefix}train_val_test{postfix}.npz")
# Indices (into the full torso_fft / torso_label arrays) of the samples in each split,
# so the split can be reconstructed from the raw data.
train_val_test_10000_5classes_choice_idx_file = os.path.join(
    dataset_path, f"{prefix}train_val_test_choice_idx{postfix}.npz")


In [64]:
# Full feature array and per-sample class labels (before any downsampling).
fft_data = np.load(fft_file)
label = np.load(label_file)
# Later cells gather rows of both arrays with the same sample indices, so they must align.
assert len(fft_data) == len(label), (fft_data.shape, label.shape)

In [81]:
# Distinct class ids and how many samples each has in the full dataset.
np.unique(ar=label, return_counts=True)

(array([1., 2., 3., 4., 5.]),
 array([26036, 26000,  8690, 24399, 33686], dtype=int64))

In [65]:
# Cap every class at MAX_PER_CLASS samples, chosen at random without replacement.
# Classes with fewer samples are kept whole.
MAX_PER_CLASS = 10000
CLASS_IDS = range(1, 1 + 5)

# Every label must be one of CLASS_IDS; otherwise that class would be dropped silently.
assert set(np.unique(label)) <= set(CLASS_IDS), np.unique(label)

choice_samp_list = []
for i in CLASS_IDS:
    one_class_idx = np.where(label == i)[0]
    choice_idx_list = np.random.choice(
        one_class_idx,
        min(MAX_PER_CLASS, len(one_class_idx)),
        replace=False)
    choice_samp_list.append(choice_idx_list)

# Indices into the full fft_data / label arrays, grouped by class.
choice_samp = np.concatenate(choice_samp_list, axis=0)
# Gather once with the combined indices instead of keeping per-class copies of the
# feature rows alive in notebook-global lists (same rows, same order).
fft_data_down_samp = fft_data[choice_samp]
label_samp = label[choice_samp]

In [66]:
# Sanity check: one label per downsampled sample.
print(*(arr.shape for arr in (fft_data_down_samp, label_samp)))

(48690, 6, 257) (48690,)


In [67]:

# Stratified split of the downsampled set. choice_samp is split alongside the data,
# so each split also records the indices of its samples in the full fft_data / label arrays.
# NOTE: no random_state is passed, so the shuffles draw from NumPy's global RNG. The split
# is only reproducible when the notebook runs top to bottom after seeding.
train_val_data, test_data, train_val_label, test_label, train_val_choice, test_choice = \
    train_test_split(fft_data_down_samp, label_samp, choice_samp, test_size=TEST_SIZE, stratify=label_samp, shuffle=True)

# VAL_SIZE is a fraction of the whole, so rescale it to the remaining train+val portion.
train_data, val_data, train_label, val_label, train_choice, val_choice = \
    train_test_split(train_val_data, train_val_label, train_val_choice, test_size=VAL_SIZE / (TRAIN_SIZE + VAL_SIZE), stratify=train_val_label, shuffle=True)

# The three splits must partition the downsampled set: sizes add up and no sample is shared.
assert len(train_choice) + len(val_choice) + len(test_choice) == len(choice_samp)
assert len(np.unique(np.concatenate([train_choice, val_choice, test_choice]))) == len(choice_samp)

print("train_data.shape, train_label.shape", train_data.shape, train_label.shape)
print("val_data.shape, val_label.shape", val_data.shape, val_label.shape)
print("test_data.shape, test_label.shape", test_data.shape, test_label.shape)

train_data.shape, train_label.shape (38952, 6, 257) (38952,)
val_data.shape, val_label.shape (4869, 6, 257) (4869,)
test_data.shape, test_label.shape (4869, 6, 257) (4869,)


In [82]:
# Indices (ascending) of samples the per-class downsampling did NOT pick.
additional_choice_idx_list = np.setdiff1d(np.arange(len(label)), choice_samp)
# Class counts remaining in that unpicked pool.
np.unique(label[additional_choice_idx_list], return_counts=True)


(array([1., 2., 4., 5.]), array([16036, 16000, 14399, 23686], dtype=int64))

In [83]:
# Extra per-class pool, drawn only from samples the downsampling did NOT pick:
# at most ADDITIONAL_PER_CLASS per class, chosen at random without replacement.
ADDITIONAL_PER_CLASS = 5000

# Labels of the unpicked samples, computed once. Positions in this array are positions
# in additional_choice_idx_list, not indices into label / fft_data.
remaining_label = label[additional_choice_idx_list]

additional_choice_list = []
for class_id in range(1, 1 + 5):
    # Empty when every sample of this class was already picked; np.random.choice
    # then returns an empty array because the requested size is 0.
    one_class_pos = np.where(remaining_label == class_id)[0]
    picked_pos = np.random.choice(
        one_class_pos,
        min(ADDITIONAL_PER_CLASS, len(one_class_pos)),
        replace=False)
    # Map pool positions back to indices into the full label / fft_data arrays.
    additional_choice_list.append(additional_choice_idx_list[picked_pos])

additional_choice = np.concatenate(additional_choice_list, axis=0)
# Gather once with absolute indices. Indexing fft_data by the whole unpicked pool inside
# the loop would copy every unpicked feature row on each class iteration.
additional_fft_data = fft_data[additional_choice]
additional_label = label[additional_choice]

In [35]:
# NOTE(review): disabled legacy per-split .npy export. The same train/val/test data and
# label arrays are written to an .npz archive by the np.savez cell, and the per-split
# path variables used here (fft_train_file, label_train_file, ...) are not defined by
# any active code. Consider deleting this cell.
# np.save(fft_train_file, train_data)
# np.save(label_train_file, train_label)
# np.save(fft_val_file, val_data)
# np.save(label_val_file, val_label)
# np.save(fft_test_file, test_data)
# np.save(label_test_file, test_label)

In [84]:
# Archive 1: features and labels of every split plus the extra "additional" pool.
split_arrays = {
    "train_data": train_data,
    "train_label": train_label,
    "val_data": val_data,
    "val_label": val_label,
    "test_data": test_data,
    "test_label": test_label,
    "additional_data": additional_fft_data,
    "additional_label": additional_label,
}
# Archive 2: for each split, the indices of its samples in the full fft_data / label arrays.
split_choice_idx = {
    "train_choice": train_choice,
    "val_choice": val_choice,
    "test_choice": test_choice,
    "additional_choice": additional_choice,
}

np.savez(train_val_test_10000_5classes_file, **split_arrays)
np.savez(train_val_test_10000_5classes_choice_idx_file, **split_choice_idx)