In [25]:
import pandas as pd
import os

In [26]:
train_tsv_dir = "dataset/train/boxes_transcripts_labels"
val_tsv_dir = "dataset/val/boxes_transcripts"
val_ann_tsv_dir = "dataset/val_w_ann/boxes_transcripts_labels"

In [27]:
def load_and_parse_tsv_files(directory):
    all_data = []
    for file in os.listdir(directory):
        if file.endswith('.tsv'):
            file_path = os.path.join(directory, file)
            df = pd.read_csv(file_path, sep=',', header=None)
            # print(df)
            if df.shape[1] == 8:
                df.columns = ['start_index', 'end_index', 'x_tl', 'y_tl', 'x_br', 'y_br', 'transcript', 'label']
            elif df.shape[1] == 7:
                df.columns = ['start_index', 'end_index', 'x_tl', 'y_tl', 'x_br', 'y_br', 'transcript']
            all_data.append(df)
    return pd.concat(all_data, ignore_index=True)

In [28]:
train_df = load_and_parse_tsv_files(train_tsv_dir)
# val_df = load_and_parse_tsv_files(val_tsv_dir)

In [29]:
print(train_df[:5])

   start_index  end_index  x_tl  y_tl  x_br  y_br  transcript  label
0           33         33   215     4   227    21           a  OTHER
1           35         44   235     3   308    21  Employee's  OTHER
2           46         51   311     3   349    20      social  OTHER
3           53         60   352     3   401    20    security  OTHER
4           62         67   404     3   457    21      number  OTHER


In [30]:
import spacy
from spacy.training import Example
import pandas as pd

In [2]:
def remove_overlapping_entities(entities):
    """Remove overlapping entities from the list."""
    if not entities:
        return entities
    
    # Sort entities by start index
    entities = sorted(entities, key=lambda x: x[0])
    
    non_overlapping_entities = []
    last_end = -1
    
    for start, end, label in entities:
        if start > last_end:
            non_overlapping_entities.append((start, end, label))
            last_end = end
    
    return non_overlapping_entities


In [31]:
import pandas as pd
import os
from tqdm import tqdm
import spacy
from spacy.tokens import DocBin

In [11]:
# def convert_to_spacy_format(df):
#     df['transcript'] = df['transcript'].astype(str)
    
#     texts = []
#     annotations = []
    
#     grouped = df.groupby('transcript')
    
#     for text, group in grouped:
#         text_entities = []
#         for _, row in group.iterrows():
#             if pd.notna(row['start_index']) and pd.notna(row['end_index']) and pd.notna(row['label']):
#                 text_entities.append((int(row['start_index']), int(row['end_index']), row['label']))
        
#         # Remove overlapping entities
#         text_entities = remove_overlapping_entities(text_entities)
        
#         if text_entities:
#             texts.append(text)
#             annotations.append({'entities': text_entities})
    
#     return list(zip(texts, annotations))
def convert_to_spacy_format(df):
    df['transcript'] = df['transcript'].astype(str)
    texts = df['transcript'].tolist()
    entities = []
    for _, row in df.iterrows():
        print(row['transcript'])
        entities.append({
            'entities': [
                (0, len(row['transcript'])-1, row['label'])
            ]
        })
    return list(zip(texts, entities))

In [None]:
# train_df = clean_dataframe(train_df)
spacy_train_data = convert_to_spacy_format(train_df)
# spacy_val_data = convert_to_spacy_format(val_df)

In [None]:
# Print a sample of spacy_train_data
for text, annotations in spacy_train_data:
    # Filter entities where the label is not 'OTHER'
    non_other_entities = [ent for ent in annotations['entities'] if ent[2] != 'OTHER']
    
    if non_other_entities:
        print(f"Text: {text}")
        print(f"Entities: {non_other_entities}")
        print("-" * 40)


In [None]:
nlp = spacy.load("en_core_web_sm")  # Create a new, blank model
# ner = nlp.add_pipe("ner")

In [107]:
import random

In [None]:
print(spacy_train_data[int(random.random()):2])

In [15]:
import json

# Define the path to your existing JSON file
json_file_path = "./spacy_train_data.json"
cleaned_json_file_path = "./cleaned_spacy_train_data.json"

# Read the JSON file with the incorrect encoding
with open(json_file_path, 'r', encoding='utf-8') as f:
    data = f.read()

# Replace unwanted characters
data = data.replace('Â', '')  # Remove the Â character

# Load the JSON data
data = json.loads(data)

# Write the cleaned JSON data back to a new file
with open(cleaned_json_file_path, 'w', encoding='utf-8') as f:
    json.dump(data, f, ensure_ascii=False, indent=4)


In [16]:
import json

# Load the cleaned JSON file to verify
with open(cleaned_json_file_path, 'r', encoding='utf-8') as f:
    cleaned_data = json.load(f)

# print(cleaned_data)  # Print to verify


In [None]:
db = DocBin() # create a DocBin object
# print(spacy_train_data[:2])
for text, annot in tqdm(cleaned_data): # data in previous format
    doc = nlp.make_doc(text) # create doc object from text
    ents = []
    # print(type(doc))
    for start, end, label in annot["entities"]: # add character indexes
        # print(start,end,label,doc)
        # print(doc)
        span = doc.char_span(start,end,label=label,alignment_mode='expand')
        if span is None:
            print("Skipping entity")
        else:
            print('made it:',start,end,text,label)
            ents.append(span)
    doc.ents = ents # label the text with the ents
    db.add(doc)
# os.chdir(r'XXXX\XXXXX')
db.to_disk("train.spacy")

In [49]:
all_labels = set()
for _, annotations in spacy_train_data:
    for start, end, label in annotations['entities']:
        all_labels.add(label)
