In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import re
from bs4 import BeautifulSoup
from transformers import BertTokenizerFast, BertModel
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import DataLoader, TensorDataset
import os

# === Constants ===
# Sentinel tags marking the virtual first/last positions of the CRF lattice.
START_TAG, STOP_TAG = "<START>", "<END>"

# BIO label inventory used for salary tagging.
LABELS = [
    "B-SALARY",
    "I-SALARY",
    "O",
]

# File in which the (embeddings, labels) training tensors are cached.
CACHE_FILE = "cached_train_embeddings.pt"

# === Helper Functions ===
def parse_actual_info(info_str):
    """Parse a ``"min-max-currency-unit"`` salary annotation.

    Args:
        info_str: Annotation made of four ``-``-separated fields
            (lower bound, upper bound, currency, unit).

    Returns:
        ``(min_salary, max_salary, currency, unit)`` with both bounds as
        floats and the unit lower-cased, or ``None`` when the annotation is
        unusable: not a string (e.g. a NaN cell), not exactly four fields,
        the ``"0-0-None-None"`` placeholder, or non-numeric bounds.
    """
    # Missing cells arrive as NaN/None rather than str; treat them as
    # "no annotation" instead of failing on .split().
    if not isinstance(info_str, str):
        return None

    parts = info_str.split('-')
    if len(parts) != 4 or parts == ['0', '0', 'None', 'None']:
        return None

    low, high, currency, unit = parts
    try:
        return (float(low), float(high), currency, unit.lower())
    except ValueError:
        # Bounds that are not numbers make the record unparseable.
        return None

def clean_html_tags(html_text):
    """Return the text content of *html_text* with the markup removed.

    The input is parsed with BeautifulSoup's ``'html.parser'`` backend and
    the result of ``get_text()`` is returned unchanged.
    """
    return BeautifulSoup(html_text, 'html.parser').get_text()

def clean_text(text):
    """Drop tag-like ``<...>`` spans and collapse whitespace in *text*.

    Returns an empty string when ``pd.isna(text)`` is true; otherwise the
    text with every ``<...>`` span removed, each whitespace run replaced by
    a single space, and leading/trailing whitespace stripped.
    """
    if pd.isna(text):
        return ""
    without_tags = re.sub(r'<[^>]+>', '', text)
    return re.sub(r'\s+', ' ', without_tags).strip()

