In [1]:
from typing import *
from glob import glob
import gzip
import re
import numpy as np
import random
import csv
import json
from itertools import chain
from collections import defaultdict
import gc
import os
from tokenizers import BertWordPieceTokenizer
from http import HTTPStatus

In [2]:
# we load the constraints data
constraints_def = {}

mapping_to_wikidata = {
    'property id': '^<http://www.wikidata.org/entity/P2302>',
    ' constraint type id': '<http://www.wikidata.org/entity/P2302>',
    'regex': '<http://www.wikidata.org/entity/P1793>',
    'exceptions': '<http://www.wikidata.org/entity/P2303>',
    'group by': '<http://www.wikidata.org/entity/P2304>',
    'items': '<http://www.wikidata.org/entity/P2305>',
    'property': '<http://www.wikidata.org/entity/P2306>',
    'namespace': '<http://www.wikidata.org/entity/P2307>',
    'class': '<http://www.wikidata.org/entity/P2308>',
    'relation': '<http://www.wikidata.org/entity/P2309>',
    'minimal date': '<http://www.wikidata.org/entity/P2310>',
    'maximum date': '<http://www.wikidata.org/entity/P2311>',
    'maximum value': '<http://www.wikidata.org/entity/P2312>',
    'minimal value': '<http://www.wikidata.org/entity/P2313>',
    'status': '<http://www.wikidata.org/entity/P2316>',
    'separator': '<http://www.wikidata.org/entity/P4155>',
    'scope': '<http://www.wikidata.org/entity/P4680>'
}
with open('/content/constraints.tsv', newline='') as fp:
    for row in csv.DictReader(fp, dialect='excel-tab'):
        predicates = []
        objects = []
        for k,vs in row.items():
            if k != 'constraint id':
                for v in vs.split(' '):
                    v = v.strip()
                    if v:
                        predicates.append(mapping_to_wikidata[k])
                        objects.append(v)
        constraints_def[row['constraint id']] = {'predicates': predicates, 'objects': objects}

In [3]:
# we load and preprocess the data. We encode string with integers to allow the data to fit easily into the main memory

class GlobalIntEncoder:
    def __init__(self):
        self._encoding = {
            '': 0
        }

    def encode(self, value: str):
        if value is None:
            value = ''
        value = str(value)
        if value not in self._encoding:
            self._encoding[value] = len(self._encoding)
        return self._encoding[value]

    def save(self, file: str):
        with open(file, 'wt') as fp:
            fp.writelines(l + '\n' for l in self._encoding.keys())

encoder = GlobalIntEncoder()

_relation_to_predicate = {
    encoder.encode('<http://www.wikidata.org/entity/Q21503252>'): [encoder.encode('<http://www.wikidata.org/entity/P31>')],
    encoder.encode('<http://www.wikidata.org/entity/Q21514624>'): [encoder.encode('<http://www.wikidata.org/entity/P279>')],
    encoder.encode('<http://www.wikidata.org/entity/Q30208840>'): [encoder.encode('<http://www.wikidata.org/entity/P31>'), encoder.encode('<http://www.wikidata.org/entity/P279>')],
}

def _convert_values(values: str) -> List[str]:
    return [v for v in (_convert_value(v.strip()) for v in values.split(' ')) if v]

def _convert_value(value: Optional[str], subject: Optional[str] = None, predicate: Optional[str] = None, obj: Optional[str] = None, other_subject: Optional[str] = None, other_predicate: Optional[str] = None, other_object: Optional[str] = None) -> Optional[str]:
    if value is None or value == '':
        return 0
    value = encoder.encode(value.replace('http://www.wikidata.org/prop/direct/', 'http://www.wikidata.org/entity/'))
    if value == subject:
        return encoder.encode('subject')
    elif value == predicate:
        return encoder.encode('predicate')
    elif value == obj:
        return encoder.encode('object')
    elif value == other_subject:
        return encoder.encode('other_subject')
    elif value == other_predicate:
        return encoder.encode('other_predicate')
    elif value == other_object:
        return encoder.encode('other_object')
    else:
        return value

def _read_entity_desc(line: List[str], desc_position: int) -> Dict[str,Any]:
    desc = line[desc_position].strip()
    result = {
            'entity_predicates': [],
            'entity_objects': [],
            'entity_labels': [],
            'http_content': ''
    }
    if not desc:
        return result
    try:
        desc = json.loads(desc)
    except ValueError:
        print('Invalid description: {}'.format(desc))
        return result
    if desc['type'] == 'page':
        try:
            status = "<http://www.w3.org/2011/http-statusCodes#{}>".format(HTTPStatus(desc['statusCode']).phrase.title().replace(' ', '').replace('-', ''))
            result['entity_predicates'].append(_convert_value('<http://wikiba.se/history/ontology#pageStatusCode>'))
            result['entity_objects'].append(_convert_value(status))
        except ValueError as e:
            print(e)
        result['http_content'] = desc['content']
    elif desc['type'] == 'entity':
        result['entity_labels'].extend(desc['labels'].values())
        for predicate, objects in desc['facts'].items():
            for obj in objects:
                result['entity_predicates'].append(_convert_value(predicate))
                result['entity_objects'].append(_convert_value(obj))
    else:
        print('Invalid description: {}'.format(result))
    return result

