In [1]:
import sys
import os

import torch as T
import numpy as np
import random

# Make the project root (parent of the notebook's working directory) importable.
# It goes at the FRONT of sys.path so the local `datasets` package wins over any
# installed package of the same name (e.g. Hugging Face `datasets`), which
# `sys.path.append` would let shadow it. Existing copies are removed first so
# re-running this cell is idempotent and never accumulates duplicate entries.
PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path[:] = [p for p in sys.path if p != PROJECT_ROOT]
sys.path.insert(0, PROJECT_ROOT)

## Dataset Creation

In [2]:
from torch.utils.data import TensorDataset, random_split, Subset, ConcatDataset

from datasets.wcst import WCST

### 1. Dataset Hyperparameters

In [3]:
BATCH_SIZE = 64            # samples per generated WCST batch
TOTAL_BATCHES = 500        # batches generated in total, divided across the contexts below
MIN_SWITCH_BATCHES = 5     # lower bound on batches between context switches (when switching is enabled)
MAX_SWITCH_BATCHES = 15    # upper bound on batches between context switches

# Seed every RNG the pipeline may draw from so the saved datasets are reproducible:
# Python `random` (switch intervals), torch (`random_split` uses the global generator),
# and numpy in case the WCST generator samples from it (TODO confirm in datasets.wcst).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
T.manual_seed(SEED)

### 2. Create the Datasets

In [4]:
def create_wcst_dataset(wcst, total_batches=5000, batch_size=64, fixed_context=None, allow_switch=True,
                        min_switch_batches=None, max_switch_batches=None):
    """
    Create a dataset of (encoder, decoder, target) tensors from a WCST environment.

    For every generated batch, the first ``batch_size`` samples are converted to:
      * encoder input - ``X_batch[i]`` flattened into a 1-D ``long`` tensor,
      * decoder input - ``t_batch[i][:-1]`` as a ``long`` tensor,
      * target        - ``[t_batch[i][-1]]``, a 1-element ``long`` tensor.

    Args:
        wcst: WCST environment instance; must provide ``gen_batch()``, ``set_context()``,
            ``context_switch()`` and ``category_feature``.
        total_batches: total number of batches to generate (must be >= 1).
        batch_size: number of samples taken from each generated batch; must not exceed
            the number of samples the environment actually yields per batch.
        fixed_context: context index to set before generating
            (e.g., 0 (color), 1 (shape), 2 (quantity)). With ``allow_switch=True`` the
            context is still switched later, so pass ``allow_switch=False`` to keep it fixed.
        allow_switch: if True, switch context after a random number of batches drawn
            uniformly from [min_switch_batches, max_switch_batches].
        min_switch_batches: lower bound of the switch interval; defaults to MIN_SWITCH_BATCHES.
        max_switch_batches: upper bound of the switch interval; defaults to MAX_SWITCH_BATCHES.

    Returns:
        TensorDataset(encoder_tensor, decoder_tensor, target_tensor) with
        ``total_batches * batch_size`` rows.

    Raises:
        ValueError: if ``total_batches < 1`` or a generated batch has fewer than
            ``batch_size`` samples.
    """
    if total_batches < 1:
        # Otherwise T.stack([]) below fails with an opaque RuntimeError after all the work.
        raise ValueError(f"total_batches must be >= 1, got {total_batches}")

    # Fall back to the notebook-level switch bounds when none are passed.
    if min_switch_batches is None:
        min_switch_batches = MIN_SWITCH_BATCHES
    if max_switch_batches is None:
        max_switch_batches = MAX_SWITCH_BATCHES

    encoder_inputs = []
    decoder_inputs = []
    targets = []

    batches_since_switch = 0
    # Drawn even when allow_switch is False, so the `random` stream is consumed exactly as before.
    switch_threshold = random.randint(min_switch_batches, max_switch_batches)

    # --- Fix context if specified ---
    if fixed_context is not None:
        wcst.set_context(fixed_context)
        print(f"[!] Fixed context mode: {wcst.category_feature}")

    for _ in range(total_batches):
        # Switch context if allowed
        if allow_switch and batches_since_switch >= switch_threshold:
            wcst.context_switch()
            print(f"[!] Context Switched - Using Category {wcst.category_feature}")
            batches_since_switch = 0
            switch_threshold = random.randint(min_switch_batches, max_switch_batches)

        # NOTE(review): a fresh generator is requested for every batch, as in the original;
        # presumably gen_batch reads the current context when called - confirm before hoisting it.
        X_batch, t_batch = next(wcst.gen_batch())
        if len(X_batch) < batch_size or len(t_batch) < batch_size:
            raise ValueError(
                f"WCST yielded a batch of {min(len(X_batch), len(t_batch))} samples, fewer than "
                f"batch_size={batch_size}; construct WCST with a matching batch size"
            )

        for i in range(batch_size):
            encoder_input = T.tensor(np.array(X_batch[i]), dtype=T.long).flatten()
            decoder_input = T.tensor(t_batch[i][:-1], dtype=T.long)
            target = T.tensor([t_batch[i][-1]], dtype=T.long)

            encoder_inputs.append(encoder_input)
            decoder_inputs.append(decoder_input)
            targets.append(target)

        batches_since_switch += 1

    encoder_tensor = T.stack(encoder_inputs)
    decoder_tensor = T.stack(decoder_inputs)
    target_tensor = T.stack(targets)
    return TensorDataset(encoder_tensor, decoder_tensor, target_tensor)