def chunk_and_align(text, min_salary, max_salary, tokenizer, max_length=512, stride=128):
    """Tokenize *text* into overlapping chunks, label each token, and embed it.

    The text is split by ``tokenizer`` into chunks of ``max_length`` tokens
    (padded to ``max_length``) that overlap by ``stride`` tokens. A token is
    labelled as salary when its source text, with everything except digits
    and ``.`` removed, parses as a float inside ``[min_salary, max_salary]``.
    The first such token in a run gets ``"B-SALARY"``, following ones
    ``"I-SALARY"``; every other token (including special/padding tokens)
    gets ``"O"``.

    Relies on the module-level ``bert_model`` and ``device``.

    Args:
        text: Raw text to tokenize.
        min_salary: Lower bound (inclusive) of the accepted value range.
        max_salary: Upper bound (inclusive) of the accepted value range.
        tokenizer: Fast tokenizer supporting offset mappings and ``word_ids``.
        max_length: Number of tokens per chunk.
        stride: Number of tokens shared by consecutive chunks.

    Returns:
        ``(tokens_all, labels_all, embeddings_all)``: three parallel lists
        with one entry per chunk -- the token strings, the label strings,
        and the ``last_hidden_state`` of ``bert_model`` for that chunk with
        the batch axis squeezed out.
    """
    encoded = tokenizer(
        text,
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_tensors="pt",
        padding="max_length"
    )

    # Bookkeeping entries that are not model inputs; the sample mapping is
    # not needed because a single text is encoded per call.
    encoded.pop("overflow_to_sample_mapping")
    offset_mappings = encoded.pop("offset_mapping")
    input_ids = encoded['input_ids']
    attention_mask = encoded['attention_mask']

    tokens_all, labels_all, embeddings_all = [], [], []
    for i, chunk_input_ids in enumerate(input_ids):
        chunk_offsets = offset_mappings[i].tolist()
        chunk_tokens = tokenizer.convert_ids_to_tokens(chunk_input_ids)
        word_ids = encoded.word_ids(i)

        labels = []
        for offset, word_id in zip(chunk_offsets, word_ids):
            # Special and padding tokens have no source word / empty span.
            if word_id is None or offset == [0, 0]:
                labels.append("O")
                continue

            word = text[offset[0]:offset[1]]
            try:
                value = float(re.sub(r'[^\d.]', '', word))
            except ValueError:
                # Token carries no parseable number (e.g. "" or "1.2.3").
                labels.append("O")
                continue

            if min_salary <= value <= max_salary:
                # Continue an open salary span, otherwise start a new one.
                if labels and labels[-1] in ["B-SALARY", "I-SALARY"]:
                    labels.append("I-SALARY")
                else:
                    labels.append("B-SALARY")
            else:
                labels.append("O")

        input_dict = {
            "input_ids": chunk_input_ids.unsqueeze(0).to(device),
            "attention_mask": attention_mask[i].unsqueeze(0).to(device)
        }
        with torch.no_grad():
            outputs = bert_model(**input_dict)
            embeddings = outputs.last_hidden_state.squeeze(0)

        tokens_all.append(chunk_tokens)
        labels_all.append(labels)
        embeddings_all.append(embeddings)

    return tokens_all, labels_all, embeddings_all

def log_sum_exp(vec):
    """Return ``log(sum(exp(vec)))`` along the tag axis of a score row.

    Args:
        vec: Tensor of shape ``(1, N)`` holding one row of scores.

    Returns:
        Tensor of shape ``(1,)``.

    Uses ``torch.logsumexp``, which is numerically stable and returns
    ``-inf`` (rather than NaN) when every score is ``-inf``.
    """
    return torch.logsumexp(vec, dim=1)

def prepare_sequence(embeds):
    """Insert a singleton axis after the first dimension of *embeds*.

    The result is a view of shape ``(len(embeds), 1, -1)``: the first
    dimension is kept and all remaining dimensions are flattened.
    """
    seq_len = len(embeds)
    return embeds.view(seq_len, 1, -1)

