In [None]:

!pip install -q langchain-groq pandas scikit-learn tqdm backoff

import os, re, json, time
import pandas as pd
from tqdm import tqdm
from langchain_groq import ChatGroq
from sklearn.metrics import classification_report, accuracy_score
import backoff


# --- Run configuration for the batched Groq cell -----------------------------
# Input CSV (must expose "text", "label" and "task" columns; validated below)
# and the destination for the per-row predictions.
CSV_PATH = "/content/final_multitask___test.csv"
OUT_CSV  = "/content/multitask_predictions_batched.csv"
# Fail fast when the key is absent instead of on the first API request.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("Set GROQ_API_KEY in environment, e.g. os.environ['GROQ_API_KEY']='gsk_...'")

PRIMARY_MODEL = "llama-3.1-8b-instant"
# NOTE(review): confirm this model id is actually served by Groq; it is only
# used when the primary model cannot be constructed or is reported missing.
FALLBACK_MODEL = "mixtral-8x7b-instant"

BATCH_SIZE = 10             # rows sent to the model per request
PAUSE_BETWEEN_CALLS = 0.2   # seconds slept after each successful batch
SAMPLE_N = 0                # 0 -> use every row; >0 -> keep only the first N rows


# --- Allowed label sets, keyed by the value of the CSV "task" column ---------
SENT_LABELS = ["Mixed_feelings", "Negative", "Positive"]

OFF_LABELS = [
    "Not_offensive",
    "Offensive",
    "Offensive_Targeted_Insult_Group",
    "Offensive_Targeted_Insult_Individual",
    # NOTE(review): trailing "e" looks like a typo -- verify against the exact
    # spelling in the CSV "label" column, because predictions are coerced to
    # this string and compared verbatim in the metrics.
    "Offensive_Untargetede"
]

ID_LABELS = ["Homophobia", "Non-anti-LGBT+ content", "Transphobia"]

TASK_LABELS = {
    "SA": SENT_LABELS,
    "OFFENS": OFF_LABELS,
    "OTHERCAT": ID_LABELS
}


# --- Load and validate the evaluation data -----------------------------------
df = pd.read_csv(CSV_PATH)
print("Loaded rows:", len(df))

# Every later step indexes these three columns, so stop here if one is missing.
for c in ("text","label","task"):
    if c not in df.columns:
        raise RuntimeError(f"Expected column '{c}' in CSV but not found. Columns: {df.columns.tolist()}")

# Optional smoke-test mode: truncate to the first SAMPLE_N rows.
if SAMPLE_N and SAMPLE_N>0:
    df = df.head(SAMPLE_N)
    print("Running on sample rows:", len(df))


def build_llm(model_name):
    """Create a Groq chat client for *model_name* with temperature 0.

    Returns:
        A ``(client, model_name)`` tuple.

    Raises:
        Whatever ``ChatGroq`` raises; the error is printed with a ``[WARN]``
        prefix first and then re-raised unchanged.
    """
    try:
        client = ChatGroq(model=model_name, groq_api_key=GROQ_API_KEY, temperature=0)
    except Exception as err:
        print(f"[WARN] Could not init model {model_name}: {err}")
        raise
    return client, model_name


# Construct the primary model; only when that raises is the fallback tried.
# If both constructors fail, surface a single error that names both models and
# is explicitly chained to the fallback failure.
try:
    llm, current_model = build_llm(PRIMARY_MODEL)
    print(f"Using model: {current_model}")
except Exception:
    try:
        llm, current_model = build_llm(FALLBACK_MODEL)
        print(f"Primary model failed; switched to fallback: {current_model}")
    except Exception as e:
        raise RuntimeError(
            f"Both primary ({PRIMARY_MODEL}) and fallback ({FALLBACK_MODEL}) failed to initialize: {e}"
        ) from e


def try_extract_json_array(s: str):
    """Parse the bracketed span of *s* as a JSON array.

    The span runs from the first ``[`` to the last ``]`` (greedy, newlines
    included). Returns the parsed list, or ``None`` when *s* is not a string,
    has no such span, or the span is not valid JSON.
    """
    if not isinstance(s, str):
        return None
    found = re.search(r"\[.*\]", s, flags=re.S)
    if found is None:
        return None
    try:
        parsed = json.loads(found.group(0))
    except Exception:
        return None
    return parsed if isinstance(parsed, list) else None

def try_extract_json_label(s: str):
    """Extract a label string from a model reply.

    Looks for the first (shortest) ``{...}`` span in *s* and parses it as JSON.
    The value under "label", "classification" or "class" (first key present, in
    that order) is returned; otherwise the object's first value is used. In all
    other cases (no span, invalid JSON, empty object) the stripped input is
    returned. Non-string input yields ``""``.
    """
    if not isinstance(s, str):
        return ""
    stripped = s.strip()
    found = re.search(r"\{.*?\}", s, flags=re.S)
    if found is None:
        return stripped
    try:
        obj = json.loads(found.group(0))
        preferred = next(
            (key for key in ("label", "classification", "class") if key in obj),
            None,
        )
        if preferred is not None:
            return str(obj[preferred]).strip()
        values = list(obj.values())
        if values:
            return str(values[0]).strip()
    except Exception:
        return stripped
    return stripped

def normalize_for_matching(s: str):
    """Return a comparison key for *s*.

    Lower-cases, spells ``+`` as ``plus``, turns every run of characters other
    than ``a-z0-9`` into one space, and trims. Falsy input gives ``""``.
    """
    lowered = (s or "").lower().replace("+", "plus")
    return " ".join(re.sub(r"[^a-z0-9]+", " ", lowered).split())

