In [None]:
!pip install -q sentence-transformers xgboost lxml wget

In [None]:
import os
import wget
import glob
import gzip
import shutil
import xgboost as xgb
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import xml.etree.ElementTree as ET
from sentence_transformers import SentenceTransformer
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, StratifiedKFold
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, roc_auc_score, classification_report
from sklearn.ensemble import VotingClassifier
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from transformers import MarianMTModel, MarianTokenizer
import logging
import warnings
import torch
import joblib

# Configure logging for the notebook: INFO level, timestamped records.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Module-level logger used by the class and the demo below.
logger = logging.getLogger(__name__)
# Suppress every Python warning from here on (this also hides deprecation
# notices emitted by the imported libraries).
warnings.filterwarnings('ignore')

class ArticleRelevanceModel:
    def __init__(self, embedding_model='all-MiniLM-L6-v2', use_gpu=torch.cuda.is_available()):
        """Set up the embedding backbone and empty model/translation state.

        Args:
            embedding_model: Model name handed to ``SentenceTransformer``.
            use_gpu: Whether CUDA may be used. CUDA is selected only if it is
                also reported available when the instance is built. Note that
                the default value is evaluated once, at class definition.
        """
        logger.info(f"Initializing with embedding model: {embedding_model}")

        cuda_ok = use_gpu and torch.cuda.is_available()
        self.device = torch.device('cuda' if cuda_ok else 'cpu')
        logger.info(f"Using device: {self.device}")

        # Sentence encoder, moved onto the selected device.
        encoder = SentenceTransformer(embedding_model)
        encoder.to(self.device)
        self.embedding_model = encoder

        # Translation models and tokenizers keyed by "<src>-<tgt>"; they are
        # filled on demand by load_translation_model().
        self.translation_models = {}
        self.translation_tokenizers = {}

        # Classifier (assigned by train_model()/load_model()) and the scaler
        # applied to feature vectors.
        self.model = None
        self.scaler = StandardScaler()

        # Container reserved for article embeddings.
        self.article_embeddings = {}

    def download_article(self, url, output_dir='./data'):
        """Download ``url`` into ``output_dir`` and decompress it if gzipped.

        Args:
            url: Location of the file to fetch.
            output_dir: Directory for the download; created when missing.

        Returns:
            The path reported by ``wget.download`` or, when that path ends in
            ``.gz``, the path ``<output_dir>/article.xml`` that receives the
            decompressed content.
        """
        os.makedirs(output_dir, exist_ok=True)

        logger.info(f"Downloading article from {url}")
        filename = wget.download(url, out=output_dir)
        logger.info(f"Downloaded to {filename}")

        # Anything that is not gzipped is returned untouched.
        if not filename.endswith('.gz'):
            return filename

        # Stream-decompress the archive into a fixed-name XML file.
        output_file = os.path.join(output_dir, 'article.xml')
        with gzip.open(filename, 'rb') as f_in, open(output_file, 'wb') as f_out:
            shutil.copyfileobj(f_in, f_out)
        logger.info(f"Extracted to {output_file}")
        return output_file

    def extract_text_from_nxml(self, xml_file):
        """Extract text from an NXML file with better parsing."""
        try:
            tree = ET.parse(xml_file)
            root = tree.getroot()

            # Try different namespaces for flexibility
            namespaces = [
                {'ns': 'http://www.ncbi.nlm.nih.gov/JATS1'},
                {'ns': 'http://dtd.nlm.nih.gov/2.0/xsd/archivearticle'},
                {'ns': 'http://jats.nlm.nih.gov/ns/archiving/1.0/'},
                {}  # No namespace
            ]

            # Extract title
            title = ""
            for ns in namespaces:
                title_elem = root.find('.//article-title', ns)
                if title_elem is not None and title_elem.text:
                    title = title_elem.text
                    break

            # Extract abstract
            abstract = ""
            for ns in namespaces:
                abstract_elems = root.findall('.//abstract//p', ns)
                if abstract_elems:
                    abstract = " ".join([elem.text or '' for elem in abstract_elems if elem.text])
                    break

            # Extract body text
            body_text = ""
            for ns in namespaces:
                body = root.find('.//body', ns)
                if body is not None:
                    paragraphs = []
                    for elem in body.iter():
                        if elem.text and elem.text.strip():
                            paragraphs.append(elem.text.strip())
                    body_text = " ".join(paragraphs)
                    break

            # Combine all text
            full_text = f"{title} {abstract} {body_text}"

            # Basic cleanup
            full_text = full_text.replace('\n', ' ').replace('  ', ' ').strip()

            metadata = {
                "title": title,
                "abstract_length": len(abstract),
                "body_length": len(body_text),
                "total_length": len(full_text)
            }

            logger.info(f"Extracted text: {len(full_text)} characters, title: {title[:50]}...")
            return full_text, metadata

        except Exception as e:
            logger.error(f"Error parsing {xml_file}: {e}")
            return "", {"error": str(e)}

    def load_translation_model(self, src_lang, tgt_lang):
        """Load a translation model for the specified language pair."""
        model_name = f"Helsinki-NLP/opus-mt-{src_lang}-{tgt_lang}"
        key = f"{src_lang}-{tgt_lang}"

        if key not in self.translation_models:
            logger.info(f"Loading translation model: {model_name}")
            try:
                tokenizer = MarianTokenizer.from_pretrained(model_name)
                model = MarianMTModel.from_pretrained(model_name)
                model.to(self.device)

                self.translation_models[key] = model
                self.translation_tokenizers[key] = tokenizer
                logger.info(f"Translation model loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load translation model: {e}")
                return False
        return True

    def translate_text(self, text, src_lang="en", tgt_lang="fr"):
        """Translate text from source language to target language."""
        key = f"{src_lang}-{tgt_lang}"

        # Load translation model if not already loaded
        if not self.load_translation_model(src_lang, tgt_lang):
            return text

        try:
            tokenizer = self.translation_tokenizers[key]
            model = self.translation_models[key]

            # Handle long texts by splitting into chunks
            max_length = 512
            chunks = [text[i:i+max_length] for i in range(0, len(text), max_length)]

            translated_chunks = []
            for chunk in chunks:
                encoded = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                encoded = {k: v.to(self.device) for k, v in encoded.items()}

                with torch.no_grad():
                    output = model.generate(**encoded)

                decoded = tokenizer.batch_decode(output, skip_special_tokens=True)
                translated_chunks.extend(decoded)

            translated_text = " ".join(translated_chunks)
            logger.info(f"Translated text from {src_lang} to {tgt_lang}, length: {len(translated_text)}")

            return translated_text
        except Exception as e:
            logger.error(f"Translation error: {e}")
            return text

    def compute_article_embedding(self, article_text):
        """Compute and return embedding for an article."""
        if not article_text:
            logger.warning("Empty article text provided")
            return np.zeros(self.embedding_model.get_sentence_embedding_dimension())

        # Use chunking for long articles to avoid CUDA memory issues
        max_length = 10000  # Characters per chunk
        if len(article_text) > max_length:
            chunks = [article_text[i:i+max_length] for i in range(0, len(article_text), max_length)]
            embeddings = [self.embedding_model.encode(chunk) for chunk in chunks]
            # Average the embeddings of all chunks
            return np.mean(embeddings, axis=0)
        else:
            return self.embedding_model.encode(article_text)

    def create_features(self, query_embedding, article_embedding):
        """Create rich feature representation from embeddings."""
        # Absolute difference
        abs_diff = np.abs(query_embedding - article_embedding)

        # Element-wise product
        element_prod = query_embedding * article_embedding

        # Cosine similarity (as a scalar feature)
        cosine_sim = np.dot(query_embedding, article_embedding) / (
            np.linalg.norm(query_embedding) * np.linalg.norm(article_embedding)
        )

        # Euclidean distance (as a scalar feature)
        euclidean_dist = np.linalg.norm(query_embedding - article_embedding)

        # Concatenate all features
        features = np.concatenate([
            abs_diff,
            element_prod,
            np.array([cosine_sim, euclidean_dist])
        ])

        return features

    def prepare_training_data(self, prompts_with_labels, article_text):
        """Prepare training data from prompts and an article."""
        article_embedding = self.compute_article_embedding(article_text)

        X, y = [], []
        for text, label in prompts_with_labels:
            query_embedding = self.embedding_model.encode(text)
            features = self.create_features(query_embedding, article_embedding)
            X.append(features)
            y.append(label)

        X = np.array(X)
        y = np.array(y)

        # Scale features
        X_scaled = self.scaler.fit_transform(X)

        return X_scaled, y

    def train_model(self, X, y, param_grid=None, cv=5):
        """Grid-search an XGBoost classifier and keep the best one on ``self.model``.

        Args:
            X: Feature matrix.
            y: Binary labels.
            param_grid: Hyperparameter grid for ``GridSearchCV``; a built-in
                grid is used when None.
            cv: Number of stratified folds.

        Returns:
            The selected estimator after a final fit on all of ``X``/``y``.
        """
        logger.info(f"Training model with {len(X)} samples")

        if param_grid is None:
            param_grid = {
                'n_estimators': [50, 100, 200],
                'max_depth': [3, 5, 7],
                'learning_rate': [0.01, 0.1, 0.2],
                'subsample': [0.8, 1.0],
                'colsample_bytree': [0.8, 1.0],
                'min_child_weight': [1, 3, 5]
            }

        # The same shuffled, seeded stratified splitter is used for the search
        # and for the extra scoring below.
        folds = StratifiedKFold(n_splits=cv, shuffle=True, random_state=42)

        base_classifier = xgb.XGBClassifier(
            use_label_encoder=False,
            eval_metric='logloss',
            objective='binary:logistic',
            random_state=42
        )

        # Exhaustive search over the grid, candidates ranked by F1.
        search = GridSearchCV(
            estimator=base_classifier,
            param_grid=param_grid,
            cv=folds,
            scoring='f1',
            n_jobs=-1,
            verbose=1
        )
        search.fit(X, y)

        self.model = search.best_estimator_
        logger.info(f"Best parameters: {search.best_params_}")
        logger.info(f"Best cross-validation score: {search.best_score_:.4f}")

        # Additional cross-validated metrics for the selected configuration.
        accuracy_scores = cross_val_score(self.model, X, y, cv=folds, scoring='accuracy')
        f1_scores = cross_val_score(self.model, X, y, cv=folds, scoring='f1')
        logger.info(f"Cross-validation accuracy: {accuracy_scores.mean():.4f} ± {accuracy_scores.std():.4f}")
        logger.info(f"Cross-validation F1: {f1_scores.mean():.4f} ± {f1_scores.std():.4f}")

        # Final fit on all data passed in.
        self.model.fit(X, y)
        return self.model

    def evaluate_model(self, X_test, y_test):
        """Evaluate the model on a test set."""
        if self.model is None:
            logger.error("Model not trained yet")
            return None

        # Make predictions
        y_pred = self.model.predict(X_test)
        y_prob = self.model.predict_proba(X_test)[:, 1]

        # Calculate metrics
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred)
        recall = recall_score(y_test, y_pred)
        f1 = f1_score(y_test, y_pred)
        auc = roc_auc_score(y_test, y_prob)

        # Create confusion matrix
        cm = confusion_matrix(y_test, y_pred)

        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'auc': auc,
            'confusion_matrix': cm
        }

        logger.info(f"Model evaluation metrics:")
        logger.info(f"  Accuracy: {accuracy:.4f}")
        logger.info(f"  Precision: {precision:.4f}")
        logger.info(f"  Recall: {recall:.4f}")
        logger.info(f"  F1 Score: {f1:.4f}")
        logger.info(f"  AUC: {auc:.4f}")
        logger.info(f"  Confusion Matrix: \n{cm}")

        return metrics

    def save_model(self, model_dir='./models'):
        """Write the classifier and the scaler to ``model_dir`` with joblib.

        The directory is created when missing. Files are named
        ``xgboost_model.pkl`` and ``scaler.pkl``.
        """
        os.makedirs(model_dir, exist_ok=True)

        model_path = os.path.join(model_dir, 'xgboost_model.pkl')
        scaler_path = os.path.join(model_dir, 'scaler.pkl')

        # Dump both artifacts first, then report where they went.
        for artifact, target in ((self.model, model_path), (self.scaler, scaler_path)):
            joblib.dump(artifact, target)

        logger.info(f"Model saved to {model_path}")
        logger.info(f"Scaler saved to {scaler_path}")

    def load_model(self, model_dir='./models'):
        """Load a trained model and scaler."""
        model_path = os.path.join(model_dir, 'xgboost_model.pkl')
        scaler_path = os.path.join(model_dir, 'scaler.pkl')

        if os.path.exists(model_path) and os.path.exists(scaler_path):
            self.model = joblib.load(model_path)
            self.scaler = joblib.load(scaler_path)
            logger.info(f"Model loaded from {model_path}")
            return True
        else:
            logger.error(f"Model or scaler file not found")
            return False

    def predict_relevance(self, query, article_text=None, article_embedding=None, threshold=0.5):
        """Predict the relevance of a query to an article."""
        if self.model is None:
            logger.error("Model not trained yet")
            return None

        # Get article embedding (either from parameter, cache, or compute new)
        if article_embedding is None:
            if article_text is None:
                logger.error("Either article_text or article_embedding must be provided")
                return None
            article_embedding = self.compute_article_embedding(article_text)

        # Get query embedding
        query_embedding = self.embedding_model.encode(query)

        # Create features
        features = self.create_features(query_embedding, article_embedding)
        features_scaled = self.scaler.transform([features])

        # Make prediction
        relevance_prob = self.model.predict_proba(features_scaled)[0][1]
        is_relevant = relevance_prob >= threshold

        result = {
            'query': query,
            'relevance_score': float(relevance_prob),
            'is_relevant': bool(is_relevant),
            'threshold_used': threshold
        }

        return result

    def find_relevant_queries(self, queries, article_text, threshold=0.5, translate_to=None):
        """Find which queries are relevant to an article with optional translation."""
        article_embedding = self.compute_article_embedding(article_text)

        results = []
        for query in queries:
            # Translate query if requested
            if translate_to:
                translated_query = self.translate_text(query, "en", translate_to)
                result = self.predict_relevance(translated_query, article_embedding=article_embedding, threshold=threshold)
                result['original_query'] = query
                result['translated_query'] = translated_query
            else:
                result = self.predict_relevance(query, article_embedding=article_embedding, threshold=threshold)
                result['original_query'] = query

            results.append(result)

        # Sort by relevance score (descending)
        results.sort(key=lambda x: x['relevance_score'], reverse=True)

        return results

    def translate_article_title(self, article_text, src_lang="en", tgt_lang="fr"):
        """Extract and translate the article title."""
        # Simple extraction of what might be the title (first 100 chars)
        potential_title = article_text[:100].split('.')[0]

        translated_title = self.translate_text(potential_title, src_lang, tgt_lang)
        return translated_title