# === BiLSTM+CRF Model ===
class BiLSTM_CRF(nn.Module):
    """Bidirectional LSTM followed by a linear-chain CRF over tag scores.

    The model consumes precomputed embeddings of shape
    ``(seq_len, 1, embedding_dim)`` (one sequence at a time). All helper
    tensors are created on the device of the incoming data so the model
    also works after being moved with ``.to(device)``.

    Args:
        embedding_dim: Size of each input embedding vector.
        hidden_dim: Size of the concatenated forward+backward LSTM output.
        tag_to_ix: Mapping from tag string to index; must contain
            ``START_TAG`` and ``STOP_TAG``.
    """

    def __init__(self, embedding_dim, hidden_dim, tag_to_ix):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.tag_to_ix = tag_to_ix
        self.tagset_size = len(tag_to_ix)

        # Each direction gets hidden_dim // 2 units so the concatenated
        # output has hidden_dim features.
        self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2, num_layers=1, bidirectional=True)
        self.hidden2tag = nn.Linear(hidden_dim, self.tagset_size)

        # transitions[i, j] is the score of moving from tag j to tag i.
        self.transitions = nn.Parameter(torch.randn(self.tagset_size, self.tagset_size))
        # Never transition into START and never transition out of STOP.
        self.transitions.data[tag_to_ix[START_TAG], :] = -10000
        self.transitions.data[:, tag_to_ix[STOP_TAG]] = -10000
        self.hidden = self.init_hidden()

    def init_hidden(self, device=None):
        """Return a random ``(h0, c0)`` pair for the BiLSTM.

        Args:
            device: Device on which to allocate the state; ``None`` uses
                the default device.
        """
        shape = (2, 1, self.hidden_dim // 2)
        return (torch.randn(shape, device=device), torch.randn(shape, device=device))

    def _forward_alg(self, feats):
        """Return the log partition function over all tag sequences.

        Args:
            feats: Emission scores, one row of ``tagset_size`` values per
                sequence position.
        """
        init_alphas = torch.full((1, self.tagset_size), -10000., device=feats.device)
        # Only START has a non-negligible score before the first position.
        init_alphas[0][self.tag_to_ix[START_TAG]] = 0.
        forward_var = init_alphas

        for feat in feats:
            alphas_t = []
            for next_tag in range(self.tagset_size):
                # The emission score is the same regardless of previous tag.
                emit_score = feat[next_tag].view(1, -1).expand(1, self.tagset_size)
                # Scores of transitioning from every tag into next_tag.
                trans_score = self.transitions[next_tag].view(1, -1)
                next_tag_var = forward_var + trans_score + emit_score
                alphas_t.append(log_sum_exp(next_tag_var).view(1))
            forward_var = torch.cat(alphas_t).view(1, -1)
        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        alpha = log_sum_exp(terminal_var)
        return alpha

    def _get_lstm_features(self, embeds):
        """Run the BiLSTM and project its outputs to per-tag emission scores.

        A fresh random initial hidden state is drawn on every call.
        """
        self.hidden = self.init_hidden(embeds.device)
        lstm_out, _ = self.lstm(embeds, self.hidden)
        lstm_out = lstm_out.view(len(embeds), self.hidden_dim)
        lstm_feats = self.hidden2tag(lstm_out)
        return lstm_feats

    def _score_sentence(self, feats, tags):
        """Return the score of the given tag sequence (shape ``(1,)``).

        Args:
            feats: Emission scores, one row per sequence position.
            tags: 1-D long tensor of gold tag indices.
        """
        score = torch.zeros(1, device=feats.device)
        start = torch.tensor(
            [self.tag_to_ix[START_TAG]], dtype=torch.long, device=tags.device
        )
        # Prepend START so tags[i] is the tag preceding position i.
        tags = torch.cat([start, tags])
        for i, feat in enumerate(feats):
            score = score + self.transitions[tags[i + 1], tags[i]] + feat[tags[i + 1]]
        score = score + self.transitions[self.tag_to_ix[STOP_TAG], tags[-1]]
        return score

    def _viterbi_decode(self, feats):
        """Return ``(path_score, best_path)`` for the highest-scoring tags.

        ``best_path`` is a list with one tag index (0-dim tensor) per
        sequence position.
        """
        backpointers = []
        init_vvars = torch.full((1, self.tagset_size), -10000., device=feats.device)
        init_vvars[0][self.tag_to_ix[START_TAG]] = 0
        forward_var = init_vvars

        for feat in feats:
            bptrs_t, viterbivars_t = [], []
            for next_tag in range(self.tagset_size):
                # Emission is added after the max since it does not depend
                # on the previous tag.
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = torch.argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        terminal_var = forward_var + self.transitions[self.tag_to_ix[STOP_TAG]]
        best_tag_id = torch.argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the backpointers from the last position to the first.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # The walk ends on the START tag, which is not part of the output.
        start = best_path.pop()
        assert start == self.tag_to_ix[START_TAG]
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, embeds, tags):
        """Return the CRF negative log-likelihood of *tags* given *embeds*."""
        feats = self._get_lstm_features(embeds)
        forward_score = self._forward_alg(feats)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, embeds):
        """Return ``(score, tag_seq)`` of the Viterbi-decoded tag sequence."""
        lstm_feats = self._get_lstm_features(embeds)
        score, tag_seq = self._viterbi_decode(lstm_feats)
        return score, tag_seq

