# Inference Demo: Fair LoRA Resume-Job Matching

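This notebook shows how to load a BGE base model (`BAAI/bge-large-en-v1.5` by default) together with either a LoRA adapter or a locally trained checkpoint, and how to compute a match score between a resume and a job description.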

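## Contents
1. Environment and imports
2. Loading the model and tokenizer (remote adapter or local best model)
3. Single-pair inference
4. Batch inference over multiple pairs
5. (Optional) Masking sensitive school names
6. Interpreting the match score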
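
How the match score is computed:
- Take the output of the last hidden layer and pool it into a single vector (mean pooling or CLS pooling; this notebook uses mean pooling).
- L2-normalize both vectors and take their dot product, which then equals cosine similarity.
- Apply a sigmoid to map the similarity into (0, 1). Because cosine similarity lies in [-1, 1], the sigmoid output actually stays within about [0.27, 0.73].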
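A minimal, dependency-free sketch of the scoring steps listed above, assuming the two pooled embeddings are already available as plain Python lists (the toy vectors are made up for illustration, not real model output). It also shows why the sigmoid score is confined to roughly [0.27, 0.73].

```python
import math

def l2_normalize(vec: list[float]) -> list[float]:
    """Scale a vector to unit length (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec] if norm > 0 else list(vec)

def dot(a: list[float], b: list[float]) -> float:
    return sum(x * y for x, y in zip(a, b))

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def score_pair(resume_emb: list[float], job_emb: list[float]) -> tuple[float, float]:
    """Return (sigmoid(cosine), cosine) for two pooled embeddings."""
    # After L2 normalization the dot product equals cosine similarity
    cosine = dot(l2_normalize(resume_emb), l2_normalize(job_emb))
    return sigmoid(cosine), cosine

# Toy pooled embeddings (hypothetical values)
score, cosine = score_pair([2.0, 0.0], [1.0, 1.0])
print(f"cosine={cosine:.4f} score={score:.4f}")

# Cosine is bounded by [-1, 1], so the sigmoid score is bounded as well
print(f"score range: [{sigmoid(-1.0):.4f}, {sigmoid(1.0):.4f}]")
```

In the notebook, `match_score` performs the same computation on torch tensors.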
> If you have already trained a model and produced `best_util_model.pt` or `best_fairness_model.pt`, you can load that file directly (it contains the full model `state_dict`).

In [None]:
# 1. Imports & basic setup
import os, json, math
import torch
import torch.nn.functional as F
from pathlib import Path
from transformers import AutoModel, AutoTokenizer

# Try optional peft import; fall back gracefully if missing
try:
    from peft import PeftModel  # type: ignore
except Exception:
    PeftModel = None  # allows running without an adapter
    print("peft not installed or failed to import; remote LoRA adapter loading will be skipped. Install with `pip install peft`.")

DEVICE = torch.device('mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu'))
BASE_MODEL = 'BAAI/bge-large-en-v1.5'
print(f'Device: {DEVICE}')

In [None]:
# 2. Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
# Optionally restrict max length for speed/memory:
tokenizer.model_max_length = 256
print('Tokenizer loaded.')

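## Choosing how to load the model
- Option A: a LoRA adapter hosted on the Hugging Face Hub (requires an adapter repo).
- Option B: a local checkpoint produced by the training pipeline (`models/fair_adversarial/best_util_model.pt`).

Option B needs some care: the training code uses a custom `FairLoRAModel`, which this notebook does not rebuild. Instead it uses a lightweight inference path (the plain base model plus pooling) and loads the checkpoint weights non-strictly. Keys that do not match the base model's parameter names are skipped, so check the missing/unexpected key counts printed after loading.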
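When a checkpoint is saved from a wrapper class such as `FairLoRAModel`, its parameter names often carry an extra prefix, so a non-strict load can silently match nothing. The sketch below shows one way to remap keys before calling `load_state_dict`. The prefix `encoder.`, the `adversary.` key, and the helper `strip_prefix` are hypothetical placeholders; inspect your checkpoint's keys (e.g. `list(state)[:5]`) to find the real prefix.

```python
def strip_prefix(state: dict, prefix: str) -> dict:
    """Return a copy of `state` with `prefix` removed from keys that start with it.

    Keys without the prefix (e.g. parameters that exist only in the training
    wrapper) are dropped, since the bare base model has nothing to load them into.
    """
    return {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}

# Hypothetical checkpoint keys; string values stand in for tensors
ckpt_state = {
    "encoder.embeddings.word_embeddings.weight": "W_emb",
    "encoder.encoder.layer.0.attention.self.query.weight": "W_q",
    "adversary.classifier.weight": "W_adv",  # wrapper-only parameter (hypothetical)
}
base_state = strip_prefix(ckpt_state, "encoder.")
print(sorted(base_state))
```

With a real checkpoint, the remapped dict would be passed as `model.load_state_dict(strip_prefix(state, prefix), strict=False)`.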

In [None]:
# 3A. Load base + remote LoRA adapter (if available)
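USE_REMOTE_ADAPTER = False  # set to True to use the LoRA adapter hosted on the Hugging Face Hub
REMOTE_ADAPTER_PATH = 'shashu2325/resume-job-matcher-lora'  # example adapter repo
# The adapter must have been trained on the same architecture as BASE_MODEL, otherwise loading fails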

base_model = AutoModel.from_pretrained(BASE_MODEL, torch_dtype=torch.float32)
if USE_REMOTE_ADAPTER and PeftModel is not None:
    try:
        model = PeftModel.from_pretrained(base_model, REMOTE_ADAPTER_PATH)
        print(f"Loaded remote LoRA adapter: {REMOTE_ADAPTER_PATH}")
    except Exception as e:
        print(f"Failed to load remote adapter: {e}; falling back to base model.")
        model = base_model
else:
    model = base_model
    if USE_REMOTE_ADAPTER and PeftModel is None:
        print("peft not available; using base model without adapter.")

model.to(DEVICE)
model.eval()
print('Model ready.')

In [None]:
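# 3B. (Optional) Load local fairness-aware checkpoint weights if available
# Assumes the checkpoint holds a 'model_state_dict' (or is itself a plain state_dict)
# whose keys match the parameter names of `model` above.
LOCAL_CKPT_PATH = Path('models/fair_adversarial/best_util_model.pt')  # or best_fairness_model.pt
if LOCAL_CKPT_PATH.exists():
    try:
        # Load onto CPU first; load_state_dict copies the tensors onto the model's device.
        # Note: torch>=2.6 defaults to weights_only=True, which rejects checkpoints that contain
        # arbitrary Python objects; pass weights_only=False only for files you trust.
        ckpt = torch.load(LOCAL_CKPT_PATH, map_location='cpu')
        state = None
        if isinstance(ckpt, dict):
            # Either a training checkpoint wrapping the weights, or a plain state_dict
            state = ckpt.get('model_state_dict', ckpt)
        if state is not None:
            missing, unexpected = model.load_state_dict(state, strict=False)
            matched = len(state) - len(unexpected)
            print(f'Loaded local checkpoint: {LOCAL_CKPT_PATH}')
            print('  Matched keys:', matched, '| Missing keys:', len(missing), '| Unexpected keys:', len(unexpected))
            if matched == 0:
                print('  Warning: no keys matched; checkpoint key names (e.g. a wrapper prefix) differ from the base model.')
        else:
            print('Checkpoint is not a dict/state_dict; skipped loading.')
    except Exception as e:
        print(f'Failed to load local checkpoint: {e}')
else:
    print('Local checkpoint not found; using current model weights.')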

## 4. Helper: Text -> Embedding
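We use mean pooling followed by L2 normalization. You can switch to CLS-token or weighted pooling as needed, but the pooling used at inference should match the pooling used during training. For mean pooling, padding positions must be excluded from the average via the attention mask. (For reference, the BGE model card uses the [CLS] token embedding for BGE models.)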
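A dependency-free sketch contrasting the pooling options mentioned above on a toy sequence of four token vectors, the last of which is padding. CLS pooling takes the first token, masked mean pooling averages only real (mask = 1) tokens, and the weighted variant uses position-proportional weights, which is just one illustrative weighting scheme. The hidden states and mask are made-up values, and these helpers are hypothetical, not part of the notebook.

```python
def weighted_pool(hidden: list[list[float]], weights: list[float]) -> list[float]:
    """Weighted average of token vectors; zero-weight positions contribute nothing."""
    total_w = sum(weights)
    dim = len(hidden[0])
    out = [0.0] * dim
    for vec, w in zip(hidden, weights):
        for d in range(dim):
            out[d] += w * vec[d]
    return [x / total_w for x in out] if total_w > 0 else out

def cls_pool(hidden: list[list[float]]) -> list[float]:
    """CLS pooling: the first token's vector ([CLS] in BERT-style models)."""
    return list(hidden[0])

def masked_mean_pool(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Mean over real tokens only; padding positions (mask == 0) are ignored."""
    return weighted_pool(hidden, [float(m) for m in mask])

def position_weighted_pool(hidden: list[list[float]], mask: list[int]) -> list[float]:
    """Weight i + 1 for position i (one illustrative choice), padding excluded."""
    return weighted_pool(hidden, [float((i + 1) * m) for i, m in enumerate(mask)])

# Toy last hidden states: 4 positions x 2 dims; the last position is padding
hidden = [[1.0, 0.0], [3.0, 2.0], [5.0, 4.0], [9.0, 9.0]]
mask = [1, 1, 1, 0]

naive_mean = [sum(col) / len(hidden) for col in zip(*hidden)]  # includes padding
print("cls:     ", cls_pool(hidden))
print("masked:  ", masked_mean_pool(hidden, mask))
print("naive:   ", naive_mean)
print("weighted:", position_weighted_pool(hidden, mask))
```

The naive mean is pulled toward the padding vector, which is why the attention mask matters whenever sequences in a batch are padded.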

In [None]:
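def encode_text(texts, max_length: int = 256) -> torch.Tensor:
    """Encode a string or a list of strings into L2-normalized embeddings of shape (n, dim)."""
    if isinstance(texts, str):
        texts = [texts]
    # Pad only to the longest text in the batch instead of always padding to max_length
    inputs = tokenizer(texts, return_tensors='pt', max_length=max_length, truncation=True, padding=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    with torch.no_grad():
        hidden = model(**inputs).last_hidden_state  # (batch, seq, dim)
        mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)  # (batch, seq, 1)
        # Masked mean pooling: average only over real tokens, ignoring padding positions
        emb = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)
        emb = F.normalize(emb, p=2, dim=1)
    return emb.cpu()  # return on CPU for similarity ops

def match_score(resume_text: str, job_text: str) -> tuple[float, float]:
    """Return (sigmoid(cosine), cosine) for one resume/job pair."""
    r_emb = encode_text(resume_text)
    j_emb = encode_text(job_text)
    # Embeddings are L2-normalized, so the dot product equals cosine similarity
    sim = torch.sum(r_emb * j_emb, dim=1)
    # sigmoid maps cosine in [-1, 1] to roughly [0.27, 0.73]; the raw cosine is returned too
    return torch.sigmoid(sim).item(), sim.item()

print('Encoder helpers ready.')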

## 5. Single Pair Inference

In [None]:
resume_text = 'Software engineer with Python experience building backend services and APIs.'
job_text = 'Seeking a backend Python developer to design scalable microservices.'
prob_score, raw_cosine = match_score(resume_text, job_text)
print(f'Match score (sigmoid of cosine, not a calibrated probability): {prob_score:.4f}')
print(f'Raw cosine similarity: {raw_cosine:.4f}')

## 6. Batch Inference over Multiple Pairs
Score several resume/job pairs in one go, which is handy for quick ranking.
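For actual ranking, a common pattern is to score one resume against many job descriptions and keep the best matches by cosine similarity. The sketch below assumes the embeddings are already L2-normalized (as the encoder helper produces) and uses hypothetical job IDs and toy unit vectors; `top_k_jobs` is an illustrative helper, not part of the notebook.

```python
import heapq

def top_k_jobs(resume_emb: list[float], job_embs: dict[str, list[float]], k: int = 2) -> list[tuple[str, float]]:
    """Return the k (job_id, cosine) pairs with the highest cosine similarity.

    Assumes all vectors are already L2-normalized, so the dot product is the cosine.
    """
    scored = ((job_id, sum(a * b for a, b in zip(resume_emb, emb))) for job_id, emb in job_embs.items())
    return heapq.nlargest(k, scored, key=lambda item: item[1])

# Toy unit vectors (hypothetical values)
resume = [0.6, 0.8]
jobs = {
    "backend-dev": [0.6, 0.8],
    "ux-designer": [1.0, 0.0],
    "scrum-master": [0.0, 1.0],
}
ranked = top_k_jobs(resume, jobs, k=2)
for job_id, cos in ranked:
    print(f"{job_id}: cosine={cos:.3f}")
```

Embedding each job description once and reusing its vector avoids re-encoding the same text for every resume.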

In [None]:
pairs = [
    ('Data scientist experienced in NLP and transformers.', 'Looking for ML engineer with NLP background'),
    ('Frontend developer skilled in React and TypeScript.', 'Need React engineer for UI component development'),
    ('Project manager with agile certification.', 'Hiring Scrum master for cross-team coordination'),
    ('Graphic designer using Figma and Adobe suite.', 'Seeking UX/UI designer to craft product interfaces')
]
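def shorten(text: str, n: int = 50) -> str:
    # Append '...' only when the text is actually truncated
    return text if len(text) <= n else text[:n] + '...'

results = []
for r, j in pairs:
    prob, cosine = match_score(r, j)
    results.append({'resume': shorten(r), 'job': shorten(j), 'prob_score': prob, 'cosine': cosine})

# Rank pairs from most to least similar
results.sort(key=lambda row: row['cosine'], reverse=True)
for row in results:
    print(f"Resume: {row['resume']}\nJob: {row['job']}\n  prob={row['prob_score']:.4f} | cosine={row['cosine']:.4f}\n")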