In [1]:
import json

In [None]:
# Load the raw train/test datasets. JSON is UTF-8 by specification, so read it
# with an explicit encoding rather than the platform default (which is not
# UTF-8 on e.g. Windows and can fail or mis-decode non-ASCII text).
with open('data/TrainData.json', encoding='utf-8') as f:
    train = json.load(f)
with open('data/TestData.json', encoding='utf-8') as f:
    test = json.load(f)

In [None]:
from collections import defaultdict

# Lookup table from numeric reason code to its human-readable reason-group label.
# Codes 1-7 have dedicated labels; any other code falls back to the R8 "Other"
# label through the defaultdict factory. Note: as with any defaultdict, looking
# up an unknown code also inserts that code into the table.
reason_group_dict = defaultdict(
    lambda: 'R8: Other - The new sentence is not coherent to the context due to other reasons that are not in this list',
    {
        1: 'R1: Sense - The sentence doesn’t make sense',
        2: 'R2: Entity connection - The new sentence discusses an entity which has not been introduced yet',
        3: 'R3: Discourse relation - The relation between this sentence and previous ones doesn’t make sense',
        4: 'R4: Data consistency - The new sentence contains information inconsistent with previous presented data',
        5: 'R5: World knowledge - The new sentence contains information inconsistent with your knowledge about the world',
        6: 'R6: Data relevance - The new sentence is not relevant to previous data in the story',
        7: 'R7: Title relevance - The new sentence is not relevant to the topic',
    },
)

In [None]:
def process_data(data: dict, reason_map=None) -> list[dict]:
    """Flatten stories into one record per sentence after the first.

    For every story in ``data``, each sentence after the first becomes a record
    that pairs it with the space-joined text of all preceding sentences. A
    sentence is marked incoherent when any entry under
    ``inst['IncrementalData']['reasons']`` lists reason codes for its index.
    The union of those codes is mapped to labels via ``reason_map``.

    Args:
        data: Mapping of story id -> story (the outer keys are ignored). Each
            story must provide ``['IncrementalData']['sentences']`` (a list of
            strings) and ``['IncrementalData']['reasons']`` (a mapping whose
            values are dicts keyed by ``str(sentence_index)`` holding an
            iterable of reason codes). NOTE(review): presumably one reasons
            entry per annotator -- confirm against the dataset docs.
        reason_map: Mapping from reason code to label. When ``None`` (the
            default), the notebook-level ``reason_group_dict`` is used,
            resolved at call time.

    Returns:
        A list of dicts with keys ``context`` (str), ``sentence`` (str),
        ``is_coherent`` (bool) and ``edit`` (list of ``{'reason': label}``).
        The ``edit`` entries are ordered by the string form of the reason
        code. Stories with no sentences contribute no records.
    """
    if reason_map is None:
        reason_map = reason_group_dict

    out = []
    for inst in data.values():
        incremental = inst['IncrementalData']
        sentences = incremental['sentences']
        if not sentences:
            # Nothing to pair; skip instead of raising IndexError on sentences[0].
            continue
        annotations = incremental['reasons']
        context = [sentences[0]]
        for idx, sent in enumerate(sentences[1:], start=1):
            key = str(idx)  # per-sentence reasons are keyed by the index as a string
            incoherent_reasons = set()
            for reasons in annotations.values():
                if key in reasons:
                    incoherent_reasons.update(reasons[key])
            out.append({
                "context": " ".join(context),
                "sentence": sent,
                "is_coherent": not incoherent_reasons,
                # Sort so the converted JSON is identical across runs; raw set
                # iteration order is arbitrary (hash-seed dependent for str codes).
                "edit": [{'reason': reason_map[reason]}
                         for reason in sorted(incoherent_reasons, key=str)],
            })
            context.append(sent)
    return out

In [None]:
# Convert both splits and write them to disk. ensure_ascii=False emits non-ASCII
# characters (e.g. the ’ in the reason labels) verbatim, so the files are opened
# explicitly as UTF-8; the platform default encoding may be unable to encode
# them, or would silently produce non-UTF-8 JSON.
train_processed = process_data(train)
test_processed = process_data(test)
with open('data/TrainDataConverted.json', 'w', encoding='utf-8') as f:
    json.dump(train_processed, f, ensure_ascii=False, indent=4)
with open('data/TestDataConverted.json', 'w', encoding='utf-8') as f:
    json.dump(test_processed, f, ensure_ascii=False, indent=4)

In [2]:
# Reload the converted splits from disk, so this cell does not depend on the
# in-memory results of the conversion cell. The files were written with
# ensure_ascii=False, so read them back explicitly as UTF-8.
with open('data/TrainDataConverted.json', encoding='utf-8') as f:
    train_processed = json.load(f)
with open('data/TestDataConverted.json', encoding='utf-8') as f:
    test_processed = json.load(f)
print(len(train_processed))
print(len(test_processed))

2499
293
