In [1]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline
import re, json, sqlite3
import sqlite3, os
import pandas as pd
from pathlib import Path

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Make sure the parent folder for locally cached models exists.
Path("../data/models").mkdir(parents=True, exist_ok=True)

# Environment flag; judging by its name it silences the huggingface_hub symlink warning.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Folder used as the download cache for the NER model.
local_dir = "../data/models/dslim-bert-base-ner"

In [3]:
# One-time download (kept commented out): the model comes from Hugging Face - "https://huggingface.co/dslim/bert-base-NER"

# model_name = "dslim/bert-base-NER"

# tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=local_dir)
# model = AutoModelForTokenClassification.from_pretrained(model_name, cache_dir=local_dir)

# print("Success!!")

In [4]:
# Load the tokenizer and model from the local cache folder instead of the network.
base = Path("../data/models/dslim-bert-base-ner")

# Sort the candidates so the choice is deterministic when several snapshots exist,
# and fail with a clear message (not a bare IndexError) when none has been downloaded.
snapshots = sorted(p for p in base.glob("**/snapshots/*") if p.is_dir())
if not snapshots:
    raise FileNotFoundError(
        f"No model snapshot found under {base}; download the model into this folder first."
    )
snapshot = snapshots[0]

print("Using snapshot path:", snapshot, "\n---")


tokenizer = AutoTokenizer.from_pretrained(snapshot)
model = AutoModelForTokenClassification.from_pretrained(snapshot)

# aggregation_strategy="simple" makes the pipeline return grouped entities
# (dicts with 'entity_group', 'score', 'word', 'start', 'end').
fin_ner = pipeline("ner", model=model, tokenizer=tokenizer, aggregation_strategy="simple")

# Quick sanity checks: sub-word tokenisation and the label set of the model.
print("---")
print(tokenizer.tokenize("HDFC Bank RBI repo rate Q2"))
print("---")
print(model.config.id2label)
# print(model.config.label2id)
print("---")

fin_ner("RBI increased the repo rate and HDFC Bank reported strong Q2 results.")

Using snapshot path: ..\data\models\dslim-bert-base-ner\models--dslim--bert-base-NER\snapshots\d1a3e8f13f8c3566299d95fcfc9a8d2382a9affc 
---


Some weights of the model checkpoint at ..\data\models\dslim-bert-base-ner\models--dslim--bert-base-NER\snapshots\d1a3e8f13f8c3566299d95fcfc9a8d2382a9affc were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cpu


---
['HD', '##FC', 'Bank', 'RBI', 're', '##po', 'rate', 'Q', '##2']
---
{0: 'O', 1: 'B-MISC', 2: 'I-MISC', 3: 'B-PER', 4: 'I-PER', 5: 'B-ORG', 6: 'I-ORG', 7: 'B-LOC', 8: 'I-LOC'}
---


[{'entity_group': 'ORG',
  'score': 0.99932396,
  'word': 'RBI',
  'start': 0,
  'end': 3},
 {'entity_group': 'ORG',
  'score': 0.99931717,
  'word': 'HDFC Bank',
  'start': 32,
  'end': 41}]

In [5]:
# Columns are selected explicitly so the DataFrame labels below can never drift
# out of sync with the table layout (as they could with SELECT *).
STORY_COLUMNS = ["id", "article_ids", "article_title", "combined_text", "num_articles"]

conn = sqlite3.connect("../data/financial_news.db")
try:
    cur = conn.cursor()
    cur.execute(f"SELECT {', '.join(STORY_COLUMNS)} FROM unique_news ORDER BY id")
    stories = cur.fetchall()
    cur.close()
finally:
    # Close the connection even if the query fails.
    conn.close()

print("SUCCESS!!")
df = pd.DataFrame(stories, columns=STORY_COLUMNS)

df.head()

SUCCESS!!


