In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForTokenClassification
from docx import Document
from PyPDF2 import PdfReader
import csv
import re
from io import StringIO, BytesIO
import base64
from ipywidgets import FileUpload, Button, Output, Text, VBox
from IPython.display import display, HTML

# Load the PII token-classification model once; run it on the GPU when one is available.
model_name = "iiiorg/piiranha-v1-detect-personal-information"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForTokenClassification.from_pretrained(model_name).to(device)

# Manual entries (stored lower-cased by add_to_hide) that are always redacted.
names_to_hide, companies_to_hide = set(), set()

def split_into_chunks(text, max_tokens=150):
    """Split *text* into chunks of at most ``max_tokens`` whitespace-separated words.

    The text is walked line by line and, within each line, sentence by
    sentence (a sentence ends at ``.``, ``!`` or ``?`` followed by
    whitespace). Sentences are packed greedily into the current chunk until
    adding the next one would exceed the budget, at which point a new chunk
    is started. Source line breaks are kept inside a chunk, and each one
    consumes one unit of the budget. A sentence that on its own exceeds
    ``max_tokens`` words is cut into consecutive ``max_tokens``-word pieces
    (re-joined with single spaces), so no chunk contains more than
    ``max_tokens`` words.

    Word counts are only a proxy for the model's subword tokens.

    Args:
        text: The text to split.
        max_tokens: Maximum number of words per chunk; must be >= 1.

    Returns:
        A list of chunk strings with surrounding whitespace stripped; empty
        when *text* contains no words.

    Raises:
        ValueError: If ``max_tokens`` is less than 1.
    """
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    chunks = []
    current_chunk = ""
    current_tokens = 0

    for line in text.split('\n'):
        for sentence in re.split(r'(?<=[.!?])\s+', line):
            words = sentence.split()
            if not words:  # skip empty / whitespace-only sentences
                continue

            if len(words) <= max_tokens:
                pieces = [(sentence, len(words))]
            else:
                # Too long for any chunk: cut at word boundaries so no chunk
                # exceeds the budget.
                pieces = []
                for i in range(0, len(words), max_tokens):
                    part = words[i:i + max_tokens]
                    pieces.append((" ".join(part), len(part)))

            for piece, piece_tokens in pieces:
                if current_tokens + piece_tokens > max_tokens:
                    if current_chunk:
                        chunks.append(current_chunk.strip())
                    # Use the same " " separator as appended sentences so no
                    # line break is invented in the middle of a source line.
                    current_chunk = piece + " "
                    current_tokens = piece_tokens
                else:
                    current_chunk += piece + " "
                    current_tokens += piece_tokens

        if current_chunk:  # keep the source line break inside the chunk
            current_chunk += "\n"
            current_tokens += 1

    if current_chunk:  # flush the last, partially filled chunk
        chunks.append(current_chunk.strip())

    return chunks

def get_text_from_file(file_content, file_type):
    """Extract plain text from an uploaded document.

    Args:
        file_content: The raw file contents (any bytes-like object accepted
            by ``io.BytesIO``).
        file_type: ``'docx'`` or ``'pdf'``.

    Returns:
        The document text. For docx, one non-empty body paragraph per line;
        for pdf, one page per line.

    Raises:
        ValueError: If ``file_type`` is not ``'docx'`` or ``'pdf'``.
    """
    if file_type == 'docx':
        doc = Document(BytesIO(file_content))
        # NOTE: only ``doc.paragraphs`` is read, so text inside tables,
        # headers and footers is not extracted.
        return "\n".join(p.text for p in doc.paragraphs if p.text)
    if file_type == 'pdf':
        pdf = PdfReader(BytesIO(file_content))
        # Defensive: treat a page that yields no text (None) as empty so
        # join() does not fail.
        return "\n".join(page.extract_text() or "" for page in pdf.pages)
    # Previously fell through and returned None, which crashed later callers.
    raise ValueError(f"Unsupported file type: {file_type!r}")

def add_to_hide(entry, entry_type):
    """Register a manual entry that anonymize_text will always redact.

    The entry is stripped of surrounding whitespace and stored lower-cased in
    ``names_to_hide`` (``entry_type == 'name'``) or ``companies_to_hide``
    (``entry_type == 'company'``). Any other ``entry_type`` is ignored.

    Blank entries are ignored: an empty pattern matches at every position of
    the text and would make the redaction loop never advance.

    Args:
        entry: The text to hide, as typed by the user.
        entry_type: ``'name'`` or ``'company'``.
    """
    cleaned = entry.strip().lower()
    if not cleaned:
        return
    if entry_type == 'name':
        names_to_hide.add(cleaned)
    elif entry_type == 'company':
        companies_to_hide.add(cleaned)

