In [None]:
import os
from pathlib import Path
import fitz  # PyMuPDF
import docx2txt
import re
import csv
from io import StringIO
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# Check PyMuPDF version
if not hasattr(fitz.Page, "find_tables"):
    logger.error("PyMuPDF version does not support table extraction. Please upgrade to version 1.23.0 or higher.")
    raise ImportError("PyMuPDF version too old. Run 'pip install --upgrade pymupdf'.")

# Paths
root_folder = Path(r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\abb_products_pdf_files")
root2_folder = Path(r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data")
output_folder = root2_folder / "processed_final_grok"
individual_docs_folder = output_folder / "individual_docs"
images_folder = output_folder / "extracted_images"
individual_docs_folder.mkdir(parents=True, exist_ok=True)
images_folder.mkdir(parents=True, exist_ok=True)

# Noise patterns
noise_patterns = [
    r"^Contents\b.*", r"^Page\s+\d+", r"^ABB\s*$",
    r"^Product version.*", r"^Issued.*", r"^Revision.*",
    r"^Figure\s+\d+\..*", r"^GUID-[\w\-]+", r"^© Copyright.*",
    r"^All rights reserved.*", r"^Trademarks.*", r"^Disclaimer.*",
    r"^http[s]?://.*", r"^\s*-+\s*$", r"^\s*$"
]
noise_regex = re.compile("|".join(noise_patterns), re.IGNORECASE)

# Utility functions
def is_near_image(block_rect, image_rects, margin=30):
    for img in image_rects:
        if (block_rect.intersects(img) or
            abs(block_rect.y0 - img.y1) < margin or
            abs(block_rect.y1 - img.y0) < margin):
            return True
    return False

def is_meaningful(line):
    words = line.split()
    if len(words) < 4:
        return False
    alpha_ratio = sum(c.isalpha() for c in line) / len(line)
    return alpha_ratio > 0.5

def extract_text_from_table_region(page, table_bbox):
    """Extract text from table region using text blocks method as fallback."""
    try:
        # Get text blocks within the table bounding box
        table_rect = fitz.Rect(table_bbox)
        blocks = page.get_text("dict", clip=table_rect)
        
        # Extract text from blocks and organize by position
        text_items = []
        for block in blocks.get("blocks", []):
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        text = span["text"].strip()
                        if text:
                            bbox = span["bbox"]
                            text_items.append({
                                "text": text,
                                "x0": bbox[0],
                                "y0": bbox[1],
                                "x1": bbox[2],
                                "y1": bbox[3]
                            })
        
        # Sort by vertical position first, then horizontal
        text_items.sort(key=lambda x: (x["y0"], x["x0"]))
        
        # Group items into rows based on vertical position
        rows = []
        current_row = []
        current_y = None
        tolerance = 5  # pixels tolerance for same row
        
        for item in text_items:
            if current_y is None or abs(item["y0"] - current_y) <= tolerance:
                current_row.append(item)
                current_y = item["y0"] if current_y is None else current_y
            else:
                if current_row:
                    # Sort current row by horizontal position
                    current_row.sort(key=lambda x: x["x0"])
                    rows.append([item["text"] for item in current_row])
                current_row = [item]
                current_y = item["y0"]
        
        # Add the last row
        if current_row:
            current_row.sort(key=lambda x: x["x0"])
            rows.append([item["text"] for item in current_row])
        
        return rows
    except Exception as e:
        logger.warning(f"Error extracting text from table region: {e}")
        return []

def format_table_to_text(page, table):
    """Convert a table to a CSV-like text format with proper text extraction."""
    try:
        if not table:
            logger.warning("Empty table encountered.")
            return ""
        
        output = StringIO()
        writer = csv.writer(output, lineterminator='\n')
        
        # First, try to extract using PyMuPDF's table extraction
        try:
            table_data = table.extract()
            if table_data and len(table_data) > 0:
                # Check if we got actual text content
                sample_cell = str(table_data[0][0]) if table_data[0] else ""
                if not (sample_cell.startswith("(") and "," in sample_cell and sample_cell.endswith(")")):
                    # We got actual text, not coordinates
                    for row in table_data:
                        cleaned_row = []
                        for cell in row:
                            cell_text = str(cell).strip() if cell is not None else ""
                            cleaned_row.append(cell_text)
                        if any(cleaned_row):  # Only add non-empty rows
                            writer.writerow(cleaned_row)
                    
                    table_text = output.getvalue()
                    output.close()
                    return table_text if table_text.strip() else ""
        except Exception as e:
            logger.warning(f"PyMuPDF table extraction failed: {e}")
        
        # Fallback: Extract text from table region using coordinate-based method
        logger.info("Using fallback text extraction for table")
        table_bbox = table.bbox
        rows = extract_text_from_table_region(page, table_bbox)
        
        if rows:
            for row in rows:
                if any(cell.strip() for cell in row):  # Only add rows with content
                    writer.writerow(row)
            
            table_text = output.getvalue()
            output.close()
            return table_text if table_text.strip() else ""
        
        output.close()
        return ""
        
    except Exception as e:
        logger.error(f"Error formatting table: {e}")
        return ""

def extract_tables_with_textboxes(page):
    """Alternative table extraction using text positioning analysis."""
    try:
        # Get all text with detailed positioning
        text_dict = page.get_text("dict")
        
        # Collect all text elements with their positions
        text_elements = []
        for block in text_dict.get("blocks", []):
            if "lines" in block:
                for line in block["lines"]:
                    for span in line["spans"]:
                        text = span["text"].strip()
                        if text and len(text) > 1:  # Filter out single characters
                            bbox = span["bbox"]
                            text_elements.append({
                                "text": text,
                                "x0": bbox[0],
                                "y0": bbox[1],
                                "x1": bbox[2],
                                "y1": bbox[3],
                                "size": span["size"]
                            })
        
        # Group elements that might form tables (similar y-coordinates, aligned x-coordinates)
        potential_tables = []
        processed_elements = set()
        
        for i, elem in enumerate(text_elements):
            if i in processed_elements:
                continue
                
            # Look for elements at similar y-level (potential row)
            row_elements = [elem]
            processed_elements.add(i)
            
            for j, other_elem in enumerate(text_elements[i+1:], i+1):
                if j in processed_elements:
                    continue
                    
                # Check if elements are on the same row (similar y-coordinates)
                if abs(elem["y0"] - other_elem["y0"]) <= 3:
                    row_elements.append(other_elem)
                    processed_elements.add(j)
            
            # If we found multiple elements in a row, it might be a table row
            if len(row_elements) >= 2:
                row_elements.sort(key=lambda x: x["x0"])  # Sort by x-coordinate
                potential_tables.append([elem["text"] for elem in row_elements])
        
        return potential_tables
        
    except Exception as e:
        logger.warning(f"Alternative table extraction failed: {e}")
        return []

# Main PDF text and table extractor
def extract_clean_text_from_pdf(file_path):
    text_blocks = []
    try:
        doc = fitz.open(file_path)
        first_heading = None
        
        for page_num, page in enumerate(doc):
            # Capture the title on the first page
            if page_num == 0:
                blocks = page.get_text("dict")["blocks"]
                titles = [
                    block["lines"][0]["spans"][0]["text"].strip()
                    for block in blocks
                    if block.get("lines") and block["lines"][0]["spans"] and block["lines"][0]["spans"][0]["size"] > 15
                ]
                if titles:
                    first_heading = " ".join(titles).strip()
            
            # Extract images to avoid text near them
            image_rects = [fitz.Rect(img[1:5]) for img in page.get_images(full=True)]
            
            # Extract tables using improved method
            tables_found = False
            try:
                tables = page.find_tables()
                for table in tables:
                    table_text = format_table_to_text(page, table)
                    if table_text.strip():
                        text_blocks.append("=== TABLE START ===")
                        text_blocks.append(table_text.strip())
                        text_blocks.append("=== TABLE END ===")
                        tables_found = True
                        logger.info(f"Successfully extracted table from page {page_num + 1}")
            except Exception as e:
                logger.warning(f"Primary table extraction failed for page {page_num + 1}: {e}")
            
            # If no tables found with primary method, try alternative approach
            if not tables_found:
                try:
                    alt_tables = extract_tables_with_textboxes(page)
                    if alt_tables and len(alt_tables) >= 2:  # At least 2 rows to consider it a table
                        text_blocks.append("=== TABLE START ===")
                        output = StringIO()
                        writer = csv.writer(output, lineterminator='\n')
                        for row in alt_tables:
                            writer.writerow(row)
                        table_text = output.getvalue()
                        output.close()
                        text_blocks.append(table_text.strip())
                        text_blocks.append("=== TABLE END ===")
                        logger.info(f"Alternative table extraction found {len(alt_tables)} rows on page {page_num + 1}")
                except Exception as e:
                    logger.warning(f"Alternative table extraction failed for page {page_num + 1}: {e}")
            
            # Extract non-tabular text
            for block in page.get_text("blocks"):
                rect = fitz.Rect(block[:4])
                text = block[4].strip()
                if is_near_image(rect, image_rects) or not text:
                    continue
                for line in text.splitlines():
                    line = line.strip()
                    if noise_regex.match(line):
                        continue
                    if is_meaningful(line):
                        text_blocks.append(line)
        
        doc.close()
        if first_heading:
            return f"{first_heading}\n\n" + "\n".join(text_blocks)
        return "\n".join(text_blocks)
    except Exception as e:
        logger.error(f"Error processing {file_path}: {e}")
        return ""

# DOCX extractor
def extract_text_from_docx(file_path):
    try:
        return docx2txt.process(file_path)
    except Exception as e:
        logger.error(f"Error reading DOCX {file_path}: {e}")
        return ""

# Image extractor
def extract_images_from_pdf(pdf_path, base_name):
    try:
        doc = fitz.open(pdf_path)
        for page_num in range(len(doc)):
            for img_index, img in enumerate(doc.get_page_images(page_num)):
                xref = img[0]
                base_image = doc.extract_image(xref)
                image_bytes = base_image["image"]
                image_ext = base_image["ext"]
                image_path = images_folder / f"{base_name}_p{page_num+1}_img{img_index+1}.{image_ext}"
                with open(image_path, "wb") as f:
                    f.write(image_bytes)
        doc.close()
    except Exception as e:
        logger.error(f"Error extracting image from {pdf_path}: {e}")

# Process all files
all_files = list(root_folder.rglob("*.pdf")) + list(root_folder.rglob("*.docx"))
merged_corpus = []

logger.info(f"Found {len(all_files)} files to process")

for i, file in enumerate(all_files):
    name = f"{i:04d}__{file.stem}".replace(" ", "_").replace("/", "_")
    logger.info(f"Processing file {i+1}/{len(all_files)}: {file.name}")
    
    try:
        if file.suffix.lower() == ".pdf":
            raw_text = extract_clean_text_from_pdf(file)
            extract_images_from_pdf(file, name)
        elif file.suffix.lower() == ".docx":
            raw_text = extract_text_from_docx(file)
        else:
            logger.warning(f"Skipping unsupported file type: {file}")
            continue

        if raw_text.strip():
            txt_path = individual_docs_folder / f"{name}.txt"
            with open(txt_path, "w", encoding="utf-8") as f:
                f.write(raw_text.strip())
            merged_corpus.append(raw_text.strip())
            logger.info(f"✅ Successfully processed: {file.name}")
        else:
            logger.warning(f"No meaningful content extracted from {file}")
    except Exception as e:
        logger.error(f"Failed to process {file}: {e}")
        continue

# Save final merged corpus
merged_path = output_folder / "abb_corpus.txt"
try:
    with open(merged_path, "w", encoding="utf-8") as f:
        for doc in merged_corpus:
            f.write(doc + "\n\n" + "="*80 + "\n\n")
    logger.info(f"✅ Merged corpus written to: {merged_path}")
except Exception as e:
    logger.error(f"Error writing merged corpus to {merged_path}: {e}")

logger.info(f"✅ Final preprocessing complete!")
logger.info(f"📄 Individual text files saved to: {individual_docs_folder}")
logger.info(f"🖼️ Images extracted to: {images_folder}")
logger.info(f"📘 Merged corpus written to: {merged_path}")
logger.info(f"📊 Total files processed: {len(merged_corpus)}/{len(all_files)}")

In [1]:
import os
import json
import re
from pathlib import Path
from typing import List, Dict, Tuple
import logging
from collections import Counter
import matplotlib.pyplot as plt
import seaborn as sns

# Tokenizer libraries
from tokenizers import Tokenizer, models, pre_tokenizers, decoders, trainers, processors
from tokenizers.models import BPE, WordPiece
from tokenizers.trainers import BpeTrainer, WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace, ByteLevel
from tokenizers.processors import TemplateProcessing
from tokenizers.decoders import ByteLevel as ByteLevelDecoder, WordPiece as WordPieceDecoder



In [5]:
# Set up logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

class ABBTokenizerTrainer:
    """Custom tokenizer trainer for ABB technical documentation."""
    
    def __init__(self, corpus_path: str, output_dir: str):
        self.corpus_path = Path(corpus_path)
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        
        # Tokenizer configurations
        self.vocab_sizes = [8000, 16000, 32000]
        self.special_tokens = ["<pad>", "<unk>", "<bos>", "<eos>", "<mask>", "=== TABLE START ===", "=== TABLE END ==="]
        
        # Statistics storage
        self.training_stats = {}
        
    def preprocess_corpus(self) -> Tuple[str, Dict]:
        """Clean and preprocess the ABB corpus for tokenizer training."""
        logger.info("Preprocessing ABB corpus...")
        
        if not self.corpus_path.exists():
            raise FileNotFoundError(f"Corpus file not found: {self.corpus_path}")
        
        with open(self.corpus_path, 'r', encoding='utf-8') as f:
            raw_text = f.read()
        
        # Clean the text
        cleaned_text = self._clean_text(raw_text)
        
        # Generate statistics
        stats = self._generate_corpus_stats(cleaned_text)
        
        # Save cleaned corpus
        cleaned_path = self.output_dir / "cleaned_corpus.txt"
        with open(cleaned_path, 'w', encoding='utf-8') as f:
            f.write(cleaned_text)
        
        logger.info(f"Cleaned corpus saved to: {cleaned_path}")
        return str(cleaned_path), stats
    
    def _clean_text(self, text: str) -> str:
        """Clean and normalize text for tokenizer training."""
        # Remove table markers
        text = re.sub(r'=== TABLE START ===.*?=== TABLE END ===', '', text, flags=re.DOTALL)
        text = re.sub(r'=== Table ===.*?=== End Table ===', '', text, flags=re.DOTALL)
        
        # Clean up document separators
        text = re.sub(r'={50,}', '\n', text)
        
        # Normalize whitespace
        text = re.sub(r'\n\s*\n', '\n\n', text)  # Multiple newlines to double newline
        text = re.sub(r'[ \t]+', ' ', text)  # Multiple spaces/tabs to single space
        
        # Preserve ABB product codes and technical terms
        # Keep alphanumeric codes like "ARC600", "IEC-104", etc.
        
        # Remove excessive punctuation
        text = re.sub(r'[.]{3,}', '...', text)
        text = re.sub(r'[-]{3,}', '---', text)
        
        # Clean up common formatting artifacts
        text = re.sub(r'\s+([.,;:!?])', r'\1', text)  # Remove space before punctuation
        
        return text.strip()
    
    def _generate_corpus_stats(self, text: str) -> Dict:
        """Generate statistics about the corpus."""
        words = text.split()
        chars = list(text)
        
        stats = {
            'total_characters': len(text),
            'total_words': len(words),
            'unique_words': len(set(words)),
            'unique_chars': len(set(chars)),
            'avg_word_length': sum(len(word) for word in words) / len(words) if words else 0,
            'vocabulary_size': len(set(word.lower() for word in words)),
        }
        
        # Find ABB-specific terms
        abb_terms = []
        patterns = [
            r'\bABB\b',
            r'\b[A-Z]{2,}\d+\b',  # Product codes like ARC600
            r'\b[A-Z]+-\d+\b',    # Standards like IEC-104
            r'\b\d+[A-Z]+\b',     # Model numbers
        ]
        
        for pattern in patterns:
            matches = re.findall(pattern, text)
            abb_terms.extend(matches)
        
        stats['abb_terms_count'] = len(abb_terms)
        stats['unique_abb_terms'] = len(set(abb_terms))
        
        return stats
    
    def train_bpe_tokenizer(self, corpus_path: str, vocab_size: int) -> Tokenizer:
        """Train a BPE tokenizer on the ABB corpus."""
        logger.info(f"Training BPE tokenizer with vocab size {vocab_size}...")
        
        # Initialize BPE tokenizer
        tokenizer = Tokenizer(BPE(unk_token="<unk>"))
        
        # Set pre-tokenizer
        tokenizer.pre_tokenizer = pre_tokenizers.Sequence([
            Whitespace(),
            ByteLevel(add_prefix_space=False)
        ])
        
        # Set decoder
        tokenizer.decoder = ByteLevelDecoder()
        
        # Configure trainer
        trainer = BpeTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            special_tokens=self.special_tokens,
            show_progress=True,
            initial_alphabet=ByteLevel.alphabet()
        )
        
        # Train the tokenizer
        tokenizer.train([corpus_path], trainer)
        
        # Add post-processor for special tokens
        tokenizer.post_processor = TemplateProcessing(
            single="<bos> $A <eos>",
            special_tokens=[("<bos>", 2), ("<eos>", 3)]
        )
        
        return tokenizer
    
    def train_wordpiece_tokenizer(self, corpus_path: str, vocab_size: int) -> Tokenizer:
        """Train a WordPiece tokenizer on the ABB corpus."""
        logger.info(f"Training WordPiece tokenizer with vocab size {vocab_size}...")
        
        # Initialize WordPiece tokenizer
        tokenizer = Tokenizer(WordPiece(unk_token="<unk>"))
        
        # Set pre-tokenizer
        tokenizer.pre_tokenizer = Whitespace()
        
        # Set decoder
        tokenizer.decoder = WordPieceDecoder()
        
        # Configure trainer
        trainer = WordPieceTrainer(
            vocab_size=vocab_size,
            min_frequency=2,
            special_tokens=self.special_tokens,
            show_progress=True
        )
        
        # Train the tokenizer
        tokenizer.train([corpus_path], trainer)
        
        # Add post-processor for special tokens
        tokenizer.post_processor = TemplateProcessing(
            single="<bos> $A <eos>",
            special_tokens=[("<bos>", 2), ("<eos>", 3)]
        )
        
        return tokenizer
    
    def evaluate_tokenizer(self, tokenizer: Tokenizer, test_text: str, name: str) -> Dict:
        """Evaluate tokenizer performance on test text."""
        logger.info(f"Evaluating {name} tokenizer...")
        
        # Encode test text
        encoding = tokenizer.encode(test_text)
        tokens = encoding.tokens
        
        # Calculate metrics
        metrics = {
            'name': name,
            'num_tokens': len(tokens),
            'num_characters': len(test_text),
            'compression_ratio': len(test_text) / len(tokens) if tokens else 0,
            'vocab_size': tokenizer.get_vocab_size(),
            'unk_count': sum(1 for token in tokens if token == '<unk>'),
            'unk_ratio': sum(1 for token in tokens if token == '<unk>') / len(tokens) if tokens else 0
        }
        
        # Sample tokenization examples
        sample_texts = [
            "ABB Wireless Controller ARC600 provides remote monitoring capabilities.",
            "The IEC-104 protocol ensures reliable communication in distribution networks.",
            "System Average Interruption Duration Index (SAIDI) measurements show improvement.",
            "Configure the switching devices using I/O expansion modules."
        ]
        
        examples = []
        for text in sample_texts:
            encoding = tokenizer.encode(text)
            examples.append({
                'text': text,
                'tokens': encoding.tokens,
                'token_count': len(encoding.tokens)
            })
        
        metrics['examples'] = examples
        return metrics
    
    def save_tokenizer(self, tokenizer: Tokenizer, name: str, vocab_size: int):
        """Save tokenizer to files."""
        tokenizer_dir = self.output_dir / f"{name}_tokenizer_{vocab_size}"
        tokenizer_dir.mkdir(exist_ok=True)
        
        # Save tokenizer
        tokenizer.save(str(tokenizer_dir / "tokenizer.json"))
        
        # Save vocabulary
        vocab = tokenizer.get_vocab()
        with open(tokenizer_dir / "vocab.json", 'w', encoding='utf-8') as f:
            json.dump(vocab, f, indent=2, ensure_ascii=False)
        
        # Save vocab as text file
        with open(tokenizer_dir / "vocab.txt", 'w', encoding='utf-8') as f:
            for token, idx in sorted(vocab.items(), key=lambda x: x[1]):
                f.write(f"{token}\n")
        
        logger.info(f"Tokenizer saved to: {tokenizer_dir}")
        return tokenizer_dir
    
    def compare_tokenizers(self, metrics_list: List[Dict]):
        """Compare different tokenizers and generate report."""
        logger.info("Generating tokenizer comparison report...")
        
        # Create comparison DataFrame-like structure
        comparison_data = []
        for metrics in metrics_list:
            comparison_data.append({
                'Name': metrics['name'],
                'Vocabulary Size': metrics['vocab_size'],
                'Compression Ratio': f"{metrics['compression_ratio']:.2f}",
                'Unknown Token %': f"{metrics['unk_ratio']*100:.2f}%",
                'Tokens Generated': metrics['num_tokens']
            })
        
        # Save comparison report
        report_path = self.output_dir / "tokenizer_comparison.json"
        with open(report_path, 'w', encoding='utf-8') as f:
            json.dump({
                'comparison_summary': comparison_data,
                'detailed_metrics': metrics_list
            }, f, indent=2, ensure_ascii=False)
        
        # Generate text report
        text_report_path = self.output_dir / "tokenizer_comparison.txt"
        with open(text_report_path, 'w', encoding='utf-8') as f:
            f.write("ABB Tokenizer Comparison Report\n")
            f.write("=" * 50 + "\n\n")
            
            for data in comparison_data:
                f.write(f"Tokenizer: {data['Name']}\n")
                f.write(f"  Vocabulary Size: {data['Vocabulary Size']}\n")
                f.write(f"  Compression Ratio: {data['Compression Ratio']} chars/token\n")
                f.write(f"  Unknown Token Rate: {data['Unknown Token %']}\n")
                f.write(f"  Tokens Generated: {data['Tokens Generated']}\n\n")
            
            # Add recommendations
            f.write("Recommendations:\n")
            f.write("-" * 20 + "\n")
            
            best_compression = max(metrics_list, key=lambda x: x['compression_ratio'])
            best_coverage = min(metrics_list, key=lambda x: x['unk_ratio'])
            
            f.write(f"Best Compression: {best_compression['name']} ({best_compression['compression_ratio']:.2f} chars/token)\n")
            f.write(f"Best Coverage: {best_coverage['name']} ({best_coverage['unk_ratio']*100:.2f}% unknown tokens)\n")
        
        logger.info(f"Comparison reports saved to: {report_path} and {text_report_path}")
    
    def train_all_tokenizers(self):
        """Train all tokenizer variants and compare them."""
        logger.info("Starting comprehensive tokenizer training...")
        
        # Preprocess corpus
        cleaned_corpus_path, corpus_stats = self.preprocess_corpus()
        
        # Save corpus statistics
        stats_path = self.output_dir / "corpus_statistics.json"
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(corpus_stats, f, indent=2)
        
        logger.info(f"Corpus statistics: {corpus_stats}")
        
        # Load test text for evaluation
        with open(cleaned_corpus_path, 'r', encoding='utf-8') as f:
            full_text = f.read()
        
        # Use first 10% of text for evaluation
        test_text = full_text[:len(full_text)//10]
        
        all_metrics = []
        
        # Train tokenizers for different vocab sizes
        for vocab_size in self.vocab_sizes:
            # Train BPE tokenizer
            bpe_tokenizer = self.train_bpe_tokenizer(cleaned_corpus_path, vocab_size)
            bpe_metrics = self.evaluate_tokenizer(bpe_tokenizer, test_text, f"BPE_{vocab_size}")
            all_metrics.append(bpe_metrics)
            self.save_tokenizer(bpe_tokenizer, "BPE", vocab_size)
            
            # Train WordPiece tokenizer
            wp_tokenizer = self.train_wordpiece_tokenizer(cleaned_corpus_path, vocab_size)
            wp_metrics = self.evaluate_tokenizer(wp_tokenizer, test_text, f"WordPiece_{vocab_size}")
            all_metrics.append(wp_metrics)
            self.save_tokenizer(wp_tokenizer, "WordPiece", vocab_size)
        
        # Generate comparison report
        self.compare_tokenizers(all_metrics)
        
        logger.info("✅ Tokenizer training complete!")
        logger.info(f"📁 All outputs saved to: {self.output_dir}")
        
        return all_metrics



In [6]:
def main():
    """Main function to run tokenizer training."""
    
    # Configuration
    corpus_path = r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\abb_corpus.txt"
    output_dir = r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenizers"
    
    # Initialize trainer
    trainer = ABBTokenizerTrainer(corpus_path, output_dir)
    
    try:
        # Train all tokenizers
        metrics = trainer.train_all_tokenizers()
        
        # Print summary
        print("\n" + "="*60)
        print("TOKENIZER TRAINING SUMMARY")
        print("="*60)
        
        for metric in metrics:
            print(f"\n{metric['name']}:")
            print(f"  Vocabulary Size: {metric['vocab_size']:,}")
            print(f"  Compression Ratio: {metric['compression_ratio']:.2f} chars/token")
            print(f"  Unknown Token Rate: {metric['unk_ratio']*100:.2f}%")
            
            # Show first example
            if metric['examples']:
                example = metric['examples'][0]
                print(f"  Example: '{example['text']}'")
                print(f"  Tokens: {example['tokens'][:5]}..." if len(example['tokens']) > 5 else f"  Tokens: {example['tokens']}")
        
        print(f"\n📁 All tokenizer files saved to: {output_dir}")
        print("✅ Ready for Step 3: SLM Architecture Building!")
        
    except Exception as e:
        logger.error(f"Error during tokenizer training: {e}")
        raise

if __name__ == "__main__":
    main()

2025-06-10 14:33:04,042 - INFO - Starting comprehensive tokenizer training...
2025-06-10 14:33:04,042 - INFO - Preprocessing ABB corpus...
2025-06-10 14:33:08,276 - INFO - Cleaned corpus saved to: C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenizers\cleaned_corpus.txt
2025-06-10 14:33:08,285 - INFO - Corpus statistics: {'total_characters': 19322004, 'total_words': 3157414, 'unique_words': 56324, 'unique_chars': 258, 'avg_word_length': 5.113947046538718, 'vocabulary_size': 49980, 'abb_terms_count': 37190, 'unique_abb_terms': 1559}
2025-06-10 14:33:08,397 - INFO - Training BPE tokenizer with vocab size 8000...
2025-06-10 14:33:10,097 - INFO - Evaluating BPE_8000 tokenizer...
2025-06-10 14:33:11,301 - INFO - Tokenizer saved to: C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenizers\BPE_tokenizer_8000
2025-06-10 14:33:11,303 - INFO - Training WordPiece tokenizer with vocab size 8000...
2025-06-10 14


TOKENIZER TRAINING SUMMARY

BPE_8000:
  Vocabulary Size: 8,000
  Compression Ratio: 4.77 chars/token
  Unknown Token Rate: 0.00%
  Example: 'ABB Wireless Controller ARC600 provides remote monitoring capabilities.'
  Tokens: ['<bos>', 'ABB', 'Wireless', 'Control', 'ler']...

WordPiece_8000:
  Vocabulary Size: 8,000
  Compression Ratio: 4.72 chars/token
  Unknown Token Rate: 0.00%
  Example: 'ABB Wireless Controller ARC600 provides remote monitoring capabilities.'
  Tokens: ['<bos>', 'ABB', 'Wireless', 'Control', '##ler']...

BPE_16000:
  Vocabulary Size: 16,000
  Compression Ratio: 5.00 chars/token
  Unknown Token Rate: 0.00%
  Example: 'ABB Wireless Controller ARC600 provides remote monitoring capabilities.'
  Tokens: ['<bos>', 'ABB', 'Wireless', 'Controller', 'ARC']...

WordPiece_16000:
  Vocabulary Size: 16,000
  Compression Ratio: 5.03 chars/token
  Unknown Token Rate: 0.00%
  Example: 'ABB Wireless Controller ARC600 provides remote monitoring capabilities.'
  Tokens: ['<bos>', 'AB

In [1]:
from tokenizers import Tokenizer

# Load your tokenizer
tokenizer_path = r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenizers\WordPiece_tokenizer_32000\tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)

# Sample text
text = "Wireless Controller ARC600 supports IEC-104 protocol succeeding."

# Encode and get both tokens and ids
encoding = tokenizer.encode(text)

tokens = encoding.tokens
token_ids = encoding.ids

# Print side-by-side
print("🧾 Tokenized Output:")
for token, id_ in zip(tokens, token_ids):
    print(f"{token:20} --> {id_}")


🧾 Tokenized Output:
<bos>                --> 2
Wireless             --> 7607
Controller           --> 10178
ARC600               --> 9841
supports             --> 2174
IEC                  --> 767
-                    --> 41
104                  --> 3378
protocol             --> 1010
succeeding           --> 27857
.                    --> 42
<eos>                --> 3


In [2]:
from pathlib import Path
import torch

corpus_file = Path(r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\abb_corpus.txt")
with open(corpus_file, 'r', encoding='utf-8') as f:
    text = f.read()

# Tokenize
token_ids = tokenizer.encode(text).ids
token_ids = torch.tensor(token_ids, dtype=torch.long)

# Save tokenized tensor to file
save_path = Path(r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenized_corpus.pt")
torch.save(token_ids, save_path)

print(f"✅ Tokenized corpus shape: {token_ids.shape}")
print(f"💾 Saved tokenized tensor to: {save_path}")

# Print the first 20 tokens from the corpus
print("\n=== First 20 Tokens ===")
first_20_ids = token_ids[:20].tolist()
first_20_tokens = tokenizer.encode(text[:100]).tokens[:20]  # Approximate mapping to tokens

print("Token ID  |  Token")
print("-" * 30)
for i, (token_id, token) in enumerate(zip(first_20_ids, first_20_tokens)):
    print(f"{token_id:<9} |  {token}")

✅ Tokenized corpus shape: torch.Size([7465108])
💾 Saved tokenized tensor to: C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenized_corpus.pt

=== First 20 Tokens ===
Token ID  |  Token
------------------------------
2         |  <bos>
178       |  —
5         |  === TABLE START ===
7607      |  Wireless
10178     |  Controller
40        |  ,
23973     |  1MRS758
15324     |  ##46
313       |  ##5
68        |  H
9841      |  ARC600
40        |  ,
2594      |  Product
1607      |  version
54        |  :
47        |  3
42        |  .
48        |  4
40        |  ,
6         |  ==


In [None]:
# Print tokens and IDs for a sample text
sample_text = "Wireless Controller ARC600 supports IEC-104 protocol."
encoding = tokenizer.encode(sample_text)
print("Tokens:", encoding.tokens)
print("Token IDs:", encoding.ids)

# Decode back (if supported)
if hasattr(tokenizer, "decode"):
    decoded = tokenizer.decode(encoding.ids)
    print("Decoded text:", decoded)
else:
    print("Tokenizer does not support decoding.")

# Check a slice of your full token_ids and print tokens with IDs
first_20_ids = token_ids[:20].tolist()
first_20_tokens = tokenizer.decode(first_20_ids).split()  # This may not be perfect for all tokenizers

print("\nFirst 20 tokens and their IDs from corpus:")
for token, id_ in zip(first_20_tokens, first_20_ids):
    print(f"{token:20} --> {id_}")

In [3]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import os
from pathlib import Path
from tokenizers import Tokenizer

class TransformerConfig:
    def __init__(
        self,
        vocab_size=32000,
        block_size=128,      # Reduced block size for faster training
        n_layer=4,           # Moderate depth
        n_head=8,            # Balanced attention heads
        n_embd=384,          # Efficient embedding size
        dropout=0.1,
        table_start_id=None,
        table_end_id=None
    ):
        self.vocab_size = vocab_size
        self.block_size = block_size
        self.n_layer = n_layer
        self.n_head = n_head
        self.n_embd = n_embd
        self.dropout = dropout
        self.table_start_id = table_start_id
        self.table_end_id = table_end_id

class TableAwareGPT(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Token and position embeddings
        self.token_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, config.block_size, config.n_embd))
        
        # Content type embedding for tables vs text
        self.content_type_emb = nn.Embedding(2, config.n_embd)  # 0=text, 1=table
        
        # Store table marker token IDs
        self.table_start_id = config.table_start_id
        self.table_end_id = config.table_end_id
        
        # Dropout
        self.drop = nn.Dropout(config.dropout)
        
        # Transformer blocks
        self.blocks = nn.Sequential(*[
            nn.TransformerEncoderLayer(
                d_model=config.n_embd,
                nhead=config.n_head,
                dim_feedforward=4 * config.n_embd,
                activation='gelu',
                batch_first=True,
                dropout=config.dropout
            ) for _ in range(config.n_layer)
        ])
        
        # Final layer norm and head
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size)
        
        self.block_size = config.block_size
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, nn.LayerNorm):
            nn.init.zeros_(module.bias)
            nn.init.ones_(module.weight)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        if T > self.block_size:
            raise ValueError(f"Sequence length {T} > block size {self.block_size}")

        # Create content type labels (0=text, 1=table)
        content_types = torch.zeros((B, T), dtype=torch.long, device=idx.device)
        
        # Identify table regions based on special tokens
        if self.table_start_id is not None and self.table_end_id is not None:
            for b in range(B):
                in_table = False
                for t in range(T):
                    if idx[b, t] == self.table_start_id:
                        in_table = True
                    elif idx[b, t] == self.table_end_id:
                        in_table = False
                    
                    # Mark as table content when inside a table
                    if in_table:
                        content_types[b, t] = 1
        
        # Get embeddings
        token_embeddings = self.token_emb(idx)         # (B, T, n_embd)
        pos_embeddings = self.pos_emb[:, :T, :]        # (1, T, n_embd)
        type_embeddings = self.content_type_emb(content_types)  # (B, T, n_embd)
        
        # Combine embeddings
        x = token_embeddings + pos_embeddings + type_embeddings
        x = self.drop(x)
        
        # Transformer blocks
        x = self.blocks(x)
        x = self.ln_f(x)
        
        # Project to vocab
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

        return logits, loss
    
    def generate(self, idx, max_new_tokens=100, temperature=1.0, top_k=None):
        """Generate text using the trained model"""
        self.eval()
        idx = idx.clone()
        
        with torch.no_grad():
            for _ in range(max_new_tokens):
                # Get the last block_size tokens
                idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:]
                
                # Forward pass
                logits, _ = self.forward(idx_cond)
                
                # Get the logits for the last position
                logits = logits[:, -1, :] / temperature
                
                # Optional top-k sampling
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = float('-inf')
                
                # Apply softmax to get probabilities
                probs = F.softmax(logits, dim=-1)
                
                # Sample from the distribution
                next_token = torch.multinomial(probs, num_samples=1)
                
                # Append to the sequence
                idx = torch.cat((idx, next_token), dim=1)
        
        return idx