def force_to_allowed(pred_raw: str, allowed_labels: list):
    """Map a raw model answer onto exactly one of *allowed_labels*.

    Matching is attempted from strictest to loosest:

    1. exact string match,
    2. case-insensitive match,
    3. match after ``normalize_for_matching``,
    4. an allowed label contained in the answer (the longest such label wins,
       so "Offensive_Targeted_Insult_Group" is not shadowed by "Offensive"),
    5. the answer contained in an allowed label (first label in list order),
    6. largest word overlap between answer and label.

    Args:
        pred_raw: Raw answer text (may itself be a JSON snippet).
        allowed_labels: Candidate labels in priority order.

    Returns:
        One element of *allowed_labels*. The first label is the default when
        the answer is empty or nothing matches; ``""`` is returned when
        *allowed_labels* is empty.
    """
    if not allowed_labels:
        return ""
    default = allowed_labels[0]
    if not pred_raw:
        return default

    pred = try_extract_json_label(pred_raw)
    for lab in allowed_labels:
        if pred == lab:
            return lab
    pred_lower = pred.lower()
    for lab in allowed_labels:
        if pred_lower == lab.lower():
            return lab

    n_pred = normalize_for_matching(pred)
    normalized = [(lab, normalize_for_matching(lab)) for lab in allowed_labels]
    for lab, n_lab in normalized:
        if n_lab == n_pred:
            return lab
    if not n_pred:
        # Nothing left to compare against; every looser rule would pick the default anyway.
        return default

    # Prefer the most specific label mentioned in the answer: plain list order
    # would let a short label that is a prefix of longer ones always win.
    mentioned = [(lab, n_lab) for lab, n_lab in normalized if n_lab and n_lab in n_pred]
    if mentioned:
        return max(mentioned, key=lambda pair: len(pair[1]))[0]
    for lab, n_lab in normalized:
        if n_pred in n_lab:
            return lab

    pred_tokens = set(n_pred.split())
    best, best_score = None, 0
    for lab, n_lab in normalized:
        score = len(pred_tokens & set(n_lab.split()))
        if score > best_score:
            best, best_score = lab, score
    return best if best else default


@backoff.on_exception(backoff.expo, Exception, max_tries=8, max_time=600)
def invoke_model_with_retry(prompt: str):
    """Send *prompt* to the active model and return the reply's ``content``.

    Every exception is re-raised so the ``backoff`` decorator (exponential
    wait, ``max_tries=8``, ``max_time=600``) can retry the call. When the error
    text mentions "model" together with a not-found / no-access phrase and the
    fallback is not already active, the module-level ``llm`` and
    ``current_model`` are switched to ``FALLBACK_MODEL`` before re-raising, so
    the next retry uses the fallback.
    """
    global llm, current_model
    try:
        return llm.invoke(prompt).content
    except Exception as e:
        text = str(e).lower()
        model_unavailable = "model" in text and any(
            phrase in text
            for phrase in ("not found", "does not exist", "you do not have access")
        )
        if model_unavailable and current_model != FALLBACK_MODEL:
            print(f"[INFO] model error: {e} -> switching to fallback model {FALLBACK_MODEL}")
            llm = ChatGroq(model=FALLBACK_MODEL, groq_api_key=GROQ_API_KEY, temperature=0)
            current_model = FALLBACK_MODEL
        # Always propagate: retrying is the decorator's job.
        raise


def build_batch_prompt(texts, tasks):
    """Build one classification prompt covering a batch of inputs.

    Args:
        texts: Input strings, in the order the answers are expected.
        tasks: Task code for each text ("SA", "OFFENS" or "OTHERCAT"), parallel
            to *texts*.

    Returns:
        The prompt text: instructions, the three label sets with exact
        spellings, the numbered inputs (each wrapped in triple double quotes),
        and a closing reminder to answer with a single JSON array.
    """
    def quoted(labels):
        return ", ".join(f'"{lab}"' for lab in labels)

    prompt_lines = [
        "You are a precise classifier. For each input, choose exactly ONE label corresponding to its task.",
        "Return EXACTLY one JSON array (no extra text) with the same number of items as inputs.",
        "Each array element must be a JSON object with a single key 'label', e.g. {\"label\":\"Positive\"}.",
        "",
        "Label sets (exact spellings):",
        "SA (Sentiment): " + quoted(SENT_LABELS),
        "OFFENS (Offensive): " + quoted(OFF_LABELS),
        "OTHERCAT (Identity): " + quoted(ID_LABELS),
        "",
        "Now classify the following inputs in order. For each item, I label its task.",
        "",
    ]
    for i, (t, task) in enumerate(zip(texts, tasks), start=1):
        # Backslash-escape embedded triple quotes so the text cannot terminate
        # its own delimiter. (The previous replacement string '\"\"\"' is just
        # '"""' again, so nothing was ever escaped.)
        safe_text = t.replace('"""', r'\"\"\"')
        prompt_lines.append(f'{i}) Task={task} --- Text:\n"""{safe_text}"""\n')
    prompt_lines.extend([
        "",
        "Return only a single JSON array. Example:",
        '[{"label":"Positive"}, {"label":"Negative"}, {"label":"Mixed_feelings"}]',
        "",
    ])
    return "\n".join(prompt_lines)


# --- Batched inference --------------------------------------------------------
# Work on a 0..n-1 index so positions in `preds` / `raws` line up with rows.
rows = df.reset_index(drop=True)
n = len(rows)
preds = [""] * n   # label after coercion onto the task's allowed set
raws  = [""] * n   # label text as extracted from the model reply

