In [13]:
from pathlib import Path

import numpy as np
from sklearn.model_selection import StratifiedKFold
from sklearn.model_selection import ParameterGrid

In [14]:
def load_arrays(path="../src/"):
    """Return the pre-saved answer arrays and embedding matrices.

    Parameters
    ----------
    path : str or os.PathLike, optional
        Directory that contains the ``Arrays`` (question 1) and ``Arrays2``
        (question 2) folders. Defaults to ``"../src/"``; a trailing
        separator is optional.

    Returns
    -------
    answers : dict
        Arrays keyed ``q1_answers``, ``q1_scores``, ``q1_sequences``,
        ``q2_answers``, ``q2_scores`` and ``q2_sequences``.
    embeddings : dict
        Embedding matrices keyed ``q1_glove``, ``q1_fasttext``, ``q1_lda``,
        ``q2_glove`` and ``q2_fasttext``.
    """
    root = Path(path)
    folders = {"q1": root / "Arrays", "q2": root / "Arrays2"}

    # Student answers with the corresponding scores
    answers = {}
    for question, folder in folders.items():
        # SECURITY: allow_pickle=True unpickles the file's contents, which can
        # execute arbitrary code. Only load answers.npy from a trusted source.
        answers[f"{question}_answers"] = np.load(
            folder / "answers.npy", allow_pickle=True
        )
        answers[f"{question}_scores"] = np.load(folder / "scores.npy")
        answers[f"{question}_sequences"] = np.load(folder / "sequences.npy")

    # GloVe, fastText and LDA embeddings.
    # No LDA matrix is loaded for question 2.
    embedding_kinds = {
        "q1": ("glove", "fasttext", "lda"),
        "q2": ("glove", "fasttext"),
    }
    embeddings = {}
    for question, kinds in embedding_kinds.items():
        for kind in kinds:
            embeddings[f"{question}_{kind}"] = np.load(
                folders[question] / f"embedding_matrix_{kind}.npy"
            )

    return answers, embeddings

In [15]:
# Load both dictionaries once; later cells index them by key (e.g. 'q1_scores').
answers, embeddings = load_arrays()

In [37]:
# Number of folds handed to StratifiedKFold in get_train_sequences.
KFOLDS = 5


def score(onehot):
    """Convert one-hot / softmax rows into integer scores.

    Each row is mapped to the position of its largest entry; when the
    maximum occurs more than once, the first position wins. The result is
    a plain list with one index per row.
    """
    scores = []
    for row in onehot:
        values = list(row)
        scores.append(values.index(max(row)))
    return scores


def get_train_sequences(n, features, labels):
    """Return the training/validation split for fold ``n``.

    The labels are reduced to integer scores with ``score`` so the folds can
    be stratified on them. The splitter uses ``KFOLDS`` folds, shuffling and
    a fixed ``random_state`` of 1.

    Parameters
    ----------
    n : int
        Index of the fold to select from the list of generated splits.
    features, labels
        Objects that can be indexed with an array of integer positions.

    Returns
    -------
    tuple
        ``(x_train, x_valid, y_train, y_valid)`` as numpy arrays.
    """
    strata = score(labels)

    splitter = StratifiedKFold(KFOLDS, shuffle=True, random_state=1)
    # Materialise every fold so ``n`` is resolved by ordinary list indexing.
    train_idx, valid_idx = list(splitter.split(features, strata))[n]

    return (
        np.array(features[train_idx]),
        np.array(features[valid_idx]),
        np.array(labels[train_idx]),
        np.array(labels[valid_idx]),
    )

In [38]:
# Build the train/validation arrays for fold index 1 of question 1.
# NOTE(review): the features passed are 'q1_answers', not 'q1_sequences' — confirm this is intended.
x_train, x_valid, y_train, y_valid = get_train_sequences(1, answers['q1_answers'],  answers['q1_scores'])

In [None]:
# Display the question-1 score labels that were passed to get_train_sequences.
# NOTE(review): this renders the entire array; a slice such as [:5] would keep the output short.
answers['q1_scores']

array([[0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 0., 1.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1., 0.],
       [0., 1., 0.],
       [1., 0., 0.],
       [0., 1