class TokenDataset(Dataset):
    """Dataset for training the language model"""
    def __init__(self, tokens, block_size):
        self.tokens = tokens
        self.block_size = block_size

    def __len__(self):
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx):
        # Add bounds checking
        if idx < 0 or idx + self.block_size >= len(self.tokens):
            raise IndexError("Index out of bounds")
        x = self.tokens[idx:idx + self.block_size]
        y = self.tokens[idx + 1:idx + self.block_size + 1]
        return x, y

def count_parameters(model):
    """Count trainable parameters in the model"""
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Total trainable parameters: {total_params:,}")
    return total_params

def train_model(
    model, 
    token_ids, 
    batch_size=256,
    learning_rate=5e-4,
    num_epochs=3,
    grad_accum_steps=2,
    max_tokens=500_000,
    device="cuda" if torch.cuda.is_available() else "cpu"
):
    # Add verification of token IDs before creating dataset
    vocab_size = model.config.vocab_size
    
    # Check for out-of-range tokens
    max_id = token_ids.max().item()
    min_id = token_ids.min().item()
    
    if max_id >= vocab_size or min_id < 0:
        print(f"⚠️ Warning: Token IDs out of range. Max: {max_id}, Min: {min_id}, Vocab Size: {vocab_size}")
        print("🔄 Clamping token IDs to valid range...")
        token_ids = torch.clamp(token_ids, min=0, max=vocab_size-1)
    
    # Rest of your function as before
    # Sample dataset if too large
    if len(token_ids) > max_tokens:
        print(f"⚙️ Sampling {max_tokens:,} tokens from dataset of {len(token_ids):,} tokens")
        token_ids = token_ids[:max_tokens]
    
    # Create dataset and dataloader
    dataset = TokenDataset(token_ids, model.block_size)
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        pin_memory=True
    )
    
    # Continue with the rest of the function...
    
    # Optimizer and scheduler
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(dataloader) * num_epochs // grad_accum_steps
    )
    
    # Move model to device
    model.to(device)
    
    # Create checkpoint directory
    checkpoint_dir = "table_aware_checkpoints"
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Training loop
    print(f"🚀 Starting training on {device} with batch size {batch_size}")
    print(f"📊 Dataset size: {len(dataset):,} samples, {len(dataloader):,} batches")
    
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}")
        
        # Reset gradient for gradient accumulation
        optimizer.zero_grad()
        
        for i, batch in enumerate(pbar):
            x, y = batch
            x = x.to(device)
            y = y.to(device)
            
            # Forward pass
            logits, loss = model(x, y)
            
            # Scale loss for gradient accumulation
            loss = loss / grad_accum_steps
            loss.backward()
            
            # Only update every grad_accum_steps
            if (i + 1) % grad_accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
            
            # Track loss (scaled back up for reporting)
            total_loss += loss.item() * grad_accum_steps
            pbar.set_postfix(loss=loss.item() * grad_accum_steps)
        
        # Final update if dataset size not divisible by grad_accum_steps
        if (i + 1) % grad_accum_steps != 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()
        
        # Calculate average loss for this epoch
        avg_loss = total_loss / len(dataloader)
        print(f"✅ Epoch {epoch+1} complete. Avg Loss: {avg_loss:.4f}")
        
        # Save checkpoint
        ckpt_path = os.path.join(checkpoint_dir, f"model_epoch{epoch+1}.pt")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss
        }, ckpt_path)
        print(f"💾 Checkpoint saved to: {ckpt_path}")
    
    print("✅ Training complete!")
    return model