# Example usage

def main():
    """Run the end-to-end demo.

    Downloads a PMC file, extracts its text, trains and evaluates the
    relevance classifier on a small hand-labelled prompt set, saves the model,
    prints predictions for a few queries, and finally exercises translation.
    """
    arm = ArticleRelevanceModel(embedding_model='all-MiniLM-L6-v2')

    # Fetch the article file and pull its text out.
    xml_file = arm.download_article('https://ftp.ebi.ac.uk/pub/databases/pmc/oa/PMC8430027_PMC8440000.xml.gz')
    article_text, metadata = arm.extract_text_from_nxml(xml_file)
    print(f"Article metadata: {metadata}")

    # Hand-labelled prompts: 1 = relevant, 0 = not relevant. The order matters
    # because the seeded split below depends on it.
    prompts = [
        ("What are the side effects of blood pressure medications?", 1),
        ("How to treat hypertension naturally?", 1),
        ("Best holiday destinations in Europe?", 0),
        ("Recipes for gluten-free desserts?", 0),
        ("How do ACE inhibitors work?", 1),
        ("How to bake a cake?", 0),
        ("What are the complications of uncontrolled high blood pressure?", 1),
        ("Can lifestyle changes help with blood pressure?", 1),
        ("Popular movies to watch in 2023", 0),
        ("Difference between systolic and diastolic pressure", 1),
        ("Best video games of the year", 0),
        ("How to improve SEO for a website?", 0),
        ("Risk factors for hypertension in adults", 1),
        ("Blood pressure medication interactions with food", 1),
        ("DIY home renovation ideas", 0)
    ]

    # Featurise, then hold out 30% (stratified) for evaluation.
    X, y = arm.prepare_training_data(prompts, article_text)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.3, stratify=y, random_state=42
    )

    # Deliberately small grid so the demo finishes quickly.
    demo_grid = {
        'n_estimators': [50, 100],
        'max_depth': [3, 5],
        'learning_rate': [0.1, 0.2]
    }
    arm.train_model(X_train, y_train, param_grid=demo_grid, cv=3)

    # Metrics are logged by evaluate_model itself.
    arm.evaluate_model(X_test, y_test)

    arm.save_model()

    # Spot-check a few unseen queries.
    test_queries = [
        "What are the side effects of hypertension drugs?",
        "How to lower blood pressure with diet?",
        "Best cell phones to buy in 2023"
    ]

    print("\nPrediction results:")
    for query in test_queries:
        result = arm.predict_relevance(query, article_text=article_text)
        print(f"Query: '{query}'")
        print(f"  Relevance score: {result['relevance_score']:.4f}")
        print(f"  Is relevant: {result['is_relevant']}")

    # Translation is optional: any failure is reported and swallowed here.
    try:
        print("\nTranslation test:")
        translated_title = arm.translate_article_title(article_text, "en", "uk")
        print(f"Translated title to Ukrainian: {translated_title}")

        # Score queries after translating them to Ukrainian.
        ukrainian_queries = ["Blood pressure monitoring", "Hypertension treatment options"]
        translated_results = arm.find_relevant_queries(
            ukrainian_queries,
            article_text,
            translate_to="uk"
        )

        print("\nRelevant queries with Ukrainian translation:")
        for result in translated_results:
            print(f"Original: '{result['original_query']}'")
            print(f"Translated: '{result['translated_query']}'")
            print(f"Relevance score: {result['relevance_score']:.4f}")
    except Exception as e:
        print(f"Translation test failed: {e}")