i = 0       # index of the first row of the current batch
calls = 0   # successful batch calls only (per-row fallback calls are not counted)
while i < n:
    batch_texts = []
    batch_tasks = []
    idxs = []
    for j in range(i, min(n, i + BATCH_SIZE)):
        batch_texts.append(str(rows.at[j, "text"]))
        batch_tasks.append(str(rows.at[j, "task"]))
        idxs.append(j)
    prompt = build_batch_prompt(batch_texts, batch_tasks)
    try:
        raw = invoke_model_with_retry(prompt)
    except Exception as e:
        # The batch call failed even after retries: classify its rows one by one.
        print(f"[ERROR] Batch call failed at indices {idxs}: {e}")
        # Per-row fallback
        for k, idx in enumerate(idxs):
            single_prompt = build_batch_prompt([batch_texts[k]], [batch_tasks[k]])
            try:
                r = invoke_model_with_retry(single_prompt)
                arr = try_extract_json_array(r)
                if arr and isinstance(arr, list) and len(arr) >= 1:
                    label_raw = arr[0].get("label", None) if isinstance(arr[0], dict) else None
                    if label_raw is None:
                        label_raw = try_extract_json_label(r)
                else:
                    label_raw = try_extract_json_label(r)
            except Exception as e2:
                print(f"[FALLBACK ERROR] single-row failed idx {idx}: {e2}")
                # Empty answer -> force_to_allowed picks the task's first label.
                label_raw = ""
            allowed = TASK_LABELS.get(batch_tasks[k], [])
            # Rows whose task has no known label set keep the raw answer.
            forced = force_to_allowed(label_raw, allowed) if allowed else label_raw
            preds[idx] = forced
            raws[idx] = label_raw
        # This path skips the inter-call pause below.
        i += BATCH_SIZE
        continue

    calls += 1
    arr = try_extract_json_array(raw)
    # Parse path 1: a JSON array with exactly one element per input row.
    if arr and isinstance(arr, list) and len(arr) == len(idxs):
        for k, idx in enumerate(idxs):
            el = arr[k]
            if isinstance(el, dict) and "label" in el:
                label_raw = el["label"]
            else:
                # Bare strings are used as-is; anything else is re-serialized.
                label_raw = el if isinstance(el, str) else json.dumps(el)
            allowed = TASK_LABELS.get(batch_tasks[k], [])
            forced = force_to_allowed(label_raw, allowed) if allowed else label_raw
            preds[idx] = forced
            raws[idx]  = label_raw
    else:
        # Parse path 2: no usable array, so collect individual {...} objects
        # in order of appearance and pair them with the rows positionally.
        js_objs = re.findall(r"\{.*?\}", raw, flags=re.S)
        if js_objs and len(js_objs) >= len(idxs):
            for k, idx in enumerate(idxs):
                try:
                    obj = json.loads(js_objs[k])
                    label_raw = obj.get("label", None) or obj.get("classification", None) or next(iter(obj.values()), "")
                except Exception:
                    label_raw = try_extract_json_label(js_objs[k])
                allowed = TASK_LABELS.get(batch_tasks[k], [])
                forced = force_to_allowed(label_raw, allowed) if allowed else label_raw
                preds[idx] = forced
                raws[idx]  = label_raw
        else:
            # Parse path 3: treat the trailing non-empty lines of the reply as
            # one answer per row, skipping bracket lines and lines mentioning "return".
            lines = [ln.strip() for ln in raw.splitlines() if ln.strip()]
            candidate_lines = []
            for ln in lines[::-1]:
                if ln.startswith("[") or ln.startswith("]") or "return" in ln.lower():
                    continue
                candidate_lines.append(ln)
                if len(candidate_lines) >= len(idxs):
                    break
            # Collected bottom-up; restore top-down order before pairing.
            candidate_lines = candidate_lines[::-1]
            for k, idx in enumerate(idxs):
                label_raw = candidate_lines[k] if k < len(candidate_lines) else ""
                allowed = TASK_LABELS.get(batch_tasks[k], [])
                forced = force_to_allowed(label_raw, allowed) if allowed else label_raw
                preds[idx] = forced
                raws[idx]  = label_raw

    time.sleep(PAUSE_BETWEEN_CALLS)
    i += BATCH_SIZE

print(f"Done inference. API calls made: {calls}")


# --- Persist predictions next to the original columns -------------------------
df_out = rows.copy()
df_out["predicted"] = preds
df_out["raw_model_output"] = raws
df_out.to_csv(OUT_CSV, index=False)
print("Saved predictions to", OUT_CSV)


# --- Per-task accuracy and classification report ------------------------------
print("\n\n====== METRICS ======\n")
for task_name in ["SA","OFFENS","OTHERCAT"]:
    sub = df_out[df_out["task"] == task_name]
    if len(sub) == 0:
        print(f"Task {task_name}: NO SAMPLES FOUND (skipping).")
        continue
    # Compare as strings so gold and predicted labels share one type.
    y_true = sub["label"].astype(str).tolist()
    y_pred = sub["predicted"].astype(str).tolist()
    acc = accuracy_score(y_true, y_pred)
    print(f"\n--- Task: {task_name} (n={len(sub)}) ---")
    print("Accuracy:", round(acc,4))
    print(classification_report(y_true, y_pred, zero_division=0))

print("\nAll done.")


Loaded rows: 1100
Using model: llama-3.1-8b-instant
Done inference. API calls made: 110
Saved predictions to /content/multitask_predictions_batched.csv




--- Task: SA (n=300) ---
Accuracy: 0.5233
                precision    recall  f1-score   support

Mixed_feelings       0.44      0.15      0.22       100
      Negative       0.57      0.67      0.61       100
      Positive       0.51      0.75      0.60       100

      accuracy                           0.52       300
     macro avg       0.51      0.52      0.48       300
  weighted avg       0.51      0.52      0.48       300


--- Task: OFFENS (n=500) ---
Accuracy: 0.278
                                      precision    recall  f1-score   support

                       Not_offensive       0.34      0.44      0.38       100
                           Offensive       0.26      0.18      0.21       100
     Offensive_Targeted_Insult_Group       0.29      0.16      0.21       100
Offensive_Targeted_Insult_Individual       0.25 

In [None]:

!pip install -q langgraph langchain-groq pandas scikit-learn tqdm


import os
import re
import json
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score

from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from typing import TypedDict, Annotated
from langchain_groq import ChatGroq