Unnamed: 0,id,article_ids,article_title,combined_text,num_articles
0,1,[np.int64(1)],Worried About Inflation? These 3 ETFs Offer Re...,Inflation has slowed but remains a major conce...,1
1,2,[np.int64(2)],Intel’s Black Friday Breakout: Apple Rumors Fu...,A holiday stock surge fueled by credible Apple...,1
2,3,[np.int64(3)],Klarna's Crypto Play: A Plan to Fix Its Profit...,Klarna's launch of a stablecoin is a strategic...,1
3,4,[np.int64(4)],Meta Platforms May Ditch NVIDIA Chips—Here’s W...,Meta Platforms may be looking to alter where i...,1
4,5,[np.int64(5)],SoFi Technologies: From Fintech Speculation to...,SoFi Technologies is proving its long-term val...,1


In [6]:
# Eyeball the raw NER output for the stories at index labels 5 through 9.
for i in range(5, 10):
    sample_text = df.at[i, "combined_text"]
    print(sample_text)
    # One call prints the entities followed by the separator block.
    print(fin_ner(sample_text), "\n---\n", sep="\n")


Bank of America sees a credible pathway for gold prices to surge higher in the coming year as supportive macroeconomic drivers remain firmly intact.
[{'entity_group': 'ORG', 'score': 0.9992711, 'word': 'Bank of America', 'start': 0, 'end': 15}]

---

Stocks were modestly higher in a shortened trading week; next week will bring a better sense of institutional sentiment heading into the holiday season
[]

---

Applied Digital's stock rallied after energizing its AI campus, validating its strategy to secure multi-billion dollar contracts from power-hungry hyperscalers.
[{'entity_group': 'ORG', 'score': 0.9991749, 'word': 'Applied Digital', 'start': 0, 'end': 15}, {'entity_group': 'ORG', 'score': 0.97146523, 'word': 'AI', 'start': 53, 'end': 55}]

---

