# Notebook: Create 5 Splits

## Packages

In [None]:
from iterstrat.ml_stratifiers import MultilabelStratifiedKFold
from collections import Counter
import numpy as np
import random
import json

## Parameters

In [None]:
# JSON file holding the full dataset that is split into folds below.
DATASET_PATH = "dataset_total/dataset_filtered_3000.json"
# Number of folds requested from the stratified splitter; one split file is
# written per fold.
N_FOLDS = 6

In [None]:
# Tag field whose value is read to build the stratification labels.
CRITERIA_RS = "tag_with_polarity"

# Label vocabulary.
ASPECTS = ["SERVICE", "FOOD", "GENERAL-IMPRESSION", "AMBIENCE", "PRICE"]
POLARITIES = ["POSITIVE", "NEGATIVE", "NEUTRAL"]
MENTIONING_TYPE = ["implicit", "explicit"]

# Every "<ASPECT>-<POLARITY>" pair, ordered aspect-major (all polarities of
# the first aspect, then all polarities of the second, and so on).
COMBINATIONS = [
    "-".join((aspect, polarity))
    for aspect in ASPECTS
    for polarity in POLARITIES
]
# Bare expression: displays the list when run as a notebook cell.
COMBINATIONS

## Code

### Load Data

In [None]:
# Parse the whole dataset file into memory (text mode is open()'s default).
with open(DATASET_PATH, encoding="utf-8") as json_file:
    dataset = json.load(json_file)
# Bare expression: shows the number of loaded examples in the notebook.
len(dataset)

In [None]:
# Inspection only: the CRITERIA_RS value of every tag of the example at index 1.
[entry[CRITERIA_RS] for entry in dataset[1]["tags"]]

In [None]:
# Build one binary indicator vector per example, aligned with COMBINATIONS:
# position k is 1 when at least one tag of the example has a CRITERIA_RS
# value equal to COMBINATIONS[k], otherwise 0.
labels_one_hot = []
for example in dataset:
    # Distinct label values present in this example.
    present = {tag[CRITERIA_RS] for tag in example["tags"]}
    indicator = [int(combination in present) for combination in COMBINATIONS]
    labels_one_hot.append(np.array(indicator))

### Split

In [None]:
# Try successive seeds until the multilabel-stratified K-fold yields folds of
# (near-)equal size, then write the held-out part of each fold to its own
# JSON file.
random_state = 0
found_balanced_split = False

# Acceptable fold sizes are derived from the data instead of a fixed 500:
# each fold must hold floor(n / k) or ceil(n / k) examples. When
# len(dataset) == 500 * N_FOLDS this is exactly {500}, i.e. the previous
# check; for any other size the previous check could never succeed and the
# loop would not terminate.
min_fold_size, remainder = divmod(len(dataset), N_FOLDS)
allowed_fold_sizes = {min_fold_size, min_fold_size + (1 if remainder else 0)}

while not found_balanced_split:
    mskf = MultilabelStratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=random_state)

    # Collect the held-out examples of every fold for this seed; the training
    # indices are not used.
    folds = []
    for train_index, test_index in mskf.split(dataset, labels_one_hot):
        test_dataset = [dataset[i] for i in test_index]
        # Per-fold report: size and distribution of the tags' "label" field.
        print(len(test_dataset), Counter(
            [tag["label"] for example in test_dataset for tag in example["tags"]]))
        folds.append(test_dataset)

    split_sizes = [len(fold) for fold in folds]
    if all(size in allowed_fold_sizes for size in split_sizes):
        print(split_sizes, random_state)
        # Persist only the accepted split; rejected attempts no longer
        # rewrite the files just to be overwritten by the next seed.
        for idx, fold in enumerate(folds):
            with open(f"../07 train models/real/split_{idx}.json", 'w', encoding='utf-8') as split_file:
                json.dump(fold, split_file, ensure_ascii=False)
        found_balanced_split = True
    # NOTE: also runs after success, so random_state ends one past the
    # accepted seed (the accepted seed is the one printed above).
    random_state += 1