# === Setup ===
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Lookup from the code found in the dataset's nation column to a currency
# code (presumably ISO 4217 -- confirm against the data).
nation_currency = {"PH": "PHP", "NZ": "NZD", "AUS": "AUD", "HK": "HKD", "ID": "IDR", "MY": "MYR", "SG": "SGD", "TH": "THB"}
tokenizer = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
# BERT is only used as a fixed feature extractor: eval mode, no gradients.
bert_model = BertModel.from_pretrained("bert-base-multilingual-cased").to(device).eval()
for param in bert_model.parameters():
    param.requires_grad = False

# === Data Preprocessing ===
# NOTE(review): absolute, machine-specific path -- the notebook only runs
# where this file exists.
dev_data = pd.read_csv("/Users/eddiezhang/Downloads/job_data_files/salary_labelled_development_set.csv")
# Map the column at position 3 through nation_currency.
# NOTE(review): this 'currency' column is overwritten two statements below
# by the currency parsed from the annotation -- confirm which is intended.
dev_data['currency'] = dev_data.iloc[:, 3].map(nation_currency)
# Column at position 5 holds the "min-max-currency-unit" annotation string;
# rows that cannot be parsed get None.
dev_data['parsed'] = dev_data.iloc[:, 5].apply(parse_actual_info)
# Expand the parsed tuples into one column per field.
dev_data[['min_salary', 'max_salary', 'currency', 'unit']] = pd.DataFrame(dev_data['parsed'].tolist(), index=dev_data.index)
# Strip HTML with BeautifulSoup, then remove leftover tag-like spans and
# collapse whitespace.
dev_data['cleaned_ad_details'] = dev_data['job_ad_details'].astype(str).apply(clean_html_tags).apply(clean_text)

# === Build Training Set (with caching) ===
# Tag indices: the labels first, then the CRF's START/STOP sentinels.
tag_to_ix = {label: i for i, label in enumerate(LABELS)}
tag_to_ix[START_TAG] = len(tag_to_ix)
tag_to_ix[STOP_TAG] = len(tag_to_ix)

if os.path.exists(CACHE_FILE):
    print("Loading cached embeddings...")
    # Load onto the CPU so a cache written on a CUDA machine can still be
    # read on a CPU-only one; tensors are moved to `device` during training.
    X_train, Y_train = torch.load(CACHE_FILE, map_location="cpu")
else:
    print("Generating embeddings...")
    X_train, Y_train = [], []
    for idx, row in dev_data.iterrows():
        job_text = row['cleaned_ad_details']
        parsed = row['parsed']
        # Rows without a usable salary annotation provide no supervision.
        if not parsed:
            continue
        min_salary, max_salary = parsed[0], parsed[1]

        token_chunks, label_chunks, embed_chunks = chunk_and_align(job_text, min_salary, max_salary, tokenizer)

        for labels, embeddings in zip(label_chunks, embed_chunks):
            # Skip chunks whose label count does not match the number of
            # embedded tokens.
            if len(labels) != embeddings.shape[0]:
                continue
            X_train.append(embeddings)
            Y_train.append(torch.tensor([tag_to_ix[lbl] for lbl in labels], dtype=torch.long))

    torch.save((X_train, Y_train), CACHE_FILE)
    print("Embeddings cached.")

# === Train Model ===
# embedding_dim=768 must match the feature size of the tensors in X_train.
model = BiLSTM_CRF(embedding_dim=768, hidden_dim=128, tag_to_ix=tag_to_ix).to(device)
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(3):
    # Sum of the per-chunk losses over the whole epoch.
    total_loss = 0.0
    # One optimizer step per chunk (no batching).
    for x, y in zip(X_train, Y_train):
        model.zero_grad()
        # Reshape to (seq_len, 1, features) as expected by the LSTM.
        feats = prepare_sequence(x).to(device)
        y = y.to(device)
        loss = model.neg_log_likelihood(feats, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}")


Generating embeddings...
Embeddings cached.
Epoch 1, Loss: 5054.3152