# --- Run configuration for the LangGraph cell ---------------------------------
CSV_PATH = "/content/final_multitask___test.csv"
OUT_CSV  = "/content/predictions_langgraph.csv"

# Fail fast: the client below reads the key straight from the environment.
if "GROQ_API_KEY" not in os.environ:
    raise RuntimeError("Set environment variable: os.environ['GROQ_API_KEY']='your-key'")

MODEL = "llama-3.1-8b-instant"


# --- Allowed label sets, keyed by the value of the CSV "task" column ---------
SENT_LABELS = ["Mixed_feelings", "Negative", "Positive"]

OFF_LABELS = [
    "Not_offensive",
    "Offensive",
    "Offensive_Targeted_Insult_Group",
    "Offensive_Targeted_Insult_Individual",
    # NOTE(review): trailing "e" looks like a typo -- verify against the exact
    # spelling in the CSV "label" column; this string is shown to the model and
    # its answer is compared verbatim in the metrics.
    "Offensive_Untargetede"
]

ID_LABELS = ["Homophobia", "Non-anti-LGBT+ content", "Transphobia"]


TASK_TO_LABELS = {
    "SA": SENT_LABELS,
    "OFFENS": OFF_LABELS,
    "OTHERCAT": ID_LABELS
}


# --- Load and validate the evaluation data -----------------------------------
df = pd.read_csv(CSV_PATH)
print("Loaded rows:", len(df))

for c in ("text", "label", "task"):
    if c not in df.columns:
        raise RuntimeError(f"Missing required column: {c}")


# Single shared chat client used by every agent node; temperature is fixed at 0.
llm = ChatGroq(
    model=MODEL,
    groq_api_key=os.environ["GROQ_API_KEY"],
    temperature=0
)


def extract_label(text):
    """Extract the predicted label from a model reply.

    The first (shortest) ``{...}`` span in *text* is parsed as JSON; when it
    holds a "label" key, that value is returned as a stripped string. Stray
    whitespace or a non-string value would otherwise never equal a gold label,
    since the metrics compare labels verbatim.

    Args:
        text: Model reply. Non-string input is returned unchanged.

    Returns:
        The label string, or the stripped reply when no usable JSON object
        with a "label" key is found.
    """
    if not isinstance(text, str):
        return text
    m = re.search(r"\{.*?\}", text, flags=re.S)
    if m:
        try:
            obj = json.loads(m.group(0))
        except (ValueError, RecursionError):
            # Not valid JSON: fall through to the raw text.
            obj = None
        if isinstance(obj, dict) and "label" in obj:
            label = obj["label"]
            return label.strip() if isinstance(label, str) else str(label)
    return text.strip()


def build_prompt(task, text):
    """Render the single-item classification prompt for *task*.

    The prompt names the task, lists that task's allowed labels (each in double
    quotes), asks for a JSON object with a "label" key, and ends with *text*
    wrapped in triple double quotes. Raises ``KeyError`` for a task missing
    from ``TASK_TO_LABELS``.
    """
    quoted_labels = ", ".join('"{}"'.format(lab) for lab in TASK_TO_LABELS[task])
    sections = (
        "",  # the prompt starts with a newline
        f"Classify the following text according to task = {task}.",
        f"Allowed labels = [{quoted_labels}]",
        "",
        "Return ONLY JSON:",
        '{ "label": "<label>" }',
        "",
        "Text:",
        f'"""{text}"""',
    )
    return "\n".join(sections)


class State(TypedDict):
    """Per-row state passed between the graph's nodes."""

    # Message log; `add_messages` is the langgraph reducer attached to this key
    # (presumably merging node-returned messages into the list -- confirm in the
    # langgraph docs).
    messages: Annotated[list, add_messages]
    # Task code; the router's conditional edge dispatches on this value.
    task: str
    # Input text that the agent nodes interpolate into their prompt.
    text: str
    # Label written by whichever agent node handled the row.
    prediction: str


def router_node(state: State):
    """Entry node: record which task the row is being routed to.

    It only appends a ``ROUTE:<task>`` marker to the message log; the branch
    itself is chosen by the conditional edge attached to this node.
    """
    route_marker = "ROUTE:{}".format(state["task"])
    return {"messages": [route_marker]}


def _run_task_agent(task, state):
    """Prompt the LLM for *task* on ``state["text"]`` and build the node update.

    Returns a dict with the extracted label under "prediction" and the raw
    model reply appended under "messages".
    """
    out = llm.invoke(build_prompt(task, state["text"])).content
    return {"prediction": extract_label(out), "messages": [out]}


def sentiment_agent(state: State):
    """Graph node: classify the row's text with the SA label set."""
    return _run_task_agent("SA", state)


def offense_agent(state: State):
    """Graph node: classify the row's text with the OFFENS label set."""
    return _run_task_agent("OFFENS", state)


def othercat_agent(state: State):
    """Graph node: classify the row's text with the OTHERCAT label set."""
    return _run_task_agent("OTHERCAT", state)


# --- Graph wiring: START -> router -> one task agent -> END -------------------
graph = StateGraph(State)

# Register the router plus one agent node per task.
graph.add_node("router", router_node)
graph.add_node("sentiment", sentiment_agent)
graph.add_node("offense", offense_agent)
graph.add_node("othercat", othercat_agent)

graph.add_edge(START, "router")

# Dispatch on the row's task code; the mapping translates it to a node name.
# A task code absent from this mapping has no branch (behaviour in that case is
# up to langgraph -- confirm before feeding other task values).
graph.add_conditional_edges(
    "router",
    lambda st: st["task"],
    {
        "SA": "sentiment",
        "OFFENS": "offense",
        "OTHERCAT": "othercat"
    }
)

# Each agent is terminal: one LLM call per row.
graph.add_edge("sentiment", END)
graph.add_edge("offense", END)
graph.add_edge("othercat", END)