Alphabet has flipped its H1 sentiment to overwhelmingly bullish, boosted by accelerating growth, the release of Gemini 3, and Berkshire’s recent stake.
[{'entity_group': 'ORG', 'score': 0.99919915, 'word': 'Alphabet', 'start': 0, 'end': 8}, 

In [7]:
# Gazetteer: category name -> list of known phrases, loaded from a JSON asset.
gazetteer_path = "../assets/fin_gazetteers.json"
# Read as UTF-8 explicitly; without it open() falls back to the platform's
# default encoding, which can mis-decode non-ASCII entries.
with open(gazetteer_path, "r", encoding="utf-8") as f:
    GAZ = json.load(f)

print(GAZ.keys())

dict_keys(['regulators', 'indices', 'sectors', 'financial_terms', 'kpi_terms', 'companies_custom', 'products'])


In [8]:
def _lowered(category):
    """Return the gazetteer phrases stored under ``category``, lower-cased."""
    return [phrase.lower() for phrase in GAZ[category]]


# One lower-cased phrase list per gazetteer category (used for case-insensitive matching).
INDICES = _lowered("indices")
SECTORS = _lowered("sectors")
REGULATORS = _lowered("regulators")
FIN_TERMS = _lowered("financial_terms")
KPI_TERMS = _lowered("kpi_terms")
COMPANIES_CUSTOM = _lowered("companies_custom")
PRODUCTS = _lowered("products")

# Regex
# Building blocks for monetary amounts:
#   _AMOUNT - digits with optional comma groups and an optional decimal part
#   _SCALE  - scale word, optional plural "s", ending on a word boundary so that
#             "5 millionaires" is not read as "5 million"
_AMOUNT = r"\d+(?:,\d+)*(?:\.\d+)?"
_SCALE = r"(?:crore|lakh|million|billion)s?\b"

# Either a currency symbol + amount (+ optional scale word), or a bare amount that
# must be followed by a scale word. Only non-capturing groups are used; callers in
# this notebook read the whole match via group(0).
MONEY_REGEX = re.compile(
    rf"(?:[₹$€£]\s?{_AMOUNT}(?:\s?{_SCALE})?|\b{_AMOUNT}\s?{_SCALE})",
    re.I,
)
# A number (optionally decimal) followed by an optional space and a percent sign.
PERCENT_REGEX = re.compile(r"\b\d+(\.\d+)?\s?%")
# Quarter results/earnings mentions and a fixed list of KPI words, case-insensitive.
KPI_REGEX = re.compile(r"\b(Q[1-4]\s?(results|earnings)|EBITDA|PAT|EPS|Revenue|Profit)\b", re.I)


def normalize(items):
    """Strip each item and drop case-insensitive duplicates, keeping first occurrences.

    Falsy items (None, empty string) are skipped. The returned list preserves the
    order in which each distinct item was first seen.
    """
    first_seen = {}  # lower-cased key -> stripped original (insertion ordered)
    for raw in items:
        if raw:
            stripped = raw.strip()
            first_seen.setdefault(stripped.lower(), stripped)
    return list(first_seen.values())


def match_rules(text):
    """Run the gazetteer and regex rules over ``text``.

    Gazetteer categories use a plain substring test against the lower-cased text;
    the regex categories collect every full match from the original text.
    Returns a dict mapping category name to a de-duplicated list of hits.
    """
    lowered = text.lower()

    def gazetteer_hits(phrases):
        # Phrases are already lower-case, so a substring test is case-insensitive.
        return normalize([phrase for phrase in phrases if phrase in lowered])

    def regex_hits(pattern):
        return normalize([found.group(0) for found in pattern.finditer(text)])

    return {
        "indices": gazetteer_hits(INDICES),
        "sectors": gazetteer_hits(SECTORS),
        "regulators": gazetteer_hits(REGULATORS),
        "policies": gazetteer_hits(FIN_TERMS),
        "custom_companies": gazetteer_hits(COMPANIES_CUSTOM),
        "products": gazetteer_hits(PRODUCTS),
        "kpis": regex_hits(KPI_REGEX),
        "money": regex_hits(MONEY_REGEX),
        "percent": regex_hits(PERCENT_REGEX),
    }

In [9]:
# Sample paragraph used to exercise the rules and the NER model. It is written as
# explicit string pieces so the leading newline and the trailing space after
# "NIFTY 50." are visible in the source.
sample_text = (
    "\n"
    "RBI increased the repo rate by 25 bps leading to volatility in NIFTY 50. \n"
    "HDFC Bank and Reliance saw strong Q2 results with EBITDA growing 12%.\n"
    "Investors expect inflation to ease in coming quarters.\n"
)

rules = match_rules(sample_text)
rules

{'indices': ['nifty', 'nifty 50'],
 'sectors': [],
 'regulators': ['rbi'],
 'policies': ['repo rate', 'inflation'],
 'custom_companies': ['hdfc bank', 'reliance'],
 'products': [],
 'kpis': ['Q2 results', 'EBITDA'],
 'money': [],
 'percent': ['12%']}

In [10]:
# Run the NER pipeline on the sample text; the bare name below displays the result.
ner_output = fin_ner(sample_text)
ner_output

[{'entity_group': 'ORG',
  'score': 0.9991365,
  'word': 'RBI',
  'start': 1,
  'end': 4},
 {'entity_group': 'ORG',
  'score': 0.9992687,
  'word': 'HDFC Bank',
  'start': 75,
  'end': 84},
 {'entity_group': 'ORG',
  'score': 0.99848515,
  'word': 'Reliance',
  'start': 89,
  'end': 97},
 {'entity_group': 'ORG',
  'score': 0.99618524,
  'word': 'EBITDA',
  'start': 125,
  'end': 131}]

In [11]:
# Sort the model's entities into three lists by entity group; other groups are ignored.
companies, people, locations = [], [], []
buckets = {"ORG": companies, "PER": people, "LOC": locations}

for ent in ner_output:
    bucket = buckets.get(ent["entity_group"])
    if bucket is not None:
        bucket.append(ent["word"])


In [12]:
# Rule categories that are copied into the result unchanged.
RULE_KEYS = (
    "indices", "sectors", "regulators", "policies",
    "products", "kpis", "money", "percent",
)

# Merge model entities with rule hits; companies combine model ORGs and gazetteer companies.
final_entities = {
    "companies": normalize(companies + rules["custom_companies"]),
    "people": normalize(people),
    "locations": normalize(locations),
    **{key: rules[key] for key in RULE_KEYS},
}


In [13]:
# Display the merged entity dictionary as this cell's output.
final_entities


{'companies': ['RBI', 'HDFC Bank', 'Reliance', 'EBITDA'],
 'people': [],
 'locations': [],
 'indices': ['nifty', 'nifty 50'],
 'sectors': [],
 'regulators': ['rbi'],
 'policies': ['repo rate', 'inflation'],
 'products': [],
 'kpis': ['Q2 results', 'EBITDA'],
 'money': [],
 'percent': ['12%']}

In [14]:
import re
from typing import List, Dict

# Lookup sets built from the lower-cased gazetteer lists (KPI_TERMS, FIN_TERMS,
# REGULATORS, COMPANIES_CUSTOM), which must already be defined when this cell runs.
KPI_SET = {term.lower() for term in KPI_TERMS}
FINTERM_SET = {term.lower() for term in FIN_TERMS}
REGULATOR_SET = {term.lower() for term in REGULATORS}
COMPANY_GAZETTEER = {term.lower() for term in COMPANIES_CUSTOM}

# Helper patterns.
PUNCT_RE = re.compile(r"[^\w\s]")        # one character that is neither word nor whitespace
NUMERIC_RE = re.compile(r"^[\d\W_]+$")   # token made only of digits, non-word chars, underscores

def _clean_span(s: str) -> str:
    """Normalize whitespace and punctuation; keep original casing for display but return cleaned lowered form for checks"""
    if s is None:
        return ""
    s = s.strip()
    s = re.sub(r"\s+", " ", s)
    return s

def _is_probably_company_token(token: str) -> bool:
    """Heuristic filter: return False for spans that should not be treated as a company.

    A span is rejected when, after stripping and lower-casing, it is empty, is an
    exact regulator / KPI / financial-term entry, consists only of digits and
    punctuation, or is at most two characters long and not purely alphabetic
    (e.g. "Q2"). Everything else is accepted.
    """
    low = token.strip().lower()
    if not low:
        return False
    # Exact matches against the known non-company vocabularies.
    if low in REGULATOR_SET or low in KPI_SET or low in FINTERM_SET:
        return False
    # Pure numbers / punctuation.
    if NUMERIC_RE.match(low):
        return False
    # Short tokens survive only when they are purely alphabetic.
    return len(low) > 2 or low.isalpha()

def prioritize_companies(model_orgs: List[str], gaz_companies: List[str], regulators: List[str], kpis: List[str], fin_terms: List[str]) -> List[str]:
    """
    Merge gazetteer and model company names into one filtered, de-duplicated list.

    - gazetteer names are added first, provided they pass the company-token filter
    - model ORG names are added when they pass the same filter, are not a
      regulator / KPI / financial term, and are not contained in a name already chosen
    - a model name that contains shorter chosen names replaces all of them
    - the result is sorted longest name first

    Args:
        model_orgs: organisation strings produced by the NER model.
        gaz_companies: company strings found via the gazetteer.
        regulators, kpis, fin_terms: terms detected in the same text; a model
            name equal to any of them (case-insensitive) is never reported as a
            company. They supplement the global REGULATOR_SET / KPI_SET / FINTERM_SET.

    Returns:
        List of company names, longest first.
    """
    # Per-text exclusions supplied by the caller (previously these arguments were ignored).
    excluded = {
        term.strip().lower()
        for group in (regulators, kpis, fin_terms)
        for term in (group or [])
        if term
    }

    chosen = []
    seen = set()  # lower-cased gazetteer names already chosen

    # 1) gazetteer matches
    for c in gaz_companies or []:
        if not c:
            continue
        clean = _clean_span(c)
        low = clean.lower()
        if low in seen:
            continue
        if _is_probably_company_token(clean):
            chosen.append(clean)
            seen.add(low)

    # 2) model ORGs
    for org in model_orgs or []:
        if not org:
            continue
        org_clean = _clean_span(org)
        low = org_clean.lower()

        # drop anything that is really a KPI / financial term / regulator
        if low in excluded or low in KPI_SET or low in FINTERM_SET or low in REGULATOR_SET:
            continue
        if not _is_probably_company_token(org_clean):
            continue

        # Already covered by an equal or longer chosen name (e.g. "HDFC" vs "HDFC Bank").
        # Every chosen name is checked, not just the first related one.
        if any(low in kept.lower() for kept in chosen):
            continue

        # The new name is longer: remove every shorter chosen name it contains.
        contained = {kept.lower() for kept in chosen if kept.lower() in low}
        if contained:
            chosen = [kept for kept in chosen if kept.lower() not in contained]
        chosen.append(org_clean)

    # 3) longer company names first (sorted() is stable, so ties keep insertion order)
    return sorted(chosen, key=lambda name: -len(name))

def postprocess_entities(model_ner_out: List[Dict], rule_out: Dict) -> Dict:
    """
    Combine NER pipeline output with rule output into one cleaned entity dict.

    Args:
        model_ner_out: list of pipeline outputs (dicts with keys 'entity_group' and 'word').
        rule_out: result of match_rules(text), with keys indices, sectors, regulators,
            custom_companies, kpis, policies, products, money, percent. Missing
            keys are treated as empty lists.

    Returns:
        Dict with keys companies, people, locations, indices, sectors, regulators,
        policies, products, kpis, money, percent; every value is a list
        de-duplicated case-insensitively.
    """
    def spans_for(groups):
        # Cleaned 'word' of every model entity whose group is in `groups`.
        return [
            _clean_span(e["word"])
            for e in model_ner_out
            if e.get("entity_group", "").upper() in groups
        ]

    # NOTE(review): MISC entities are grouped with ORG here, so they can end up
    # in "companies" - confirm this is intended.
    model_orgs = spans_for(("ORG", "MISC", "MISCELLANEOUS"))
    model_pers = spans_for(("PER", "PERSON"))
    model_locs = spans_for(("LOC", "GPE"))

    # rule outputs
    gaz_companies = rule_out.get("custom_companies", []) or []
    regulators = rule_out.get("regulators", []) or []
    kpis = rule_out.get("kpis", []) or []
    fint = rule_out.get("policies", []) or []   # financial terms
    indices = rule_out.get("indices", []) or []
    sectors = rule_out.get("sectors", []) or []
    products = rule_out.get("products", []) or []
    money = rule_out.get("money", []) or []
    percent = rule_out.get("percent", []) or []

    # Build companies list with prioritization & filtering
    companies = prioritize_companies(model_orgs, gaz_companies, regulators, kpis, fint)

    # Re-add gazetteer companies that the filter dropped, but not ones already
    # covered by an equal or longer chosen name - otherwise a short name that was
    # deliberately replaced by a longer one would come straight back.
    present = [c.lower() for c in companies]   # computed once, extended as we append
    for g in gaz_companies:
        if not g:
            continue
        g_clean = _clean_span(g)
        g_low = g_clean.lower()
        if any(g_low in name for name in present):
            continue
        companies.append(g_clean)
        present.append(g_low)

    def dedupe_list(lst):
        # Case-insensitive, order-preserving de-duplication of stripped items.
        out = []
        seen = set()
        for item in lst:
            if not item:
                continue
            k = item.strip()
            if k.lower() in seen:
                continue
            seen.add(k.lower())
            out.append(k)
        return out

    final = {
        "companies": dedupe_list(companies),
        "people": dedupe_list(model_pers),
        "locations": dedupe_list(model_locs),
        "indices": dedupe_list(indices),
        "sectors": dedupe_list(sectors),
        "regulators": dedupe_list(regulators),
        "policies": dedupe_list(fint),
        "products": dedupe_list(products),
        "kpis": dedupe_list(kpis),
        "money": dedupe_list(money),
        "percent": dedupe_list(percent)
    }

    # Remove any company that exactly equals a regulator, KPI or financial term
    # (this also undoes re-added gazetteer entries of those kinds).
    final["companies"] = [
        c for c in final["companies"]
        if c.lower() not in REGULATOR_SET
        and c.lower() not in KPI_SET
        and c.lower() not in FINTERM_SET
    ]

    return final


In [15]:
# End-to-end check on the sample text: model entities + rule hits -> cleaned entities.
ner_out = fin_ner(sample_text)             # pipeline output
rules = match_rules(sample_text)       # gazetteer + regex
cleaned = postprocess_entities(ner_out, rules)
cleaned


{'companies': ['hdfc bank', 'reliance'],
 'people': [],
 'locations': [],
 'indices': ['nifty', 'nifty 50'],
 'sectors': [],
 'regulators': ['rbi'],
 'policies': ['repo rate', 'inflation'],
 'products': [],
 'kpis': ['Q2 results', 'EBITDA'],
 'money': [],
 'percent': ['12%']}

In [16]:
def longest_match_gazetteer(text: str, phrase_list: list):
    """
    Return the phrases from ``phrase_list`` found in ``text``, longest match first.

    Matching is a case-insensitive substring search. Once a span of the text has
    been claimed by a phrase, shorter phrases cannot match inside it, which prevents
    'nifty' from matching when 'nifty 50' is present at the same place.

    Args:
        text: text to search.
        phrase_list: candidate phrases; empty phrases are ignored.

    Returns:
        Each matched phrase once, in a deterministic order: longest phrase first,
        ties in the order given in ``phrase_list``.
    """
    text_lower = text.lower()
    occupied = [False] * len(text_lower)   # True where a longer phrase already matched
    matches = {}                           # dict keys: ordered, duplicate-free

    for phrase in sorted(phrase_list, key=lambda x: -len(x)):
        needle = phrase.lower()
        if not needle:
            # An empty phrase would "match" at every position of the text.
            continue
        start_idx = text_lower.find(needle)
        while start_idx != -1:
            end_idx = start_idx + len(needle)
            if not any(occupied[start_idx:end_idx]):
                matches.setdefault(phrase, None)
                occupied[start_idx:end_idx] = [True] * len(needle)
            start_idx = text_lower.find(needle, start_idx + 1)

    # list(set(...)) would lose the ordering and vary between interpreter runs.
    return list(matches)


In [17]:
def match_rules(text):
    """Run the gazetteer and regex rules over ``text`` (indices use longest-match).

    Indices go through longest_match_gazetteer; the other gazetteer categories use
    a plain substring test against the lower-cased text; regex categories collect
    every full match from the original text. Returns a dict mapping category name
    to a de-duplicated list of hits.
    """
    lowered = text.lower()

    def gazetteer_hits(phrases):
        # Phrases are already lower-case, so a substring test is case-insensitive.
        return normalize([phrase for phrase in phrases if phrase in lowered])

    def regex_hits(pattern):
        return normalize([found.group(0) for found in pattern.finditer(text)])

    return {
        "indices": normalize(longest_match_gazetteer(text, INDICES)),
        "sectors": gazetteer_hits(SECTORS),
        "regulators": gazetteer_hits(REGULATORS),
        "policies": gazetteer_hits(FIN_TERMS),
        "custom_companies": gazetteer_hits(COMPANIES_CUSTOM),
        "products": gazetteer_hits(PRODUCTS),
        "kpis": regex_hits(KPI_REGEX),
        "money": regex_hits(MONEY_REGEX),
        "percent": regex_hits(PERCENT_REGEX),
    }

In [18]:
# Re-run the sample with whichever match_rules definition was executed last.
ner_out = fin_ner(sample_text)             # pipeline output
rules = match_rules(sample_text)       # gazetteer + regex
cleaned = postprocess_entities(ner_out, rules)
cleaned

{'companies': ['hdfc bank', 'reliance'],
 'people': [],
 'locations': [],
 'indices': ['nifty 50'],
 'sectors': [],
 'regulators': ['rbi'],
 'policies': ['repo rate', 'inflation'],
 'products': [],
 'kpis': ['Q2 results', 'EBITDA'],
 'money': [],
 'percent': ['12%']}

In [None]:
# Create the output table once; IF NOT EXISTS makes the cell safe to re-run.
conn = sqlite3.connect("../data/financial_news.db")
try:
    cur = conn.cursor()
    cur.execute(
        """
        CREATE TABLE IF NOT EXISTS news_entities (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            story_id INT,
            article_ids TEXT,
            article_title TEXT,
            companies TEXT,
            sectors TEXT,
            people TEXT,
            indices TEXT,
            regulators TEXT,
            policies TEXT,
            products TEXT,
            locations TEXT,
            kpis TEXT,
            financial_terms TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
        );
        """
    )
    conn.commit()
    cur.close()
finally:
    # Close the connection even if the statement fails.
    conn.close()

In [23]:
# Module-level connection used by fetch_stories(); it is not closed in this cell.
conn = sqlite3.connect("../data/financial_news.db")

def fetch_stories():
    """Return every row of unique_news, ordered by id.

    Each row is a tuple (id, article_ids, article_title, combined_text, num_articles).
    Uses the module-level ``conn``.
    """
    sql = "SELECT id, article_ids, article_title, combined_text, num_articles FROM unique_news ORDER BY id;"
    cur = conn.cursor()
    try:
        cur.execute(sql)
        return cur.fetchall()
    finally:
        # Release the cursor even when the query raises.
        cur.close()


In [31]:
rows = fetch_stories()
# Show only the first story tuple to confirm the row layout.
if rows:
    row = rows[0]
    print(row)

(1, '[np.int64(1)]', 'Worried About Inflation? These 3 ETFs Offer Real Protection', 'Inflation has slowed but remains a major concern for many investors; these ETFs can help provide a buffer through the use of TIPS, commodities, or T-Bills.', 1)


In [32]:
def save_entities(story, story_id, story_article_ids, story_title, db_path="../data/financial_news.db"):
    """Insert one row of extracted entities into the news_entities table.

    Args:
        story: dict of entity lists; every list column listed below is read with
            ``story.get(column, [])`` and stored as a JSON string.
        story_id: value for the story_id column.
        story_article_ids: value for the article_ids column (stored as given).
        story_title: value for the article_title column.
        db_path: SQLite database file; defaults to the notebook's database.

    Raises:
        sqlite3.Error: if the insert fails; the connection is closed either way.
    """
    # Order must match the column list in the INSERT statement.
    # NOTE(review): "financial_terms" is not a key of postprocess_entities' result,
    # so that column receives "[]"; "money" and "percent" have no column here.
    list_columns = (
        "companies", "sectors", "people", "indices", "regulators",
        "policies", "products", "locations", "kpis", "financial_terms",
    )
    values = [story_id, story_article_ids, story_title]
    values.extend(json.dumps(story.get(column, [])) for column in list_columns)

    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute(
            """
            INSERT INTO news_entities 
            (story_id, article_ids, article_title, companies, sectors, people, indices, regulators, policies, products, locations, kpis, financial_terms)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            values,
        )
        conn.commit()
        cur.close()
    finally:
        # Close the connection even when the insert raises.
        conn.close()


# Extract entities for every story and store one news_entities row per story.
rows = fetch_stories()
for story in rows:
    story_id, article_ids, title, text = story[:4]
    ner_out = fin_ner(text)                          # model entities
    rules = match_rules(text)                        # gazetteer + regex hits
    cleaned = postprocess_entities(ner_out, rules)   # merged and filtered
    save_entities(
        cleaned,
        story_id=story_id,
        story_article_ids=article_ids,
        story_title=title,
    )

print("SUCCESS")

SUCCESS