In [4]:
def generate_text(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    """Generate text using the trained model"""
    # Encode prompt
    prompt_ids = tokenizer.encode(prompt).ids
    prompt_tensor = torch.tensor([prompt_ids], dtype=torch.long)
    prompt_tensor = prompt_tensor.to(next(model.parameters()).device)
    
    # Generate
    output_ids = model.generate(
        prompt_tensor, 
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_k=40
    )
    
    # Decode
    output_text = tokenizer.decode(output_ids[0].tolist())
    return output_text

In [6]:

# Load tokenizer
tokenizer_path = r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenizers\WordPiece_tokenizer_32000\tokenizer.json"
tokenizer = Tokenizer.from_file(tokenizer_path)

# Get special token IDs
table_start_id = tokenizer.token_to_id("=== TABLE START ===")
table_end_id = tokenizer.token_to_id("=== TABLE END ===")

# Load tokenized corpus
corpus_path = r"C:\Users\INKARED5\OneDrive - ABB\Karan_ABB_Internship\Projects\Data\processed_final_grok\tokenized_corpus.pt"
token_ids = torch.load(corpus_path)
print(f"📄 Loaded tokenized corpus with {len(token_ids):,} tokens")

# Configure model
config = TransformerConfig(
    vocab_size=tokenizer.get_vocab_size(),
    block_size=128,     # Smaller block size
    n_layer=1,          # Fewer layers
    n_head=1,
    n_embd=128,         # Smaller embedding size
    table_start_id=table_start_id,
    table_end_id=table_end_id
)

# Create model
model = TableAwareGPT(config)
print("🤖 Model created with the following configuration:")
print(f"  - Vocab size: {config.vocab_size}")
print(f"  - Context length: {config.block_size} tokens")
print(f"  - Layers: {config.n_layer}")
print(f"  - Attention heads: {config.n_head}")
print(f"  - Embedding dim: {config.n_embd}")

# Count parameters
num_params = count_parameters(model)
print(f"💪 Model size: {num_params / 1_000_000:.2f}M parameters")

# Train model
trained_model = train_model(
    model=model,
    token_ids=token_ids,
    batch_size=128,
    learning_rate=5e-4,
    num_epochs=1,
    grad_accum_steps=2,
    max_tokens=10_000  # Limit dataset size for faster training
)



📄 Loaded tokenized corpus with 7,465,108 tokens
🤖 Model created with the following configuration:
  - Vocab size: 29431
  - Context length: 128 tokens
  - Layers: 1
  - Attention heads: 1
  - Embedding dim: 128
✅ Total trainable parameters: 7,778,935
💪 Model size: 7.78M parameters
⚙️ Sampling 10,000 tokens from dataset of 7,465,108 tokens
🚀 Starting training on cpu with batch size 128
📊 Dataset size: 9,872 samples, 78 batches


Epoch 1/1: 100%|██████████| 78/78 [07:25<00:00,  5.71s/it, loss=8.17]

✅ Epoch 1 complete. Avg Loss: 8.8133
💾 Checkpoint saved to: table_aware_checkpoints\model_epoch1.pt
✅ Training complete!





In [8]:
# Generate sample text with table
print("\n🔮 Generating sample text with table:")
prompt = "Wireless Controller ARC600 is a compact, solution based device\n=== TABLE START ===\n"
generated_text = generate_text(trained_model, tokenizer, prompt, max_new_tokens=200)
print("\nGenerated text:")
print("=" * 80)
print(generated_text)
print("=" * 80)

print("\n✅ Done! The model can now be used for ABB technical documentation generation.")


🔮 Generating sample text with table:

Generated text:
Wireless Controller ARC600 is a compact, solution based device field via switching Full LEDs - and by product these the ) 30 ABB product way Open the Power by / closed 30 switching product ABB, central field in 30" field 5 The Weight protocols ) the A" by 50, Power card fieldd existing LEDs the field 12 these switching switching as the 128 via, protection, : protocols card W via,- of 30 LEDs field ", central 12 range 50 product protocols, in 5 switching protection and / is standard : battery these protocols of 30 ) ) / battery I battery switching via LEDs Open and ) Full 50 battery : product Full 30 in user if closed : LEDs of 50 product by protection Full )d range ABB A for I ABB ) the as protocols switching via user have ", product and these LEDs card user interface ABB are 101 ", product I the protection and 5 30 switching product Open these 30 standard EN battery 6 30 50 asd the battery user, ABB the asd of Extra reserved acces