In [5]:
# Generate one fixed-context dataset per WCST context and split each into train/val/test.
# Training splits stay separate per context; validation/test splits are pooled across contexts.
contexts = [0, 1, 2]
train_datasets = {}

VAL_RATIO = 0.2
TEST_RATIO = 0.2
assert VAL_RATIO >= 0 and TEST_RATIO >= 0 and VAL_RATIO + TEST_RATIO < 1, \
    "VAL_RATIO + TEST_RATIO must leave room for training data"

# Floor division: any remainder of TOTAL_BATCHES is not generated
# (with the current config, 500 // 3 = 166 batches per context).
batches_per_context = TOTAL_BATCHES // len(contexts)
assert batches_per_context >= 1, "TOTAL_BATCHES must be at least len(contexts)"

print("\n[DATASET CREATION]")
val_splits, test_splits = [], []

for ctx in contexts:
    print(f"\n[CONTEXT {ctx}] Generating dataset...")
    wcst = WCST(BATCH_SIZE)

    # Create dataset for this fixed context (no switching inside it)
    full_dataset = create_wcst_dataset(
        wcst,
        total_batches=batches_per_context,
        batch_size=BATCH_SIZE,
        fixed_context=ctx,
        allow_switch=False
    )

    # Compute split sizes; rounding leftovers go to the training split
    total_size = len(full_dataset)
    val_size = int(total_size * VAL_RATIO)
    test_size = int(total_size * TEST_RATIO)
    train_size = total_size - val_size - test_size

    # Split into train/val/test for this context
    train_subset, val_subset, test_subset = random_split(full_dataset, [train_size, val_size, test_size])

    train_datasets[ctx] = train_subset
    val_splits.append(val_subset)
    test_splits.append(test_subset)

# Combine validation & test splits across all contexts
validation_dataset = ConcatDataset(val_splits)
test_dataset = ConcatDataset(test_splits)

# Logging
print("\nDataset creation complete!")
for ctx, ds in train_datasets.items():
    print(f"Train ({ctx}): {len(ds)}")
print(f"Validation (combined): {len(validation_dataset)}")
print(f"Test (combined): {len(test_dataset)}")


[DATASET CREATION]

[CONTEXT 0] Generating dataset...
[!] Fixed context mode: 0

[CONTEXT 1] Generating dataset...
[!] Fixed context mode: 1

[CONTEXT 2] Generating dataset...
[!] Fixed context mode: 2

Dataset creation complete!
Train (0): 6376
Train (1): 6376
Train (2): 6376
Validation (combined): 6372
Test (combined): 6372


### 3. Save Datasets

In [6]:
def save_dataset(dataset, filename):
    """Materialize a (possibly nested) dataset into plain tensors and save them with ``torch.save``.

    Supported dataset types, arbitrarily nested:
      * ``TensorDataset`` - its ``.tensors`` are used directly.
      * ``Subset``        - the parent's tensors are indexed with ``.indices``.
      * ``ConcatDataset`` - each child's tensors are concatenated along dim 0.

    The file holds a tuple with one tensor per dataset field, in the underlying
    ``TensorDataset`` order. For the datasets built in this notebook that is
    ``(encoder_data, decoder_data, targets)``.

    Args:
        dataset: the dataset to save.
        filename: destination ``.pt`` path; missing parent directories are created.

    Raises:
        TypeError: if an unsupported dataset type is encountered.
        ValueError: if the children of a ``ConcatDataset`` expose different numbers of tensors.
    """

    def extract_tensors(ds):
        if isinstance(ds, TensorDataset):
            return tuple(ds.tensors)

        if isinstance(ds, Subset):
            base_tensors = extract_tensors(ds.dataset)
            idx = ds.indices
            return tuple(t[idx] for t in base_tensors)

        if isinstance(ds, ConcatDataset):
            parts = [extract_tensors(d) for d in ds.datasets]
            field_counts = {len(p) for p in parts}
            # zip() would silently drop trailing fields if the children disagree.
            if len(field_counts) > 1:
                raise ValueError(f"ConcatDataset children have differing tensor counts: {sorted(field_counts)}")
            return tuple(T.cat(p, dim=0) for p in zip(*parts))

        raise TypeError(f"Unsupported dataset type: {type(ds)}")

    tensors = extract_tensors(dataset)

    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    T.save(tensors, filename)
    print(f"Saved {filename} ({len(tensors[-1])} samples)")

In [7]:
# Persist each per-context training split first, then the pooled validation and test splits.
save_plan = [(ds, f'../datasets/train_context_{context}.pt') for context, ds in train_datasets.items()]
save_plan += [
    (validation_dataset, '../datasets/validation_dataset.pt'),
    (test_dataset, '../datasets/test_dataset.pt'),
]

for ds, out_path in save_plan:
    save_dataset(ds, out_path)

Saved ../datasets/train_context_0.pt (6376 samples)
Saved ../datasets/train_context_1.pt (6376 samples)
Saved ../datasets/train_context_2.pt (6376 samples)
Saved ../datasets/validation_dataset.pt (6372 samples)
Saved ../datasets/test_dataset.pt (6372 samples)
