In [61]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
import functional as F
# Relative path to the 24x24 Tsukuba hand-shape .mat file; resolved against
# the kernel's current working directory.
DATA_PATH = '../data/TsukubaHandSize24x24.mat'
# loadmat returns a dict of MATLAB variables; the image array is stored under
# 'data'. Keep the dict under its own name instead of rebinding `data` to two
# different kinds of object.
mat_contents = scipy.io.loadmat(DATA_PATH)
data = mat_contents['data']
# The preprocessing below slices six axes (the person count is read from the
# last one), so fail fast with a clear message if the file has another layout.
assert data.ndim == 6, f'expected a 6-D array in {DATA_PATH!r}, got shape {data.shape}'

In [62]:
def split_train_test(data, test_ratio=0.2):
    """Split ``data`` along its last axis into a train part and a test part.

    The split is deterministic (no shuffling): the first ``num_test`` entries
    of the last axis go to the test set and the remaining ones to train. In
    this notebook the last axis is treated as the person axis, so the split is
    subject-independent (assumption based on the ``num_people`` naming).

    Parameters
    ----------
    data : np.ndarray
        Array whose last axis is split. The notebook passes a 6-D array.
    test_ratio : float, default 0.2
        Fraction of the last axis placed in the test set; must be in [0, 1].

    Returns
    -------
    (train_data, test_data) : tuple of np.ndarray
        Slices (views) of ``data`` along the last axis.

    Raises
    ------
    ValueError
        If ``test_ratio`` is outside [0, 1] (or NaN).
    """
    if not 0.0 <= test_ratio <= 1.0:
        raise ValueError(f'test_ratio must be in [0, 1], got {test_ratio!r}')
    num_people = data.shape[-1]
    # Strip float noise before truncating: e.g. 100 * 0.29 evaluates to
    # 28.999999999999996, which a bare int() would floor to 28 instead of 29.
    num_test = int(round(num_people * test_ratio, 9))
    test_data = data[..., :num_test]
    train_data = data[..., num_test:]
    return train_data, test_data

def reshape_data(data, n_samples=None):
    """Merge all per-person samples into one image stack per class.

    ``data`` is expected to be laid out as
    ``(H, W, *sample_axes, n_classes, n_people)``. For this dataset there are
    two sample axes whose sizes multiply to 420 (the original hard-coded
    ``60*7``; TODO confirm their meaning). All sample axes are merged, then
    people and samples are merged person-major. The result has shape
    ``(H, W, n_people * n_per_person, n_classes)``: index
    ``p * n_per_person + s`` is sample ``s`` of person ``p``.

    Parameters
    ----------
    data : np.ndarray
        Array of rank >= 4 in the layout above.
    n_samples : int, optional
        Number of people (size of the last axis). Kept for backward
        compatibility; inferred from ``data`` when omitted.

    Returns
    -------
    np.ndarray
        Array of shape ``(H, W, n_people * n_per_person, n_classes)``.

    Raises
    ------
    ValueError
        If ``n_samples`` is given and does not equal ``data.shape[-1]``.
    """
    height, width = data.shape[:2]
    n_classes, n_people = data.shape[-2:]
    if n_samples is not None and n_samples != n_people:
        raise ValueError(
            f'n_samples={n_samples} does not match last axis of size {n_people}')
    # Explicit sizes (no -1) so zero-person splits still reshape cleanly.
    n_per_person = int(np.prod(data.shape[2:-2], dtype=np.int64))
    merged = data.reshape(height, width, n_per_person, n_classes, n_people)
    # (H, W, samples, classes, people) -> (H, W, people, samples, classes)
    merged = np.transpose(merged, (0, 1, 4, 2, 3))
    return merged.reshape(height, width, n_people * n_per_person, n_classes)

