In [None]:
from skmultilearn.model_selection import iterative_train_test_split
X_train, y_train, X_test, y_test = iterative_train_test_split(x, y, test_size = 0.1)

In [8]:
import pickle
import fnmatch
import os
from sklearn.model_selection import train_test_split
from pathlib import Path
import cv2
from datamodule import Datamodule
from tqdm import tqdm
import pandas as pd
import numpy as np
from params import RANDOM_SEED, LocationConfig, CreateDataConfig, TrainingConfig

In [3]:
def create_new_data_directories(path):
    Path().mkdir(exist_ok=True, parents=True)
    Path(path + 'train').mkdir(exist_ok=True, parents=True)
    Path(path + 'test').mkdir(exist_ok=True, parents=True)
    
    
def get_short_video_name(videoNames):
    ShortVideoName = []
    for videoName in videoNames.values:
        ShortVideoName.append(videoName.split('.')[0])
    return ShortVideoName

def create_mean_video_name_df(df):
    cols = ['ValueExtraversion','ValueAgreeableness','ValueConscientiousness','ValueNeurotisicm','ValueOpenness','ShortVideoName']
    grouped_df = df[cols].groupby('ShortVideoName')
    mean_df = grouped_df.mean()
    mean_df = mean_df.reset_index()
    return mean_df

In [4]:
create_new_data_directories(LocationConfig.shuffle_data)

df = pd.read_csv(LocationConfig.labels + 'bigfive_labels.csv')

df['ShortVideoName'] = get_short_video_name(df['VideoName'])

mean_df = create_mean_video_name_df(df)
mean_df.to_csv(LocationConfig.labels + 'bigfive_labels_mean.csv')
mean_df = mean_df.set_index('ShortVideoName')

X_train, X_test = train_test_split(
    np.array(mean_df.index), 
    test_size=CreateDataConfig.test_size_ratio,
    random_state=RANDOM_SEED
)
images_dict_train = {'X':[], 'Y':[]}
images_dict_test = {'X':[], 'Y':[]}
total_files = len(fnmatch.filter(os.listdir(LocationConfig.crop_data), '*.jpg'))
for image_path in tqdm(Path(LocationConfig.crop_data).glob('*.jpg'), total=total_files):
    X = cv2.imread(str(image_path), cv2.IMREAD_GRAYSCALE) 
    X = np.expand_dims(X, axis=2) 
    image_group = image_path.name.split('.')[0]
    image_no = image_path.name.split('.')[2][-5:]
    Y = mean_df.loc[image_group].values
    if CreateDataConfig.classification:
        Y = list(np.where(Y>CreateDataConfig.Y_threshold, 1, 0))

    if image_group in X_test:
        images_dict_test['X'].append(X)
        images_dict_test['Y'].append(Y)
    else:
        images_dict_train['X'].append(X)
        images_dict_train['Y'].append(Y)

with open(LocationConfig.shuffle_data + 'train/train.pickle', 'wb') as handle:
    pickle.dump(images_dict_train, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open(LocationConfig.shuffle_data + 'test/test.pickle', 'wb') as handle:
    pickle.dump(images_dict_test, handle, protocol=pickle.HIGHEST_PROTOCOL)

100%|██████████| 26604/26604 [00:05<00:00, 4971.94it/s]


In [22]:
train_data_path = Path(LocationConfig.new_data + 'train')
test_data_path = Path(LocationConfig.new_data + 'test')
dm = Datamodule(
        batch_size=TrainingConfig.batch_size,
        train_dir=train_data_path,
        val_dir=test_data_path,
        )
# dm.setup(val_only=True)
dm.setup()

file: new_data/train/train.pickle
file: new_data/test/test.pickle


In [23]:
print('train:')
acc_class_global_0 = 0
i=0
for batch in tqdm(dm.train_dataloader()):
    X, Y = batch['normalized'], batch['label']
    Y_pred = np.zeros_like(Y)
    acc_class_0 = np.sum(Y_pred == np.array(Y), axis=0) / len(Y)
    acc_class_global_0 += acc_class_0
    i+=1
acc_class_global_0 /= i
print(acc_class_global_0)
print(acc_class_global_0.mean())

print('test:')
acc_class_global_0 = 0
i=0
for batch in tqdm(dm.val_dataloader()):
    X, Y = batch['normalized'], batch['label']
    Y_pred = np.zeros_like(Y)
    acc_class_0 = np.sum(Y_pred == np.array(Y), axis=0) / len(Y)
    acc_class_global_0 += acc_class_0
    i+=1
acc_class_global_0 /= i
print(acc_class_global_0)
print(acc_class_global_0.mean())

train:


100%|██████████| 190/190 [00:02<00:00, 85.92it/s]


[0.56551927 0.285961   0.40653195 0.41715226 0.2837641 ]
0.39178571428571424
test:


100%|██████████| 19/19 [00:00<00:00, 72.80it/s]

[0.52679211 0.3142343  0.40442206 0.41275567 0.28363975]
0.38836877931214175



