In [None]:
import pandas as pd
import numpy as np
from html import unescape
import re

def prepare_medicine_data_for_nlp(csv_path, min_text_len=50):
    """
    Convert a Medicine_Details-style CSV to the project-compatible format.

    Reads the CSV, cleans the free-text columns (HTML entities, whitespace),
    builds templated model inputs and parses the USES column into a list of
    condition labels per drug. Prints a short summary.

    Parameters
    ----------
    csv_path : str, path-like or file-like
        Anything accepted by ``pandas.read_csv``.
    min_text_len : int, default 50
        A row is kept only if the combined length of its cleaned CONTAINS,
        HOW_WORKS and INTRODUCTION text is greater than this. It is measured
        on the content itself, not on ``input_text``, whose fixed field
        prefixes alone are 54 characters long.

    Returns
    -------
    pandas.DataFrame
        A copy of the rows that pass the quality filters, with the added
        columns ``input_text``, ``input_text_extended`` and
        ``condition_labels`` (a list of strings per row).
    """
    df = pd.read_csv(csv_path)

    # ===== Step 1: Clean HTML entities =====
    def clean_text(text, keep_newlines=False):
        """Unescape HTML entities and normalize whitespace.

        By default every whitespace run (newlines included) becomes a single
        space. With ``keep_newlines=True`` each line is normalized on its own,
        empty lines are dropped and the remaining lines are joined with
        ``'\\n'``, so multi-line fields can still be split later.
        """
        if pd.isna(text):
            return ""
        text = unescape(str(text))  # &rsquo; -> '
        if keep_newlines:
            lines = (re.sub(r'\s+', ' ', line).strip() for line in text.splitlines())
            return '\n'.join(line for line in lines if line)
        return re.sub(r'\s+', ' ', text).strip()  # any whitespace run -> single space

    # Clean all text columns
    text_cols = ['INTRODUCTION', 'USES', 'BENEFITS', 'SIDE_EFFECT',
                 'HOW_TO_USE', 'HOW_WORKS', 'CONTAINS']
    for col in text_cols:
        if col in df.columns:
            # USES lists one condition per line. Collapsing its newlines would
            # merge every condition into a single label in Step 3.
            df[col] = df[col].apply(clean_text, keep_newlines=(col == 'USES'))
        else:
            # An absent column is treated as empty text so the steps below
            # do not raise KeyError.
            df[col] = ""

    # ===== Step 2: Build input features (X) =====
    # Corresponds to Description + Mechanism + Pharmacodynamics in your project.
    # All text columns are NaN-free strings after Step 1.
    df['input_text'] = (
        "Drug Composition: " + df['CONTAINS'] + " " +
        "Mechanism of Action: " + df['HOW_WORKS'] + " " +
        "Description: " + df['INTRODUCTION']
    )

    # Optional: Add more information
    df['input_text_extended'] = (
        df['input_text'] + " " +
        "Clinical Use: " + df['HOW_TO_USE'] + " " +
        "Benefits: " + df['BENEFITS']
    )

    # ===== Step 3: Extract disease labels (y) =====
    # The USES field contains several conditions, one per line.

    def extract_conditions(uses_text):
        """
        Extract the list of conditions from a cleaned USES field.

        Example: 'Schizophrenia\\n2. Cancer' -> ['Schizophrenia', 'Cancer']
        Leading "N." numbering is removed, and entries of 2 characters or
        fewer are dropped.
        """
        if pd.isna(uses_text) or uses_text == '':
            return []
        stripped = (line.strip() for line in uses_text.splitlines())
        unnumbered = (re.sub(r'^\d+\.\s*', '', line) for line in stripped if line)
        return [cond for cond in unnumbered if len(cond) > 2]  # Filter too short

    df['condition_labels'] = df['USES'].apply(extract_conditions)

    # ===== Step 4: Filter data quality =====
    # Measure the real content. The templated input_text always has at least
    # 54 characters of prefixes, so testing its length would keep every row.
    content_len = (
        df['CONTAINS'].map(len) + df['HOW_WORKS'].map(len) + df['INTRODUCTION'].map(len)
    )
    df_clean = df[
        (content_len > min_text_len) &  # Input text long enough
        (df['condition_labels'].map(len) > 0)  # At least one label
    ].copy()

    print(f"Original data: {len(df)} rows")
    print(f"After cleaning: {len(df_clean)} rows")
    print(f"Average labels per drug: {df_clean['condition_labels'].apply(len).mean():.2f}")

    return df_clean


