In [7]:
import os
import xml.etree.ElementTree as ET
from transformers import AutoTokenizer, AutoModelForTokenClassification
from torch.utils.data import Dataset, DataLoader
import torch
import xml.etree.ElementTree as ET

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
def parse_tcf(xml_file):
    # Parse the XML file
    tree = ET.parse(xml_file)
    root = tree.getroot()

    # Handling namespaces in the XML
    namespaces = {
        'tc': 'http://www.dspin.de/data/textcorpus'
    }

    # Extract tokens
    tokens = {token.get('ID'): token.text for token in root.findall('.//tc:tokens/tc:token', namespaces)}

    # Extract coreference information
    coreferences = []
    for entity in root.findall('.//tc:references/tc:entity', namespaces):
        # Each entity may contain multiple references
        entity_mentions = []
        for reference in entity.findall('.//tc:reference', namespaces):
            ref_id = reference.get('ID')
            token_ids = reference.get('tokenIDs').split()
            mention_text = [tokens[token_id] for token_id in token_ids]
            entity_mentions.append((ref_id, mention_text, token_ids))
        coreferences.append(entity_mentions)

    return tokens, coreferences

# Usage
tokens, coreferences = parse_tcf('/home/tico/Desktop/master_classes/NLP/coref149_v1_b/coref149_v1/ssj4.15.tcf')


print("\nCoreferences:")
for entity in coreferences:
    print(entity)


Coreferences:
[('rc_0', ['televiziji', 'BBC'], ['t_80', 't_81'])]
[('rc_1', ['Liverpoola'], ['t_7'])]
[('rc_2', ['papeža', 'Janeza', 'Pavla', 'II.'], ['t_38', 't_39', 't_40', 't_41']), ('rc_3', ['Janeza', 'Pavla', 'II.'], ['t_39', 't_40', 't_41'])]
[('rc_4', ['princa', 'Charlesa'], ['t_58', 't_59']), ('rc_5', ['Charlesa'], ['t_59'])]
[('rc_6', ['Camille', 'Parker', 'Bowles'], ['t_61', 't_62', 't_63'])]
[('rc_7', ['Angliji'], ['t_100'])]
[('rc_8', ['dirka', 'Aintree', '2005'], ['t_2', 't_3', 't_4']), ('rc_9', ['Aintree', '2005'], ['t_3', 't_4']), ('rc_10', ['osrednje', 'dirke'], ['t_31', 't_32']), ('rc_11', ['dirka'], ['t_68']), ('rc_12', ['dirke', 'Grand', 'National'], ['t_108', 't_109', 't_110']), ('rc_13', ['dirko'], ['t_124'])]


In [16]:
def parse_tcf(xml_file):
    tree = ET.parse(xml_file)
    root = tree.getroot()
    namespaces = {'tc': 'http://www.dspin.de/data/textcorpus'}

    tokens = {token.get('ID'): token.text for token in root.findall('.//tc:tokens/tc:token', namespaces)}
    coreferences = []
    unique_coref_ids = set()  # Set to track unique coreference IDs

    for entity in root.findall('.//tc:references/tc:entity', namespaces):
        entity_mentions = []
        for reference in entity.findall('.//tc:reference', namespaces):
            ref_id = reference.get('ID')
            token_ids = reference.get('tokenIDs').split()
            entity_mentions.append((ref_id, token_ids))
            unique_coref_ids.add(ref_id)  # Add coreference ID to the set
        coreferences.append(entity_mentions)

    return tokens, coreferences, unique_coref_ids
def preprocess_data(tokens, coreferences,tokenizer):
    text = ' '.join(tokens.values())
    encoded = tokenizer(text, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
    # Prepare labels based on coreference data, initially set all labels to -1 (ignore)
    labels = [-1] * len(encoded['input_ids'][0])
    for coref in coreferences:
        for ref_id, token_ids in coref:
            for token_id in token_ids:
                if token_id in tokens:
                    token_index = list(tokens.keys()).index(token_id)
                    labels[token_index] = ref_id
    encoded['labels'] = labels
    return encoded
class CorefDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings
    
    def __getitem__(self, idx):
        return {key: val[idx] for key, val in self.encodings.items()}
    
    def __len__(self):
        return len(self.encodings['input_ids'])
def load_dataset(directory,tokenizer):
    all_unique_coref_ids = set()
    dataset_encodings = {'input_ids': [], 'attention_mask': [], 'token_type_ids': [], 'labels': []}
    for filename in os.listdir(directory):
        if filename.endswith('.tcf'):
            tokens, coreferences, file_unique_ids  = parse_tcf(os.path.join(directory, filename))
            encodings = preprocess_data(tokens, coreferences,tokenizer)
            all_unique_coref_ids.update(file_unique_ids)
            for key in dataset_encodings:
                dataset_encodings[key].append(encodings[key][0])
    number_of_unique_coref_ids = len(all_unique_coref_ids)
    return dataset_encodings,number_of_unique_coref_ids

In [18]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
dataset_encodings = load_dataset('/home/tico/Desktop/master_classes/NLP/coref149_v1_b/coref149_v1',tokenizer)
dataset = CorefDataset(dataset_encodings)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

TypeError: tuple indices must be integers or slices, not str

In [None]:
# Example usage
all_unique_coref_ids = set()
for filename in os.listdir('path_to_tcf_directory'):
    if filename.endswith('.tcf'):
        _, _, file_unique_ids = parse_tcf(os.path.join('path_to_tcf_directory', filename))
        all_unique_coref_ids.update(file_unique_ids)

number_of_unique_coref_ids = len(all_unique_coref_ids)

In [None]:

dataset_encodings = load_dataset('path_to_tcf_directory')
dataset = CorefDataset(dataset_encodings)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True)