# Enhanced Correlation Analysis

This notebook performs enhanced correlation analysis between companies and regulations using:
- AWS Bedrock Titan embeddings for semantic similarity
- AWS Comprehend for entity extraction
- Weighted scoring based on field types

In [1]:
import json
import warnings
from pathlib import Path

import boto3
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity

In [2]:
# AWS clients, both created for the us-west-2 region.
# `bedrock` is used by get_embedding to invoke the Titan embedding model;
# `comprehend` is used by extract_entities for entity detection.
bedrock = boto3.client('bedrock-runtime', region_name='us-west-2')
comprehend = boto3.client('comprehend', region_name='us-west-2')

# File paths (absolute, under /home/sagemaker-user/shared).
# COMP_PATH and REGS_PATH are the company and regulation CSVs read by
# run_enhanced_correlation; OUT_PATH is the CSV it writes the ranked matches to.
COMP_PATH = Path("/home/sagemaker-user/shared/outputs/sec_matrix.csv")
REGS_PATH = Path("/home/sagemaker-user/shared/regulations_example.csv")
OUT_PATH = Path("/home/sagemaker-user/shared/outputs/enhanced_correlations.csv")

In [3]:
def get_embedding(text: str) -> np.ndarray:
    """Embed ``text`` with the Bedrock Titan text-embedding model.

    Args:
        text: Text to embed. It is coerced with ``str()`` and truncated to its
            first 8000 characters before being sent.

    Returns:
        The embedding as a 1-D array. A 1024-element zero vector is returned
        when the text is blank or when the request fails, so callers always
        receive a vector.
    """
    fallback_dim = 1024  # length of the zero vector used for blank input and failures
    payload_text = str(text)[:8000]
    if not payload_text.strip():
        # Nothing to embed: skip the network round trip entirely.
        return np.zeros(fallback_dim)
    try:
        response = bedrock.invoke_model(
            modelId='amazon.titan-embed-text-v2:0',
            body=json.dumps({"inputText": payload_text})
        )
        result = json.loads(response['body'].read())
        return np.array(result['embedding'])
    except Exception as exc:
        # Deliberately best-effort: one failed request must not abort the whole
        # run. Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit
        # propagate and makes the failure visible instead of silent.
        warnings.warn(
            f"get_embedding failed; using a zero vector: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return np.zeros(fallback_dim)

def extract_entities(text: str) -> list:
    """Extract high-confidence entity strings from ``text`` using Comprehend.

    Args:
        text: Text to analyse. It is coerced with ``str()`` and truncated to its
            first 5000 characters before being sent.

    Returns:
        Lower-cased text of every detected entity whose score is above 0.8.
        An empty list is returned when the text is blank or the request fails.
    """
    payload_text = str(text)[:5000]
    if not payload_text.strip():
        # Blank input cannot contain entities: skip the service call.
        return []
    try:
        response = comprehend.detect_entities(
            Text=payload_text,
            LanguageCode='en'
        )
        return [entity['Text'].lower() for entity in response['Entities']
                if entity['Score'] > 0.8]
    except Exception as exc:
        # Deliberately best-effort: a failed request degrades to "no entities".
        # Unlike a bare ``except``, this lets KeyboardInterrupt/SystemExit
        # propagate and reports the failure instead of hiding it.
        warnings.warn(
            f"extract_entities failed; returning no entities: {exc!r}",
            RuntimeWarning,
            stacklevel=2,
        )
        return []

In [4]:
def semantic_similarity(text1: str, text2: str) -> float:
    """Return the cosine similarity between the embeddings of two texts.

    Each text is embedded with ``get_embedding`` (first ``text1``, then
    ``text2``) and the two vectors are compared with ``cosine_similarity``.
    """
    first_vec, second_vec = (get_embedding(piece) for piece in (text1, text2))
    # cosine_similarity works on 2-D inputs and returns a matrix; wrap each
    # vector as a single row and read the only cell back out.
    similarity_matrix = cosine_similarity([first_vec], [second_vec])
    return similarity_matrix[0][0]

def entity_overlap(text1: str, text2: str) -> float:
    """Return the Jaccard overlap of the entities extracted from two texts.

    The result is ``|A & B| / |A | B|`` over the entity sets produced by
    ``extract_entities``; it is 0.0 when neither text yields any entity.
    """
    first = set(extract_entities(text1))
    second = set(extract_entities(text2))
    combined = first | second
    if not combined:
        # No entities on either side: define the overlap as zero.
        return 0.0
    shared = first & second
    return len(shared) / len(combined)

In [5]:
def enhanced_similarity(text1: str, text2: str, field_type: str) -> float:
    """Blend embedding similarity and entity overlap into a single score.

    The two component scores are mixed with a (semantic, entity) weight pair
    chosen by ``field_type``; unknown field types use (0.6, 0.4).
    """
    embedding_part = semantic_similarity(text1, text2)
    entity_part = entity_overlap(text1, text2)

    # (semantic weight, entity weight) per field type
    mix_by_field = dict(
        country=(0.8, 0.2),            # mostly semantic
        sector=(0.6, 0.4),
        activities=(0.5, 0.5),         # even split
        regulatory_domain=(0.7, 0.3),
    )
    fallback_mix = (0.6, 0.4)

    semantic_weight, entity_weight = mix_by_field.get(field_type, fallback_mix)
    return semantic_weight * embedding_part + entity_weight * entity_part

def make_geo_context(row: pd.Series) -> str:
    """Summarise where a company operates as a short text snippet.

    Combines the ``region_exposure_<region>`` columns (US, Europe, China, India)
    with the free-text ``revenue_by_region_notes`` column.

    Args:
        row: One company record.

    Returns:
        Space-separated ``<region>:<level>`` tokens for every region whose
        exposure is "high", "medium" or "low" (case-insensitive), followed by
        the notes when present, e.g. ``"US:high Europe:medium <notes>"``.
        An empty string when nothing usable is found.
    """
    parts = []
    for region in ["US", "Europe", "China", "India"]:
        col = f"region_exposure_{region}"
        if col in row and pd.notna(row[col]):
            val = str(row[col]).strip().lower()
            # Keep only the recognised exposure labels.
            if val in {"high", "medium", "low"}:
                parts.append(f"{region}:{val}")

    # Free-text notes about the revenue split by region. A missing cell arrives
    # as NaN, which is truthy, so it must be tested explicitly; otherwise the
    # literal text "nan" would be appended to the context.
    raw_notes = row.get("revenue_by_region_notes", "")
    if pd.api.types.is_scalar(raw_notes) and pd.isna(raw_notes):
        notes = ""
    else:
        notes = str(raw_notes or "").strip()
    if notes:
        parts.append(notes)
    return " ".join(parts).strip()

def make_activity_context(row):
    """Return the text that describes what a company does.

    Uses the ``activities`` field when it has content; otherwise falls back to
    the ``sector`` and ``critical_dependencies`` fields joined by a space.

    Args:
        row: One company record (anything exposing ``.get(key, default)``).

    Returns:
        The activity description, or an empty string when all three fields
        are missing or blank.
    """
    def _clean(key):
        # Missing cells arrive as NaN/None; str() would turn them into the
        # truthy text "nan"/"None" and defeat the fallback below.
        value = row.get(key, "")
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value).strip()

    base = _clean("activities")
    if base:
        return base
    fallback_parts = (_clean("sector"), _clean("critical_dependencies"))
    return " ".join(part for part in fallback_parts if part)