for label in all_labels:
    ner.add_label(label)

In [45]:
import random

In [51]:
print(len(spacy_train_data))

20183


In [None]:
from spacy.training import Example

def train_spacy_model(nlp, train_data, batch_size=1000, n_epochs=50):
    """Train SpaCy model with batch processing."""
    optimizer = nlp.begin_training()

    for epoch in range(n_epochs):
        losses = {}
        # Shuffle the training data at the start of each epoch
        random.shuffle(train_data)
        
        # Process the data in batches
        for i in range(0, len(train_data), batch_size):
            batch = train_data[i:i+batch_size]
            examples = [Example.from_dict(nlp.make_doc(text), annotations) for text, annotations in batch]
            # Update the model with the current batch
            nlp.update(examples, drop=0.5, losses=losses)

        # Print losses for the current epoch
        print(f"Epoch {epoch}, Losses: {losses}")

# Assuming you have already defined `spacy_train_data`
train_spacy_model(nlp, spacy_train_data, batch_size=1000, n_epochs=3)


In [None]:
print(len(train_df))

In [14]:
nlp.to_disk("Runs/SCAPY3")

In [2]:
import spacy

In [34]:
model1 = spacy.load("output/model-best")

In [32]:
import os
import pandas as pd
import spacy

def add_predictions_to_tsv_single(model, input_file, output_file):
    # Load the TSV file into a DataFrame
    df = pd.read_csv(input_file, sep=',', header=None)
    df.columns = ['start_index', 'end_index', 'x_tl', 'y_tl', 'x_br', 'y_b', 'transcript']
    # Initialize a list to store predicted labels
    predicted_labels = []
    # Iterate through each transcript and predict labels
    for transcript in df['transcript']:
        if isinstance(transcript, str) and transcript.strip():
            # print(transcript)
            doc = model(transcript)  # Use the loaded model to predict entities
            if doc.ents:
                # Get the label of the first predicted entity (if any) for simplicity
                predicted_label = doc.ents[0].label_
            else:
                predicted_label = "OTHER"  # Fallback to "OTHER" if no entity is predicted
        else:
            predicted_label = "OTHER"  # Handle empty or invalid transcripts

        predicted_labels.append(predicted_label)
    
    # Add the predicted labels as a new column to the DataFrame
    df['predicted_label'] = predicted_labels
    
    # Save the updated DataFrame to the specified output TSV file
    df.to_csv(output_file, sep='\t', index=False, header=False)
    print(f"Updated file with predictions saved to {output_file}")

# Example usage
add_predictions_to_tsv_single(model1, "dataset/val/boxes_transcripts/7ca4a06b-83af-43cb-b737-3b14af27c894_document-1_page-1.tsv", "predictions.tsv")


Updated file with predictions saved to predictions.tsv


In [35]:
def add_predictions_to_tsv(model, tsv_file):
    for file_name in os.listdir(tsv_file):
        if file_name.endswith('.tsv'):
            file_path = os.path.join(tsv_file, file_name)
            df = pd.read_csv(file_path, sep=',', header=None)
            df.columns = ['start_index', 'end_index', 'x_tl', 'y_tl', 'x_br', 'y_b', 'transcript']
            predicted_labels = []
            for transcript in df['transcript']:
                if isinstance(transcript, str) and transcript.strip():
                    doc = model(transcript)  # Use the loaded model to predict entities
                    print(doc.ents)
                    if doc.ents:
                        # Get the label of the first predicted entity (if any) for simplicity
                        predicted_label = doc.ents[0].label_
                    else:
                        predicted_label = "OTHER"  # Fallback to "OTHER" if no entity is predicted
                else:
                    predicted_label = "OTHER"  # Handle empty or invalid transcripts

                predicted_labels.append(predicted_label)
            
            # Add the predicted labels as a new column to the DataFrame
            df['predicted_label'] = predicted_labels
            
            # Save the updated DataFrame to the same TSV file
            df.to_csv(file_path, sep=',', index=False, header=False)
            print(f"Updated {file_name} with predictions.")

add_predictions_to_tsv(model1, "dataset/val/boxes_transcripts")

()
(Employee's,)
(social,)
(security,)
(number,)
(Safe,)
(Accurate,)
(fe,)
()
(file,)
(Visit,)
(the,)
(IRS,)
(Website,)
(634-61-7592,)
(OMB,)
(No,)
(1545-0008,)
(FAST,)
(Use,)
(at,)
(www.irs,)
(govefile,)
()
(Employer,)
(identification,)
(number,)
((EIN,)
(Wages,)
(tips,)
(other,)
(compensation,)
()
(Federal,)
(income,)
(tax,)
(withheld,)
(80-3888506,)
(180450,)
()
(56,)
(30700,)
()
(05,)
(Social,)
(security,)
(wages,)
(Social,)
(security,)
(tax,)
(withheld,)
()
(Employer's,)
(name,)
(address,)
(and,)
(ZIP,)
(code,)
(205649,)
(34,)
(15732,)
()
(17,)
(Arnold,)
(Wood,)
(and,)
(Rivera,)
(LLC,)
(Medicare,)
(wages,)
(and,)
(tips,)
()
(Medicare,)
(tax,)
(withheld,)
(250,)
(Martinez,)
(Causeway,)
(Suite,)
(141,)
(223895,)
()
(04,)
(6492,)
()
(96,)
(Kelseyfort,)
(AK,)
(91543-2875,)
(Social,)
(security,)
(tips,)
()
(Allocated,)
(tips,)
(205649,)
(34,)
(223895,)
(04,)
()
(Control,)
(number,)
()
(Verification,)
(Code,)
(10,)
(Dependent,)
(care,)
(benefits,)
(170,)
(776585,)
(2a,)
(See,)
(instruct