# Class Diagram Extraction Pipeline

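This notebook reads the `.txt` documents in the `survey/target_text` directory and runs each one through a pipeline to extract class diagrams. The pipeline has four steps: sentence preprocessing, BERT-based sentence classification, named-entity extraction of classes and attributes, and PlantUML generation with Azure OpenAI.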

In [None]:
# Import required libraries
import os
import re
import pandas as pd
import time
import torch
from tqdm import tqdm
from dotenv import load_dotenv
from openai import AzureOpenAI
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
from collections import defaultdict
from spacy import displacy
from torch.utils.data import Dataset, DataLoader

# Reading Files from target_text Directory

In [None]:
# Create output directory for PlantUML files
# (paths are resolved relative to the parent of the current working directory).
output_dir = os.path.join(os.path.dirname(os.getcwd()), "survey/class_diagrams_output_prosus")
if not os.path.exists(output_dir):
    os.makedirs(output_dir)
    print(f"Created output directory: {output_dir}")
else:
    print(f"Output directory already exists: {output_dir}")

# Define target directory path
target_dir = os.path.join(os.path.dirname(os.getcwd()), "survey/target_text")
print(f"Reading files from: {target_dir}")

# Fail fast if the input directory is missing
if not os.path.exists(target_dir):
    print(f"Error: Directory '{target_dir}' does not exist.")
    raise FileNotFoundError(f"Directory '{target_dir}' does not exist.")

# Collect the .txt files. os.listdir returns names in arbitrary order, so sort
# them to make the processing order (and the logs) reproducible across runs.
text_files = sorted(f for f in os.listdir(target_dir) if f.endswith('.txt'))

if not text_files:
    print("No text files found in the directory.")
else:
    print(f"Found {len(text_files)} text file(s): {text_files}")

# Process Each Text File

In [None]:
def preprocess_document(document_text, filename):
    """Normalize whitespace in *document_text* and split it into sentence rows.

    Every run of whitespace (including newlines) becomes a single space. The text
    is then cut on each '.' character, and fragments that are empty or
    whitespace-only are dropped.

    Args:
        document_text: Raw document contents.
        filename: Name of the source file (used only for logging).

    Returns:
        pandas.DataFrame with one column, ``sentence``, and a fresh 0..n-1 index.
        Fragments keep any leading/trailing spaces and lose their period.
    """
    print(f"Processing document: {filename}")
    print("Step 1: Preprocessing")

    # Collapsing \s+ also covers newlines, so one substitution is enough.
    flattened = re.sub(r'\s+', ' ', document_text)

    frame = pd.DataFrame(flattened.split('.'), columns=['sentence'])
    has_text = frame['sentence'].str.strip().str.len() > 0
    return frame[has_text].reset_index(drop=True)

In [None]:
# Load the pre-trained BERT model for sentence classification
def load_sentence_classifier_model(device=None):
    """Load the fine-tuned FinBERT sentence classifier used in step 2.

    The tokenizer and base architecture come from ``ProsusAI/finbert`` with a
    2-label head. The fine-tuned weights are read from a fixed ``.bin`` state
    dict under ``../requirement_classification/all_model``.

    Args:
        device: Optional ``torch.device`` or device string to place the model on.
            When omitted, ``cuda:1`` is used if at least two GPUs are visible,
            ``cuda:0`` if exactly one is, and the CPU otherwise. Previously
            ``cuda:1`` was hard-coded and crashed on other hosts.

    Returns:
        tuple: ``(model, tokenizer, device)``, with the model in eval mode on ``device``.
    """
    model_dir = os.path.join(os.path.dirname(os.getcwd()), "requirement_classification", "all_model")

    # Fixed checkpoint file; no automatic selection of the newest model is done.
    model_file = "ProsusAI-finbert_structure_focus_epochs9_kfold10_batch8.bin"
    model_path = os.path.join(model_dir, model_file)
    print(f"Loading sentence classifier model from: {model_path}")

    # Base model whose architecture/tokenizer match the fine-tuned checkpoint
    model_name = "ProsusAI/finbert"

    if device is None:
        if torch.cuda.device_count() > 1:
            device = torch.device("cuda:1")  # keep the previous preferred GPU
        elif torch.cuda.is_available():
            device = torch.device("cuda:0")
        else:
            device = torch.device("cpu")
    else:
        device = torch.device(device)

    # Load the tokenizer and model; ignore_mismatched_sizes allows replacing the
    # base checkpoint's classification head with a 2-label one.
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2, ignore_mismatched_sizes=True)
    # map_location lets a state dict saved on a GPU be restored on the chosen device.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()
    model.to(device)

    return model, tokenizer, device

In [None]:
# TextDataset class for sentence classification
class TextDataset(Dataset):
    """Map-style dataset over a batch of sentences tokenized once up front.

    The tokenizer is called a single time on all ``texts`` with
    ``padding=True``, ``truncation=True``, ``max_length`` and PyTorch tensor
    output. Item ``i`` is a dict holding row ``i`` of every tokenizer output.
    """

    def __init__(self, texts, tokenizer, max_length=512):
        self.encodings = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )

    def __len__(self):
        # One dataset item per encoded row.
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        return dict((name, column[idx]) for name, column in self.encodings.items())

In [None]:
def classify_sentences(document_sentence, filename, output_dir):
    """Label each sentence as useful (1) or not useful (0) for class diagram extraction.

    Runs the BERT classifier from ``load_sentence_classifier_model`` over
    ``document_sentence['sentence']`` in batches of 8. The arg-max label is
    stored in a ``useful`` column.

    Args:
        document_sentence: DataFrame with a ``sentence`` column (as produced by
            ``preprocess_document``). It is modified in place.
        filename: Source file name; its stem names the checkpoint CSV.
        output_dir: Directory that receives ``<stem>_checkpoint.csv``.

    Returns:
        The same DataFrame with a ``useful`` column.
        - If the classifier cannot be loaded, every row keeps ``useful == 0`` and
          no checkpoint is written.
        - If there are no sentences, the model is not loaded at all, but an
          (empty) checkpoint is still written.
    """
    print("Step 2: Categorizing sentences using pre-trained model")
    document_sentence['useful'] = 0
    sentences = document_sentence['sentence'].tolist()

    predictions = []
    if sentences:
        # Load the pre-trained model only when there is something to classify.
        try:
            model, tokenizer, device = load_sentence_classifier_model()
        except Exception as e:
            print(f"Error loading sentence classifier model: {e}")
            print("Falling back to default classification method")
            # No alternative classifier is implemented: all sentences stay labelled 0.
            return document_sentence

        dataset = TextDataset(sentences, tokenizer)
        dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

        with torch.no_grad():
            for batch in tqdm(dataloader, desc="Classifying sentences"):
                inputs = {key: val.to(device) for key, val in batch.items()}
                logits = model(**inputs).logits
                # Predicted class: 0 = not useful, 1 = useful. tolist() yields plain ints.
                predictions.extend(torch.argmax(logits, dim=1).cpu().tolist())

        document_sentence['useful'] = predictions
    # else: an empty document needs neither model loading nor tokenization.

    useful_count = sum(predictions)
    print(f"Found {useful_count} useful sentences out of {len(sentences)} total sentences")

    # Save checkpoint
    checkpoint_file = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_checkpoint.csv")
    document_sentence.to_csv(checkpoint_file)
    print(f"Saved checkpoint to {checkpoint_file}")

    return document_sentence

In [None]:
def extract_entities(document_sentence, filename, output_dir):
    """Run the NER model over the sentences flagged as useful and persist the results.

    Steps:
      * keep rows whose ``useful`` column equals 1 and join their ``sentence``
        text with single spaces into one string (``document_class``);
      * tag that string with the ``"ner"`` pipeline loaded from
        ``../key-term-extraction/BERT-Style-model/microsoft/deberta-v3-large-4-epoch-8-bs``
        with ``aggregation_strategy="simple"``;
      * collect the distinct words whose lower-cased group is ``class`` or
        ``attr``, in first-seen order;
      * write ``<stem>_entities.csv``, ``<stem>_entities.json`` and
        ``<stem>_entities.txt`` into ``output_dir``.

    Returns:
        dict: ``{"entities", "summary", "document_class"}`` on success. Each
        entity is a plain dict with ``entity_group``, ``word``, ``score``
        (float), ``start`` and ``end``. Returns ``{"error": message}`` if loading
        the pipeline, tagging, or writing any output raised.
    """
    print("Step 3: Extracting class diagram entities")
    flagged = document_sentence.loc[document_sentence['useful'] == 1, 'sentence']
    document_class = ' '.join(flagged.tolist())

    model_path = os.path.join(
        os.path.dirname(os.getcwd()),
        "key-term-extraction",
        "BERT-Style-model/microsoft/deberta-v3-large-4-epoch-8-bs",
    )
    print(f"Using NER model from: {model_path}")

    try:
        tagger = pipeline("ner", model=model_path, aggregation_strategy="simple")
        entities = tagger(document_class)

        # Distinct words per tracked group; dict keys preserve first-seen order.
        seen = {"class": {}, "attr": {}}
        for ent in entities:
            group, word = ent["entity_group"].lower(), ent["word"]
            if group in seen:
                seen[group][word] = None
        summary = {group: list(words) for group, words in seen.items()}

        stem = os.path.splitext(filename)[0]

        csv_path = os.path.join(output_dir, f"{stem}_entities.csv")
        pd.DataFrame(entities).to_csv(csv_path)
        print(f"Entities saved to CSV: {csv_path}")

        # Keep only JSON-friendly fields; score may be a non-builtin numeric type, hence float().
        plain_entities = [
            {
                "entity_group": ent["entity_group"],
                "word": ent["word"],
                "score": float(ent["score"]),
                "start": ent["start"],
                "end": ent["end"],
            }
            for ent in entities
        ]

        import json

        payload = {
            "entities": plain_entities,
            "summary": summary,
            "document_class": document_class,
        }
        json_path = os.path.join(output_dir, f"{stem}_entities.json")
        with open(json_path, 'w', encoding='utf-8') as handle:
            json.dump(payload, handle, ensure_ascii=False, indent=2)
        print(f"Entities saved to JSON: {json_path}")

        rule = '-' * 20
        txt_path = os.path.join(output_dir, f"{stem}_entities.txt")
        with open(txt_path, 'w', encoding='utf-8') as handle:
            handle.write(f"Document: {filename}\n\n")
            handle.write(f"Classes identified:\n{rule}\n")
            handle.writelines(f"- {name}\n" for name in summary['class'])
            handle.write(f"\nAttributes identified:\n{rule}\n")
            handle.writelines(f"- {name}\n" for name in summary['attr'])
            handle.write(f"\nDetailed Entities:\n{rule}\n")
            handle.writelines(
                f"Type: {ent['entity_group']}, Word: {ent['word']}, Score: {ent['score']:.4f}\n"
                for ent in plain_entities
            )
        print(f"Entities saved to TXT: {txt_path}")

        return payload

    except Exception as e:
        print(f"Error extracting entities: {e}")
        return {"error": str(e)}

In [None]:
def generate_diagram(extraction_result, filename, output_dir, llm_client=None, model_name=None):
    """Generate a PlantUML class diagram (step 4) from an ``extract_entities`` result.

    Sends the extracted class/attribute lists and the useful-sentence text to the
    chat model. Markdown code fences are removed from the reply, which is then
    written to ``<output_dir>/<stem>_class_diagram.puml``.

    Args:
        extraction_result (dict): Return value of ``extract_entities``, either
            ``{"summary", "document_class", "entities"}`` or ``{"error": ...}``.
        filename (str): Source file name; its stem names the output file.
        output_dir (str): Directory that receives the ``.puml`` file.
        llm_client: Object exposing ``chat.completions.create`` (e.g. an
            ``AzureOpenAI`` client). Defaults to the notebook-level ``client``.
        model_name (str): Deployment/model name passed as ``model``. Defaults to
            the notebook-level ``deployment``.

    Returns:
        dict: ``{"filename", "entities", "plantuml_code", "output_file"}`` on
        success. Returns ``{"filename", "error"}`` if extraction had failed or
        any step here raised.
    """
    print("Step 4: Generating PlantUML code")

    # Propagate an upstream extraction failure instead of calling the model.
    if "error" in extraction_result:
        print(f"Cannot generate diagram due to extraction error: {extraction_result['error']}")
        return {"filename": filename, "error": extraction_result['error']}

    try:
        # Resolved lazily: the notebook defines `client`/`deployment` in a later cell.
        llm = client if llm_client is None else llm_client
        model = deployment if model_name is None else model_name

        # Prepare summary for diagram generation
        summary = extraction_result["summary"]
        document_class = extraction_result["document_class"]
        summary_string = f"class: {summary['class']}, attribute: {summary['attr']}, description: {document_class}"

        chat_prompt = [
            {
                "role": "system",
                "content": [{
                    "type": "text",
                    "text": "You will be given a JSON of class names, attributes, and a system description. Your task is to generate plantuml script containing classes, attributes, and relationships according to the system description. Strictly produce only plantuml script"
                }]
            },
            {
                "role": "user",
                "content": [{
                    "type": "text",
                    "text": summary_string
                }]
            }
        ]

        completion = llm.chat.completions.create(
            model=model,
            messages=chat_prompt,
            max_tokens=800,
            temperature=0,
            top_p=0.95,
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            stream=False
        )
        plantuml_result = completion.choices[0].message.content
        if plantuml_result is None:
            raise ValueError("Azure OpenAI returned no message content")

        # Remove Markdown fences. Do not use str.strip('```plantuml') here: it
        # strips any of those *characters* from both ends, so "@enduml" would lose
        # its trailing "uml" and become "@end".
        plantuml_result = plantuml_result.strip()
        plantuml_result = plantuml_result.replace("```plantuml", "").replace("```", "").strip()

        output_file = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_class_diagram.puml")
        with open(output_file, "w", encoding="utf-8") as file:
            file.write(plantuml_result)
        print(f"PlantUML diagram saved to {output_file}")

        return {
            "filename": filename,
            "entities": extraction_result["entities"],
            "plantuml_code": plantuml_result,
            "output_file": output_file
        }

    except Exception as e:
        print(f"Error generating diagram: {e}")
        return {
            "filename": filename,
            "error": str(e)
        }

In [None]:
def process_document(document_text, filename, out_dir=None):
    """Run pipeline steps 1-3 (preprocess, classify, extract entities) on one document.

    Diagram generation (step 4) is not invoked here.

    Args:
        document_text (str): Raw document contents.
        filename (str): Source file name; its stem names the output artifacts.
        out_dir (str, optional): Directory for the checkpoint and entity files.
            Defaults to the notebook-level ``output_dir`` global. A later cell
            reassigns that global, so pass this explicitly when re-running cells
            out of order.

    Returns:
        str: ``"true " + filename``. Extraction failures are reported by
        ``extract_entities`` itself and do not change this value.
    """
    target = output_dir if out_dir is None else out_dir

    # Step 1: Preprocess document
    document_sentence = preprocess_document(document_text, filename)

    # Step 2: Classify sentences
    document_sentence = classify_sentences(document_sentence, filename, target)

    # Step 3: Extract entities (called for its file outputs)
    extract_entities(document_sentence, filename, target)

    return "true " + filename

In [None]:
# Run every text file in the target directory through the pipeline.
# Each entry of `results` is process_document's return value, or a
# {"filename", "error"} dict when reading or processing that file raised.
results = []
divider = '=' * 50

for filename in text_files:
    file_path = os.path.join(target_dir, filename)
    print(f"\n{divider}\nProcessing file: {filename}\n{divider}")

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            document_text = file.read()
        # The file handle is closed before the long-running pipeline starts.
        result = process_document(document_text, filename)
        results.append(result)
    except Exception as e:
        print(f"Error reading file {filename}: {e}")
        results.append({"filename": filename, "error": str(e)})

print("\nProcessing complete!")
print(f"Processed {len(results)} files.")
print(f"Output saved to: {output_dir}")

# Generate PlantUML Class Diagrams from Saved Entities

In [5]:
import os
import re
from dotenv import load_dotenv
from openai import AzureOpenAI

In [6]:
# Initialize Azure OpenAI client
# Settings come from environment variables; load_dotenv() first loads any
# variables found in a .env file.
load_dotenv()

# Endpoint, deployment and key fall back to "" when unset. The API version has
# no fallback and stays None when its variable is missing.
endpoint, deployment, subscription_key = (
    os.getenv(var_name, "")
    for var_name in (
        "AZURE_OPENAI_ENDPOINT_URL_1",
        "AZURE_OPENAI_DEPLOYMENT_NAME_1",
        "AZURE_OPENAI_API_KEY_1",
    )
)
api_version = os.getenv("AZURE_OPENAI_API_VERSION_1")

for label, value in (("endpoint", endpoint), ("deployment", deployment)):
    print(f"Azure OpenAI {label}: {value}")

# Key-based authentication against the configured Azure endpoint.
client = AzureOpenAI(
    api_version=api_version,
    api_key=subscription_key,
    azure_endpoint=endpoint,
)

Azure OpenAI endpoint: https://dewi.openai.azure.com/
Azure OpenAI deployment: gpt-4o


In [24]:
def generate_class_diagram_with_azure(classes, attributes, document_text, filename,
                                      output_dir="./plantuml_result_combine",
                                      llm_client=None, model_name=None):
    """
    Generate a PlantUML class diagram with Azure OpenAI from extracted entities.

    The prompt embeds the requirement text plus the extracted class and attribute
    lists. Markdown code fences are removed from the model's reply, which is then
    written to ``<output_dir>/<filename>_class_diagram.puml``.

    Args:
        classes (list): Extracted class names.
        attributes (list): Extracted attribute names.
        document_text (str): Requirement text used as context.
        filename (str): Base name of the processed file (without extension).
        output_dir (str): Directory for the ``.puml`` file; created if missing.
        llm_client: Object exposing ``chat.completions.create`` (e.g. an
            ``AzureOpenAI`` client). Defaults to the notebook-level ``client``.
        model_name (str): Deployment/model name passed as ``model``. Defaults to
            the notebook-level ``deployment``.

    Returns:
        tuple[str, str]: ``(plantuml_code, output_file_path)``.

    Raises:
        ValueError: If the completion carries no message content.
    """
    # Fall back to the notebook-level client/deployment from the previous cell.
    llm = client if llm_client is None else llm_client
    model = deployment if model_name is None else model_name

    # Prompt text is kept exactly as before (including the tab-separated bullets).
    prompt_text = (
        "Given a description of software requirement: {document_text}\n"
        "There are list of entities\n"
        "List of classes:  {classes}\n"
        "List of attributes: {attributes}\n"
        "\n"
        "Generate a Class Diagram based on the description and entities with following factor:\n"
        "-\tprioritize using entities given\n"
        "-\tDiscover relationship between classes according to the description\n"
        "-\tAssign attributes to the classes according to the description\n"
        "-   Since a class could be an attribute of another class, it is allowed when necessary\n"
        "Set the output strictly to only PlantUML of the result diagram \n"
    ).format(document_text=document_text, classes=classes, attributes=attributes)

    chat_prompt = [
        {
            "role": "user",
            "content": [{
                "type": "text",
                "text": prompt_text
            }]
        },
    ]
    print(chat_prompt)

    completion = llm.chat.completions.create(
        model=model,
        messages=chat_prompt,
        max_tokens=800,
        temperature=0,
        top_p=0.95,
        frequency_penalty=0,
        presence_penalty=0,
        stop=None,
        stream=False
    )

    plantuml_result = completion.choices[0].message.content
    if plantuml_result is None:
        raise ValueError("Azure OpenAI returned no message content")

    # Clean up any markdown code block markers
    plantuml_result = plantuml_result.strip()
    plantuml_result = plantuml_result.replace("```plantuml", "").replace("```", "").strip()

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
        print(f"Created directory: {output_dir}")

    # Save PlantUML code to file
    output_file = os.path.join(output_dir, f"{filename}_class_diagram.puml")
    with open(output_file, "w", encoding="utf-8") as file:
        file.write(plantuml_result)
    print(f"PlantUML diagram saved to {output_file}")

    return plantuml_result, output_file

In [25]:
import json 
import glob
import os 

# NOTE(review): this rebinds the notebook-level `output_dir` used by
# process_document, and it points at a different folder than the one the first
# cell creates (survey/class_diagrams_output_prosus). Confirm which one is intended.
output_dir = "./class_diagrams_output"
# Get all JSON files from the output directory (sorted: glob order is arbitrary)
json_files = sorted(glob.glob(os.path.join(output_dir, "*_entities.json")))
print(f"Found {len(json_files)} JSON files")

# One record per file whose diagram was generated successfully
extracted_data = []
suffix = '_entities.json'

for json_file in json_files:
    try:
        with open(json_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Strip only the trailing suffix (glob guarantees it is present).
        filename = os.path.basename(json_file)[:-len(suffix)]

        # Extract document text, classes, and attributes
        document = data.get('document_class', '')
        classes = data.get('summary', {}).get('class', [])
        attributes = data.get('summary', {}).get('attr', [])

        plantuml_code, output_file = generate_class_diagram_with_azure(classes, attributes, document, filename)
        # Record the result so the final count reflects the actual successes.
        extracted_data.append({
            "filename": filename,
            "document": document,
            "classes": classes,
            "attributes": attributes,
            "plantuml_code": plantuml_code,
            "output_file": output_file,
        })
    except Exception as e:
        print(f"Error processing {json_file}: {e}")

print(f"Successfully extracted data from {len(extracted_data)} files")



Found 8 JSON files
[{'role': 'user', 'content': [{'type': 'text', 'text': "Given a description of software requirement: The clinic basically schedules patients, provides services for them, and bills them for those services  New patients fill out a form listing their name, address, telephone numbers, allergies, and state of mind prior to scheduling their first appointment  Schedules are entered into a central appointment book; patient records (including contact information) are kept in paper files  Appointments are for one of three procedures: dental hygiene, cavities and fillings, and oral surgery (including root canals and tooth extractions)  For each procedure the patient needs to be prepared and supplies need to be collected (e , probes, drill bits, cements, resins, etc  Billing is always done by the month, and bills are always sent by mail to patients' contact addresses  Each patient also generates a reimbursement request to an insurance company  The clinic maintains a supplies inv