def load_dataset(file_path, max_size: int = 100000):
    dataset = {
        'constraint_id': [],
        'constraint_predicates': [],
        'constraint_objects': [],
        'subject': [],
        'predicate': [],
        'object': [],
        'object_text': [],
        'other_subject': [],
        'other_predicate': [],
        'other_object': [],
        'other_object_text': [],
        'subject_predicates': [],
        'subject_objects': [],
        'object_predicates': [],
        'object_objects': [],
        'other_entity_predicates': [],
        'other_entity_objects': [],
        'add_subject': [],
        'add_predicate': [],
        'add_object': [],
        'del_subject': [],
        'del_predicate': [],
        'del_object': []
    }
    with gzip.open(file_path, 'rt') as fp:
        for line_i, line in enumerate(fp):
            if line_i == max_size:
                break

            elements = line.split('\t')
            if elements[0] not in constraints_def:
                continue

            constraint = constraints_def[elements[0]]
            subject = _convert_value(elements[2])
            predicate = _convert_value(elements[3])
            obj = _convert_value(elements[4])
            other_subject = _convert_value(elements[5])
            other_predicate = _convert_value(elements[6])
            other_object = _convert_value(elements[7])
            add_subject = None
            add_predicate = None
            add_object = None
            del_subject= None
            del_predicate = None
            del_object = None
            entity = None
            i = 12
            while i < len(elements):
                if elements[i] == '<http://wikiba.se/history/ontology#addition>':
                    add_subject = elements[i - 3]
                    add_predicate = elements[i - 2]
                    add_object = elements[i - 1]
                elif elements[i] == '<http://wikiba.se/history/ontology#deletion>':
                    del_subject = elements[i - 3]
                    del_predicate = elements[i - 2]
                    del_object = elements[i - 1]
                else:
                    print('Unexpected entity: {}'.format(elements[i-3:i+1]))
                    continue
                i += 4

            subject_desc = _read_entity_desc(elements, -3)
            object_desc = _read_entity_desc(elements, -2)
            other_entity_desc = _read_entity_desc(elements, -1)
            if any(label in object_desc['http_content'] for label in subject_desc['entity_labels']):
                object_desc['entity_predicates'].append(_convert_value('<http://wikiba.se/history/ontology#pageContainsLabel>'))
                object_desc['entity_objects'].append(subject)
            if any(label in object_desc['http_content'] for label in other_entity_desc['entity_labels']):
                object_desc['entity_predicates'].append(_convert_value('<http://wikiba.se/history/ontology#pageContainsLabel>'))
                object_desc['entity_objects'].append(other_subject)
            if any(label in other_entity_desc['http_content'] for label in subject_desc['entity_labels']):
                other_entity_desc['entity_predicates'].append(_convert_value('<http://wikiba.se/history/ontology#pageContainsLabel>'))
                other_entity_desc['entity_objects'].append(subject)
            if any(label in other_entity_desc['http_content'] for label in object_desc['entity_labels']):
                other_entity_desc['entity_predicates'].append(_convert_value('<http://wikiba.se/history/ontology#pageContainsLabel>'))
                other_entity_desc['entity_objects'].append(obj)

            dataset['constraint_id'].append(_convert_value(elements[0]))
            dataset['constraint_predicates'].append([_convert_value(v) for v in constraint['predicates']])
            dataset['constraint_objects'].append([_convert_value(v) for v in constraint['objects']])
            dataset['subject'].append(subject)
            dataset['predicate'].append(predicate)
            if elements[4].startswith('<http://www.wikidata.org/entity/'):
                dataset['object'].append(obj)
                dataset['object_text'].append('')
            else:
                dataset['object'].append(0)
                dataset['object_text'].append(elements[4].split('^^')[0])
            dataset['other_subject'].append(other_subject)
            dataset['other_predicate'].append(other_predicate)
            if elements[7].startswith('<http://www.wikidata.org/entity/'):
                dataset['other_object'].append(other_object)
                dataset['other_object_text'].append('')
            else:
                dataset['other_object'].append(0)
                dataset['other_object_text'].append(elements[7].split('^^')[0])
            dataset['add_subject'].append(_convert_value(add_subject, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['add_predicate'].append(_convert_value(add_predicate, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['add_object'].append(_convert_value(add_object, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['del_subject'].append(_convert_value(del_subject, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['del_predicate'].append(_convert_value(del_predicate, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['del_object'].append(_convert_value(del_object, subject, predicate, obj, other_subject, other_predicate, other_object))
            dataset['subject_predicates'].append(subject_desc['entity_predicates'])
            dataset['subject_objects'].append(subject_desc['entity_objects'])
            dataset['object_predicates'].append(object_desc['entity_predicates'])
            dataset['object_objects'].append(object_desc['entity_objects'])
            dataset['other_entity_predicates'].append(other_entity_desc['entity_predicates'])
            dataset['other_entity_objects'].append(other_entity_desc['entity_objects'])
    return dataset

In [4]:
import zipfile
import os

# 1. Unzip the dataset
with zipfile.ZipFile('/content/constraint-corrections.zip', 'r') as zip_ref:
    zip_ref.extractall('/content/constraint-corrections')  # Extracts to /content/constraint-corrections/

# 2. Verify extraction (optional)
print("Extracted files:", os.listdir('/content/constraint-corrections'))

# 3. Modify your `load` function to use the correct path
def load(kind: str, targets: List[str]):
    result = defaultdict(list)
    for target in targets:
        print(f'Loading {target} {kind}')
        file_path = f'/content/constraint-corrections/constraint-corrections-{target}.tsv.gz.full.{kind}.tsv.gz'
        data = load_dataset(file_path)
        for k, v in data.items():
            result[k].extend(v)  # Keep as lists
    gc.collect()
    return result

# Now run your original code
target = '*'
if target == '*':
    targets = ['conflictWith', 'distinct', 'inverse', 'itemRequiresStatement', 'oneOf', 'single', 'type', 'valueRequiresStatement', 'valueType']
    train_dataset = load('train', targets)
    dev_dataset = load('dev', targets)
    test_dataset = {target: load('test', [target]) for target in targets}
else:
    train_dataset = load('train', [target])
    dev_dataset = load('dev', [target])
    test_dataset = {target: load('test', [target])}
gc.collect()

Extracted files: ['constraint-corrections-valueRequiresStatement.tsv.gz.full.test.tsv.gz', 'constraint-corrections-distinct.tsv.gz.full.train.tsv.gz', 'constraint-corrections-oneOf.tsv.gz.full.train.tsv.gz', 'constraint-corrections-type.tsv.gz.full.train.tsv.gz', 'constraint-corrections-valueType.tsv.gz.full.dev.tsv.gz', 'constraint-corrections-distinct.tsv.gz.full.dev.tsv.gz', 'constraint-corrections-single.tsv.gz.full.test.tsv.gz', 'constraint-corrections-valueType.tsv.gz.full.train.tsv.gz', 'constraint-corrections-type.tsv.gz.full.test.tsv.gz', 'constraint-corrections-valueRequiresStatement.tsv.gz.full.train.tsv.gz', 'constraint-corrections-single.tsv.gz.full.dev.tsv.gz', 'constraint-corrections-oneOf.tsv.gz.full.dev.tsv.gz', 'constraint-corrections-conflictWith.tsv.gz.full.test.tsv.gz', 'constraint-corrections-itemRequiresStatement.tsv.gz.full.test.tsv.gz', 'constraint-corrections-inverse.tsv.gz.full.test.tsv.gz', 'constraint-corrections-valueType.tsv.gz.full.test.tsv.gz', 'constra

0

In [8]:
# Convert all dataset values to numpy arrays, handling variable-length sequences
def convert_to_numpy(data):
    for key in data.keys():
        # For lists of numbers, convert directly
        if all(isinstance(x, (int, float)) for x in data[key]):
            data[key] = np.array(data[key])
        # For lists of lists (variable length sequences), convert to numpy object array
        else:
            data[key] = np.array(data[key], dtype=object)
    return data

# Apply conversion to all datasets
train_dataset = convert_to_numpy(train_dataset)
dev_dataset = convert_to_numpy(dev_dataset)
test_dataset = {target: convert_to_numpy(test_dataset[target]) for target in test_dataset}

In [9]:
# Prints some statistics about the dataset
from collections import defaultdict

def dataset_stats(dataset: Iterable[dict]):
    known_entities = [encoder.encode('subject'), encoder.encode('predicate'), encoder.encode('object'), encoder.encode('other_subject'), encoder.encode('other_predicate'), encoder.encode('other_object'), encoder.encode('constraint_predicate')]
    count = 0
    constraints = defaultdict(int)
    with_subject_desc = 0
    subject_desc_sum = 0
    with_object_desc = 0
    object_desc_sum = 0
    with_object_http_status = 0
    object_http_contains_label = 0
    with_other_triple = 0
    with_other_entity_desc = 0
    other_entity_desc_sum = 0
    with_other_entity_http_status = 0
    other_entity_http_contains_label = 0
    with_add_subject = 0
    with_add_predicate = 0
    with_add_object = 0
    with_del_subject = 0
    with_del_predicate = 0
    with_del_object = 0
    with_add_subject_in_input = 0
    with_add_predicate_in_input = 0
    with_add_object_in_input = 0
    with_del_subject_in_input = 0
    with_del_predicate_in_input = 0
    with_del_object_in_input = 0
    for i in range(len(dataset['predicate'])):
        count += 1
        constraints[dataset['constraint_id'][i]] += 1
        if dataset['subject_predicates'][i]:
            with_subject_desc += 1
            subject_desc_sum += len(dataset['subject_predicates'][i])
            assert len(dataset['subject_predicates'][i]) == len(dataset['subject_objects'][i]) # and len(dataset['subject_predicates'][i]) == len(dataset['subject_objects_text'][i])
        if dataset['object_predicates'][i]:
            with_object_desc += 1
            object_desc_sum += len(dataset['object_predicates'][i])
            assert len(dataset['object_predicates'][i]) == len(dataset['object_objects'][i]) # and len(dataset['object_predicates'][i]) == len(dataset['object_objects_text'][i])
        if encoder.encode('<http://wikiba.se/history/ontology#pageStatusCode>') in dataset['object_predicates'][i]:
            with_object_http_status += 1
        if encoder.encode('<http://wikiba.se/history/ontology#pageContainsLabel>') in dataset['object_predicates'][i]:
            object_http_contains_label += 1
        if dataset['other_subject'][i]:
            with_other_triple += 1
        if dataset['other_entity_predicates'][i]:
            with_other_entity_desc += 1
            other_entity_desc_sum += len(dataset['other_entity_predicates'][i])
            assert len(dataset['other_entity_predicates'][i]) == len(dataset['other_entity_objects'][i]) # and len(dataset['other_entity_predicates'][i]) == len(dataset['other_entity_objects_text'][i])
        if encoder.encode('<http://wikiba.se/history/ontology#pageStatusCode>') in dataset['other_entity_predicates'][i]:
            with_other_entity_http_status += 1
        if encoder.encode('<http://wikiba.se/history/ontology#pageContainsLabel>') in dataset['other_entity_predicates'][i]:
            other_entity_http_contains_label += 1
        if dataset['add_subject'][i]:
            with_add_subject += 1
            if dataset['add_subject'][i] in known_entities:
                with_add_subject_in_input += 1
        if dataset['add_predicate'][i]:
            with_add_predicate += 1
            if dataset['add_predicate'][i] in known_entities:
                with_add_predicate_in_input += 1
        if dataset['add_object'][i]:
            with_add_object += 1
            if dataset['add_object'][i] in known_entities:
                with_add_object_in_input += 1
        if dataset['del_subject'][i]:
            with_del_subject += 1
            if dataset['del_subject'][i] in known_entities:
                with_del_subject_in_input += 1
        if dataset['del_predicate'][i]:
            with_del_predicate += 1
            if dataset['del_predicate'][i] in known_entities:
                with_del_predicate_in_input += 1
        if dataset['del_object'][i]:
            with_del_object += 1
            if dataset['del_object'][i] in known_entities:
                with_del_object_in_input += 1
    print('{} past violations for {} constraints'.format(sum(constraints.values()), len(constraints)))
    print('with subject desc: {} (average length: {})'.format(with_subject_desc / count, subject_desc_sum / with_subject_desc))
    print('with object desc: {} (average length: {})'.format(with_object_desc / count, object_desc_sum / with_object_desc))
    print('with object web page: {} (with label in page: {})'.format(with_object_http_status / count, object_http_contains_label / with_object_http_status if with_object_http_status else '?'))
    print('with other triple: {} ({})'.format(with_other_triple, with_other_triple / count))
    print('with other entity desc: {} (average length: {})'.format(with_other_entity_desc / count, other_entity_desc_sum / with_other_entity_desc if with_other_entity_desc else '?'))
    print('with other entity web page: {} (with label in page: {})'.format(with_other_entity_http_status / count, other_entity_http_contains_label / with_other_entity_http_status if with_other_entity_http_status else '?'))
    print('in input: add subject: {} add predicate: {} add object: {} del subject: {} del predicate: {} del object: {}'.format(with_add_subject_in_input / with_add_subject, with_add_predicate_in_input / with_add_predicate, with_add_object_in_input / with_add_object, with_del_subject_in_input / with_del_subject, with_del_predicate_in_input / with_del_predicate, with_del_object_in_input / with_del_object))
    print('add: {} ({}, subject {} known object {} known)'.format(with_add_subject, with_add_subject / count, with_add_subject_in_input / with_add_subject if with_add_subject else '?', with_add_object_in_input / with_add_object if with_add_object else '?'))
    print('del: {} ({}, subject {} known object {} known)'.format(with_del_subject, with_del_subject / count, with_del_subject_in_input / with_del_subject if with_del_subject else '?', with_del_object_in_input / with_del_object if with_del_object else '?'))
dataset_stats(dev_dataset)

4500 past violations for 649 constraints
with subject desc: 0.992 (average length: 16.012096774193548)
with object desc: 0.798 (average length: 19.10359231411863)
with object web page: 0.1648888888888889 (with label in page: 0.5646900269541779)
with other triple: 1449 (0.322)
with other entity desc: 0.2768888888888889 (average length: 7.06099518459069)
with other entity web page: 0.08355555555555555 (with label in page: 0.5452127659574468)
in input: add subject: 1.0 add predicate: 0.278336686787391 add object: 0.18376928236083165 del subject: 1.0 del predicate: 1.0 del object: 0.9901869158878505
add: 2982 (0.6626666666666666, subject 1.0 known object 0.18376928236083165 known)
del: 2140 (0.47555555555555556, subject 1.0 known object 0.9901869158878505 known)


In [10]:
# train the BERT tokenizer
with open('raw_corrections_text.txt', 'wt') as fp:
    for line in chain(
        train_dataset['object_text'],
        train_dataset['other_object_text']
    ):
        if line:
            fp.write(line)
            fp.write('\n')
tokenizer = BertWordPieceTokenizer(
    clean_text=True,
    handle_chinese_chars=True,
    strip_accents=True,
    lowercase=True,
    pad_token="[PAD]",
    unk_token="[UNK]",
    sep_token="[SEP]",
    cls_token="[CLS]",
    mask_token="[MASK]"
)
tokenizer.train('raw_corrections_text.txt', vocab_size=30000)
os.remove('raw_corrections_text.txt')

In [18]:
# Final preprocessing and model code

import tensorflow as tf
from tensorflow import keras

SEQUENCE_SIZE = 64
tokenizer.enable_padding(length=SEQUENCE_SIZE, pad_token="[PAD]")  # Note 'length' instead of 'max_length'
tokenizer.enable_truncation(max_length=SEQUENCE_SIZE)


def tokenize_sequence(sequence: List[str]) -> np.array:
    matrix = np.zeros((len(sequence), SEQUENCE_SIZE), dtype='int32')
    for i,v in enumerate(tokenizer.encode_batch(list(sequence))):
        matrix[i] = v.ids
    return matrix

class TermEncoder:
    def __init__(self,  min_occurences_count, max_sequence_length):
        self._terms_index = {0: 0}
        self._terms_inverse_index = [0]
        self._terms_count = {}
        self._min_occurences_count = min_occurences_count
        self._max_sequence_length = max_sequence_length

    def fit(self, input: str):
        for value in input:
            if value not in self._terms_index:
                c = self._terms_count.get(value, 0) + 1
                self._terms_count[value] = c
                if c == self._min_occurences_count:
                    self._terms_index[value] = len(self._terms_inverse_index)
                    self._terms_inverse_index.append(value)
                    del self._terms_count[value]

    def fit_sequence(self, sequence: np.array):
        for values in sequence:
            self.fit(values)

    def transform(self, input: np.array) -> np.array:
        return np.array([self._terms_index.get(v, 0) for v in input])

    def transform_sequence(self, input: np.array) -> np.array:
        return keras.preprocessing.sequence.pad_sequences([[self._terms_index.get(v, 0) for v in values] for values in input], maxlen=self._max_sequence_length)

    def decode(self, input: np.array) -> np.array:
        return np.array([self._terms_inverse_index[v] for v in input])

    def save(self, file: str):
        with open(file, 'wt') as fp:
            fp.writelines('{}\n'.format(l) for l in self._terms_inverse_index)

    def __len__(self):
        return len(self._terms_inverse_index)

class DatasetSequence(keras.utils.Sequence):
    def __init__(self, dataset, constraint_id_encoder, predicate_encoder, entity_encoder, output_predicate_encoder, output_entity_encoder, batch_size: int, shuffle: bool = False):
        self.dataset = dataset
        self._constraint_id_encoder = constraint_id_encoder
        self._entity_encoder = entity_encoder
        self._predicate_encoder = predicate_encoder
        self._output_entity_encoder = output_entity_encoder
        self._output_predicate_encoder = output_predicate_encoder
        self.batch_size = batch_size
        self._shuffle = shuffle
        self.indices = np.arange(len(self.dataset['add_subject']))
        self.on_epoch_end()

    def __len__(self):
        return len(self.dataset['add_subject']) // self.batch_size

    def __getitem__(self, idx):
        inds = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        print(f"Indices: {inds}")  # Debug: Check indices
        if not isinstance(inds, np.ndarray):
            inds = np.array(inds)

        # Get batch data
        batch = {
            'constraint_id': np.array(self.dataset['constraint_id'])[inds],
            'constraint_predicates': np.array(self.dataset['constraint_predicates'])[inds],
            'constraint_objects': np.array(self.dataset['constraint_objects'])[inds],
            'subject': np.array(self.dataset['subject'])[inds],
            'subject_predicates': np.array(self.dataset['subject_predicates'])[inds],
            'subject_objects': np.array(self.dataset['subject_objects'])[inds],
            'predicate': np.array(self.dataset['predicate'])[inds],
            'object': np.array(self.dataset['object'])[inds],
            'object_text': np.array(self.dataset['object_text'])[inds],
            'object_predicates': np.array(self.dataset['object_predicates'])[inds],
            'object_objects': np.array(self.dataset['object_objects'])[inds],
            'other_subject': np.array(self.dataset['other_subject'])[inds],
            'other_predicate': np.array(self.dataset['other_predicate'])[inds],
            'other_object': np.array(self.dataset['other_object'])[inds],
            'other_object_text': np.array(self.dataset['other_object_text'])[inds],
            'other_entity_predicates': np.array(self.dataset['other_entity_predicates'])[inds],
            'other_entity_objects': np.array(self.dataset['other_entity_objects'])[inds],
            'add_subject': np.array(self.dataset['add_subject'])[inds],
            'add_predicate': np.array(self.dataset['add_predicate'])[inds],
            'add_object': np.array(self.dataset['add_object'])[inds],
            'del_subject': np.array(self.dataset['del_subject'])[inds],
            'del_predicate': np.array(self.dataset['del_predicate'])[inds],
            'del_object': np.array(self.dataset['del_object'])[inds]
        }

        # Debug: Print batch keys and shapes
        print("Batch keys:", batch.keys())
        for key, value in batch.items():
            print(f"{key}: shape={np.array(value).shape}")

        # Transform the batch data
        x = {
            'constraint_id': self._constraint_id_encoder.transform(batch['constraint_id']),
            'constraint_predicates': self._predicate_encoder.transform_sequence(batch['constraint_predicates']),
            'constraint_objects': self._predicate_encoder.transform_sequence(batch['constraint_objects']),
            'subject': self._entity_encoder.transform(batch['subject']),
            'subject_predicates': self._predicate_encoder.transform_sequence(batch['subject_predicates']),
            'subject_objects': self._entity_encoder.transform_sequence(batch['subject_objects']),
            'predicate': self._predicate_encoder.transform(batch['predicate']),
            'object': self._entity_encoder.transform(batch['object']),
            'object_text': tokenize_sequence(batch['object_text']),
            'object_predicates': self._predicate_encoder.transform_sequence(batch['object_predicates']),
            'object_objects': self._entity_encoder.transform_sequence(batch['object_objects']),
            'other_subject': self._entity_encoder.transform(batch['other_subject']),
            'other_predicate': self._predicate_encoder.transform(batch['other_predicate']),
            'other_object': self._entity_encoder.transform(batch['other_object']),
            'other_object_text': tokenize_sequence(batch['other_object_text']),
            'other_entity_predicates': self._predicate_encoder.transform_sequence(batch['other_entity_predicates']),
            'other_entity_objects': self._entity_encoder.transform_sequence(batch['other_entity_objects'])
        }

        y = {
            'add_subject': self._output_entity_encoder.transform(batch['add_subject']),
            'add_predicate': self._output_predicate_encoder.transform(batch['add_predicate']),
            'add_object': self._output_entity_encoder.transform(batch['add_object']),
            'del_subject': self._output_entity_encoder.transform(batch['del_subject']),
            'del_predicate': self._output_predicate_encoder.transform(batch['del_predicate']),
            'del_object': self._output_entity_encoder.transform(batch['del_object'])
        }

        return x, y

    def on_epoch_end(self):
        gc.collect()
        if self._shuffle:
            print('shuffling dataset')
            np.random.shuffle(self.indices)

class EntityFactsEmbedding:
    def __init__(self, entities_count: int, predicates_count: int, dropout: keras.layers.Dropout):
        self._dropout = dropout
        self._entity_predicates_embedding = keras.layers.Embedding(predicates_count, 128, mask_zero=True, name="predicate_embedding")
        self._entity_objects_embedding = keras.layers.Embedding(entities_count, 128, mask_zero=True, name="entity_embedding")
        self._entity_desc_combination = keras.layers.Concatenate(-1, name="entity_desc_combination")
        self._entity_desc_featurizer = keras.layers.Dense(128, activation='relu', name="entity_desc_featurizer")
        self._entity_desc_attention = tf.keras.layers.Attention(name="entity_desc_attention")
        self.to_sequence = keras.layers.Reshape((1,128), name="entity_desc_to_sequence")
        self._from_sequence = keras.layers.Reshape((128,), name="entity_desc_from_sequence")
        self._query_featurizer = keras.layers.Dense(128, activation='relu', name="entity_desc_query_featurizer")
        self._entity_desc_pool = keras.layers.GlobalMaxPooling1D(name="entity_desc_reduction")

    def __call__(self, predicates, objects, query = None):
        value = self._entity_desc_featurizer(
            self._dropout(
                self._entity_desc_combination([
                    self._entity_predicates_embedding(predicates),
                    self._entity_objects_embedding(objects),
                ])
            )
        )
        if query is None:
            return self._entity_desc_pool(value)
        else:
            return self._from_sequence(self._entity_desc_attention(
                [self.to_sequence(self._query_featurizer(query)), value],
            ))


class Model:
    def __init__(self, train_dataset: Dict[str,np.array], dev_dataset: Dict[str,np.array], epochs: int = 2,  batch_size: int = 32, dropout: float = 0., with_attention: bool = False, with_entity_facts: bool = True, with_literals: bool = True, with_constraint_id: bool = False, with_subject: bool = False):
        embedding_size = 128
        min_occurences = 100
        dropout_layer = keras.layers.Dropout(dropout, name="dropout")

        self._entity_encoder = TermEncoder(min_occurences, SEQUENCE_SIZE)
        self._entity_encoder.fit_sequence(train_dataset['constraint_objects'])
        self._entity_encoder.fit(train_dataset['subject'])
        self._entity_encoder.fit(train_dataset['object'])
        self._entity_encoder.fit(train_dataset['other_subject'])
        self._entity_encoder.fit(train_dataset['other_object'])
        self._entity_encoder.fit_sequence(train_dataset['subject_objects'])
        self._entity_encoder.fit_sequence(train_dataset['object_objects'])
        self._entity_encoder.fit_sequence(train_dataset['other_entity_objects'])

        self._output_entity_encoder = TermEncoder(min_occurences, SEQUENCE_SIZE)
        self._output_entity_encoder.fit(train_dataset['add_subject'])
        self._output_entity_encoder.fit(train_dataset['add_object'])
        self._output_entity_encoder.fit(train_dataset['del_subject'])
        self._output_entity_encoder.fit(train_dataset['del_object'])

        self._predicate_encoder = TermEncoder(min_occurences, SEQUENCE_SIZE)
        self._predicate_encoder.fit_sequence(train_dataset['constraint_predicates'])
        self._predicate_encoder.fit(train_dataset['predicate'])
        self._predicate_encoder.fit(train_dataset['other_predicate'])
        self._predicate_encoder.fit_sequence(train_dataset['subject_predicates'])
        self._predicate_encoder.fit_sequence(train_dataset['object_predicates'])
        self._predicate_encoder.fit_sequence(train_dataset['other_entity_predicates'])

        self._output_predicate_encoder = TermEncoder(min_occurences, SEQUENCE_SIZE)
        self._output_predicate_encoder.fit(train_dataset['add_predicate'])
        self._output_predicate_encoder.fit(train_dataset['del_predicate'])

        self._constraint_id_encoder = TermEncoder(min_occurences, SEQUENCE_SIZE)
        self._constraint_id_encoder.fit(train_dataset['constraint_id'])

        print('Dataset stats: {} input predicates, {} input entities, {} output predicates, {} output entities'.format(len(self._predicate_encoder), len(self._entity_encoder), len(self._output_predicate_encoder), len(self._output_entity_encoder)))

        word_embedding = keras.layers.Embedding(30000, embedding_size, mask_zero=True, name="text_embedding")
        text_pool = keras.layers.GlobalMaxPool1D(name="text_pool")
        #text_pool = keras.layers.Bidirectional(keras.layers.LSTM(128, name="text_lstm"), name="text_pool")
        text_embedding = lambda i: text_pool(word_embedding(i))
        entity_desc_embedding = EntityFactsEmbedding(len(self._entity_encoder), len(self._predicate_encoder), dropout_layer)
        from_seq = keras.layers.Reshape((embedding_size,), name="from_sequence")
        predicate_embedding = lambda i: from_seq(entity_desc_embedding._entity_predicates_embedding(i))
        entity_embedding = lambda i: from_seq(entity_desc_embedding._entity_objects_embedding(i))

        # constraint
        constraint_id_input = keras.Input(shape=(1,), name="constraint_id")
        constraint_predicates_input = keras.Input(shape=(None,), name="constraint_predicates")
        constraint_objects_input = keras.Input(shape=(None,), name="constraint_objects")
        if with_constraint_id:
            constraint_id_embedding = keras.layers.Embedding(len(self._constraint_id_encoder), embedding_size, name="constraint_id_embedding")
            constraint_features = from_seq(constraint_id_embedding(constraint_id_input))
        else:
            constraint_features = entity_desc_embedding(predicates=constraint_predicates_input, objects=constraint_objects_input)

        # violation
        subject_input = keras.Input(shape=(1,), name="subject")
        subject_predicates_input = keras.Input(shape=(None,), name="subject_predicates")
        subject_objects_input = keras.Input(shape=(None,), name="subject_objects")
        predicate_input = keras.Input(shape=(1,), name="predicate")
        object_input = keras.Input(shape=(1,), name="object")
        object_text_input = keras.Input(shape=(None,), name="object_text")
        object_predicates_input = keras.Input(shape=(None,), name="object_predicates")
        object_objects_input = keras.Input(shape=(None,), name="object_objects")
        other_subject_input = keras.Input(shape=(1,), name="other_subject")
        other_predicate_input = keras.Input(shape=(1,), name="other_predicate")
        other_object_input = keras.Input(shape=(1,), name="other_object")
        other_object_text_input = keras.Input(shape=(None,), name="other_object_text")
        other_entity_predicates_input = keras.Input(shape=(None,), name="other_entity_predicates")
        other_entity_objects_input = keras.Input(shape=(None,), name="other_entity_objects")

        subject_features = entity_embedding(subject_input)
        predicate_features = predicate_embedding(predicate_input)
        object_features = entity_embedding(object_input)
        object_text_features = text_embedding(object_text_input)
        other_subject_features = entity_embedding(other_subject_input)
        other_predicate_features = predicate_embedding(other_predicate_input)
        other_object_features = entity_embedding(other_object_input)
        other_object_text_features = text_embedding(other_object_text_input)
        if with_attention:
            subject_query_features = keras.layers.Dense(units=embedding_size, name="subject_query_features")(dropout_layer(constraint_features))
            subject_desc_features = entity_desc_embedding(predicates=subject_predicates_input, objects=subject_objects_input, query=subject_query_features)
        else:
            subject_desc_features = entity_desc_embedding(predicates=subject_predicates_input, objects=subject_objects_input)
        if with_attention:
            object_query_features = keras.layers.Dense(units=embedding_size, name="object_query_features")(dropout_layer(constraint_features))
            object_desc_features = entity_desc_embedding(predicates=object_predicates_input, objects=object_objects_input, query=object_query_features)
        else:
            object_desc_features = entity_desc_embedding(predicates=object_predicates_input, objects=object_objects_input)
        if with_attention:
            other_entity_query_features = keras.layers.Dense(units=embedding_size, name="other_entity_query_features")(dropout_layer(constraint_features))
            other_entity_desc_features = entity_desc_embedding(predicates=other_entity_predicates_input, objects=other_entity_objects_input, query=other_entity_query_features)
        else:
            other_entity_desc_features = entity_desc_embedding(predicates=other_entity_predicates_input, objects=other_entity_objects_input)

        inputs = [
            constraint_features,
            predicate_features,
            object_features,
            other_predicate_features,
            other_object_features
        ]
        if with_subject:
            inputs.append(subject_features)
            inputs.append(other_subject_features)
        if with_entity_facts:
            inputs.append(subject_desc_features)
            inputs.append(object_desc_features)
            inputs.append(other_entity_desc_features)
        if with_literals:
            inputs.append(object_text_features)
            inputs.append(other_object_text_features)
        dense_input = dropout_layer(keras.layers.concatenate(inputs, name="input_concat"))
        dense_l1 = dropout_layer(keras.layers.Dense(units=embedding_size*4, activation='relu', name="dense_1")(dense_input))
        dense_l2 = dropout_layer(keras.layers.Dense(units=embedding_size*4, activation='relu', name="dense_2")(dense_l1))
        add_subject = keras.layers.Dense(units=len(self._output_entity_encoder), activation='softmax', name="add_subject")(dense_l2)
        add_predicate = keras.layers.Dense(units=len(self._output_predicate_encoder), activation='softmax', name="add_predicate")(dense_l2)
        add_object = keras.layers.Dense(units=len(self._output_entity_encoder), activation='softmax', name="add_object")(dense_l2)
        del_subject = keras.layers.Dense(units=len(self._output_entity_encoder), activation='softmax', name="del_subject")(dense_l2)
        del_predicate = keras.layers.Dense(units=len(self._output_predicate_encoder), activation='softmax', name="del_predicate")(dense_l2)
        del_object = keras.layers.Dense(units=len(self._output_entity_encoder), activation='softmax', name="del_object")(dense_l2)

        self._model = keras.Model(inputs=[
            constraint_id_input, constraint_predicates_input,  constraint_objects_input,
            subject_input,
            subject_predicates_input, subject_objects_input,
            predicate_input,
            object_input, object_text_input,
            object_predicates_input, object_objects_input,
            other_subject_input,
            other_predicate_input,
            other_object_input, other_object_text_input,
            other_entity_predicates_input, other_entity_objects_input
        ], outputs=[add_subject, add_predicate, add_object, del_subject, del_predicate, del_object])
        self._model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['sparse_categorical_accuracy'])
        print('model build done')
        print(self._model.summary())

        print('training with {} epochs, a batch size of {} and a dropout rate of {}'.format(epochs, batch_size, dropout))

        best_weights_filepath = './best_weights_corrections.keras'
        early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=1, verbose=1, mode='auto')
        save_best_model = keras.callbacks.ModelCheckpoint(best_weights_filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')
        history = self._model.fit(self._to_dataset(train_dataset, batch_size=batch_size, shuffle=True), epochs=epochs, validation_data=self._to_dataset(dev_dataset, batch_size=batch_size), callbacks=[early_stopping, save_best_model])
        self._model.load_weights(best_weights_filepath)

    def _to_dataset(self, dataset: Dict[str,np.array], batch_size, shuffle: bool = False):
        return DatasetSequence(dataset, self._constraint_id_encoder, self._predicate_encoder, self._entity_encoder, self._output_predicate_encoder, self._output_entity_encoder, batch_size=batch_size, shuffle=shuffle)

    def eval(self, dataset: Dict[str,np.array]):
        ok_by_constraint = defaultdict(int)
        error_by_constraint = defaultdict(int)
        total_by_constraint = defaultdict(int)
        ok = 0
        error = 0
        total = 0
        parameters_found_and_correct = defaultdict(int)
        parameters_predicted = defaultdict(int)
        parameters_expected = defaultdict(int)

        for add_subject_pred, add_predicate_pred, add_object_pred, del_subject_pred, del_predicate_pred, del_object_pred in model.predict(dataset):
            constraint = dataset['constraint_id'][total]
            predictions = {
                'add_subject': add_subject_pred,
                'add_predicate': add_predicate_pred,
                'add_object': add_object_pred,
                'del_subject': del_subject_pred,
                'del_predicate': del_predicate_pred,
                'del_object': del_object_pred,
            }
            if predictions['add_subject'] != 0 or predictions['add_predicate'] != 0 or predictions['add_object'] != 0:
                if predictions['add_subject'] == 0 or predictions['add_predicate'] == 0 or predictions['add_object'] == 0:
                    total += 1
                    total_by_constraint[constraint] += 1
                    continue # Missing value
            if predictions['del_subject'] != 0 or predictions['del_predicate'] != 0 or predictions['del_object'] != 0:
                if predictions['del_subject'] == 0 or predictions['del_predicate'] == 0 or predictions['del_object'] == 0:
                    total += 1
                    total_by_constraint[constraint] += 1
                    continue # Missing value

            if all(dataset[k][total] == v for k, v in predictions.items()):
                ok += 1
                ok_by_constraint[constraint] += 1
            else:
                error += 1
                error_by_constraint[constraint] += 1
            for k, v in predictions.items():
                if v is not None:
                    parameters_predicted[k] += 1
                    if v == dataset[k][total] :
                        parameters_found_and_correct[k] += 1
                if dataset[k][total] is not None:
                    parameters_expected[k] += 1
            total_by_constraint[constraint] += 1
            total += 1

        by_constraint = [self._precision_recall(ok_by_constraint[c], ok_by_constraint[c] + error_by_constraint[c], total_by_constraint[c]) for c in total_by_constraint.keys()]
        return {
            **self._precision_recall(ok, ok+error, total),
            'accuracy': ok/total,
            'parameters': {k: self._precision_recall(parameters_found_and_correct[k], parameters_predicted[k],v) for k,v in parameters_expected.items()},
            'by_constraint': by_constraint,
            'ok': ok,
            'error': error,
            'total': total
        }

    @staticmethod
    def _precision_recall(found_and_correct: int, predicted: int, expected: int) -> dict:
        precision = found_and_correct / predicted if predicted else float('nan')
        recall = found_and_correct / expected if expected else float('nan') # TODO: should be found and correct ??? This seems badly wrong and it's what we did in the CorHist paper
        F1 = 2 * precision*recall / (precision+recall) if precision + recall else float('nan')
        return {
            'precision': precision,
            'recall': recall,
            'F1': F1
        }

    def predict(self, dataset: Dict[str,np.array]):
        dataset = self._to_dataset(dataset, batch_size=128)
        for i in range(len(dataset)):
            add_subject_pred, add_predicate_pred, add_object_pred, del_subject_pred, del_predicate_pred, del_object_pred = self._model.predict(dataset[i][0])
            add_subject_pred = self._output_entity_encoder.decode(np.argmax(add_subject_pred, 1))
            add_predicate_pred = self._output_predicate_encoder.decode(np.argmax(add_predicate_pred, 1))
            add_object_pred = self._output_entity_encoder.decode(np.argmax(add_object_pred, 1))
            del_subject_pred = self._output_entity_encoder.decode(np.argmax(del_subject_pred, 1))
            del_predicate_pred = self._output_predicate_encoder.decode(np.argmax(del_predicate_pred, 1))
            del_object_pred = self._output_entity_encoder.decode(np.argmax(del_object_pred, 1))
            yield from zip(add_subject_pred, add_predicate_pred, add_object_pred, del_subject_pred, del_predicate_pred, del_object_pred)

    def save(self, dir: str):
        os.makedirs(dir, exist_ok=True)
        self._entity_encoder.save(dir + '/entity_encoding.txt')
        self._output_entity_encoder.save(dir + '/output_entity_encoding.txt')
        self._predicate_encoder.save(dir + '/predicate_encoding.txt')
        self._output_predicate_encoder.save(dir + '/output_predicate_encoding.txt')
        self._constraint_id_encoder.save(dir + '/constraint_id_encoding.txt')
        self._model.save(dir + '/model')

In [19]:
model = None
# hack to clear memory before creating the new model
gc.collect()
model = Model(train_dataset, dev_dataset, epochs=20, batch_size=256, dropout=0.1,
              with_attention=False, with_entity_facts=True, with_literals=True, with_constraint_id=False, with_subject=False)
gc.collect()

Dataset stats: 238 input predicates, 330 input entities, 6 output predicates, 8 output entities
model build done


None
training with 20 epochs, a batch size of 256 and a dropout rate of 0.1
shuffling dataset
Indices: [2446 1019 2129 3822 1168 2754 2249 4401  866 3334 2160 2929 3005 2341
 1417 4139 4160 1533 3897  220 3021 4174 4295 3187  453   96 1710  533
 4110 1364 2006  896  991 4498  753 3088  474 3595 2779 1205 1897 1891
  525 2246 1938  795 4116 4196 3043 4007 1119  404 2081 2050  711 3602
 3415 1173 1402 1460 2128  301 4330   41 4338 1662 4372 3520  971 1287
 1563 1440 3920 1032 3410  765 2758  396 1286  329 3186 3381 2313 1656
  542 3391 4419 1077 3740 2883 2481  240 3793  512 1060 2922  651 1876
 4215  541 1340 2672 4314 1846 1484 2232 2790 3367 2846 1112 1040 3609
 3877 4362 2838 2409 1596 2970 2679 1705  478 3659  493 1885 3805 2323
 1886 3518 4415  235 2507 4038 1799 3741  881 3825  204 2146 3790 2334
   26 2184 2868  143 1282 1715  199  728  946 3763  446 3516 3127 3663
  975 2887 1436  151 3279 3158 3539 1039  298 2188 2338 1450  771 1661
 2984 1022 4089  331 2432 3513 1155 3575 1355

ValueError: For a model with multiple outputs, when providing the `metrics` argument as a list, it should have as many entries as the model has outputs. Received:
metrics=['sparse_categorical_accuracy']
of length 1 whereas the model has 6 outputs.