app = graph.compile()
print("Graph built successfully.")


# Run the compiled graph once per row (tqdm shows progress) and keep only the
# label each run leaves in its final state.
preds = [
    app.invoke({"messages": [], "task": row["task"], "text": row["text"]})["prediction"]
    for _, row in tqdm(df.iterrows(), total=len(df))
]

df["predicted"] = preds
df.to_csv(OUT_CSV, index=False)
print("Saved predictions to:", OUT_CSV)


# Per-task accuracy and classification report; tasks with no rows are skipped.
print("\n===== METRICS =====\n")
for task_name in ("SA", "OFFENS", "OTHERCAT"):
    task_rows = df[df["task"] == task_name]
    if len(task_rows) == 0:
        continue
    gold, guessed = task_rows["label"], task_rows["predicted"]
    print(f"\n--- Task {task_name} (n={len(task_rows)}) ---")
    print("Accuracy:", accuracy_score(gold, guessed))
    print(classification_report(gold, guessed, zero_division=0))


Loaded rows: 1100
Graph built successfully.


100%|██████████| 1100/1100 [38:45<00:00,  2.11s/it]

Saved predictions to: /content/predictions_langgraph.csv

===== METRICS =====


--- Task SA (n=300) ---
Accuracy: 0.43333333333333335
                precision    recall  f1-score   support

Mixed_feelings       0.36      0.84      0.51       100
      Negative       0.60      0.21      0.31       100
      Positive       0.74      0.25      0.37       100

      accuracy                           0.43       300
     macro avg       0.57      0.43      0.40       300
  weighted avg       0.57      0.43      0.40       300


--- Task OFFENS (n=500) ---
Accuracy: 0.316
                                      precision    recall  f1-score   support

                       Not_offensive       0.42      0.36      0.39       100
                           Offensive       0.00      0.00      0.00       100
     Offensive_Targeted_Insult_Group       0.25      0.62      0.35       100
Offensive_Targeted_Insult_Individual       0.33      0.29      0.31       100
               Offensive_Untargeted




In [None]:


import os, re, json, time
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
import backoff
from langchain_google_genai import ChatGoogleGenerativeAI


# --- Run configuration for the batched Gemini cell ----------------------------
# Input CSV (must expose "text", "label" and "task" columns; validated below)
# and the destination for the per-row predictions.
CSV_PATH = "/content/final_multitask___test.csv"
OUT_CSV  = "/content/multitask_predictions_gemini.csv"

PRIMARY_MODEL  = "gemini-2.5-flash"
FALLBACK_MODEL = "gemini-1.5-flash"   # used when the primary is unavailable

BATCH_SIZE = 10             # rows sent to the model per request
PAUSE_BETWEEN_CALLS = 0.2   # seconds slept after each successful batch
SAMPLE_N = 0                # 0 -> use every row; >0 -> keep only the first N rows

# Raises KeyError right here when the variable is not set.
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]


# --- Allowed label sets, keyed by the value of the CSV "task" column ---------
SENT_LABELS = ["Mixed_feelings", "Negative", "Positive"]

OFF_LABELS = [
    "Not_offensive",
    "Offensive",
    "Offensive_Targeted_Insult_Group",
    "Offensive_Targeted_Insult_Individual",
    # NOTE(review): trailing "e" looks like a typo -- verify against the exact
    # spelling in the CSV "label" column, because predictions are coerced to
    # this string and compared verbatim in the metrics.
    "Offensive_Untargetede"
]

ID_LABELS = ["Homophobia", "Non-anti-LGBT+ content", "Transphobia"]

TASK_LABELS = {
    "SA": SENT_LABELS,
    "OFFENS": OFF_LABELS,
    "OTHERCAT": ID_LABELS
}


# --- Load and validate the evaluation data -----------------------------------
df = pd.read_csv(CSV_PATH)
print("Loaded rows:", len(df))

for c in ("text","label","task"):
    if c not in df.columns:
        raise RuntimeError(f"Missing column '{c}'")

# Optional smoke-test mode: truncate to the first SAMPLE_N rows.
if SAMPLE_N > 0:
    df = df.head(SAMPLE_N)


def build_llm(model_name):
    """Create a Gemini chat client for *model_name* with temperature 0.

    Returns a ``(client, model_name)`` tuple; construction errors propagate.
    """
    client = ChatGoogleGenerativeAI(
        model=model_name,
        google_api_key=GOOGLE_API_KEY,
        temperature=0,
    )
    return client, model_name

# Prefer the primary model; fall back only when constructing it raises.
# The fallback branch previously referenced a misspelled name
# (FALLallback_MODEL) and would itself have failed with NameError.
try:
    llm, current_model = build_llm(PRIMARY_MODEL)
    print(f"Using model: {current_model}")
except Exception:
    llm, current_model = build_llm(FALLBACK_MODEL)
    print(f"Primary failed. Using fallback: {current_model}")


