In [1]:
# Cell 1 - install required packages (comment this line out if they are already installed)
!pip install -q gradio sentence-transformers scikit-learn xgboost joblib deep_translator


[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
!pip install langdetect deep-translator
!pip install ipywidgets




[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Collecting ipywidgets
  Using cached ipywidgets-8.1.7-py3-none-any.whl.metadata (2.4 kB)
Collecting widgetsnbextension~=4.0.14 (from ipywidgets)
  Using cached widgetsnbextension-4.0.14-py3-none-any.whl.metadata (1.6 kB)
Collecting jupyterlab_widgets~=3.0.15 (from ipywidgets)
  Using cached jupyterlab_widgets-3.0.15-py3-none-any.whl.metadata (20 kB)
Using cached ipywidgets-8.1.7-py3-none-any.whl (139 kB)
Using cached jupyterlab_widgets-3.0.15-py3-none-any.whl (216 kB)
Using cached widgetsnbextension-4.0.14-py3-none-any.whl (2.2 MB)
Installing collected packages: widgetsnbextension, jupyterlab_widgets, ipywidgets
Successfully installed ipywidgets-8.1.7 jupyterlab_widgets-3.0.15 widgetsnbextension-4.0.14



[notice] A new release of pip is available: 24.0 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [3]:
# Core dependencies used throughout the notebook.
import os, io, json, warnings, tempfile
import numpy as np
import pandas as pd
import joblib
import torch
from sentence_transformers import SentenceTransformer

# Seed Python, NumPy and torch RNGs with the same fixed value.
# NOTE(review): numpy and torch are re-imported here although they are already
# imported above; the repetition is harmless.
import random, numpy as np, torch
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Optional libraries (translation / language detection / UI). When one is
# missing, its name is bound to None and a warning explains the consequence,
# so later cells can test for availability instead of failing on import.
try:
    from deep_translator import GoogleTranslator
except Exception:
    GoogleTranslator = None
    warnings.warn("`deep_translator` không cài đặt — chức năng translate sẽ fallback (không dịch).")

try:
    from langdetect import detect
except Exception:
    detect = None
    warnings.warn("`langdetect` không cài đặt — sẽ dùng heuristic ASCII để kiểm tra ngôn ngữ.")

try:
    import gradio as gr
except Exception:
    gr = None
    warnings.warn("`gradio` không cài đặt. Cài nếu muốn UI local (pip install gradio).")

# Report which torch device is available in this kernel.
print("Environment ready. Torch device:", "cuda" if torch.cuda.is_available() else "cpu")

Environment ready. Torch device: cuda


In [4]:
# Cell 2 - Embedder (SentenceTransformer)
# Name of the sentence-transformers checkpoint used to embed posts.
MODEL_NAME = "all-MiniLM-L6-v2"
# Use CUDA when torch reports it as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module-level embedder; `encode_posts_mean` binds it as a default argument.
embedder = SentenceTransformer(MODEL_NAME, device=device)
print("Loaded embedder:", MODEL_NAME, "on", device)


Loaded embedder: all-MiniLM-L6-v2 on cuda


In [5]:
# Cell 3 - MBTI -> secondary roles, secondary -> main core role mappings
# The four core roles every team tries to cover, in assignment order.
CORE_ROLE_NAMES = ["Leader", "Planner", "Executor", "Facilitator"]

# A plausible mapping MBTI -> secondary roles (2 per MBTI).
# If you have the original mapping, put it in this variable so that results
# match the previous version.
mbti_roles = {
        "INFP": ["Supporter", "Idea Generator"],
        "INFJ": ["Leader", "Supporter"],
        "ENFP": ["Idea Generator", "Communicator"],
        "ENFJ": ["Leader", "Supporter"],
        "INTP": ["Idea Generator", "Checker"],
        "INTJ": ["Leader", "Checker"],
        "ENTP": ["Idea Generator", "Leader"],
        "ENTJ": ["Leader", "Implementer"],
        "ISFJ": ["Supporter", "Finisher"],
        "ESFJ": ["Supporter", "Communicator"],
        "ISTJ": ["Checker", "Finisher"],
        "ESTJ": ["Leader", "Implementer"],
        "ISFP": ["Supporter", "Implementer"],
        "ESFP": ["Communicator", "Idea Generator"],
        "ISTP": ["Implementer", "Problem Solver"],
        "ESTP": ["Leader", "Implementer"],
}

# Map detailed (secondary) roles from mbti_roles -> the 4 main roles.
# Only roles that actually appear in mbti_roles are listed.
role_to_main = {
    # Roles mapping to Leader
    "Leader": "Leader",
    # Roles mapping to Planner
    "Idea Generator": "Planner",
    "Checker": "Planner",
    # Roles mapping to Executor
    "Implementer": "Executor",
    "Finisher": "Executor",
    "Problem Solver": "Executor",
    # Roles mapping to Facilitator
    "Supporter": "Facilitator",
    "Communicator": "Facilitator",
}

# Make sure every core role name is itself a key of role_to_main, mapped to
# itself, so looking up a core role name returns that same name.
for c in CORE_ROLE_NAMES:
    role_to_main.setdefault(c, c)

print("Role mappings loaded. Core roles:", CORE_ROLE_NAMES)

Role mappings loaded. Core roles: ['Leader', 'Planner', 'Executor', 'Facilitator']


In [6]:
# Cell 4 - Load binary models (tries folder per-axis or a single joblib)
def load_binary_models(base_dir="src/binary_model", single_paths=None):
    """Load the per-axis binary MBTI classifiers.

    Lookup order:
      1. ``<base_dir>/<AXIS>/xgb_<AXIS>.joblib`` for each of EI, SN, TF, JP,
         with an optional ``label_map.json`` next to it (its ``id2label`` or
         ``id_to_label`` entry is kept as the class-id -> letter map).
      2. Each path in ``single_paths``: a joblib holding a dict keyed by the
         four axes, or a single classifier (assigned to axis EI as fallback).
      3. If neither yields a result but some per-axis models were loaded in
         step 1, that partial set is returned rather than discarded.

    Args:
        base_dir: Folder containing one sub-folder per axis.
        single_paths: Candidate single-file joblib paths; a default list of
            file names is used when None or empty.

    Returns:
        dict mapping axis name -> ``{"clf": model, "id2label": dict or None}``.
        May contain fewer than four axes; empty when nothing could be loaded.

    Note:
        ``joblib.load`` unpickles the file, which can execute arbitrary code.
        Only point this function at model files you trust.
    """
    single_paths = single_paths or ["primary_mbti_clf.joblib", "binary_models.joblib", "binary_model.joblib"]
    axes = ["EI", "SN", "TF", "JP"]
    models = {}

    # 1) Per-axis folders.
    for axis in axes:
        model_path = os.path.join(base_dir, axis, f"xgb_{axis}.joblib")
        map_path = os.path.join(base_dir, axis, "label_map.json")
        if not os.path.exists(model_path):
            continue
        try:
            clf = joblib.load(model_path)
            id2label = None
            if os.path.exists(map_path):
                with open(map_path, "r", encoding="utf-8") as f:
                    mm = json.load(f)
                id2label = mm.get("id2label") or mm.get("id_to_label")
            models[axis] = {"clf": clf, "id2label": id2label}
        except Exception as e:
            warnings.warn(f"Không thể load {model_path}: {e}")
    if len(models) == len(axes):
        print("Loaded binary models from folders:", list(models.keys()))
        return models

    # 2) Single joblib files.
    for p in single_paths:
        if not os.path.exists(p):
            continue
        try:
            data = joblib.load(p)
            if isinstance(data, dict) and all(k in data for k in axes):
                out = {}
                for axis in axes:
                    val = data[axis]
                    if isinstance(val, dict) and "clf" in val:
                        out[axis] = {"clf": val["clf"], "id2label": val.get("id2label")}
                    else:
                        # Bare classifier stored directly under the axis key.
                        out[axis] = {"clf": val, "id2label": None}
                print(f"Loaded binary models from single joblib: {p}")
                return out
            # A lone classifier: assign it to EI as a fallback.
            if hasattr(data, "predict_proba"):
                warnings.warn(f"Found single model in {p}. Assigned to axis EI as fallback.")
                return {"EI": {"clf": data, "id2label": None}}
        except Exception as e:
            warnings.warn(f"Error loading {p}: {e}")

    # 3) Keep whatever per-axis models did load: a partial set is still more
    #    informative than the fully random fallback used for an empty dict.
    if models:
        missing = [axis for axis in axes if axis not in models]
        warnings.warn(
            f"Only loaded per-axis models for {list(models.keys())}; "
            f"missing axes {missing} will get uniform probabilities."
        )
        return models

    warnings.warn("No binary models found. Inference will fallback to random probs.")
    return {}

# Load the models once, using the default base_dir / single_paths.
# Later cells bind this `binary_models` object as a default argument, so
# re-run those cells after reloading the models here.
binary_models = load_binary_models()


Loaded binary models from folders: ['EI', 'SN', 'TF', 'JP']


In [7]:
# Cell 5 - language detect & translate helper + embedding helper
def detect_is_english(text):
    """Return True when ``text`` looks English (or is empty/blank).

    Uses ``langdetect.detect`` when it is available; if detection is missing
    or raises, falls back to treating all-ASCII text as English.
    """
    is_blank = (not text) or text.strip() == ""
    if is_blank:
        return True

    if detect:
        try:
            return detect(text).lower().startswith("en")
        except Exception:
            # Best effort: detection failed, use the ASCII heuristic below.
            pass

    # Fallback heuristic: any non-ASCII character => probably not English.
    for ch in text:
        if ord(ch) >= 128:
            return False
    return True

def maybe_translate(text, enable_translate=True, target="en"):
    """Translate ``text`` to ``target`` when it is not detected as English.

    Returns ``text`` unchanged when translation is disabled, when the
    translator library is unavailable, when the text is already English, or
    when the translation call fails (a warning is emitted in that case).
    """
    if not enable_translate or GoogleTranslator is None:
        return text

    try:
        needs_translation = not detect_is_english(text)
        if needs_translation:
            return GoogleTranslator(source='auto', target=target).translate(text)
    except Exception as e:
        warnings.warn(f"Translate failed: {e}")
    return text

def encode_posts_mean(posts, embedder=embedder, batch_size=32):
    """Embed a ``|||``-separated string of posts and average the vectors.

    ``posts`` is converted with ``str()``, split on ``|||`` and blank pieces
    are dropped. With no usable piece, a float32 zero vector of the
    embedder's sentence-embedding dimension is returned; otherwise the mean
    over the encoded pieces.
    """
    segments = []
    for raw in str(posts).split("|||"):
        cleaned = raw.strip()
        if cleaned:
            segments.append(cleaned)

    if not segments:
        return np.zeros(embedder.get_sentence_embedding_dimension(), dtype=np.float32)

    vectors = embedder.encode(segments, convert_to_numpy=True, batch_size=batch_size)
    # A 1-D result is already a single vector; otherwise average over rows.
    return vectors if vectors.ndim == 1 else vectors.mean(axis=0)


In [8]:
# Cell 6 - inference: axis probs -> 16-type probs -> classify user
# The four MBTI axes, in the order their letters appear in a type string.
AXES = ["EI","SN","TF","JP"]
# Default (first, second) letter for each axis; used when a classifier's
# classes cannot be mapped to letters any other way.
DEFAULT_LETTERS = {"EI":("E","I"), "SN":("S","N"), "TF":("T","F"), "JP":("J","P")}

def _ensure_id2label(m):
    if m is None:
        return None
    # id2label may have keys as str; normalize to int keys
    try:
        return {int(k): v for k,v in m.items()}
    except Exception:
        return m

def _letter_from_label_map(id2label, cls, letters):
    """Return the axis letter that ``id2label`` assigns to class ``cls``, or None.

    The map's keys may be ints (after ``_ensure_id2label``) or strings, while
    ``cls`` may be an int-like or a string, so the lookup tries the raw value,
    its int form and its str form in turn. Labels that are not one of
    ``letters`` are ignored.
    """
    if not isinstance(id2label, dict) or not id2label:
        return None
    keys = [cls]
    try:
        keys.append(int(cls))
    except (TypeError, ValueError):
        pass
    keys.append(str(cls))
    for key in keys:
        try:
            label = id2label.get(key)
        except TypeError:
            # Unhashable class value; try the next representation.
            continue
        if isinstance(label, str) and label in letters:
            return label
    return None


def predict_mbti_from_embedding(embedding, binary_models=binary_models):
    """Predict MBTI type probabilities from one embedding vector.

    For each axis in ``AXES`` the matching classifier's ``predict_proba`` is
    called and its classes are mapped to the axis letters, in this order:
    the model's ``id2label`` map, a class that is already a letter, the class
    used as an integer index into ``DEFAULT_LETTERS``, and finally the
    positional index of the class. The sixteen type probabilities are the
    normalized product of the per-axis letter probabilities.

    Args:
        embedding: Feature vector passed (wrapped in a list) to each classifier.
        binary_models: dict axis -> ``{"clf": ..., "id2label": ...}``. Missing
            or failing axes get 0.5/0.5. If empty, random type probabilities
            are returned.

    Returns:
        ``(top_mbti, top_conf, axis_probs, mbti_probs)`` where ``axis_probs``
        maps axis -> {letter: prob} and ``mbti_probs`` maps type -> prob.
    """
    mbti_list = [a+b+c+d for a in ["E","I"] for b in ["S","N"] for c in ["T","F"] for d in ["J","P"]]

    # Fallback: no models at all -> random distribution over the 16 types.
    if not binary_models:
        r = np.random.rand(len(mbti_list))
        r /= r.sum()
        axis_probs = {axis: {DEFAULT_LETTERS[axis][0]: 0.5, DEFAULT_LETTERS[axis][1]: 0.5} for axis in AXES}
        top = mbti_list[int(np.argmax(r))]
        return top, float(np.max(r)), axis_probs, dict(zip(mbti_list, r.tolist()))

    axis_probs = {}
    for axis in AXES:
        letters = DEFAULT_LETTERS[axis]
        l0, l1 = letters
        info = binary_models.get(axis)
        if info is None:
            axis_probs[axis] = {l0: 0.5, l1: 0.5}
            continue
        clf = info["clf"]
        id2label = _ensure_id2label(info.get("id2label"))
        try:
            probs = clf.predict_proba([embedding])[0]
            classes = list(clf.classes_)
        except Exception as e:
            warnings.warn(f"Predict_proba error for axis {axis}: {e}")
            axis_probs[axis] = {l0: 0.5, l1: 0.5}
            continue

        this_axis = {}
        for idx, cls in enumerate(classes):
            # Prefer the saved label map. The previous lookup used str(cls)
            # against int-normalized keys and therefore never matched.
            mapped = _letter_from_label_map(id2label, cls, letters)
            if mapped is None:
                if isinstance(cls, str) and cls in letters:
                    mapped = cls
                else:
                    try:
                        mapped = letters[int(cls)]
                    except Exception:
                        mapped = letters[idx]
            this_axis[mapped] = float(probs[idx])

        # Ensure both letters exist, then normalize so they sum to 1.
        this_axis.setdefault(l0, 1e-9)
        this_axis.setdefault(l1, 1e-9)
        s = this_axis[l0] + this_axis[l1]
        this_axis[l0] /= s
        this_axis[l1] /= s
        axis_probs[axis] = this_axis

    # 16-type probability = product of the chosen letter's probability per axis.
    mbti_probs = {}
    for mbti in mbti_list:
        p = 1.0
        for i, axis in enumerate(AXES):
            p *= axis_probs[axis].get(mbti[i], 0.0)
        mbti_probs[mbti] = p
    total = sum(mbti_probs.values())
    if total > 0:
        for k in mbti_probs:
            mbti_probs[k] /= total
    else:
        # Degenerate case: spread probability uniformly.
        for k in mbti_probs:
            mbti_probs[k] = 1.0/len(mbti_probs)

    top_mbti = max(mbti_probs.items(), key=lambda x: x[1])[0]
    top_conf = float(mbti_probs[top_mbti])
    return top_mbti, top_conf, axis_probs, mbti_probs

def classify_user_binary(user_dict, binary_models=binary_models):
    """Classify one user and derive role probabilities from the MBTI result.

    The user's ``posts`` are embedded and passed to
    ``predict_mbti_from_embedding``. Each MBTI type's probability is then
    propagated to its secondary roles (``mbti_roles``) and their main roles
    (``role_to_main``), keeping the maximum when several types map to the
    same role. Every core role is present in ``role_probs`` (0.0 if unseen).

    Returns:
        dict with keys ``name``, ``top_mbti``, ``top_conf``, ``axis_probs``,
        ``mbti_probs``, ``role_probs`` and ``sec_roles_top`` (the up-to-two
        highest-scoring secondary roles).
    """
    def _keep_max(table, key, value):
        # Same result as max(table.get(key, 0.0), value).
        current = table.get(key, 0.0)
        table[key] = value if value > current else current

    vector = encode_posts_mean(user_dict.get("posts",""))
    top_mbti, top_conf, axis_probs, mbti_probs = predict_mbti_from_embedding(vector, binary_models)

    sec_role_probs = {}
    main_role_probs = {}
    for mbti_type, prob in mbti_probs.items():
        for secondary in mbti_roles.get(mbti_type, []):
            _keep_max(sec_role_probs, secondary, prob)
            main_role = role_to_main.get(secondary)
            if main_role:
                _keep_max(main_role_probs, main_role, prob)
    for core in CORE_ROLE_NAMES:
        if core not in main_role_probs:
            main_role_probs[core] = 0.0

    # Two best secondary roles; ties keep first-seen order (stable sort).
    sec_top = sorted(sec_role_probs, key=sec_role_probs.get, reverse=True)[:2]

    result = {}
    result["name"] = user_dict.get("name")
    result["top_mbti"] = top_mbti
    result["top_conf"] = float(top_conf)
    result["axis_probs"] = axis_probs
    result["mbti_probs"] = mbti_probs
    result["role_probs"] = main_role_probs
    result["sec_roles_top"] = sec_top
    return result


In [9]:
# Cell 7 - Parse .txt with format: "Name, post1 ||| post2 ||| ..."
def load_users_from_txt(path):
    """Parse a text file of users, one per line: ``Name, post1 ||| post2 ||| ...``.

    Args:
        path: A filesystem path (``str``, ``bytes`` or ``os.PathLike``) or a
            file-like object exposing the path in its ``.name`` attribute.

    Returns:
        list of ``{"name": str, "posts": str}`` dicts, one per non-blank line.
        Lines without a comma use their first 30 characters as the name and
        the whole line as the posts.
    """
    # Real paths are used as-is. pathlib.Path also has a `.name` attribute
    # (the basename), so it must be recognised before the file-like check.
    if isinstance(path, (str, bytes, os.PathLike)):
        fp = path
    elif hasattr(path, "name"):
        fp = path.name
    else:
        fp = path

    users = []
    # utf-8-sig drops a leading BOM if present, so it cannot end up inside the
    # first user's name; plain UTF-8 files are read identically.
    with open(fp, "r", encoding="utf-8-sig") as f:
        for ln in f:
            ln = ln.strip()
            if not ln:
                continue
            if "," in ln:
                # Split on the first comma only, so posts may contain commas.
                name, rest = ln.split(",", 1)
            else:
                # Fallback: no comma, derive a name from the line itself.
                name = ln[:30]
                rest = ln
            users.append({"name": name.strip(), "posts": rest.strip()})
    return users

# Quick test (uncomment to test)
# print(load_users_from_txt("sample.txt"))


In [10]:
# Cell 8 - grouping (core-first)
def pick_best_for_role(candidates, role):
    """Pick the candidate with the highest ``role_probs[role]``.

    ``candidates`` is a list of enriched user dicts (with ``role_probs`` and
    ``top_conf``). Scores within 1e-9 of the current best count as ties and
    are broken in favour of the higher ``top_conf``.

    Returns:
        ``(best_candidate, best_score)``; ``(None, -1.0)`` for an empty list.
    """
    winner, winner_score = None, -1.0
    for cand in candidates:
        cand_score = cand["role_probs"].get(role, 0.0)
        if cand_score > winner_score:
            winner, winner_score = cand, cand_score
            continue
        if abs(cand_score - winner_score) < 1e-9:
            # Tie: the more confident MBTI prediction wins.
            incumbent_conf = winner["top_conf"] if winner else 0
            if cand["top_conf"] > incumbent_conf:
                winner, winner_score = cand, cand_score
    return winner, winner_score

def group_users_core_first(enriched_users, team_size=4):
    """Build one team, covering the core roles first.

    Args:
        enriched_users: list of dicts as returned by ``classify_user_binary``
            (needs ``name``, ``role_probs``, ``top_mbti``, ``top_conf`` and
            ``sec_roles_top``).
        team_size: Desired team size; capped at 20.

    Returns:
        list of dicts with keys ``name``, ``assigned_role``, ``why``,
        ``top_mbti``, ``top_conf`` and ``role_prob``. Each user name appears
        at most once.

    Steps:
        1. Give each core role to the best remaining candidate.
        2. Fill remaining slots with each user's strongest main role, while
           limiting Leaders (1 for teams under 8, otherwise 3).
        3. If slots are still open, give not-yet-assigned users one of their
           top secondary roles, still respecting the Leader limit.
    """
    team_size = min(int(team_size), 20)
    remaining = list(enriched_users)
    assigned = []
    assigned_names = set()

    note = None
    if team_size < 4:
        note = "⚠️ Team nhỏ hơn tiêu chuẩn, thiếu core roles."

    # STEP 1: cover the core roles (as far as people and slots allow).
    for core in CORE_ROLE_NAMES:
        if len(assigned) >= team_size:
            break
        cand, score = pick_best_for_role(remaining, core)
        if cand:
            assigned.append({
                "name": cand["name"],
                "assigned_role": core,
                "why": f"Best match for {core} (prob {score:.2f})",
                "top_mbti": cand["top_mbti"],
                "top_conf": cand["top_conf"],
                "role_prob": score
            })
            assigned_names.add(cand["name"])
            remaining = [u for u in remaining if u["name"] != cand["name"]]

    # STEP 2: Leader limit.
    max_leaders = 1 if team_size < 8 else 3
    current_leaders = sum(1 for a in assigned if a["assigned_role"] == "Leader")

    # STEP 3: fill remaining slots with each user's strongest main role.
    if len(assigned) < team_size:
        ranked = []
        for u in remaining:
            main, rscore = max(u["role_probs"].items(), key=lambda x: x[1])
            ranked.append((u, main, rscore))
        ranked.sort(key=lambda x: (-x[2], -x[0]["top_conf"], x[0]["name"]))

        for u, main, rscore in ranked:
            if len(assigned) >= team_size:
                break
            # Enforce the Leader limit; skipped users are retried in STEP 4.
            if main == "Leader" and current_leaders >= max_leaders:
                continue
            assigned.append({
                "name": u["name"],
                "assigned_role": main,
                "why": f"Fill slot with {main} (prob {rscore:.2f})",
                "top_mbti": u["top_mbti"],
                "top_conf": u["top_conf"],
                "role_prob": rscore
            })
            assigned_names.add(u["name"])
            if main == "Leader":
                current_leaders += 1

    # STEP 4: slots still open (rare) -> assign a secondary role.
    if len(assigned) < team_size:
        for u in remaining:
            if len(assigned) >= team_size:
                break
            # `remaining` still holds users placed in STEP 3; never assign
            # the same person twice.
            if u["name"] in assigned_names:
                continue
            leader_cap_hit = current_leaders >= max_leaders
            # First top secondary role that does not break the Leader limit.
            sec = next(
                (
                    s for s in u["sec_roles_top"]
                    if not (leader_cap_hit and role_to_main.get(s, s) == "Leader")
                ),
                "Supporter",
            )
            assigned.append({
                "name": u["name"],
                "assigned_role": sec,
                "why": f"Assigned secondary role {sec} (no core slot left)",
                "top_mbti": u["top_mbti"],
                "top_conf": u["top_conf"],
                "role_prob": u["role_probs"].get(role_to_main.get(sec, sec), 0.0)
            })
            assigned_names.add(u["name"])
            if role_to_main.get(sec, sec) == "Leader":
                current_leaders += 1

    # Prefix the team-level note when the team is smaller than four.
    if note:
        for a in assigned:
            a["why"] = note + " " + a["why"]

    return assigned



In [11]:
# Cell 9 - process pipeline for uploaded file
def _ensure_filepath(file_input):
    # file_input may be path string or file object (gradio returns TemporaryFile)
    if file_input is None:
        raise ValueError("No file provided.")
    if isinstance(file_input, str) and os.path.exists(file_input):
        return file_input
    # file-like
    if hasattr(file_input, "name") and os.path.exists(file_input.name):
        return file_input.name
    # else try to write bytes to temp file if it's a dict (gradio older)
    if isinstance(file_input, dict) and "name" in file_input and "data" in file_input:
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".txt")
        tmp.write(file_input["data"])
        tmp.close()
        return tmp.name
    raise ValueError("Cannot resolve uploaded file path. Received object: " + str(type(file_input)))

def _md_cell(value):
    """Render ``value`` as text that is safe inside one Markdown table cell."""
    return str(value).replace("|", "\\|").replace("\n", " ")


def process_file_and_assign(uploaded_file, team_size=4, translate_enable=False):
    """Run the full pipeline for an uploaded users file.

    Reads the users, optionally translates their posts, classifies every
    user, groups them into one team and formats the result.

    Args:
        uploaded_file: Anything ``_ensure_filepath`` can resolve to a path.
        team_size: Desired team size; capped at 20.
        translate_enable: When True, each post is passed through
            ``maybe_translate`` before classification.

    Returns:
        ``(markdown_str, pandas.DataFrame)`` for the Gradio outputs. When no
        user can be read, a message and an empty DataFrame are returned.
    """
    team_size = min(int(team_size), 20)
    fp = _ensure_filepath(uploaded_file)
    users = load_users_from_txt(fp)
    if not users:
        return "Không đọc được user từ file.", pd.DataFrame()

    # Optional translate, then classify every user.
    enriched = []
    for u in users:
        posts = u["posts"]
        if translate_enable:
            # Translate post by post: one request for the whole "|||"-joined
            # string can exceed the translator's per-call length limit (the
            # failure is swallowed and the text stays untranslated) and may
            # damage the "|||" delimiter that the embedder splits on.
            translated_parts = []
            for part in str(posts).split("|||"):
                part = part.strip()
                if not part:
                    continue
                result = maybe_translate(part, enable_translate=True, target="en")
                # Keep the original post if the translator returned nothing.
                if isinstance(result, str) and result.strip():
                    translated_parts.append(result.strip())
                else:
                    translated_parts.append(part)
            posts = " ||| ".join(translated_parts)
        u2 = {"name": u["name"], "posts": posts}
        res = classify_user_binary(u2, binary_models=binary_models)
        enriched.append(res)

    # Group into a team.
    assigned = group_users_core_first(enriched, team_size=team_size)

    # Format the Markdown table; free-text cells are escaped so that a "|"
    # in a name cannot break the table layout.
    md = "| # | Name | Assigned Role | Reason | MBTI (conf) | RoleProb |\n"
    md += "|---:|---|---|---|---|---|\n"
    for i, a in enumerate(assigned, start=1):
        md += (
            f"| {i} | {_md_cell(a['name'])} | {_md_cell(a['assigned_role'])} | {_md_cell(a['why'])} "
            f"| {a['top_mbti']} ({a['top_conf']:.2f}) | {a['role_prob']:.2f} |\n"
        )

    # Same data as a DataFrame (unescaped values).
    df = pd.DataFrame(assigned)
    return md, df


In [12]:
# Cell 10 - Gradio UI (paste & run)
# The UI is only built when the optional `gradio` import succeeded earlier.
if gr is None:
    print("Gradio not installed. Install with `pip install gradio` to launch UI.")
else:
    with gr.Blocks() as demo:
        gr.Markdown("## 🧠 MBTI Team Builder — Core-first (Binary models)")
        gr.Markdown("Upload `.txt` file with lines: `Name, post1 ||| post2 ||| ...`")
        # Inputs: the users file, the team size and the translation toggle.
        with gr.Row():
            file_input = gr.File(label="Upload .txt file", file_types=[".txt"])
            team_slider = gr.Slider(minimum=1, maximum=20, step=1, value=4, label="Team size (max 20)")
            translate_chk = gr.Checkbox(value=False, label="Auto-translate posts to English (if not English)")
        run_btn = gr.Button("Assign Team")
        # Outputs: a Markdown table and the same rows as a dataframe.
        md_out = gr.Markdown()
        df_out = gr.Dataframe(headers=["name","assigned_role","why","top_mbti","top_conf","role_prob"], label="Assigned members (table)")

        def _wrap(file, team_size, translate):
            """Run the pipeline and turn any exception into an error message."""
            try:
                md, df = process_file_and_assign(file, team_size=team_size, translate_enable=translate)
                return md, df
            except Exception as e:
                # Show the failure in the Markdown output instead of raising.
                return f"Error: {e}", pd.DataFrame()

        run_btn.click(fn=_wrap, inputs=[file_input, team_slider, translate_chk], outputs=[md_out, df_out])

    # Start the app once the Blocks layout is fully defined.
    demo.launch()


* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.
