In [1]:
import json
import numpy as np
from laserembeddings import Laser
from sklearn.metrics.pairwise import cosine_similarity
import spacy
from scipy.optimize import linear_sum_assignment
import re
import time
from pathlib import Path


class MTSentenceAligner:
    """Align sentences inside English/Norwegian paragraph pairs.

    Each paragraph pair is split into sentences (spaCy when the models are
    installed, NLTK's ``sent_tokenize`` otherwise). Both sides are embedded
    with LASER, and sentences are paired by cosine similarity. Pairs below a
    similarity threshold are discarded.
    """

    # Compiled once; used by clean_sentence() for every sentence.
    _WHITESPACE_RE = re.compile(r'\s+')

    def __init__(self):
        print("Initializing LASER embeddings...")
        self.laser = Laser()

        print("Loading Spacy models...")
        # Load each model independently so one missing model does not
        # disable the other.
        self.en_nlp = self._load_spacy_model("en_core_web_sm")
        self.no_nlp = self._load_spacy_model("nb_core_news_sm")

        if self.en_nlp is not None and self.no_nlp is not None:
            print("Spacy models loaded successfully")
        else:
            print("Warning: Spacy models not found, using NLTK fallback")
            import nltk
            # Older NLTK releases use the 'punkt' resource; NLTK >= 3.8.2
            # looks up 'punkt_tab' instead. Fetch both so either version works.
            nltk.download('punkt', quiet=True)
            nltk.download('punkt_tab', quiet=True)

    @staticmethod
    def _load_spacy_model(name):
        """Return the spaCy pipeline ``name``, or None if it cannot be loaded."""
        try:
            return spacy.load(name)
        except OSError:
            return None

    def clean_sentence(self, sentence):
        """Strip ``sentence`` and collapse every whitespace run into one space."""
        return self._WHITESPACE_RE.sub(' ', sentence.strip())

    def count_words(self, text):
        """Return the number of whitespace-separated tokens in ``text``."""
        return len(text.split())

    def remove_duplicate_title(self, text):
        """Drop the first line of ``text`` if it also appears inside the second line.

        This is a heuristic for paragraphs whose title line is repeated at
        the start of the body. The text is returned unchanged when it has
        fewer than two lines or either of the first two lines is blank.
        """
        lines = text.strip().split('\n')
        if len(lines) >= 2:
            first_line = lines[0].strip()
            second_line = lines[1].strip()
            if first_line and second_line and first_line in second_line:
                return '\n'.join(lines[1:])
        return text

    def split_sentences(self, text, lang='en', min_words=3):
        """Split ``text`` into cleaned sentences.

        Args:
            text: Paragraph text. A duplicated title line is removed first.
            lang: 'en' or 'no'. spaCy is used when the matching model was
                loaded. Otherwise NLTK is used: English when ``lang == 'en'``,
                Norwegian for any other value.
            min_words: Sentences with fewer words than this are dropped.

        Returns:
            List of whitespace-normalised sentence strings.
        """
        text = self.remove_duplicate_title(text)

        nlp = {'en': self.en_nlp, 'no': self.no_nlp}.get(lang)
        if nlp is not None:
            sentences = [sent.text for sent in nlp(text).sents]
        else:
            from nltk.tokenize import sent_tokenize
            nltk_lang = 'english' if lang == 'en' else 'norwegian'
            sentences = sent_tokenize(text, language=nltk_lang)

        cleaned = (self.clean_sentence(s) for s in sentences)
        return [s for s in cleaned if self.count_words(s) >= min_words]

    def merge_short_sentences(self, sentences, min_words=5):
        """Concatenate consecutive sentences while the combined text stays short.

        A sentence is appended to the running buffer only if the buffer plus
        that sentence has fewer than ``min_words`` words. Otherwise the buffer
        is emitted and a new buffer is started with that sentence.

        NOTE(review): split_sentences() already drops sentences under 3 words.
        With the default min_words=5, any two sentences total at least 6 words,
        so nothing is ever merged and the input is returned unchanged. Confirm
        whether the intent was to merge until a chunk *reaches* min_words.
        Merging each language independently could also break 1:1
        correspondence between the two sides.
        """
        merged = []
        buffer = ""

        for sent in sentences:
            if self.count_words(buffer + " " + sent) < min_words:
                buffer = (buffer + " " + sent).strip()
            else:
                if buffer:
                    merged.append(buffer)
                buffer = sent

        if buffer:
            merged.append(buffer)

        return merged

    def _split_and_align(self, en_text, no_text, threshold):
        """Split and align one paragraph pair, also returning the sentence counts.

        Returns:
            ``(aligned_pairs, n_en, n_no)``. ``n_en`` and ``n_no`` are the
            sentence counts from split_sentences() before merging. Returning
            them here spares callers a second (expensive) spaCy pass.
        """
        en_sentences = self.split_sentences(en_text, 'en')
        no_sentences = self.split_sentences(no_text, 'no')
        n_en, n_no = len(en_sentences), len(no_sentences)

        if not en_sentences or not no_sentences:
            return [], n_en, n_no

        en_merged = self.merge_short_sentences(en_sentences)
        no_merged = self.merge_short_sentences(no_sentences)

        en_embeddings = self.laser.embed_sentences(en_merged, lang='en')
        no_embeddings = self.laser.embed_sentences(no_merged, lang='nb')

        # Rows index English sentences, columns index Norwegian sentences.
        similarity_matrix = cosine_similarity(en_embeddings, no_embeddings)

        if abs(len(en_merged) - len(no_merged)) <= 1:
            # Counts differ by at most one: pair by position, assuming the
            # translation preserves sentence order. The surplus sentence,
            # if any, stays unpaired.
            method = 'sequential'
            n_common = min(len(en_merged), len(no_merged))
            candidates = zip(range(n_common), range(n_common))
        else:
            # linear_sum_assignment minimises total cost, so negate the
            # similarities to maximise them. It accepts rectangular matrices.
            method = 'hungarian'
            row_indices, col_indices = linear_sum_assignment(-similarity_matrix)
            candidates = zip(row_indices, col_indices)

        aligned_pairs = []
        for i, j in candidates:
            similarity = float(similarity_matrix[i, j])
            if similarity >= threshold:
                aligned_pairs.append({
                    'en': en_merged[i],
                    'no': no_merged[j],
                    'similarity': similarity,
                    'method': method
                })

        return aligned_pairs, n_en, n_no

    def align_sentences(self, en_text, no_text, threshold=0.75):
        """Align the sentences of one English/Norwegian paragraph pair.

        Returns:
            List of dicts with keys 'en', 'no', 'similarity' (cosine, as a
            float) and 'method' ('sequential' or 'hungarian'). Only pairs
            with similarity >= ``threshold`` are kept. The list is empty if
            either side has no usable sentences.
        """
        aligned_pairs, _, _ = self._split_and_align(en_text, no_text, threshold)
        return aligned_pairs

    def process_mt_corpus(self, input_file, output_file, similarity_threshold=0.75,
                          batch_save_interval=1000):
        """Align every paragraph pair in ``input_file`` and write the results.

        Args:
            input_file: JSON file holding a list of objects with 'source'
                (English) and 'target' (Norwegian) text.
            output_file: Destination JSON, rewritten at every checkpoint and
                at the end.
            similarity_threshold: Minimum cosine similarity to keep a pair.
            batch_save_interval: Checkpoint after every this many paragraph
                pairs. Values <= 0 disable intermediate checkpoints.

        Returns:
            The result dict written to ``output_file``.
        """
        with open(input_file, 'r', encoding='utf-8') as f:
            mt_data = json.load(f)

        all_sentence_pairs = []
        doc_stats = []

        total = len(mt_data)
        print("\n" + "=" * 70)
        print(f"Processing {total} MT pairs")
        print(f"Saving intermediate results every {batch_save_interval} pairs")
        print("=" * 70 + "\n")

        start_time = time.time()

        for i, item in enumerate(mt_data):
            if i > 0 and i % 100 == 0:
                elapsed = time.time() - start_time
                remaining = (total - i) * (elapsed / i)
                print(f"Progress: {i}/{total} ({i / total * 100:.1f}%) | "
                      f"Elapsed: {elapsed / 60:.1f}min | "
                      f"ETA: {remaining / 60:.1f}min | "
                      f"Pairs found: {len(all_sentence_pairs)}")

            # One split per side; the counts come back with the alignments.
            aligned_pairs, n_en, n_no = self._split_and_align(
                item['source'], item['target'], similarity_threshold
            )
            pair_id = i + 1

            doc_stats.append({
                'pair_id': pair_id,
                'en_sentences': n_en,
                'no_sentences': n_no,
                'aligned_pairs': len(aligned_pairs),
                'avg_similarity': (float(np.mean([p['similarity'] for p in aligned_pairs]))
                                   if aligned_pairs else 0)
            })

            for pair in aligned_pairs:
                pair['pair_id'] = pair_id
            all_sentence_pairs.extend(aligned_pairs)

            if batch_save_interval > 0 and pair_id % batch_save_interval == 0:
                self.save_intermediate_results(
                    output_file, all_sentence_pairs, doc_stats,
                    mt_data, similarity_threshold, pair_id
                )

        result = self.save_intermediate_results(
            output_file, all_sentence_pairs, doc_stats,
            mt_data, similarity_threshold, total
        )

        self.print_statistics(doc_stats, all_sentence_pairs, mt_data)

        return result

    def save_intermediate_results(self, output_file, sentence_pairs, doc_stats,
                                  mt_data, threshold, processed_count):
        """Write the current alignment state to ``output_file`` and return it.

        The JSON is first written to a sibling ``.tmp`` file, which is then
        renamed over ``output_file``. An interrupted write therefore never
        leaves a truncated checkpoint behind.
        """
        result = {
            'sentence_pairs': sentence_pairs,
            'pair_stats': doc_stats,
            'total_mt_pairs': len(mt_data),
            'processed_pairs': processed_count,
            'total_sentence_pairs': len(sentence_pairs),
            'similarity_threshold': threshold
        }

        path = Path(output_file)
        tmp_path = path.with_name(path.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False, indent=2)
        tmp_path.replace(path)

        print(f"  Checkpoint saved: {len(sentence_pairs)} sentence pairs")
        return result

    def print_statistics(self, doc_stats, sentence_pairs, mt_data):
        """Print corpus-level alignment totals.

        The alignment rate is aligned pairs divided by the larger of the two
        total sentence counts. Rate and average similarity are reported as
        0 when there is nothing to divide by or average.
        """
        total_en_sentences = sum(s['en_sentences'] for s in doc_stats)
        total_no_sentences = sum(s['no_sentences'] for s in doc_stats)
        max_sentences = max(total_en_sentences, total_no_sentences)
        alignment_rate = (len(sentence_pairs) / max_sentences * 100
                          if max_sentences else 0.0)
        avg_similarity = (float(np.mean([p['similarity'] for p in sentence_pairs]))
                          if sentence_pairs else 0.0)

        print("\n" + "=" * 70)
        print("ALIGNMENT RESULTS")
        print("=" * 70)
        print(f"MT paragraph pairs:        {len(mt_data):,}")
        print(f"Total EN sentences:        {total_en_sentences:,}")
        print(f"Total NO sentences:        {total_no_sentences:,}")
        print(f"Aligned sentence pairs:    {len(sentence_pairs):,}")
        print(f"Alignment rate:            {alignment_rate:.1f}%")
        print(f"Average similarity:        {avg_similarity:.3f}")
        print("=" * 70 + "\n")

    def create_clean_training_data(self, aligned_file, output_file):
        """Convert an aligned-results JSON into a flat source/target pair list.

        Reads the 'sentence_pairs' list written by process_mt_corpus() and
        writes a list of {'source', 'target', 'source_lang', 'target_lang'}
        dicts to ``output_file``.
        """
        with open(aligned_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        clean_pairs = [
            {
                'source': pair['en'],
                'target': pair['no'],
                'source_lang': 'en',
                'target_lang': 'no'
            }
            for pair in data['sentence_pairs']
        ]

        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(clean_pairs, f, ensure_ascii=False, indent=2)

        print(f"Created clean training data: {output_file}")
        print(f"Total sentence pairs: {len(clean_pairs):,}")


def quick_test(input_file, n_samples=3):
    """Align the first ``n_samples`` paragraph pairs and print what was found.

    Meant as a quick visual sanity check before a full corpus run.
    ``input_file`` is read as a JSON list of objects with 'source' (EN) and
    'target' (NO) text. Alignment uses align_sentences' default threshold.
    """
    aligner = MTSentenceAligner()
    banner = "=" * 70

    print("\n" + banner)
    print(f"QUICK TEST: Processing first {n_samples} paragraph pairs")
    print(banner + "\n")

    with open(input_file, 'r', encoding='utf-8') as handle:
        samples = json.load(handle)[:n_samples]

    for number, sample in enumerate(samples, start=1):
        print(f"\nParagraph Pair {number}")
        print("-" * 70)
        print(f"EN: {sample['source'][:100]}...")
        print(f"NO: {sample['target'][:100]}...")

        matches = aligner.align_sentences(sample['source'], sample['target'])
        print(f"\nFound {len(matches)} sentence pairs:\n")

        for rank, match in enumerate(matches, start=1):
            header = f"  Pair {rank} | Similarity: {match['similarity']:.3f} | Method: {match['method']}"
            print(header)
            print(f"    EN: {match['en'][:80]}...")
            print(f"    NO: {match['no'][:80]}...\n")


if __name__ == "__main__":
    # Machine-specific location of the raw Equinor EN/NO data; adjust as needed.
    DATA_DIR = Path("/mnt/d/J/Desktop/language_technology/course/projects_AI/mt_oil/experiments/lora/mt_oli_en_no/data/00_raw_equ")
    INPUT_FILE = DATA_DIR / "equinor_data.json"
    # NOTE(review): the processed folder is created *inside* the raw-data
    # folder (00_raw_equ); confirm this is intended rather than a sibling.
    OUTPUT_DIR = DATA_DIR / "01_processed_equ"
    ALIGNED_FILE = OUTPUT_DIR / "equinor_aligned_full.json"
    CLEAN_FILE = OUTPUT_DIR / "clean_equinor_data.json"

    # Fail early with a clear message instead of erroring mid-setup.
    if not INPUT_FILE.is_file():
        raise FileNotFoundError(f"Input file not found: {INPUT_FILE}")
    OUTPUT_DIR.mkdir(exist_ok=True)

    # "test": print alignments for a few pairs; "full": align the whole corpus.
    MODE = "full"

    if MODE == "test":
        print("\nRUNNING IN TEST MODE")
        quick_test(INPUT_FILE, n_samples=5)

    elif MODE == "full":
        print("\nRUNNING IN FULL MODE")
        aligner = MTSentenceAligner()

        aligner.process_mt_corpus(
            input_file=INPUT_FILE,
            output_file=ALIGNED_FILE,
            similarity_threshold=0.75,
            batch_save_interval=1000
        )

        aligner.create_clean_training_data(
            aligned_file=ALIGNED_FILE,
            output_file=CLEAN_FILE
        )

        print("\nAll done! Check the processed folder for results")

    else:
        raise ValueError(f"Unknown MODE {MODE!r}; expected 'test' or 'full'")


RUNNING IN FULL MODE
Initializing LASER embeddings...
Loading Spacy models...

Processing 2777 MT pairs
Saving intermediate results every 1000 pairs

Progress: 100/2777 (3.6%) | Elapsed: 0.1min | ETA: 2.4min | Pairs found: 707
Progress: 200/2777 (7.2%) | Elapsed: 0.2min | ETA: 2.3min | Pairs found: 1527
Progress: 300/2777 (10.8%) | Elapsed: 0.3min | ETA: 2.1min | Pairs found: 2259
Progress: 400/2777 (14.4%) | Elapsed: 0.3min | ETA: 2.1min | Pairs found: 3094
Progress: 500/2777 (18.0%) | Elapsed: 0.4min | ETA: 2.0min | Pairs found: 4005
Progress: 600/2777 (21.6%) | Elapsed: 0.5min | ETA: 2.0min | Pairs found: 4902
Progress: 700/2777 (25.2%) | Elapsed: 0.6min | ETA: 1.9min | Pairs found: 5809
Progress: 800/2777 (28.8%) | Elapsed: 0.7min | ETA: 1.8min | Pairs found: 6920
Progress: 900/2777 (32.4%) | Elapsed: 0.8min | ETA: 1.8min | Pairs found: 7799
  Checkpoint saved: 8736 sentence pairs
Progress: 1000/2777 (36.0%) | Elapsed: 0.9min | ETA: 1.7min | Pairs found: 8736
Progress: 1100/2777 (