In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import re
import nltk
from bs4 import BeautifulSoup
from transformers import BertTokenizerFast, BertModel
from sklearn.preprocessing import LabelEncoder
from nltk.corpus import wordnet as wn

# === Constants ===
# Sentinel tags added to the tag vocabulary for the CRF: every path starts
# at START_TAG and ends by transitioning into STOP_TAG.
START_TAG = "<START>"
STOP_TAG = "<END>"
# BIO labels assigned to tokens: first salary token, following salary
# tokens, and everything else.
LABELS = ["B-SALARY", "I-SALARY", "O"]
# An integer or a decimal number (digits, optionally ".digits").
NUMBER_RE = re.compile(r"\d+(?:\.\d+)?")
# Anything that looks like a single markup tag: "<" ... ">".
CLEAN_HTML_RE = re.compile(r"<[^>]+>")
# A run of one or more whitespace characters.
MULTI_SPACE_RE = re.compile(r"\s+")

# === Salary Lexicon from WordNet ===
# Fetch the corpora the lexicon builder reads: WordNet itself and the
# Open Multilingual WordNet.
for _corpus_name in ("wordnet", "omw-1.4"):
    nltk.download(_corpus_name)

def get_au_multilingual_salary_terms():
    """Build a set of salary-related terms in several languages.

    Each English seed word is looked up in WordNet; for every synset found,
    the lemma names in each requested language are collected.

    Returns:
        set[str]: lemma names, lower-cased, with underscores replaced by
        spaces. Languages the installed WordNet data does not provide
        contribute nothing.
    """
    # Seed English terms related to salary
    seed_words = [
        "salary", "wage", "pay", "income", "compensation",
        "remuneration", "bonus", "allowance", "earnings", "reimbursement",
        "payment", "benefit", "commission", "gratuity"
    ]

    # Prominent languages in Australia
    au_languages = [
        "eng",  # English
        "cmn",  # Mandarin
        "arb",  # Arabic
        "vie",  # Vietnamese
        "ita",  # Italian
        "ell",  # Greek
        "hin",  # Hindi
        "spa",  # Spanish
        "tgl",  # Tagalog/Filipino
        "kor",  # Korean
        "tha",  # Thai
        "urd"   # Urdu
    ]

    # Decide which languages to query from the reader's own language list.
    # The private ``synset._lemma_names`` attribute holds English lemma
    # strings, not language codes, so testing a code against it never
    # matches and leaves the lexicon empty.
    # NOTE(review): NLTK is expected to reject language codes it has no
    # data for, hence the filter - confirm against the installed version.
    available_languages = set(wn.langs())
    languages = [lang for lang in au_languages if lang in available_languages]

    lexicon = set()
    for word in seed_words:
        for synset in wn.synsets(word, lang="eng"):
            for lang in languages:
                lexicon.update(
                    lemma.lower().replace("_", " ")
                    for lemma in synset.lemma_names(lang)
                )

    return lexicon

# Built once when this cell runs; chunk_and_align checks every token
# (lower-cased) for membership in this set.
SALARY_LEX = get_au_multilingual_salary_terms()
# === Helper Functions ===
def parse_actual_info(info_str):
    """Parse a ``min-max-currency-unit`` label into its four components.

    Args:
        info_str: Label text such as ``"1500-1800-MYR-MONTHLY"``.

    Returns:
        ``(min_salary, max_salary, currency, unit)`` with both bounds as
        floats and the unit lower-cased, or ``None`` when the value is not
        a string, does not split into exactly four ``-``-separated parts,
        is the ``0-0-None-None`` "no salary" marker, or has bounds that
        are not numbers.
    """
    # A missing cell is not a str (e.g. a float NaN), so it has no label.
    if not isinstance(info_str, str):
        return None
    parts = info_str.split("-")
    if len(parts) != 4 or parts == ['0', '0', 'None', 'None']:
        return None
    try:
        low, high = float(parts[0]), float(parts[1])
    except ValueError:
        # Treated like any other malformed label rather than aborting the
        # whole DataFrame.apply pass.
        return None
    return low, high, parts[2], parts[3].lower()