def _redact_manual_entries(text, result, found_items, entries, label):
    """Redact every case-insensitive occurrence of each entry of *entries*.

    *result* holds one slot per character of *text* and is mutated in place:
    every slot of a match is blanked and its first slot becomes ``[label]``.
    Each match is appended to *found_items* as ``(label, matched_text)``.
    """
    for entry in entries:
        # A blank entry would match everywhere without ever advancing.
        if not entry.strip():
            continue
        # Match against *text* itself (not text.lower()): lower() can change
        # the string length for some characters, which would misalign the
        # offsets used to index *result*.
        for match in re.finditer(re.escape(entry), text, flags=re.IGNORECASE):
            start, end = match.span()
            found_items.append((label, match.group()))
            for i in range(start, end):
                result[i] = ''
            result[start] = f'[{label}]'


def anonymize_text(text):
    """Replace personal information in *text* with bracketed entity labels.

    Two redaction sources are combined:

    * token-level predictions of the piiranha model: every token whose
      predicted label is not ``'O'`` is replaced by ``[LABEL]``;
    * case-insensitive literal matches of the manual entries in
      ``names_to_hide`` (tagged ``[I-GIVENNAME]``) and ``companies_to_hide``
      (tagged ``[I-ORG]``).

    Text longer than the model's input limit is processed in consecutive
    token windows instead of being truncated, so no tail is left unredacted.

    Args:
        text: The text to anonymize.

    Returns:
        ``(anonymized_text, found_items)`` where ``found_items`` is a list of
        ``(label, original_substring)`` pairs in detection order.

    Raises:
        Exception: Tokenizer/model errors are propagated rather than
            answered with the unredacted text, which would otherwise be
            presented as "anonymized".
    """
    if not text:
        return text, []

    # Cap each window at the model's positional capacity; the tokenizer's
    # model_max_length may be a very large placeholder when it is unset.
    max_length = min(
        tokenizer.model_max_length,
        getattr(model.config, "max_position_embeddings", 512),
    )

    # A single tokenizer call feeds the model AND supplies character offsets;
    # overflowing windows cover text beyond max_length.
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,
    )
    offset_windows = encoded.pop("offset_mapping").tolist()
    encoded.pop("overflow_to_sample_mapping", None)  # not a model input
    inputs = {k: v.to(device) for k, v in encoded.items()}

    with torch.no_grad():
        logits = model(**inputs).logits
    prediction_windows = torch.argmax(logits, dim=-1).tolist()

    outside_id = model.config.label2id['O']
    result = list(text)
    found_items = []

    for offsets, predictions in zip(offset_windows, prediction_windows):
        for (start, end), label_id in zip(offsets, predictions):
            # Special/padding tokens carry empty (start == end) offsets.
            if start == end or label_id == outside_id:
                continue
            label = model.config.id2label[label_id]
            found_items.append((label, text[start:end]))
            for i in range(start, end):
                result[i] = ''
            result[start] = f'[{label}]'

    _redact_manual_entries(text, result, found_items, names_to_hide, 'I-GIVENNAME')
    _redact_manual_entries(text, result, found_items, companies_to_hide, 'I-ORG')

    return ''.join(result), found_items

def create_csv(chunks_data):
    """Render the per-chunk results as an HTML CSV download link.

    Every CSV field is quoted (``csv.QUOTE_ALL``). The CSV text is encoded
    with the default UTF-8 codec and embedded as a base64 ``data:`` URI.

    Args:
        chunks_data: Iterable of dicts with ``'original'``, ``'anonymized'``
            and ``'found_items'`` keys, one per chunk, in chunk order.

    Returns:
        An HTML ``<a>`` element that downloads ``anonymized_text.csv``.
    """
    buffer = StringIO()
    csv_writer = csv.writer(buffer, quoting=csv.QUOTE_ALL)
    csv_writer.writerow(['Chunk Number', 'Original Text', 'Anonymized Text', 'Found Items'])
    csv_writer.writerows(
        [f"Chunk {number}", entry['original'], entry['anonymized'], entry['found_items']]
        for number, entry in enumerate(chunks_data, start=1)
    )

    payload = base64.b64encode(buffer.getvalue().encode()).decode()
    return f'<a href="data:text/csv;base64,{payload}" download="anonymized_text.csv">Download CSV</a>'