# ===== Step 5: Convert to multi-label format (for BioBERT) =====
from sklearn.preprocessing import MultiLabelBinarizer

def create_multilabel_dataset(df_clean, min_freq=5):
    """
    Create a multi-label classification dataset.

    Counts how often each condition occurs in ``df_clean['condition_labels']``
    and drops conditions seen fewer than ``min_freq`` times. Rows left without
    any label are removed, and the remaining labels are binarized with
    ``MultiLabelBinarizer``. Prints frequency and shape summaries.

    Parameters
    ----------
    df_clean : pandas.DataFrame
        Must have a ``condition_labels`` column holding lists of strings, as
        produced by ``prepare_medicine_data_for_nlp``. It is not modified.
    min_freq : int, default 5
        Minimum number of occurrences for a condition to be kept as a label.

    Returns
    -------
    tuple
        ``(df_out, y, mlb)``:
        - ``df_out``: the kept rows plus a ``filtered_labels`` column.
        - ``y``: the label matrix returned by ``mlb.fit_transform``, with rows
          in the same order as ``df_out``.
        - ``mlb``: the fitted binarizer; ``mlb.classes_`` names the columns
          of ``y``.
    """
    from collections import Counter
    from itertools import chain

    # Count disease frequency over the flattened label lists
    condition_counts = Counter(chain.from_iterable(df_clean['condition_labels']))
    print(f"\nTotal unique conditions found: {len(condition_counts)}")
    print("\nTop 20 most common conditions:")
    for condition, count in condition_counts.most_common(20):
        print(f"  {condition}: {count}")

    # Keep only conditions appearing >= min_freq times
    frequent_conditions = {cond for cond, count in condition_counts.items()
                           if count >= min_freq}

    print(f"\nKeeping conditions with frequency >= {min_freq}: {len(frequent_conditions)} conditions")

    # Work on a copy so the caller's DataFrame does not gain a new column.
    df_out = df_clean.copy()
    df_out['filtered_labels'] = df_out['condition_labels'].apply(
        lambda labels: [label for label in labels if label in frequent_conditions]
    )

    # Remove samples without valid labels
    df_out = df_out[df_out['filtered_labels'].map(len) > 0].copy()

    # Multi-label binarization
    mlb = MultiLabelBinarizer()
    y = mlb.fit_transform(df_out['filtered_labels'])

    print(f"\nFinal dataset:")
    print(f"  Number of samples: {len(df_out)}")
    print(f"  Number of labels: {len(mlb.classes_)}")
    print(f"  Label matrix shape: {y.shape}")
    print(f"  Average labels per sample: {y.sum(axis=1).mean():.2f}")

    return df_out, y, mlb


# ===== Complete example =====
if __name__ == "__main__":
    import pickle

    # 1. Load and clean
    df_clean = prepare_medicine_data_for_nlp('../data/Mid_converted.csv')

    # 2. Create multi-label dataset
    df_final, y, mlb = create_multilabel_dataset(df_clean)

    # 3. Save processed data (human-readable copy)
    df_final[['NAME', 'input_text', 'filtered_labels']].to_csv(
        'medicine_processed.csv', index=False
    )

    # 4. Save in project-required format.
    # NOTE(review): input_text holds Python strings, so np.save stores it as a
    # pickled object array. Load it back with
    # np.load('X_medicine.npy', allow_pickle=True).
    np.save('X_medicine.npy', df_final['input_text'].values)
    np.save('y_medicine.npy', y)

    # Fitted binarizer: mlb.classes_ maps the columns of y to condition names.
    with open('mlb_medicine.pkl', 'wb') as f:
        pickle.dump(mlb, f)

    print("\n✅ Data processing complete!")
    print("Generated files:")
    print("  - medicine_processed.csv")
    print("  - X_medicine.npy")
    print("  - y_medicine.npy")
    print("  - mlb_medicine.pkl")