def round_numbers(text):
    """Replace each ``digits.digits`` literal in *text* with its rounded integer.

    Rounding is done by Python's built-in ``round``; text that contains no
    decimal literal is returned unchanged.
    """
    decimal_literal = re.compile(r"\d+\.\d+")

    def as_rounded_int(match):
        value = float(match.group())
        return str(int(round(value, 0)))

    return decimal_literal.sub(as_rounded_int, text)

def clean_text(text):
    """Normalise the numbers, markup and whitespace of an ad text.

    Returns an empty string for a missing value (``pd.isna``). Otherwise
    the steps run in this order:
      1. re-join decimals split by whitespace ("1500. 0" -> "1500.0");
      2. round every decimal to an integer ("1500.0" -> "1500");
      3. delete anything matching CLEAN_HTML_RE;
      4. collapse whitespace runs to one space and trim both ends.
    """
    if pd.isna(text):
        return ""

    rejoined = re.sub(r"(\d+)\.\s+(\d+)", r"\1.\2", text)
    integers_only = round_numbers(rejoined)
    without_tags = CLEAN_HTML_RE.sub("", integers_only)
    single_spaced = MULTI_SPACE_RE.sub(" ", without_tags)
    return single_spaced.strip()


def clean_html_tags(html_text):
    """Parse *html_text* with BeautifulSoup's "html.parser" and return its text."""
    return BeautifulSoup(html_text, "html.parser").get_text()

def _spans_matching_value(occurrences, target):
    """Return char spans of the numbers equal to *target*, else of the closest one."""
    exact = [(s, e) for (s, e, v) in occurrences if abs(v - target) < 1e-3]
    if exact:
        return exact
    s, e, _ = min(occurrences, key=lambda occ: abs(occ[2] - target))
    return [(s, e)]