In [8]:
def compute_correlation(args):
    """Score one company against every regulation.

    Args:
        args: Tuple ``(i, crow, regulations, comp_cols, get_field, WEIGHTS)``:
            ``crow`` is the company row, ``regulations`` the regulations
            DataFrame, ``get_field(row, logical_name)`` returns a company field
            as a string, and ``WEIGHTS`` maps "jurisdiction_country",
            "geo_context", "sector", "activities" and "regulatory_domain" to
            their weight in the total. ``i`` and ``comp_cols`` are accepted for
            the caller's tuple layout but are not used here.

    Returns:
        A list with one dict per regulation holding the company ticker/name,
        the regulation's ``law_id`` and ``date``, the five per-field match
        scores (rounded to 3 decimals) and ``score_total`` (rounded to 4).
    """
    _i, crow, regulations, _comp_cols, get_field, WEIGHTS = args

    def _reg_text(rrow, key):
        # A missing regulation cell is NaN; str(NaN) would be compared as the
        # literal text "nan". Treat it as empty instead.
        value = rrow.get(key, "")
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return str(value)

    c_ticker = get_field(crow, "ticker")
    c_name = get_field(crow, "company_name")
    c_country = get_field(crow, "jurisdiction_country")
    c_sector = get_field(crow, "sector")
    c_theme = get_field(crow, "regulatory_domain")
    c_geo_ctx = make_geo_context(crow)
    # Company-side activity text does not depend on the regulation, so build
    # it once instead of once per regulation.
    c_act_ctx = make_activity_context(crow)

    rows = []
    for _, rrow in regulations.iterrows():
        r_country = _reg_text(rrow, "jurisdiction_country")

        m_country = enhanced_similarity(c_country, r_country, "country")
        m_sector = enhanced_similarity(c_sector, _reg_text(rrow, "sector"), "sector")
        m_acts = enhanced_similarity(c_act_ctx, _reg_text(rrow, "activity"), "activities")
        m_theme = enhanced_similarity(c_theme, _reg_text(rrow, "regulatory_domain"), "regulatory_domain")
        # The geographic context is matched against the regulation's country too.
        m_geo_ctx = enhanced_similarity(c_geo_ctx, r_country, "country")

        score = (
            WEIGHTS["jurisdiction_country"] * m_country +
            WEIGHTS["geo_context"]          * m_geo_ctx +
            WEIGHTS["sector"]               * m_sector +
            WEIGHTS["activities"]           * m_acts +
            WEIGHTS["regulatory_domain"]    * m_theme
        )
        rows.append({
            "company_ticker": c_ticker,
            "company_name": c_name,
            "law_id": rrow.get("law_id", ""),
            "date": rrow.get("date", ""),
            "country_match": round(m_country, 3),
            "geo_context_match": round(m_geo_ctx, 3),
            "sector_match": round(m_sector, 3),
            "activities_match": round(m_acts, 3),
            "domain_match": round(m_theme, 3),
            "score_total": round(score, 4)
        })
    return rows

