In [17]:
from transformers import BertTokenizer, BertForTokenClassification, BertForSequenceClassification , BertTokenizerFast
import torch
import json

In [18]:
# Load the NER model and tokenizer
ner_model = BertForTokenClassification.from_pretrained("bert_ner")
ner_tokenizer = BertTokenizerFast.from_pretrained("bert_ner")

# Load the text classification model and tokenizer
clf_model = BertForSequenceClassification.from_pretrained("bert_clf")
clf_tokenizer = BertTokenizer.from_pretrained("bert_clf")

In [19]:
label_map = {
    0: "O",             # Outside any entity
    1: "ANAT",          # Anatomical entity
    2: "OBS-ABSENT",    # Observation absent
    3: "OBS-PRESENT",   # Observation present
    4: "OBS-UNCERTAIN", # Observation uncertain
    5: "PAD"            # Padding token
}

In [23]:
def classify_text(text, model, tokenizer):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).item()
    predictions = predictions + 1
    return str(predictions)  # Convert prediction to string


def predict_entities(text, model, tokenizer, label_map, merge_threshold=4):
    """
    Predict entities and merge contiguous entities with the same label if they are within a certain threshold.
    
    Args:
        text (str): The input text to process.
        model: The NER model.
        tokenizer: The tokenizer for the NER model.
        label_map (dict): A mapping from label IDs to label names.
        merge_threshold (int): The maximum gap between entities to consider them as the same entity.
    
    Returns:
        List[List]: A list of merged entities.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, return_offsets_mapping=True)
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
    
    entities = []
    current_start = None
    current_label = None
    for idx, pred in enumerate(predictions):
        label = label_map.get(pred, "O")
        start, end = offset_mapping[idx]
        
        if label == 'O' or label == 'PAD':
            if current_label and current_label != 'O':
                # End of an entity
                entities.append([current_start, offset_mapping[idx-1][1], current_label])
                current_start = None
                current_label = None
        else:
            if label == current_label and current_start is not None:
                # Extend the current entity
                end = offset_mapping[idx][1]
            else:
                # Start a new entity
                if current_label and current_label != 'O' and current_start is not None:
                    entities.append([current_start, offset_mapping[idx-1][1], current_label])
                current_start = start
                current_label = label
    
    # Add the last entity if applicable
    if current_label and current_label != 'O' and current_start is not None:
        entities.append([current_start, offset_mapping[-1][1], current_label])
    
    # Merge entities that are close to each other
    merged_entities = []
    if entities:
        merged_start, merged_end, merged_label = entities[0]
        
        for start, end, label in entities[1:]:
            if label == merged_label and start <= merged_end + merge_threshold:
                # Extend the current merged entity
                merged_end = max(merged_end, end)
            else:
                # Add the current merged entity to the list
                merged_entities.append([merged_start, merged_end, merged_label])
                merged_start, merged_end, merged_label = start, end, label
        
        # Add the last merged entity
        merged_entities.append([merged_start, merged_end, merged_label])
    
    return merged_entities


In [33]:
def predict_entities(text, model, tokenizer, label_map, merge_threshold=4):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, return_offsets_mapping=True)
    offset_mapping = inputs.pop("offset_mapping")[0].tolist()
    
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    predictions = torch.argmax(logits, dim=-1).squeeze().tolist()
    
    entities = []
    current_start = None
    current_label = None
    for idx, pred in enumerate(predictions):
        label = label_map.get(pred, "O")
        start, end = offset_mapping[idx]
        
        if label == 'O' or label == 'PAD':
            if current_label and current_label != 'O':
                # End of an entity
                entities.append([current_start, offset_mapping[idx-1][1], current_label])
                current_start = None
                current_label = None
        else:
            if label == current_label and current_start is not None:
                # Extend the current entity
                end = offset_mapping[idx][1]
            else:
                # Start a new entity
                if current_label and current_label != 'O' and current_start is not None:
                    entities.append([current_start, offset_mapping[idx-1][1], current_label])
                current_start = start
                current_label = label
    
    # Add the last entity if applicable
    if current_label and current_label != 'O' and current_start is not None:
        entities.append([current_start, offset_mapping[-1][1], current_label])
    
    # Merge entities that are close to each other
    merged_entities = []
    if entities:
        merged_start, merged_end, merged_label = entities[0]
        
        for start, end, label in entities[1:]:
            if label == merged_label and start <= merged_end + merge_threshold:
                # Extend the current merged entity
                merged_end = max(merged_end, end)
            else:
                # Add the current merged entity to the list
                merged_entities.append([merged_start, merged_end, merged_label])
                merged_start, merged_end, merged_label = start, end, label
        
        # Add the last merged entity
        merged_entities.append([merged_start, merged_end, merged_label])
    
    return merged_entities


In [34]:
def process_json_data(data, clf_model, clf_tokenizer, ner_model, ner_tokenizer, label_map):
    for item in data['tahminler']:
        text = item['text']

        # Predict the class using the classification model
        predicted_class = classify_text(text, clf_model, clf_tokenizer)
        item['cats'].append(predicted_class)

        # Predict the entities using the NER model
        predicted_entities = predict_entities(text, ner_model, ner_tokenizer, label_map)

        # Format the predicted entities as [start, end, label]
        formatted_entities = [
            [entity[0], entity[1], entity[2]] for entity in predicted_entities
        ]

        # Replace the entities with the formatted entities
        item['entities'] = formatted_entities

    return data


In [31]:
def process_json_data(data, clf_model, clf_tokenizer, ner_model, ner_tokenizer, label_map):
    for item in data['tahminler']:
        text = item['text']

        # Predict the class using the classification model
        predicted_class = classify_text(text, clf_model, clf_tokenizer)
        item['cats'].append(predicted_class)

        # Predict the entities using the NER model
        predicted_entities = predict_entities(text, ner_model, ner_tokenizer, label_map)

        # Format the predicted entities
        formatted_entities = [
            [entity[0], entity[1], entity[2]] for entity in predicted_entities
        ]

        # Extend the 'entities' list with formatted entities
        item['entities'].extend(formatted_entities)

    return data


In [35]:
import json

# Open and read the JSON file
with open("nlp_test_dataset.json", "r", encoding="utf-8") as file:
    data = json.load(file)  # Load the JSON content into a Python dictionary

# Process the data using both models
processed_data = process_json_data(data, clf_model, clf_tokenizer, ner_model, ner_tokenizer, label_map)

# Save the processed data to a new JSON file
with open("293962_LayerLords_2.json", "w", encoding="utf-8") as outfile:
    json.dump(processed_data, outfile, ensure_ascii=False, indent=4)

print("Processed data has been saved to 'processed_data.json'")


Processed data has been saved to 'processed_data.json'