def try_extract_json_array(s):
    """Parse the bracketed span of *s* as JSON.

    The span runs from the first ``[`` to the last ``]`` (greedy, newlines
    included).

    Returns:
        The parsed value, or ``None`` when *s* is not a string, has no
        bracketed span, or the span is not valid JSON.
    """
    if not isinstance(s, str):
        return None
    m = re.search(r"\[.*\]", s, flags=re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except (ValueError, RecursionError):
        # Only parse failures mean "no array"; a bare except would also
        # swallow KeyboardInterrupt / SystemExit.
        return None

def try_extract_json_label(s):
    """Extract a label string from a model reply.

    The first (shortest) ``{...}`` span in *s* is parsed as JSON, and the value
    under "label", "classification" or "class" (first key present, in that
    order) is returned as a stripped string. Callers invoke ``.lower()`` on the
    result, so a numeric or null JSON value must not leak through as a
    non-string.

    Returns:
        The label text; the stripped input when no such key can be read; ``""``
        for non-string input.
    """
    if not isinstance(s, str):
        return ""
    m = re.search(r"\{.*?\}", s, flags=re.S)
    if m:
        try:
            obj = json.loads(m.group(0))
        except (ValueError, RecursionError):
            # Not valid JSON: fall through to the raw text.
            obj = None
        if isinstance(obj, dict):
            for k in ("label", "classification", "class"):
                if k in obj:
                    return str(obj[k]).strip()
    return s.strip()

def normalize_for_matching(s):
    """Return a comparison key for *s*.

    Lower-cases, spells ``+`` as ``plus``, turns every run of characters other
    than ``a-z0-9`` into one space, and trims. Falsy input gives ``""``.
    """
    lowered = (s or "").lower().replace("+", "plus")
    words = re.sub(r"[^a-z0-9]+", " ", lowered).split()
    return " ".join(words)

def force_to_allowed(pred_raw, allowed):
    """Map a raw model answer onto one of the *allowed* labels.

    Tried in order: exact match, case-insensitive match, match after
    ``normalize_for_matching``, then the label sharing the most words with the
    answer. Falls back to ``allowed[0]`` when nothing overlaps. With an empty
    *allowed* the raw answer is returned untouched.
    """
    if not allowed:
        return pred_raw
    pred = try_extract_json_label(pred_raw)
    if pred in allowed:
        return pred

    folded = pred.lower()
    hit = next((lab for lab in allowed if lab.lower() == folded), None)
    if hit is not None:
        return hit

    n_pred = normalize_for_matching(pred)
    hit = next((lab for lab in allowed if normalize_for_matching(lab) == n_pred), None)
    if hit is not None:
        return hit

    # Last resort: word overlap; the earliest label wins ties.
    pred_words = set(n_pred.split())
    winner, top = None, 0
    for lab in allowed:
        overlap = len(pred_words & set(normalize_for_matching(lab).split()))
        if overlap > top:
            winner, top = lab, overlap
    return winner if winner else allowed[0]

# ------------------------
# RETRY WRAPPER
# ------------------------
@backoff.on_exception(backoff.expo, Exception, max_tries=8)
def invoke_model(prompt):
    """Send *prompt* to the active model and return the reply's ``content``.

    Every exception is re-raised so the ``backoff`` decorator (exponential
    wait, ``max_tries=8``) can retry. When the error text contains "not found"
    or "access" and the fallback is not already active, the module-level
    ``llm`` / ``current_model`` are switched to ``FALLBACK_MODEL`` first, so
    the next retry uses the fallback.
    """
    global llm, current_model
    try:
        return llm.invoke(prompt).content
    except Exception as e:
        lowered = str(e).lower()
        looks_unavailable = any(marker in lowered for marker in ("not found", "access"))
        if looks_unavailable and current_model != FALLBACK_MODEL:
            print("Switching to fallback:", FALLBACK_MODEL)
            llm, current_model = build_llm(FALLBACK_MODEL)
        # Always propagate: retrying is the decorator's job.
        raise

# ------------------------
# BUILD BATCH PROMPT
# ------------------------
def build_batch_prompt(texts, tasks):
    """Build one classification prompt covering a batch of inputs.

    Args:
        texts: Input strings, in the order the answers are expected.
        tasks: Task code for each text ("SA", "OFFENS" or "OTHERCAT"), parallel
            to *texts*.

    Returns:
        The prompt text: instructions, the three label sets, the numbered
        inputs (each wrapped in triple double quotes), and the required JSON
        output format.
    """
    out = [
        "You are a strict classifier. Return ONLY one JSON array.",
        'Each element must look like {"label":"Positive"}.',
        "",
        "Label sets:",
        "SA: " + ", ".join(SENT_LABELS),
        "OFFENS: " + ", ".join(OFF_LABELS),
        "OTHERCAT: " + ", ".join(ID_LABELS),
        "",
        "Classify in order:",
        "",
    ]
    for i, (t, tk) in enumerate(zip(texts, tasks), start=1):
        # Backslash-escape embedded triple quotes so the text cannot terminate
        # its own delimiter. (The previous replacement string '\"\"\"' is just
        # '"""' again, so nothing was ever escaped.)
        safe = t.replace('"""', r'\"\"\"')
        out.append(f'{i}) Task={tk}\n"""{safe}"""')
        out.append("")
    out.append('Return only this JSON format: [{"label":"Positive"}]')
    return "\n".join(out)

# ------------------------
# BATCH INFERENCE
# ------------------------
# Work on a 0..n-1 index so positions in `preds` / `raws` line up with rows.
rows = df.reset_index(drop=True)
n = len(rows)

preds = [""] * n   # label after coercion onto the task's allowed set
raws = [""] * n    # label text (or whole reply) the prediction was derived from
calls = 0          # successful batch calls

for start in range(0, n, BATCH_SIZE):
    idxs = list(range(start, min(n, start + BATCH_SIZE)))
    batch_texts = [str(rows.at[j, "text"]) for j in idxs]
    batch_tasks = [str(rows.at[j, "task"]) for j in idxs]

    prompt = build_batch_prompt(batch_texts, batch_tasks)

    try:
        raw = invoke_model(prompt)
    except Exception as e:
        # Retries exhausted: these rows keep their empty prediction.
        print(f"[ERROR] Failed batch {idxs}: {e}")
        continue

    calls += 1
    arr = try_extract_json_array(raw)

    if arr and len(arr) == len(idxs):
        for k, idx in enumerate(idxs):
            el = arr[k]
            # The model sometimes returns bare strings instead of
            # {"label": ...} objects; calling .get() on those would abort the
            # whole run before anything is saved.
            if isinstance(el, dict):
                raw_lab = el.get("label", "")
            elif isinstance(el, str):
                raw_lab = el
            else:
                raw_lab = json.dumps(el)
            preds[idx] = force_to_allowed(raw_lab, TASK_LABELS.get(batch_tasks[k], []))
            raws[idx] = raw_lab
    else:
        # No array of the right length: every row in the batch is derived from
        # the whole reply.
        print("[WARN] JSON parse fallback")
        for k, idx in enumerate(idxs):
            preds[idx] = force_to_allowed(raw, TASK_LABELS.get(batch_tasks[k], []))
            raws[idx] = raw

    time.sleep(PAUSE_BETWEEN_CALLS)

print("Inference complete. Calls:", calls)


# Persist predictions next to the original columns.
df_out = rows.copy()
df_out["predicted"] = preds
df_out["raw_output"] = raws
df_out.to_csv(OUT_CSV, index=False)
print("Saved to:", OUT_CSV)


# Per-task accuracy and classification report; tasks with no rows are skipped.
print("\n===== METRICS =====")
for task in ["SA", "OFFENS", "OTHERCAT"]:
    sub = df_out[df_out["task"] == task]
    if len(sub) == 0:
        continue
    # Compare as strings so gold and predicted labels share one type.
    y_true = sub["label"].astype(str).tolist()
    y_pred = sub["predicted"].astype(str).tolist()
    print(f"\n--- {task} (n={len(sub)}) ---")
    print("Accuracy:", round(accuracy_score(y_true, y_pred), 4))
    print(classification_report(y_true, y_pred, zero_division=0))

print("\nDone.")


Loaded rows: 1100
Using model: gemini-2.5-flash
[WARN] JSON parse fallback
Inference complete. Calls: 110
Saved to: /content/multitask_predictions_gemini.csv

===== METRICS =====

--- SA (n=300) ---
Accuracy: 0.5933
                precision    recall  f1-score   support

Mixed_feelings       0.55      0.21      0.30       100
      Negative       0.58      0.70      0.64       100
      Positive       0.61      0.87      0.72       100

      accuracy                           0.59       300
     macro avg       0.58      0.59      0.55       300
  weighted avg       0.58      0.59      0.55       300


--- OFFENS (n=500) ---
Accuracy: 0.33
                                      precision    recall  f1-score   support

                       Not_offensive       0.34      0.94      0.50       100
                           Offensive       0.24      0.04      0.07       100
     Offensive_Targeted_Insult_Group       0.44      0.31      0.36       100
Offensive_Targeted_Insult_Individual 

In [None]:


import os, re, json, time
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import classification_report, accuracy_score
import backoff
from langchain_google_genai import ChatGoogleGenerativeAI


# --- Run configuration for the few-shot Gemini cell ---------------------------
TRAIN_CSV = "/content/final_multitask___train.csv"   # source of the few-shot examples
TEST_CSV  = "/content/final_multitask___test.csv"    # rows to classify and score
OUT_CSV   = "/content/multitask_predictions_fewshot.csv"



PRIMARY_MODEL  = "gemini-2.5-flash"
FALLBACK_MODEL = "gemini-1.5-flash"   # used when the primary is unavailable

BATCH_SIZE = 10   # rows sent to the model per request
PAUSE = 0.2       # seconds slept after each batch

# Raises KeyError right here when the variable is not set.
GOOGLE_API_KEY = os.environ["GOOGLE_API_KEY"]


train_df = pd.read_csv(TRAIN_CSV)
test_df  = pd.read_csv(TEST_CSV)

print("Train rows:", len(train_df))
print("Test rows :", len(test_df))


# --- Allowed label sets, keyed by the value of the CSV "task" column ---------
SENT_LABELS = ["Mixed_feelings", "Negative", "Positive"]

OFF_LABELS = [
    "Not_offensive",
    "Offensive",
    "Offensive_Targeted_Insult_Group",
    "Offensive_Targeted_Insult_Individual",
    # NOTE(review): trailing "e" looks like a typo -- verify against the exact
    # spelling in the CSV "label" column. Here it also selects the few-shot
    # examples below, so a mismatch would yield no examples for this class.
    "Offensive_Untargetede"
]

ID_LABELS = ["Homophobia", "Non-anti-LGBT+ content", "Transphobia"]

TASK_LABELS = {
    "SA": SENT_LABELS,
    "OFFENS": OFF_LABELS,
    "OTHERCAT": ID_LABELS
}


# --- Few-shot pool: fewshot[task][label] -> list of example texts -------------
fewshot = {task:{} for task in TASK_LABELS}

for task, labels in TASK_LABELS.items():
    for label in labels:
        # First three training rows (in file order) for this task/label pair;
        # the list is shorter, possibly empty, when fewer rows match.
        examples = train_df[(train_df["task"]==task) & (train_df["label"]==label)].head(3)
        fewshot[task][label] = examples["text"].tolist()

print("Few-shot examples extracted.")


def build_llm(model):
    """Create a Gemini chat client for *model* with temperature 0.

    Returns a ``(client, model)`` tuple; construction errors propagate.
    """
    client = ChatGoogleGenerativeAI(
        model=model,
        google_api_key=GOOGLE_API_KEY,
        temperature=0,
    )
    return client, model

# Prefer the primary model; fall back only when constructing it raises.
# `except Exception` (rather than a bare except) keeps Ctrl-C and interpreter
# shutdown from being treated as "primary model failed".
try:
    llm, current_model = build_llm(PRIMARY_MODEL)
    print("Using model:", current_model)
except Exception:
    llm, current_model = build_llm(FALLBACK_MODEL)
    print("Using fallback:", current_model)


def build_prompt(texts, tasks):
    """Build the few-shot classification prompt for one batch.

    Args:
        texts: Input strings, in the order the answers are expected.
        tasks: Task code for each text, parallel to *texts*.

    Returns:
        The prompt text: instructions, the label sets, the few-shot examples
        for every distinct task in the batch, the numbered inputs, and the
        required JSON output format.
    """
    out = [
        "You are a STRICT multilabel classifier.",
        "Return ONLY one JSON array.",
        'Each element must be {"label":"LABEL"}.',
        "\n### LABEL SETS",
        "SA: " + ", ".join(SENT_LABELS),
        "OFFENS: " + ", ".join(OFF_LABELS),
        "OTHERCAT: " + ", ".join(ID_LABELS),
        "\n### FEW-SHOT EXAMPLES",
    ]

    # `tasks` has one entry per input, so iterating it directly repeated the
    # same example block once per row. Emit each task's examples once,
    # in first-appearance order.
    for task in dict.fromkeys(tasks):
        examples_by_label = fewshot.get(task)
        if not examples_by_label:
            # Task without a few-shot pool: nothing to show (was a KeyError).
            continue
        out.append(f"\nTask={task} examples:")
        for lab, exs in examples_by_label.items():
            for e in exs:
                out.append(f'Example Text: "{e}"\nLabel: {lab}\n')

    out.append("\n### NOW CLASSIFY THESE:")
    for i, (txt, tk) in enumerate(zip(texts, tasks), start=1):
        # Neutralize embedded triple quotes so the text cannot end its delimiter.
        safe = txt.replace('"""', "'")
        out.append(f'{i}) Task={tk}\n"""{safe}"""')

    out.append('\nReturn ONLY JSON like: [{"label":"Positive"}]')
    return "\n".join(out)


def extract_json_array(s):
    """Parse the bracketed span of *s* as JSON.

    The span runs from the first ``[`` to the last ``]`` (greedy, newlines
    included).

    Returns:
        The parsed value, or ``None`` when *s* is not a string (``re.search``
        would otherwise raise ``TypeError``), has no bracketed span, or the
        span is not valid JSON.
    """
    if not isinstance(s, str):
        return None
    m = re.search(r"\[.*\]", s, flags=re.S)
    if not m:
        return None
    try:
        return json.loads(m.group(0))
    except (ValueError, RecursionError):
        # Only parse failures mean "no array"; a bare except would also
        # swallow KeyboardInterrupt / SystemExit.
        return None

@backoff.on_exception(backoff.expo, Exception, max_tries=6)
def ask_llm(prompt):
    """Send *prompt* to the active model and return the reply's ``content``.

    Every exception is re-raised so the ``backoff`` decorator (exponential
    wait, ``max_tries=6``) can retry. When the error text contains "not found"
    or "access" and the fallback is not already active, the module-level
    ``llm`` / ``current_model`` are switched to ``FALLBACK_MODEL`` first, so
    the next retry uses the fallback.
    """
    global llm, current_model
    try:
        return llm.invoke(prompt).content
    except Exception as e:
        lowered = str(e).lower()
        looks_unavailable = any(marker in lowered for marker in ("not found", "access"))
        if looks_unavailable and current_model != FALLBACK_MODEL:
            llm, current_model = build_llm(FALLBACK_MODEL)
        # Always propagate: retrying is the decorator's job.
        raise


# Work on a 0..n-1 index so positions in `preds` line up with rows.
rows = test_df.reset_index(drop=True)
n = len(rows)

preds = [""] * n
calls = 0   # successful batch calls

for start in range(0, n, BATCH_SIZE):
    idxs = list(range(start, min(n, start + BATCH_SIZE)))
    batch_texts = [str(rows.at[j, "text"]) for j in idxs]
    batch_tasks = [str(rows.at[j, "task"]) for j in idxs]

    prompt = build_prompt(batch_texts, batch_tasks)
    try:
        raw = ask_llm(prompt)
    except Exception as e:
        # Retries exhausted: mark the batch and keep going, so one bad request
        # does not discard every prediction collected so far.
        print(f"[ERROR] Failed batch {idxs}: {e}")
        for idx in idxs:
            preds[idx] = "Unknown"
        continue
    calls += 1

    # A non-string reply cannot hold a parsable array.
    arr = extract_json_array(raw) if isinstance(raw, str) else None
    if not arr or len(arr) != len(idxs):
        print("⚠ Fallback parse")
        for idx in idxs:
            preds[idx] = "Unknown"
    else:
        for k, idx in enumerate(idxs):
            el = arr[k]
            # Accept {"label": ...} objects and bare strings; anything else
            # (or a non-string label value) counts as unparseable.
            label = el.get("label", "Unknown") if isinstance(el, dict) else el
            preds[idx] = label if isinstance(label, str) else "Unknown"

    time.sleep(PAUSE)

print("Done. Total calls:", calls)


test_df["predicted"] = preds
test_df.to_csv(OUT_CSV, index=False)
print("Saved predictions:", OUT_CSV)

# Per-task classification report; a task with no rows is skipped because the
# report cannot be computed on empty input.
print("\n===== METRICS =====")
for task in ["SA", "OFFENS", "OTHERCAT"]:
    sub = test_df[test_df["task"] == task]
    if len(sub) == 0:
        continue
    print(f"\n--- {task} ---")
    print(classification_report(sub["label"], sub["predicted"], zero_division=0))

print("\nAll Done.")


Train rows: 28600
Test rows : 1100
Few-shot examples extracted.
Using model: gemini-2.5-flash
Done. Total calls: 110
Saved predictions: /content/multitask_predictions_fewshot.csv

===== METRICS =====

--- SA ---
                precision    recall  f1-score   support

Mixed_feelings       0.52      0.34      0.41       100
      Negative       0.64      0.72      0.68       100
      Positive       0.66      0.81      0.73       100

      accuracy                           0.62       300
     macro avg       0.61      0.62      0.61       300
  weighted avg       0.61      0.62      0.61       300


--- OFFENS ---
                                      precision    recall  f1-score   support

                       Not_offensive       0.40      0.87      0.55       100
                           Offensive       0.61      0.19      0.29       100
     Offensive_Targeted_Insult_Group       0.45      0.39      0.42       100
Offensive_Targeted_Insult_Individual       0.42      0.59      0