def split_and_reshape(data):
    """Split a ``(H, W, n_images, n_classes)`` array into one stack per class.

    Parameters
    ----------
    data : np.ndarray
        Rank-4 array whose axis 3 indexes classes.

    Returns
    -------
    list of np.ndarray
        ``data.shape[3]`` views of shape ``(H, W, n_images)``; element ``k``
        is ``data[:, :, :, k]``.
    """
    # Indexing the class axis is equivalent to np.split(..., axis=3) followed
    # by squeeze(axis=3), but uses the real class count instead of a fixed 30.
    return [data[:, :, :, k] for k in range(data.shape[3])]

def flatten_and_transpose(data):
    """Turn each ``(H, W, n_images)`` class stack into an ``(n_images, H * W)`` matrix.

    Pixels are flattened in row-major (C) order, so column ``r * W + c`` holds
    pixel ``(r, c)``. Each returned array is the transpose (``.T``) of the
    reshaped stack.

    Parameters
    ----------
    data : iterable of np.ndarray
        Rank-3 arrays, one per class.

    Returns
    -------
    list of np.ndarray
        One ``(n_images, H * W)`` array per input stack.
    """
    flattened = []
    for arr in data:
        # Unpacking raises ValueError for anything that is not rank 3.
        height, width, n_images = arr.shape
        flattened.append(arr.reshape(height * width, n_images).T)
    return flattened

def plot_first_images(first_images, grid_cols=5):
    """Show one grayscale thumbnail per class in a grid and display it.

    Parameters
    ----------
    first_images : iterable of 2-D arrays
        One image per class; image ``i`` is titled ``Class {i+1}``.
    grid_cols : int, default 5
        Number of grid columns. Rows are added as needed, so any number of
        images fits. The previous fixed 6x5 grid raised IndexError beyond 30
        images; 30 images still produce the same 6x5 layout.

    Raises
    ------
    ValueError
        If ``grid_cols`` is less than 1.
    """
    if grid_cols < 1:
        raise ValueError(f'grid_cols must be >= 1, got {grid_cols!r}')
    images = list(first_images)
    if not images:
        return
    grid_rows = -(-len(images) // grid_cols)  # ceiling division
    # squeeze=False keeps `axes` 2-D even for a single row or column;
    # the figure height scales with the rows (6 rows -> 12 in, as before).
    fig, axes = plt.subplots(grid_rows, grid_cols,
                             figsize=(12, 2 * grid_rows), squeeze=False)

    for i, ax in enumerate(axes.flat):
        if i < len(images):
            ax.imshow(images[i], cmap='gray')
            ax.set_title(f'Class {i+1}')
        # Also hides the empty trailing cells of a partially filled last row.
        ax.axis('off')

    plt.subplots_adjust(hspace=0.5, wspace=0.5)
    plt.show()

# Deterministic subject-independent split along the person (last) axis.
train_raw, test_raw = split_train_test(data)

# Read the per-split person counts from the arrays instead of hard-coding
# 80/20, so changing test_ratio cannot desynchronise the reshape.
train_stacked = reshape_data(train_raw, train_raw.shape[-1])
test_stacked = reshape_data(test_raw, test_raw.shape[-1])

# One (height, width, n_images) stack per class.
train_by_class = split_and_reshape(train_stacked)
test_by_class = split_and_reshape(test_stacked)

# Set to True to preview the first image of every class in each split.
PLOT_FIRST_IMAGES = False
if PLOT_FIRST_IMAGES:
    plot_first_images([cls_arr[:, :, 0] for cls_arr in train_by_class])
    plot_first_images([cls_arr[:, :, 0] for cls_arr in test_by_class])

# Final per-class design matrices: one (n_images, height * width) array each.
X_train = flatten_and_transpose(train_by_class)
X_test = flatten_and_transpose(test_by_class)

# The per-class views keep the large reshaped stacks alive; only X_train and
# X_test are needed from here on, so release the intermediates.
del train_raw, test_raw, train_stacked, test_stacked, train_by_class, test_by_class

print('Number of train classes ', len(X_train))
print('Shape of train class ', X_train[0].shape)
print('Number of test classes ', len(X_test))
print('Shape of test class ', X_test[0].shape)


Number of train classes  30
Shape of train class  (33600, 576)
Numver of test classes  30
Shape of test class  (8400, 576)
