## 0) Environment setup (install + Claude client)

In [1]:
!pip -q install anthropic pandas scikit-learn

import os
import pandas as pd

# --- Claude API client setup ---
# Preferred in Colab: set a Secret named ANTHROPIC_API_KEY
try:
    from google.colab import userdata
    api_key = userdata.get("ANTHROPIC_API_KEY")
except Exception:
    api_key = None

if not api_key:
    api_key = os.environ.get("ANTHROPIC_API_KEY", "")
    if not api_key:
        raise ValueError("Missing ANTHROPIC_API_KEY. Add it to Colab Secrets or set it as an environment variable.")

from anthropic import Anthropic
client = Anthropic(api_key=api_key)

print("Claude client is ready.")

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/390.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m143.4/390.3 kB[0m [31m4.2 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m389.1/390.3 kB[0m [31m4.7 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m390.3/390.3 kB[0m [31m2.9 MB/s[0m eta [36m0:00:00[0m
[?25hClaude client is ready.


## 1) Few-shot learning example (TF–IDF retrieval + Claude classification)

In [2]:
from google.colab import files

uploaded = files.upload()  # Upload your CSV file (employee reviews dataset)
csv_name = next(iter(uploaded.keys()))
print("Uploaded:", csv_name)

df = pd.read_csv(csv_name)
df.head()

Saving employee_reviews.csv to employee_reviews (1).csv
Uploaded: employee_reviews (1).csv


Unnamed: 0,id,department,role,salary_sar_monthly,age,review_period,manager_rating_1_5,peer_rating_1_5,engagement_1_5,promotion_recommendation,review_summary
0,E0001,Customer Operations,Customer Support Specialist,8500,26,2025-H2,4.2,4.1,4.3,No,Consistently resolves billing tickets on first...
1,E0002,Customer Operations,Customer Support Specialist,7800,24,2025-H2,3.6,3.8,3.9,No,"Good tone and empathy; speed is improving, but..."
2,E0003,Customer Operations,Senior Customer Support Specialist,10200,30,2025-H2,4.5,4.4,4.2,Yes,Strong ownership of complex device cases; ment...
3,E0004,Customer Operations,Team Leader - Support,14500,34,2025-H2,4.1,4.0,4.1,No,Maintains service levels under pressure; shoul...
4,E0005,Customer Operations,Quality Assurance Analyst,12000,29,2025-H2,4.3,4.2,4.4,Yes,High attention to detail; feedback to agents i...


In [3]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity

text_col = "review_summary"

vectorizer = TfidfVectorizer(stop_words="english")
X = vectorizer.fit_transform(df[text_col].astype(str).tolist())

def retrieve_examples(query, k=6):
    qv = vectorizer.transform([query])
    sims = cosine_similarity(qv, X).ravel()
    idx = sims.argsort()[::-1][:k]
    return df.iloc[idx][[
        "id",
        "department",
        "role",
        "salary_sar_monthly",
        "age",
        "manager_rating_1_5",
        "promotion_recommendation",
        text_col
    ]].to_dict("records")

In [4]:
def predict_promotion(review_text, k=6, model="claude-3-5-sonnet-latest"):
    examples = retrieve_examples(review_text, k=k)

    examples_block = "\n\n".join(
        [f"EXAMPLE {i+1}\n"
         f"Role: {e['role']} | Dept: {e['department']} | Age: {e['age']} | Salary(SAR): {e['salary_sar_monthly']}\n"
         f"Manager rating: {e['manager_rating_1_5']} | Promotion: {e['promotion_recommendation']}\n"
         f"Review: {e['review_summary']}"
         for i, e in enumerate(examples)]
    )

    system = (
        "You are an HR performance reviewer. "
        "Use only the patterns implied by the provided examples. "
        "Return a single label only: Yes or No."
    )

    user = (
        "Here are similar historical labeled examples:\n\n"
        f"{examples_block}\n\n"
        "Now label this new review text for promotion recommendation:\n"
        f"{review_text}\n\n"
        "Answer with Yes or No only."
    )

    resp = client.messages.create(
        model="claude-sonnet-4-5",
        max_tokens=10,
        temperature=0.0,
        system=system,
        messages=[{"role": "user", "content": user}]
    )
    return resp.content[0].text.strip()

test_review = "Consistently exceeds targets, mentors peers, and drives measurable process improvements; trusted to lead incident reviews."
print(predict_promotion(test_review))

Yes


## 2) Tool calling example (compute an HR statistic via a Python tool)

In [5]:
import pandas as pd

data = {
    "employee_id": ["E01", "E02", "E03", "E04", "E05"],
    "absent_days": [2, 7, 1, 10, 3]
}
df_abs = pd.DataFrame(data)
df_abs

Unnamed: 0,employee_id,absent_days
0,E01,2
1,E02,7
2,E03,1
3,E04,10
4,E05,3


In [6]:
def calculate_average_absence(absent_days: list) -> float:
    return sum(absent_days) / len(absent_days)

In [7]:
tools = [
    {
        "name": "calculate_average_absence",
        "description": "Calculate the average number of absence days for employees",
        "input_schema": {
            "type": "object",
            "properties": {
                "absent_days": {"type": "array", "items": {"type": "number"}}
            },
            "required": ["absent_days"]
        }
    }
]

In [8]:
system_message = (
    "You are an HR analytics assistant. "
    "Decide when a statistical calculation is required and use tools when appropriate. "
    "After receiving tool results, explain the implications in HR terms."
)

user_message = (
    "Here are employee absence days:\n"
    f"{df_abs['absent_days'].tolist()}\n\n"
    "Is the average absence level acceptable if the policy threshold is 5 days?"
)

response = client.messages.create(
    model="claude-sonnet-4-5",
    max_tokens=500,
    temperature=0.2,
    system=system_message,
    messages=[{"role": "user", "content": user_message}],
    tools=tools
)

response

Message(id='msg_01Y94B9tyegYMAqJtPGRGLcK', content=[ToolUseBlock(id='toolu_01QF3V4Z1oQVVDbw8aMH9gas', input={'absent_days': [2, 7, 1, 10, 3]}, name='calculate_average_absence', type='tool_use')], model='claude-sonnet-4-5-20250929', role='assistant', stop_reason='tool_use', stop_sequence=None, type='message', usage=Usage(cache_creation=CacheCreation(ephemeral_1h_input_tokens=0, ephemeral_5m_input_tokens=0), cache_creation_input_tokens=0, cache_read_input_tokens=0, input_tokens=649, output_tokens=71, server_tool_use=None, service_tier='standard'))

In [9]:
message = response.content[0]

if message.type == "tool_use":
    tool_input = message.input
    avg_absence = calculate_average_absence(tool_input["absent_days"])

    final_response = client.messages.create(
        model="claude-sonnet-4-5",
        max_tokens=400,
        temperature=0.2,
        messages=[
            {"role": "user", "content": user_message},
            {"role": "assistant", "content": response.content},
            {
                "role": "user",
                "content": [
                    {
                        "type": "tool_result",
                        "tool_use_id": message.id,
                        "content": str(avg_absence)
                    }
                ]
            }
        ]
    )

    print(final_response.content[0].text)
else:
    print(getattr(message, "text", str(message)))

**Yes, the average absence level is acceptable.**

- **Average absence:** 4.6 days
- **Policy threshold:** 5 days

Since 4.6 days is below the 5-day threshold, the team's average absence level meets the policy requirement.

**However, note:** While the average is acceptable, one employee has 10 absence days, which is double the threshold. You may want to review individual cases, as the average can mask outliers.


## 3) Retrieval-Augmented Generation (RAG) example (document upload → chunking → FAISS retrieval → grounded answers)

In [10]:
from google.colab import files

uploaded = files.upload()  # Upload a .txt, .pdf, or .docx document
file_path = next(iter(uploaded.keys()))
print("Uploaded:", file_path)

Saving Survival Analysis.pdf to Survival Analysis.pdf
Uploaded: Survival Analysis.pdf


In [11]:
from pathlib import Path

!pip -q install pypdf python-docx

def read_txt(path: str) -> str:
    return Path(path).read_text(encoding="utf-8", errors="ignore")

def read_pdf(path: str) -> str:
    from pypdf import PdfReader
    reader = PdfReader(path)
    pages = []
    for p in reader.pages:
        pages.append(p.extract_text() or "")
    return "\n".join(pages)

def read_docx(path: str) -> str:
    import docx
    d = docx.Document(path)
    return "\n".join([para.text for para in d.paragraphs])

def load_document(path: str) -> str:
    ext = Path(path).suffix.lower()
    if ext == ".txt":
        return read_txt(path)
    if ext == ".pdf":
        return read_pdf(path)
    if ext == ".docx":
        return read_docx(path)
    raise ValueError(f"Unsupported file type: {ext}. Use .txt, .pdf, or .docx")

doc_text = load_document(file_path)
doc_text[:800]

[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/329.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/329.0 kB[0m [31m3.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m327.7/329.0 kB[0m [31m5.1 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m329.0/329.0 kB[0m [31m4.4 MB/s[0m eta [36m0:00:00[0m
[?25h[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/253.0 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m253.0/253.0 kB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[?25h

'الفصل الثاني عشر عشر: استخدامتحليل البقاء (Survival Analysis )\nللتنبؤ بمدة بقاء الموظفين \n"تحليل البقاء لا يقتصر على التنبؤ بموعد وقوع الحدث، بل يتعمّق في تفسير أسبابه والعوامل التي  \nتؤخره أو تعجّله ". \n— ديفيد كلاينباوم، مؤلف كتاب (تحليل البقاء باستخدام نموذج كوكس ) \nمقدمة عن تحليل البقاء  \nتحليل البقاء (Survival Analysis) هو طريقة إحصائية تُستخدم لدراسة الوقت الذي يستغرقه حدوث شيء ما، مثل استقالة \nالموظف من العمل. ما يميز هذا التحليل أنه يستطيع التعامل مع الحالات التي لم يحدث فيها الحدث بعد، مثل الموظف الذي \nما زال يعمل عند انتهاء الدراسة، وهي ما تُسمى بالرقابة ( Censoring). هذه الطريقة تعطي فهمًا أدق للبيانات من الطرق \nالعادية مثل المتوسط أو الانحدار، لأنها تأخذ بعين الاعتبار الزمن واحتمالية وقوع الحدث (Jin et al., 2020 .) \nفي مجال الموارد البشريةHR)،يكتسب تحليل البقاء أهمية متزايدة، ل'

In [12]:
from typing import List

def chunk_text(text: str, chunk_size: int = 900, overlap: int = 150) -> List[str]:
    text = text.replace("\r\n", "\n").replace("\r", "\n")
    text = "\n".join([line.strip() for line in text.split("\n")]).strip()

    chunks = []
    start = 0
    n = len(text)

    while start < n:
        end = min(start + chunk_size, n)
        chunk = text[start:end].strip()
        if chunk:
            chunks.append(chunk)
        if end == n:
            break
        start = max(0, end - overlap)

    return chunks

chunks = chunk_text(doc_text, chunk_size=900, overlap=150)
print("Chunks:", len(chunks))
print(chunks[0][:300])

Chunks: 34
الفصل الثاني عشر عشر: استخدامتحليل البقاء (Survival Analysis )
للتنبؤ بمدة بقاء الموظفين
"تحليل البقاء لا يقتصر على التنبؤ بموعد وقوع الحدث، بل يتعمّق في تفسير أسبابه والعوامل التي
تؤخره أو تعجّله ".
— ديفيد كلاينباوم، مؤلف كتاب (تحليل البقاء باستخدام نموذج كوكس )
مقدمة عن تحليل البقاء
تحليل البقاء 


In [13]:
!pip -q install faiss-cpu sentence-transformers

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer

embedder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")

chunk_vectors = embedder.encode(
    chunks,
    batch_size=64,
    show_progress_bar=True,
    convert_to_numpy=True,
    normalize_embeddings=True
).astype("float32")

dim = chunk_vectors.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(chunk_vectors)

print("FAISS index size:", index.ntotal)

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.8/23.8 MB[0m [31m50.0 MB/s[0m eta [36m0:00:00[0m
[?25h

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md: 0.00B [00:00, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

FAISS index size: 34


In [14]:
def retrieve(query: str, top_k: int = 5):
    q_vec = embedder.encode([query], normalize_embeddings=True, convert_to_numpy=True).astype("float32")
    scores, ids = index.search(q_vec, top_k)

    results = []
    for score, idx in zip(scores[0], ids[0]):
        if idx == -1:
            continue
        results.append((float(score), chunks[int(idx)]))
    return results

In [15]:
SYSTEM = (
    "You are an HR assistant answering questions using ONLY the provided document excerpts.\n"
    "Rules:\n"
    "- Use only the excerpts as factual grounding.\n"
    "- If the answer is not contained in the excerpts, say you don't have enough information from the file.\n"
    "- Keep the answer clear and professional.\n"
    "- When useful, quote short phrases from the excerpts to justify your answer.\n"
)

def answer_with_rag(question: str, top_k: int = 5, model: str = "claude-3-5-sonnet-latest"):
    hits = retrieve(question, top_k=top_k)

    context_blocks = []
    for i, (score, text) in enumerate(hits, start=1):
        context_blocks.append(f"[EXCERPT {i} | score={score:.3f}]\n{text}")

    context = "\n\n".join(context_blocks) if context_blocks else "No excerpts retrieved."

    user_message = (
        "Document excerpts:\n"
        f"{context}\n\n"
        "Question:\n"
        f"{question}\n\n"
        "Answer (grounded in the excerpts):"
    )

    resp = client.messages.create(
        model="claude-sonnet-4-5",
        max_tokens=800,
        temperature=0.2,
        system=SYSTEM,
        messages=[{"role": "user", "content": user_message}],
    )
    return resp.content[0].text, hits

In [16]:
question = "What techniques described in the document?"
answer, hits = answer_with_rag(question, top_k=5)
print(answer)

Based on the provided excerpts, the document describes several **survival analysis techniques** in the context of employee attrition (استنزاف الموظفين):

1. **Kaplan-Meier Estimator (Kaplan-Meier Fitter)**
   - Described as "one of the most commonly used tools in survival analysis"
   - Used to calculate and represent survival curves showing the probability of employees remaining over time

2. **Log-Rank Test**
   - Used for statistical comparison between groups
   - The document shows an example comparing survival curves between males and females, with a p-value of 0.4189, indicating no statistically significant difference between genders

3. **Cox Proportional Hazards Model (Cox PHFitter)**
   - Described as "one of the most commonly used models in survival analysis"
   - Used to analyze the effect of variables on the risk of employee attrition

The document also mentions fundamental concepts like:
- **Censoring (الرقابة)** - handling incomplete data
- **Hazard rate (معدل الخطر)** - 

In [17]:
!pip -q install gradio

import gradio as gr

def chat_fn(question, top_k):
    answer, hits = answer_with_rag(question, top_k=int(top_k))
    sources = "\n\n".join([f"EXCERPT {i+1} (score={s:.3f}): {t[:350]}..." for i, (s, t) in enumerate(hits)])
    return answer, sources

with gr.Blocks() as demo:
    gr.Markdown(
        "## Simple HR RAG (Claude API)\n"
        "Upload your HR policy / job description / handbook first (done earlier in the notebook), then ask questions."
    )
    q = gr.Textbox(label="Question", placeholder="ما فائدة Log-Rank Test؟")
    k = gr.Slider(1, 10, value=5, step=1, label="Top-K excerpts")
    out_answer = gr.Textbox(label="Claude Answer", lines=8)
    out_sources = gr.Textbox(label="Retrieved excerpts (snippets)", lines=10)
    btn = gr.Button("Ask")
    btn.click(chat_fn, inputs=[q, k], outputs=[out_answer, out_sources])

demo.launch(share=True)

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://44915e53634a877c8e.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)




## 4) Agentic workflow with CrewAI (multi-agent hiring decision memo)

In [None]:
!pip -q install crewai

In [None]:
def claude_complete(prompt: str,
                    model: str = "claude-sonnet-4-5",
                    max_tokens: int = 800,
                    temperature: float = 0.2) -> str:
    resp = client.messages.create(
        model=model,
        max_tokens=max_tokens,
        temperature=temperature,
        messages=[{"role": "user", "content": prompt}]
    )
    return resp.content[0].text

In [None]:
class ClaudeToolingLLM:
    def __init__(self, model="claude-sonnet-4-5", temperature=0.2, max_tokens=900):
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens

    def __call__(self, prompt: str) -> str:
        return claude_complete(
            prompt=prompt,
            model=self.model,
            max_tokens=self.max_tokens,
            temperature=self.temperature
        )

llm = ClaudeToolingLLM(model="claude-sonnet-4-5", temperature=0.2, max_tokens=900)

In [None]:
job_description = """Role: HR Data Analyst
Key requirements:
- SQL, Python (pandas), dashboarding (Power BI or similar)
- Basic statistics (hypothesis testing), clear communication
- 2+ years experience in analytics
"""

candidate_cv = """Candidate: Sara A.
Experience: 3 years as People Analytics Specialist
Skills: Python (pandas, numpy), SQL (PostgreSQL), Power BI, A/B testing
Projects: Attrition dashboard, headcount forecasting, survey analysis
Education: BSc in Statistics
"""

interview_notes = """- Communication is strong, explains assumptions clearly.
- Needs more depth in experimental design edge cases.
- Demonstrated good data cleaning practices.
"""

In [None]:
from crewai import Agent

screener = Agent(
    role="HR Screener",
    goal="Match job requirements to candidate evidence and identify strengths/gaps concisely.",
    backstory="You are a meticulous HR analyst who only uses provided text as evidence.",
    llm=llm,
    verbose=True
)

risk_fairness = Agent(
    role="Risk & Fairness Reviewer",
    goal="Flag missing info, risks, and potential bias; suggest clarifying interview questions.",
    backstory="You ensure hiring decisions are fair, evidence-based, and defensible.",
    llm=llm,
    verbose=True
)

writer = Agent(
    role="Hiring Recommendation Writer",
    goal="Write a short professional hiring memo grounded only in prior agent outputs.",
    backstory="You write clear hiring memos for leadership.",
    llm=llm,
    verbose=True
)

In [None]:
from crewai import Task

task1 = Task(
    description=f"""
You will receive a job description, a candidate CV, and interview notes.
Extract a structured assessment:
1) Requirements met (with brief evidence quotes)
2) Gaps/uncertainties
3) Overall fit in one sentence
Inputs:
JOB DESCRIPTION:
{job_description}
CANDIDATE CV:
{candidate_cv}
INTERVIEW NOTES:
{interview_notes}
""",
    agent=screener,
    expected_output="A short structured assessment with evidence and gaps."
)

task2 = Task(
    description="""
Review the prior assessment and identify:
1) Any leaps of logic or unsupported claims
2) Potential bias risks (e.g., relying on non-job-related factors)
3) What information is missing to finalize a decision
4) 5 targeted follow-up interview questions
""",
    agent=risk_fairness,
    expected_output="Risk/fairness review + missing info + 5 questions."
)

task3 = Task(
    description="""
Using ONLY the outputs of the two previous agents, draft a final recommendation memo:
- Decision: Proceed / Hold / Reject
- 3 concise justifications
- Next steps (if proceeding)
Keep it under 180 words.
""",
    agent=writer,
    expected_output="A short memo with decision, justifications, and next steps."
)

In [None]:
from crewai import Crew, Process

crew = Crew(
    agents=[screener, risk_fairness, writer],
    tasks=[task1, task2, task3],
    process=Process.sequential,
    verbose=True
)

result = crew.kickoff()
result

## 5) Fine-tuning (Hugging Face Transformers) — binary text classification on HR review summaries

In [None]:
!pip -q install transformers datasets evaluate accelerate

In [None]:
from google.colab import files
uploaded = files.upload()
csv_name = next(iter(uploaded.keys()))
print("Uploaded:", csv_name)

df_ft = pd.read_csv(csv_name)  # e.g., employee_reviews.csv
df_ft = df_ft[["review_summary", "promotion_recommendation"]].dropna()

label_map = {"No": 0, "Yes": 1}
df_ft["label"] = df_ft["promotion_recommendation"].map(label_map)

df_ft.head()

In [None]:
from datasets import Dataset

ds = Dataset.from_pandas(df_ft[["review_summary", "label"]])
ds = ds.train_test_split(test_size=0.2, seed=42)
train_ds = ds["train"]
test_ds = ds["test"]

train_ds, test_ds

In [None]:
from transformers import AutoTokenizer

model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(batch):
    return tokenizer(batch["review_summary"], truncation=True, padding="max_length", max_length=256)

train_tok = train_ds.map(tokenize, batched=True)
test_tok = test_ds.map(tokenize, batched=True)

train_tok = train_tok.remove_columns(["review_summary"])
test_tok = test_tok.remove_columns(["review_summary"])

train_tok.set_format("torch")
test_tok.set_format("torch")

train_tok[0]

In [None]:
import evaluate
import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy.compute(predictions=preds, references=labels)["accuracy"],
        "f1": f1.compute(predictions=preds, references=labels, average="binary")["f1"],
    }

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

args = TrainingArguments(
    output_dir="hf_small_ft",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=8,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="no",
    logging_steps=10,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_tok,
    eval_dataset=test_tok,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [None]:
trainer.train()
trainer.evaluate()