In [7]:
from sklearn.model_selection import train_test_split
import torch
import pandas as pd
import os, import_ipynb
from data import *

importing Jupyter notebook from data.ipynb


In [2]:
# Split the metadata dataframe into stratified train / validation / test sets
def split_train_test_stratified(test_size, valid_size, csv_path='processed_shape_data.csv', random_state=0):
    """Split the processed metadata table into stratified train/valid/test sets.

    The CSV must provide the columns ``sessionID``, ``path``, ``labels``,
    ``positive``, ``neutral`` and ``negative``. The three sentiment columns
    are collapsed into one integer class id (positive=0, neutral=1,
    negative=2) that is used as the stratification key and returned as the
    ``y_*`` series. This assumes the sentiment columns are one-hot -- TODO
    confirm against the preprocessing step.

    Parameters
    ----------
    test_size : float
        Share of the *whole* dataset that ends up in the test set.
    valid_size : float
        Share of the *whole* dataset that ends up in the validation set.
    csv_path : str, optional
        Location of the metadata CSV. Defaults to the file name the notebook
        has always used, resolved against the current working directory.
    random_state : int, optional
        Seed passed to both ``train_test_split`` calls so the split is
        reproducible.

    Returns
    -------
    tuple
        ``(x_train, x_valid, x_test, y_train, y_valid, y_test)``; the ``x_*``
        items are dataframes, the ``y_*`` items are the matching ``num_label``
        series.
    """
    df = pd.read_csv(csv_path)
    df['num_label'] = df.positive*0 + df.neutral*1 + df.negative*2
    # Six-emotion alternative, should it be needed again: build num_label from
    # anger/frustration/sad/neutral/happy/excited and select those columns.
    X = df[['sessionID', 'path', 'labels', 'positive', 'neutral', 'negative']]
    Y = df['num_label']

    # Step 1: carve off a hold-out pool that contains validation + test rows.
    # (An earlier single train/test split lived here; its results were always
    # overwritten by the two steps below, so it has been removed.)
    holdout_size = test_size + valid_size
    x_train, x_temp, y_train, y_temp = train_test_split(
        X, Y, test_size=holdout_size, stratify=Y, random_state=random_state
    )

    # Step 2: divide the pool so that the test part is `test_size` of the
    # whole dataset, i.e. test_size / holdout_size of the pool.
    x_valid, x_test, y_valid, y_test = train_test_split(
        x_temp,
        y_temp,
        test_size=(test_size / holdout_size),
        stratify=y_temp,
        random_state=random_state,
    )

    return x_train, x_valid, x_test, y_train, y_valid, y_test

In [3]:
# Create the label .pt files (the model's target outputs)
def create_train_test_label_file(x_train, x_valid, x_test, source):
    """Save the sentiment label columns of each split as a tensor file.

    For every split the ``positive``, ``neutral`` and ``negative`` columns are
    converted to a tensor with one row per sample and written into ``source``
    as ``label_<split>_relabelfru_removesad.pt``.

    Parameters
    ----------
    x_train, x_valid, x_test : pandas.DataFrame
        Split dataframes containing the three sentiment columns.
    source : str
        Directory that receives the three ``.pt`` files.
    """
    sentiment_cols = ['positive', 'neutral', 'negative']
    outputs = (
        ('label_train_relabelfru_removesad.pt', x_train),
        ('label_valid_relabelfru_removesad.pt', x_valid),
        ('label_test_relabelfru_removesad.pt', x_test),
    )
    # Same order as before: train, then validation, then test.
    for file_name, frame in outputs:
        label_tensor = torch.tensor(frame[sentiment_cols].values)
        torch.save(label_tensor, os.path.join(source, file_name))

In [4]:
# Create the data .pt files (the model's inputs)
def create_train_test_data_file(x_train, x_valid, x_test, source):
    """Stack the per-sample tensors of each split into one tensor file.

    Every entry of a split's ``path`` column is loaded with ``torch.load`` and
    copied into a pre-allocated ``(n_samples, 2, 40, T)`` buffer, where ``T``
    is the size of the last dimension of the first training sample. The buffer
    is then written into ``source`` as ``data_<split>_relabelfru_removesad.pt``.

    Parameters
    ----------
    x_train, x_valid, x_test : pandas.DataFrame
        Split dataframes with a ``path`` column pointing at saved tensors.
    source : str
        Directory that receives the three ``.pt`` files.
    """
    # The first training sample fixes the time dimension for all three splits.
    n_frames = torch.load(x_train.reset_index().path[0]).shape[-1]

    outputs = (
        ('data_train_relabelfru_removesad.pt', x_train),
        ('data_valid_relabelfru_removesad.pt', x_valid),
        ('data_test_relabelfru_removesad.pt', x_test),
    )
    for file_name, frame in outputs:
        stacked = torch.empty((len(frame), 2, 40, n_frames))
        # Fill row by row so only one sample is held in memory besides the buffer.
        for row, sample_path in enumerate(frame['path']):
            stacked[row] = torch.load(sample_path)
        torch.save(stacked, os.path.join(source, file_name))

In [5]:
def main_(test_size, valid_size, source):
    """Run the full export: split the metadata, then write label and data files.

    Parameters
    ----------
    test_size, valid_size
        Forwarded unchanged to ``split_train_test_stratified``.
    source : str
        Output directory handed to both file-writing helpers.
    """
    # Only the feature frames are needed here; the y_* series are discarded.
    x_train, x_valid, x_test, _, _, _ = split_train_test_stratified(test_size, valid_size)

    create_train_test_label_file(x_train, x_valid, x_test, source)
    create_train_test_data_file(x_train, x_valid, x_test, source)
    return None

In [10]:
# Output directory: <link_to_data()>/model/train_data. link_to_data comes from
# the data notebook; presumably it returns the dataset root -- confirm there.
root = os.path.join(link_to_data(), 'model', 'train_data')
# 10% test and 20% validation leave 70% of the rows for training.
main_(test_size=0.1, valid_size=0.2, source=root)

In [11]:
def print_shape(source):
    """Load each exported tensor file from ``source`` and print its shape.

    Parameters
    ----------
    source : str
        Directory containing the six ``*_relabelfru_removesad.pt`` files.
    """
    stems = (
        'data_train_relabelfru_removesad',
        'data_valid_relabelfru_removesad',
        'data_test_relabelfru_removesad',
        'label_train_relabelfru_removesad',
        'label_valid_relabelfru_removesad',
        'label_test_relabelfru_removesad',
    )
    for stem in stems:
        loaded = torch.load(os.path.join(source, stem + '.pt'))
        print(stem, ' has shape: ', loaded.shape)

In [12]:
# Sanity check: reload every exported file and confirm the sample counts of
# each data tensor and its label tensor line up.
print_shape(root)

data_train_relabelfru_removesad  has shape:  torch.Size([10576, 2, 40, 900])
data_valid_relabelfru_removesad  has shape:  torch.Size([3022, 2, 40, 900])
data_test_relabelfru_removesad  has shape:  torch.Size([1512, 2, 40, 900])
label_train_relabelfru_removesad  has shape:  torch.Size([10576, 3])
label_valid_relabelfru_removesad  has shape:  torch.Size([3022, 3])
label_test_relabelfru_removesad  has shape:  torch.Size([1512, 3])