def chunk_and_align(text, min_salary, max_salary, tokenizer,
                    max_length=512, context=128):
    """Cut a window around the salary figures in *text* and label its tokens.

    The salary span runs from the first occurrence of ``min_salary`` to the
    last occurrence of ``max_salary`` (falling back to the numerically
    closest number when a value does not occur). The span is widened by
    ``context`` characters on each side, tokenized, embedded with the
    module-level ``bert_model``, and every token overlapping the span is
    labelled B-SALARY / I-SALARY; all others are labelled O.

    NOTE(review): the window is located from ``min_salary``/``max_salary``,
    so callers must already know the salary, including at evaluation time.

    Args:
        text: Cleaned ad text.
        min_salary: Lower salary bound to locate in ``text``.
        max_salary: Upper salary bound to locate in ``text``.
        tokenizer: Tokenizer called with ``return_offsets_mapping=True``.
        max_length: Token limit passed to the tokenizer (with truncation).
        context: Characters of surrounding text kept on each side.

    Returns:
        ``(token_chunks, label_chunks, embed_chunks)``. All three are empty
        lists when ``text`` contains no number; otherwise each holds one
        entry: the token strings, their labels, and a
        ``(num_tokens, hidden + 1)`` tensor whose last column is the
        salary-lexicon flag.
    """
    # 1) every number in the text with its character span
    occurrences = [(m.start(), m.end(), float(m.group()))
                   for m in NUMBER_RE.finditer(text)]
    if not occurrences:
        return [], [], []

    # 2) spans of the min / max values (exact match, else closest number)
    min_pos = _spans_matching_value(occurrences, min_salary)
    max_pos = _spans_matching_value(occurrences, max_salary)

    # 3) salary span: earliest min occurrence to latest max occurrence
    span_start = min(s for s, _ in min_pos)
    span_end = max(e for _, e in max_pos)
    if span_start > span_end:
        # The maximum is written before the minimum; cover both figures
        # instead of producing an empty, unlabelled span.
        both = min_pos + max_pos
        span_start = min(s for s, _ in both)
        span_end = max(e for _, e in both)

    # 4) expand by a bit of context, clamped to the text. The span itself
    #    is kept separately: subtracting ``context`` back out of a clamped
    #    window edge would not recover it.
    window_start = max(0, span_start - context)
    window_end = min(len(text), span_end + context)
    window_text = text[window_start:window_end]

    # 5) tokenize just that slice
    inputs = tokenizer(
        window_text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt"
    )
    offsets = inputs.pop("offset_mapping")[0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    # 6) lexicon feature per token: 1.0 when the token is a lexicon entry
    lex_feats = [1 if tok.lower() in SALARY_LEX else 0 for tok in tokens]
    lex_tensor = torch.tensor(lex_feats, dtype=torch.float, device=device).unsqueeze(1)

    # 7) label tokens by overlap with the salary span (offsets are relative
    #    to the window, so shift them back into text coordinates)
    labels = []
    saw_begin = False
    for tok_start, tok_end in offsets:
        abs_start = tok_start + window_start
        abs_end = tok_end + window_start
        covers_text = tok_end > tok_start  # zero-width offsets carry no text
        if covers_text and abs_end > span_start and abs_start < span_end:
            labels.append("I-SALARY" if saw_begin else "B-SALARY")
            saw_begin = True
        else:
            labels.append("O")

    # 8) BERT pass + concat lex feature as one extra column
    with torch.no_grad():
        out = bert_model(
            input_ids=inputs["input_ids"].to(device),
            attention_mask=inputs["attention_mask"].to(device)
        )
    embeddings = out.last_hidden_state.squeeze(0)
    embeddings = torch.cat([embeddings, lex_tensor], dim=1)

    return [tokens], [labels], [embeddings]

def log_sum_exp(vec):
    """Compute log(sum(exp(vec))) for a ``(1, n)`` row without overflow.

    The row maximum is subtracted before exponentiating and added back
    afterwards. The result has shape ``(1,)``.
    """
    best_col = torch.argmax(vec, dim=1)
    peak = vec[0, best_col]
    width = vec.size()[1]
    peak_row = peak.view(1, -1).expand(1, width)
    shifted_total = torch.sum(torch.exp(vec - peak_row))
    return peak + torch.log(shifted_total)

def prepare_sequence(embeds):
    """Reshape *embeds* to ``(len(embeds), 1, -1)``, i.e. add a size-1 middle axis."""
    seq_len = len(embeds)
    return embeds.view(seq_len, 1, -1)

class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM encoder with a linear-chain CRF on top.

    ``transitions[i, j]`` is the score of moving *to* tag ``i`` *from* tag
    ``j``. Transitions into START_TAG and out of STOP_TAG are pinned to
    -10000 so that no path uses them.

    All working tensors are created on the device that holds the model's
    parameters, so the module does not rely on any module-level ``device``.

    Args:
        embedding_dim: Size of each input vector.
        hidden_dim: Combined size of the two LSTM directions (each
            direction gets ``hidden_dim // 2`` units).
        tag_to_ix: Mapping from tag string to index; must contain
            START_TAG and STOP_TAG.
    """

    def __init__(self, embedding_dim, hidden_dim, tag_to_ix):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        # Never transition into START, never transition out of STOP.
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000
        self.hidden = self.init_hidden()

    def init_hidden(self):
        """Return a fresh random ``(h0, c0)`` pair on the model's device.

        NOTE(review): the state is drawn from ``randn`` on every call, so
        decoding the same input twice need not give identical tags.
        """
        device = self.transitions.device
        shape = (2, 1, self.hidden_dim // 2)
        return (
            torch.randn(shape, device=device),
            torch.randn(shape, device=device)
        )

    def _forward_alg(self, feats):
        """Return the log partition function over all tag paths, shape ``(1,)``.

        Args:
            feats: Emission scores, one row of ``tagset_size`` per token.
        """
        forward_var = torch.full((1, self.tagset_size), -10000., device=feats.device)
        forward_var[0][self.tag_to_ix[START_TAG]] = 0
        for feat in feats:
            # scores[i, j] = alpha[j] + transitions[i, j] + emission[i]
            scores = forward_var + self.transitions + feat.unsqueeze(1)
            forward_var = torch.logsumexp(scores, dim=1).unsqueeze(0)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        return torch.logsumexp(terminal_var, dim=1)

    def _get_lstm_features(self, embeds):
        """Run the LSTM over ``embeds`` and project each step to tag scores."""
        self.hidden = self.init_hidden()
        lstm_out, _ = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(embeds), self.hidden_dim)
        return self.hidden2tag(lstm_out)

    def _score_sentence(self, feats, tags):
        """Return the score of the given tag sequence, shape ``(1,)``.

        Args:
            feats: Emission scores, one row per token.
            tags: 1-D long tensor of gold tag indices, one per token.
        """
        score = torch.zeros(1, device=feats.device)
        start = torch.tensor(
            [self.tag_to_ix[START_TAG]], dtype=torch.long, device=tags.device
        )
        tags = torch.cat([start, tags])
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        """Return ``(path_score, best_path)`` for the highest-scoring tag path.

        ``best_path`` is a list with one 0-d long tensor (tag index) per row
        of ``feats``.
        """
        backpointers = []
        forward_var = torch.full((1, self.tagset_size), -10000., device=feats.device)
        forward_var[0][self.tag_to_ix[START_TAG]] = 0
        for feat in feats:
            # candidates[i, j]: best score so far ending in j, then moving to i
            candidates = forward_var + self.transitions
            best_scores, best_prev = candidates.max(dim=1)
            forward_var = (best_scores + feat).view(1, -1)
            backpointers.append(best_prev)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = torch.argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]
        best_path = [best_tag_id]
        for bptrs in reversed(backpointers):
            best_tag_id = bptrs[best_tag_id]
            best_path.append(best_tag_id)
        # Pop outside the assert: with assertions disabled (-O) the pop must
        # still happen, or START would stay in the returned path.
        start = best_path.pop()
        assert int(start) == self.tag_to_ix[START_TAG]
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, embeds, tags):
        """Return the CRF loss ``log Z - score(gold path)``, shape ``(1,)``."""
        feats = self._get_lstm_features(embeds)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, embeds):
        """Decode ``embeds`` and return ``(path_score, best_path)``."""
        lstm_feats = self._get_lstm_features(embeds)
        return self._viterbi_decode(lstm_feats)

# === Setup ===
# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
bert_model = BertModel.from_pretrained("bert-base-multilingual-cased")
bert_model = bert_model.to(device).eval()
# BERT only supplies features here: switch off gradients for all its weights.
for param in bert_model.parameters():
    param.requires_grad = False

# === Data Preprocessing ===
dev_data = pd.read_csv("/Users/eddiezhang/Downloads/job_data_files/salary_labelled_development_set.csv")
nation_currency = {
    "PH": "PHP", "NZ": "NZD", "AUS": "AUD", "HK": "HKD",
    "ID": "IDR", "MY": "MYR", "SG": "SGD", "TH": "THB"
}
# NOTE(review): this nation-derived 'currency' column is overwritten a few
# lines below by the currency parsed from the label - confirm which one is
# meant to win.
dev_data['currency'] = dev_data.iloc[:, 3].map(nation_currency)
dev_data['parsed'] = dev_data.iloc[:, 5].apply(parse_actual_info)
# Rows without a usable label parse to None. Pad them to four fields and
# name the columns explicitly so the expansion always yields four columns,
# whichever row comes first and even when the frame is empty.
_salary_columns = ['min_salary', 'max_salary', 'currency', 'unit']
_parsed_rows = [
    p if isinstance(p, tuple) else (None, None, None, None)
    for p in dev_data['parsed']
]
dev_data[_salary_columns] = pd.DataFrame(
    _parsed_rows, index=dev_data.index, columns=_salary_columns
)
dev_data['cleaned_ad_details'] = dev_data['job_ad_details'].astype(str).apply(clean_html_tags).apply(clean_text)

# === Build Training Set ===
# Tag vocabulary: the BIO labels first, then the two CRF sentinel tags.
tag_to_ix = dict(zip(LABELS, range(len(LABELS))))
tag_to_ix[START_TAG] = len(tag_to_ix)
tag_to_ix[STOP_TAG] = len(tag_to_ix)

X_train, Y_train = [], []
for _, row in dev_data.iterrows():
    parsed = row['parsed']
    # Rows whose label did not parse carry no salary to learn from.
    if not parsed:
        continue
    token_chunks, label_chunks, embed_chunks = chunk_and_align(
        row['cleaned_ad_details'], parsed[0], parsed[1], tokenizer
    )
    for labels, embeddings in zip(label_chunks, embed_chunks):
        # Keep only chunks with exactly one label per embedded token.
        if len(labels) != embeddings.shape[0]:
            continue
        label_ids = [tag_to_ix[l] for l in labels]
        X_train.append(embeddings)
        Y_train.append(torch.tensor(label_ids, dtype=torch.long, device=device))

# === Train Model ===
# The input size is BERT's hidden size plus one: chunk_and_align appends a
# single lexicon-flag column to every token embedding.
model = BiLSTM_CRF(
    embedding_dim = bert_model.config.hidden_size + 1,
    hidden_dim = 128,
    tag_to_ix = tag_to_ix
).to(device)
optimizer = optim.Adam(model.parameters(), lr = 0.01)

# Ten passes over the training chunks, one optimizer step per chunk.
for epoch in range(10):
    total_loss = 0.0
    for x, y in zip(X_train, Y_train):
        # Clear gradients left over from the previous chunk.
        model.zero_grad()
        feats = prepare_sequence(x)
        loss = model.neg_log_likelihood(feats, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # Reported loss is the sum over chunks, not the mean.
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")


[nltk_data] Downloading package wordnet to
[nltk_data]     /Users/eddiezhang/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to
[nltk_data]     /Users/eddiezhang/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!


In [21]:
def extract_span_from_tags(tokens, tags):
    """Return the text from the first to the last token tagged as salary.

    Every token between those two positions is included, whatever its own
    tag. Returns an empty string when no token is tagged B-SALARY or
    I-SALARY.
    """
    salary_tags = ("B-SALARY", "I-SALARY")
    first = last = None
    for position, tag in enumerate(tags):
        if tag in salary_tags:
            if first is None:
                first = position
            last = position
    if first is None:
        return ""
    covered = tokens[first : last + 1]
    return tokenizer.convert_tokens_to_string(covered).strip()

def _format_salary_prediction(raw_span, currency, unit):
    """Format a predicted span as ``low-high-CURRENCY-UNIT``; None if it has no number."""
    nums = NUMBER_RE.findall(raw_span)
    if not nums:
        return None
    # First and last number in the span; a single number serves as both.
    low_i = int(float(nums[0]))
    high_i = int(float(nums[-1]))
    return f"{low_i}-{high_i}-{currency.upper()}-{unit.upper()}"


def evaluate_on_test_set(test_csv_path):
    """Tag each labelled test ad, print per-ad results and overall accuracy.

    Rows whose label does not parse are skipped and do not count towards
    the total. A prediction is correct only when it equals the expected
    label after removing spaces and ignoring case.

    NOTE(review): the text window comes from ``chunk_and_align`` using the
    gold min/max salary, and the currency and unit in the formatted
    prediction are taken from the gold label, so only the numeric span is
    actually predicted.

    Args:
        test_csv_path: Path of a CSV with ``job_id``, ``job_ad_details``
            and ``y_true`` columns, whose sixth column holds the label.
    """
    test_data = pd.read_csv(test_csv_path)
    test_data['parsed'] = test_data.iloc[:, 5].apply(parse_actual_info)
    test_data['cleaned_ad_details'] = (
        test_data['job_ad_details']
        .astype(str)
        .apply(clean_html_tags)
        .apply(clean_text)
    )

    # Inverse tag mapping; built once instead of once per chunk.
    ix_to_tag = {v: k for k, v in tag_to_ix.items()}

    correct = 0
    total = 0

    print("\n=== Test Set Evaluation ===\n")
    for _, row in test_data.iterrows():
        job_id = row['job_id']
        y_true = row['y_true']
        parsed = row['parsed']
        if not parsed:
            continue
        min_salary, max_salary, currency, unit = parsed
        job_text = row['cleaned_ad_details']

        token_chunks, _, embed_chunks = chunk_and_align(
            job_text, min_salary, max_salary, tokenizer
        )

        best_prediction = "NONE"
        best_raw_span = ""
        for tokens, embeddings in zip(token_chunks, embed_chunks):
            with torch.no_grad():
                feats = prepare_sequence(embeddings).to(device)
                _, pred_ids = model(feats)

            pred_tags = [ix_to_tag[i.item()] for i in pred_ids]

            raw_span = extract_span_from_tags(tokens, pred_tags)
            if not raw_span:
                continue

            best_raw_span = raw_span

            formatted = _format_salary_prediction(raw_span, currency, unit)
            if formatted is None:
                continue

            best_prediction = formatted
            break

        # Equality, not substring containment: "21-121-NZD-HOURLY" is a
        # substring of "121-121-NZD-HOURLY" but is not the same answer.
        is_correct = (
            best_prediction.replace(" ", "").lower()
            == y_true.replace(" ", "").lower()
        )
        status = "✅" if is_correct else "❌"

        print(
            f"[{status}] Job ID {job_id}\n"
            f"    Raw span:    '{best_raw_span}'\n"
            f"    Formatted:   '{best_prediction}'\n"
            f"    Expected:    '{y_true}'\n"
        )

        total += 1
        if is_correct:
            correct += 1

    accuracy = correct / total if total > 0 else 0.0
    print(f"\nOverall Accuracy: {correct}/{total} = {accuracy:.2%}\n")

# Run Evaluation
# Prints one result per labelled test ad followed by the overall accuracy;
# relies on the model, tokenizer and tag_to_ix created in the cell above.
evaluate_on_test_set("/Users/eddiezhang/Downloads/job_data_files/salary_labelled_test_set.csv")



=== Test Set Evaluation ===

[❌] Job ID 72527377
    Raw span:    '1500'
    Formatted:   '1500-1500-MYR-MONTHLY'
    Expected:    '1500-1800-MYR-MONTHLY'

[✅] Job ID 73593343
    Raw span:    '60'
    Formatted:   '60-60-HKD-HOURLY'
    Expected:    '60-60-HKD-HOURLY'

[✅] Job ID 60150523
    Raw span:    '21'
    Formatted:   '21-21-NZD-HOURLY'
    Expected:    '21-21-NZD-HOURLY'

[✅] Job ID 79030770
    Raw span:    '32'
    Formatted:   '32-32-AUD-HOURLY'
    Expected:    '32-32-AUD-HOURLY'

[❌] Job ID 68264916
    Raw span:    ''
    Formatted:   'NONE'
    Expected:    '2000-3000-MYR-MONTHLY'

[✅] Job ID 70451682
    Raw span:    '3000Incentives : RM500 - 1000Total Package : MYR 3000 [UNK] MYR 4000 / MonthOther'
    Formatted:   '3000-4000-MYR-MONTHLY'
    Expected:    '3000-4000-MYR-MONTHLY'

[❌] Job ID 72593577
    Raw span:    ''
    Formatted:   'NONE'
    Expected:    '80-90-HKD-HOURLY'

[❌] Job ID 60040933
    Raw span:    '142642 - $ 156491 Fortnightly salary $ 5468 - $ 5

KeyboardInterrupt: 

In [None]:
How many are correct?