# Run the demo only when this module is the entry point, not when imported.
if __name__ == "__main__":
    main()

Article metadata: {'title': 'The Traditional Chinese Medicine Compound Huangqin Qingre Chubi Capsule Inhibits the Pathogenesis of Rheumatoid Arthritis Through the CUL4B/Wnt Pathway', 'abstract_length': 8772136, 'body_length': 17435, 'total_length': 8787260}
Fitting 3 folds for each of 8 candidates, totalling 24 fits

Prediction results:
Query: 'What are the side effects of hypertension drugs?'
  Relevance score: 0.7257
  Is relevant: True
Query: 'How to lower blood pressure with diet?'
  Relevance score: 0.7257
  Is relevant: True
Query: 'Best cell phones to buy in 2023'
  Relevance score: 0.2743
  Is relevant: False

Translation test:


tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/809k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/1.01M [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.37M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/305M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/305M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

Translated title to Ukrainian: Традиційний китайський медичний композит Гуангкін Чінжер Чубі Капсуле Інхібіт є патогенезом

Relevant queries with Ukrainian translation:
Original: 'Blood pressure monitoring'
Translated: 'Аналіз кров'яного тиску'
Relevance score: 0.7257
Original: 'Hypertension treatment options'
Translated: 'Параметри лікування гіпертонією'
Relevance score: 0.2743
