In [8]:
import os
import random
import sys

import numpy as np
import torch as T

# Make the parent of the current working directory importable so the local
# `datasets` package imported below resolves. Insert at the front (rather than
# append) so it takes precedence over any installed package with the same name
# (e.g. Hugging Face `datasets`); drop an existing entry first so re-running this
# cell does not keep growing sys.path.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
if PROJECT_ROOT in sys.path:
    sys.path.remove(PROJECT_ROOT)
sys.path.insert(0, PROJECT_ROOT)

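## Dataset Creation

Generate WCST samples, split them into train/validation/test sets, and save each split under `../datasets/`.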

In [9]:
from torch.utils.data import TensorDataset, random_split

from datasets.wcst import WCST

### 1. Dataset Hyperparameters

In [23]:
BATCH_SIZE = 32            # passed to WCST and the number of samples kept from each batch
BATCHES_PER_CONTEXT = 100  # batches drawn before each call to context_switch()
N_CONTEXT_SWITCHES = 1     # outer-loop iterations: one context sampled, then switched
SEED = 42                  # seeds the RNGs below so the generated splits are reproducible

# Seed every global RNG that data generation or random_split may draw from.
# NOTE(review): assumes WCST uses the `random`/numpy/torch global RNGs — confirm.
random.seed(SEED)
np.random.seed(SEED)
T.manual_seed(SEED)  # also the default generator used by random_split

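### 2. Create the Dataset

For each of `N_CONTEXT_SWITCHES` contexts, draw `BATCHES_PER_CONTEXT` batches from the WCST generator and then call `context_switch()`. Finally, split all samples 60/20/20 into train, validation, and test sets.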

In [24]:
# Collect BATCH_SIZE samples from each generated batch, switching the WCST
# context after every BATCHES_PER_CONTEXT batches.
wcst = WCST(BATCH_SIZE)
samples = []  # one entry per example (row of X)
labels = []   # the matching entry of t for each example

for _context in range(N_CONTEXT_SWITCHES):
    for _batch in range(BATCHES_PER_CONTEXT):
        # NOTE(review): a fresh generator is created per batch and only its first
        # batch is consumed; kept as-is because gen_batch()'s internals aren't visible here.
        X, t = next(wcst.gen_batch())
        samples.extend(X[:BATCH_SIZE])
        labels.extend(t[:BATCH_SIZE])
    wcst.context_switch()

data = T.tensor(np.array(samples), dtype=T.long)
targets = T.tensor(np.array(labels), dtype=T.long)
assert len(data) == len(targets) == N_CONTEXT_SWITCHES * BATCHES_PER_CONTEXT * BATCH_SIZE
dataset = TensorDataset(data, targets)

# 60/20/20 split; with no `generator` argument random_split draws from torch's
# default RNG, so the split is reproducible only if that RNG was seeded earlier.
train_dataset, validation_dataset, test_dataset = random_split(
    dataset, [0.6, 0.2, 0.2]
)
len(train_dataset), len(validation_dataset), len(test_dataset)

### 3. Save Datasets

Each split is written with `torch.save` as a `(data, targets)` tuple of tensors.

In [25]:
def save_subset(subset, filename):
    """Save the samples selected by a ``Subset`` of a ``TensorDataset`` to disk.

    The first two tensors of the underlying dataset are gathered at the subset's
    indices and written with ``T.save`` as a ``(data, targets)`` tuple. Nested
    subsets (a ``Subset`` of a ``Subset``) are resolved by composing their
    indices down to the tensor-backed base dataset.

    Args:
        subset: a ``torch.utils.data.Subset`` (e.g. from ``random_split``) whose
            innermost dataset exposes ``.tensors`` like ``TensorDataset``.
        filename: destination path passed to ``T.save``.
    """
    from torch.utils.data import Subset

    # Explicit long dtype so even an empty index list yields a (0, ...) slice.
    indices = T.as_tensor(subset.indices, dtype=T.long)
    base = subset.dataset
    # Map indices through each enclosing Subset until reaching the base dataset.
    while isinstance(base, Subset):
        indices = T.as_tensor(base.indices, dtype=T.long)[indices]
        base = base.dataset
    data = base.tensors[0][indices]
    targets = base.tensors[1][indices]
    T.save((data, targets), filename)

In [26]:
# Persist each split next to the local `datasets` package, one file per split.
for split_subset, split_path in (
    (train_dataset, '../datasets/train_dataset.pt'),
    (validation_dataset, '../datasets/validation_dataset.pt'),
    (test_dataset, '../datasets/test_dataset.pt'),
):
    save_subset(split_subset, split_path)