# Jupyter widgets. Step 1: single-file upload restricted to .docx / .pdf.
upload = FileUpload(accept='.docx,.pdf', multiple=False)

# Step 2: manual entries (names / companies) that are always redacted.
name_input = Text(description='Name:')
add_name_button = Button(description="Add Name")
company_input = Text(description='Company:')
add_company_button = Button(description="Add Company")

# Step 3: the run button and the area where results are rendered.
process_button = Button(description="Process and Anonymize")
output = Output()

def on_add_name_clicked(b):
    """Button callback: queue the typed name for redaction, then clear the field.

    Args:
        b: The clicked Button (unused; required by ``Button.on_click``).
    """
    typed_name = name_input.value
    add_to_hide(typed_name, 'name')
    name_input.value = ''

def on_add_company_clicked(b):
    """Button callback: queue the typed company for redaction, then clear the field.

    Args:
        b: The clicked Button (unused; required by ``Button.on_click``).
    """
    typed_company = company_input.value
    add_to_hide(typed_company, 'company')
    company_input.value = ''

def on_process_clicked(b):
    """Button callback: anonymize the uploaded document and show the results.

    Clears the ``output`` widget, then extracts text from the uploaded
    .docx/.pdf file and splits it into chunks. It anonymizes each chunk and
    renders the original text, the anonymized text and the found items per
    chunk. Finally it shows a summary of all found items and a CSV download
    link. Any error is printed into the output widget.

    Args:
        b: The clicked Button (unused; required by ``Button.on_click``).
    """
    import html

    def _as_html(raw):
        # Document text is untrusted: escape it before embedding it in
        # markup, then keep its line breaks visible.
        return html.escape(raw).replace('\n', '<br>')

    with output:
        output.clear_output()
        if not upload.value:
            print("Please upload a file first.")
            return

        try:
            file_info = upload.value[0]
            content = file_info['content']
            filename = file_info['name']

            # Compare extensions case-insensitively so "Report.PDF" is accepted.
            lowered_name = filename.lower()
            if lowered_name.endswith('.docx'):
                text = get_text_from_file(content, 'docx')
            elif lowered_name.endswith('.pdf'):
                text = get_text_from_file(content, 'pdf')
            else:
                raise ValueError("Please upload a .docx or .pdf file")

            chunks = split_into_chunks(text)
            print(f"Processing {len(chunks)} chunks...")

            chunks_data = []
            all_found = []

            for i, chunk in enumerate(chunks, 1):
                anonymized, found = anonymize_text(chunk)
                all_found.extend(found)
                found_text = ', '.join(f"{label}: {item}" for label, item in found)

                chunks_data.append({
                    'original': chunk,
                    'anonymized': anonymized,
                    'found_items': found_text
                })

                # Display results for this chunk
                display(HTML(f"<h3>Chunk {i}:</h3>"))
                display(HTML(f"<p><strong>Original:</strong><br>{_as_html(chunk)}</p>"))
                display(HTML(f"<p><strong>Anonymized:</strong><br>{_as_html(anonymized)}</p>"))
                if found:
                    display(HTML(f"<p><strong>Found Items:</strong><br>{_as_html(found_text)}</p>"))
                display(HTML("<hr>"))

            # Show summary
            if all_found:
                display(HTML("<h3>All Found Items:</h3>"))
                summary = ', '.join(f'{label}: {item}' for label, item in all_found)
                display(HTML(f"<p>{_as_html(summary)}</p>"))

            # Create download link (CSV payload is base64-encoded, not HTML)
            display(HTML(create_csv(chunks_data)))

        except Exception as e:
            print(f"Error: {e}")

# Connect each button to its callback.
for _button, _callback in (
    (add_name_button, on_add_name_clicked),
    (add_company_button, on_add_company_clicked),
    (process_button, on_process_clicked),
):
    _button.on_click(_callback)

# Render the UI: a heading per step followed by its widgets, then the output area.
for _item in (
    HTML("<h2>Step 1: Upload a Document</h2>"),
    upload,
    HTML("<h2>Step 2: Add Manual Entries (Optional)</h2>"),
    VBox([name_input, add_name_button, company_input, add_company_button]),
    HTML("<h2>Step 3: Process Document</h2>"),
    process_button,
    output,
):
    display(_item)

del _button, _callback, _item

FileUpload(value=(), accept='.docx,.pdf', description='Upload')

VBox(children=(Text(value='', description='Name:'), Button(description='Add Name', style=ButtonStyle()), Text(…

Button(description='Process and Anonymize', style=ButtonStyle())

Output()