from concurrent.futures import ThreadPoolExecutor

def run_enhanced_correlation(max_company_rows: int = 500, max_workers: int = 20):
    """Score every company against every regulation and save the ranked matches.

    Reads the company CSV (COMP_PATH, at most ``max_company_rows`` rows) and the
    regulations CSV (REGS_PATH), runs ``compute_correlation`` for each company
    on a thread pool, sorts the results by ticker and descending total score,
    and writes them to OUT_PATH.

    Args:
        max_company_rows: Maximum number of company rows read from COMP_PATH.
        max_workers: Number of threads used to process companies concurrently.

    Returns:
        DataFrame with one row per (company, regulation) pair. When there is
        nothing to score it is empty but still carries the expected columns.
    """
    companies = pd.read_csv(COMP_PATH, nrows=max_company_rows)
    regulations = pd.read_csv(REGS_PATH)
    # Lower-cased header -> actual header, so field lookups ignore header casing.
    comp_cols = {c.lower(): c for c in companies.columns}

    # Logical field -> candidate company columns (lower-case), in priority order.
    field_candidates = {
        "ticker": ["ticker"],
        "company_name": ["company", "company_name"],
        "jurisdiction_country": ["headquarters_country", "country"],
        "sector": ["sector"],
        "activities": ["activities", "business_function"],
        "regulatory_domain": ["regulatory_dependencies"]
    }

    def get_field(row, logical):
        """Return the first non-missing candidate column as a string, else ''."""
        for cand in field_candidates.get(logical, []):
            if cand in comp_cols and pd.notna(row[comp_cols[cand]]):
                return str(row[comp_cols[cand]])
        return ""

    # Per-field weights of the total score. They do not sum to 1, so
    # score_total is not confined to [0, 1].
    WEIGHTS = {
        "jurisdiction_country": 0.8,
        "geo_context": 0.8,
        "sector": 0.7,
        "activities": 0.8,
        "regulatory_domain": 0.2
    }

    # One argument tuple per company, in the layout compute_correlation unpacks.
    args_list = [
        (i, crow, regulations, comp_cols, get_field, WEIGHTS)
        for i, crow in companies.iterrows()
    ]

    print("Launching parallel correlation computations ...")
    rows = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        for res in executor.map(compute_correlation, args_list):
            rows.extend(res)

    if rows:
        matches = pd.DataFrame(rows)
    else:
        # No company/regulation pairs: build an empty frame that still has the
        # result columns, so the sort below does not raise KeyError.
        matches = pd.DataFrame(columns=[
            "company_ticker", "company_name", "law_id", "date",
            "country_match", "geo_context_match", "sector_match",
            "activities_match", "domain_match", "score_total",
        ])
    matches = matches.sort_values(["company_ticker", "score_total"], ascending=[True, False])

    # Make sure the destination folder exists before writing.
    OUT_PATH.parent.mkdir(parents=True, exist_ok=True)
    matches.to_csv(OUT_PATH, index=False)
    print(f"\nEnhanced correlations saved to: {OUT_PATH}")
    print("Top 10 matches:")
    print(matches.head(10))
    return matches


In [9]:
# Run the enhanced correlation analysis with the default limit
# (max_company_rows=500). The call reads COMP_PATH and REGS_PATH, writes the
# ranked matches to OUT_PATH, and returns them as a DataFrame.
matches = run_enhanced_correlation()

Launching parallel correlation computations ...

Enhanced correlations saved to: /home/sagemaker-user/shared/outputs/enhanced_correlations.csv
Top 10 matches:
  company_ticker                company_name  \
0              A  Agilent Technologies, Inc.   
1           AAPL                  Apple Inc.   
2           ABBV                 AbbVie Inc.   
3           ABNB                Airbnb, Inc.   
4            ABT         Abbott Laboratories   
5           ACGL     Arch Capital Group Ltd.   
6            ACN                   Accenture   
7           ADBE                  Adobe Inc.   
8            ADI        Analog Devices, Inc.   
9            ADM                         ADM   

                                 law_id        date  country_match  \
0  2.H.R.1 - One Big Beautiful Bill Act  2025-07-04          0.800   
1  2.H.R.1 - One Big Beautiful Bill Act  2025-07-04          0.496   
2  2.H.R.1 - One Big Beautiful Bill Act  2025-07-04          0.800   
3  2.H.R.1 - One Big Beautiful B