In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [47]:
import json, os, time, uuid, concurrent.futures as cf
import polars as pl
from concurrent.futures import FIRST_COMPLETED
from typing import Any, Dict, List

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential
from sklearn.metrics import accuracy_score, classification_report

In [3]:
CHECKPOINT_PATH = "../data/zero_shot_llm_ckpt.parquet"
FINAL_PATH = "../data/zero_shot_with_llm.parquet"   # optional merged output
TEXT_COL = "text"
KEY_COL = "row_id"
MAX_WORKERS = 8
MAX_RETRIES = 3
BASE_BACKOFF = 1.0
IN_FLIGHT_FACTOR = 4               # max queued = MAX_WORKERS * IN_FLIGHT_FACTOR
FLUSH_EVERY = 100 

In [24]:
train_df = pl.read_parquet('../data/train_yelp_review.parquet')
test_df = pl.read_parquet('../data/test_yelp_review.parquet')

In [25]:
working_df = test_df.group_by('label').agg(pl.all().sample(n=1000, with_replacement=False))
df_obj = {"label": [], "text": []}
few_shot_examples = ""
for row in working_df.iter_rows():
    for text in row[1]:
        df_obj["label"].append(row[0] + 1)
        df_obj["text"].append(text)
df = pl.DataFrame(df_obj)

In [26]:
df.head()

label,text
i64,str
4,"""So happy to discover that Edin…"
4,"""We avoided the Saturday Badger…"
4,"""The food is good...but..I thin…"
4,"""Gets the job done .had a Weiss…"
4,"""The Fountains of Bellagio. Som…"


In [27]:
# For Serverless API or Managed Compute endpoints
from os import getenv


client = ChatCompletionsClient(
    endpoint="https://portfolio-resource.services.ai.azure.com/models",
    credential=AzureKeyCredential(getenv("AZURE_AI_INFERENCE_KEY")),
    api_version="2024-05-01-preview"
)
model_name = "llama-31_8b"

In [28]:
SYSTEM_PROMPT = """You are a strict Yelp review classifier.

Rubric:
1 = very negative, would not return
2 = negative with some redeeming detail
3 = mixed/neutral
4 = positive with minor issues
5 = very positive, enthusiastic

Respond ONLY in valid JSON with:
- "stars": integer in [1,2,3,4,5]
- "explanation": one concise sentence (no step-by-step)

Format:
{{"stars": <int>, "explanation": "<1 sentence>"}}

{examples}
"""

# SYSTEM_PROMPT = """You are a precise sentiment rating assistant.
# Given a single review, decide the star rating from 1-5 (1=very negative, 3=neutral/mixed, 5=very positive).
# Return ONLY a strict JSON object with keys 'stars' (an integer 1-5) and 'explanation' (a brief string).
# Do not include any extra keys or text outside JSON.
# Return strict JSON: {{\"stars\": <int>, \"explanation\": \"<short reason>\"}}.
# """

ZERO_SHOT_USER_TMPL = """Review:
\"\"\"{review_text}\"\"\""""

In [29]:
# system_prompt = SYSTEM_PROMPT + "\nHere are the examples of reviews and their ratings: {examples}"

In [30]:
# system_prompt = system_prompt.format_map({'examples': "\n\n".join([f"- Review: \"\"\"{example['review_text']}\"\"\"\nRating: {example['rating']}\nExplanation: {example['explanation']}" for example in few_shot_examples])})
# few_shot_example = ZERO_SHOT_USER_TMPL.format_map({'review_text': first_review['text'].item()})

# Testing on actual information

In [31]:
if KEY_COL not in df.columns:
    df = df.with_row_index(name=KEY_COL)

In [32]:
schema = {KEY_COL: pl.Int64, "zero_star": pl.Int64, "zero_explanation": pl.Utf8, "few_star": pl.Int64, "few_explanation": pl.Utf8}

In [33]:
if os.path.exists(CHECKPOINT_PATH):
    ckpt = pl.read_parquet(CHECKPOINT_PATH).select(
        pl.col(KEY_COL).cast(pl.Int64),
        pl.col("zero_star").cast(pl.Int64),
        pl.col("zero_explanation").cast(pl.Utf8),
        pl.col("few_star").cast(pl.Int64),
        pl.col("few_explanation").cast(pl.Utf8),
    )
else:
    ckpt = pl.DataFrame({k: pl.Series([], dtype=t) for k, t in schema.items()}, schema=schema)

In [34]:
seen_ids = set(ckpt[KEY_COL].to_list()) if ckpt.height else set()
work_df = df.filter(~pl.col(KEY_COL).is_in(list(seen_ids))).select([KEY_COL, TEXT_COL])

In [35]:
work_df.shape

(5000, 2)

In [36]:
print(f"Total rows: {df.height} | Already done: {len(seen_ids)} | Remaining: {work_df.height}")
if work_df.is_empty():
    # (Optional) produce merged DF and exit
    zero_shot_df = df.join(ckpt, on=KEY_COL, how="left")
    zero_shot_df.write_parquet(FINAL_PATH)
    print(f"No remaining rows. Merged output saved to {FINAL_PATH}")
    # You still have zero_shot_df in memory
    # return/exit if in a script

Total rows: 5000 | Already done: 0 | Remaining: 5000


In [37]:
def get_rating(text: str, system_prompt: str, user: str):
    try:
        response = client.complete(
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=user.format_map({'review_text': text})),
            ],
            max_tokens=128,
            temperature=0.1,
            top_p=0.95,
            model=model_name,
            response_format="json_object"
        )
        output = response.choices[0].message.content
        json_output = json.loads(output)
        return {"llm_stars": int(json_output['stars']), "llm_explanation": json_output['explanation']}
    except Exception as ex:
        print(str(ex))
        return {"llm_stars": 0, "llm_explanation": ""}


In [38]:
def call_with_retry(row_id: int, text: str, zero_shot_system: str, few_shot_system: str) -> Dict[str, Any]:
    for attempt in range(MAX_RETRIES):
        try:
            zerp_shot_out = get_rating(text, zero_shot_system, ZERO_SHOT_USER_TMPL)
            few_shot_out = get_rating(text, few_shot_system, ZERO_SHOT_USER_TMPL)
            zerp_shot_out[KEY_COL] = row_id
            few_shot_out[KEY_COL] = row_id
            return zerp_shot_out, few_shot_out
        except Exception as e:
            sleep_s = BASE_BACKOFF * (2 ** attempt) * (1 + 0.15 * attempt)
            time.sleep(sleep_s)
            if attempt == MAX_RETRIES - 1:
                attemptt_response = {
                    KEY_COL: row_id,
                    "llm_stars": 3,  # fallback neutral
                    "llm_explanation": f"fallback after error: {type(e).__name__}"
                }
                return attemptt_response, attemptt_response

In [39]:
def atomic_write_parquet(df: pl.DataFrame, path: str):
    tmp = f"{path}.tmp.{uuid.uuid4().hex}.parquet"
    df.write_parquet(tmp)
    os.replace(tmp, path)  

In [40]:
few_shot_df = train_df.group_by('label').agg(pl.all().sample(n=2, with_replacement=False))
few_shot_examples = ""
for row in few_shot_df.iter_rows():
    for text in row[1]:
        few_shot_examples += f"- Review: \"\"\"{text}\"\"\"\nStar: {row[0]}\n\n"
few_shot_examples = few_shot_examples.strip()

In [41]:
def submit_more(it, ex, futures, in_flight_cap, zero_shot_system, few_shot_system):
    """Submit up to (in_flight_cap - len(futures)) new tasks from iterator it."""
    slots = in_flight_cap - len(futures)
    added = 0
    while slots > 0:
        try:
            row = next(it)  # {'row_id':..., 'text':...}
        except StopIteration:
            break
        fut = ex.submit(call_with_retry, int(row[KEY_COL]), str(row[TEXT_COL]),
                        zero_shot_system, few_shot_system)
        futures[fut] = row[KEY_COL]
        slots -= 1
        added += 1
    return added

In [42]:
rows_iter = work_df.iter_rows(named=True)

In [43]:
zero_shot_system = SYSTEM_PROMPT.format_map({"examples": ""})
few_shot_system = SYSTEM_PROMPT.format_map({"examples": few_shot_examples})

buffer = []
completed = 0
in_flight_cap = max(MAX_WORKERS * IN_FLIGHT_FACTOR, MAX_WORKERS)

if work_df.height:
    with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures: dict[cf.Future, int] = {}
        submit_more(rows_iter, ex, futures, in_flight_cap, zero_shot_system, few_shot_system)  # seed the pool

        # PROCESS EVERYTHING (keep submitting, keep draining)
        while futures:
            done, _pending = cf.wait(set(futures.keys()), return_when=FIRST_COMPLETED)

            # handle all newly-finished futures
            for fut in done:
                rid = futures.pop(fut)  # remove from tracking

                try:
                    zero_shot_res, few_shot_res = fut.result()
                except Exception as e:
                    zero_shot_res = {KEY_COL: rid, "llm_stars": 3, "llm_explanation": f"fallback error: {type(e).__name__}"}
                    few_shot_res  = {KEY_COL: rid, "llm_stars": 3, "llm_explanation": f"fallback error: {type(e).__name__}"}

                if rid not in seen_ids:
                    buffer.append({
                        KEY_COL: rid,
                        "zero_star": zero_shot_res["llm_stars"],
                        "zero_explanation": zero_shot_res["llm_explanation"],
                        "few_star":  few_shot_res["llm_stars"],
                        "few_explanation": few_shot_res["llm_explanation"],
                    })
                    seen_ids.add(rid)
                    completed += 1

                if len(buffer) >= FLUSH_EVERY:
                    buf_df = pl.DataFrame(buffer, schema=schema)
                    ckpt = ckpt.vstack(buf_df)
                    atomic_write_parquet(ckpt, CHECKPOINT_PATH)
                    buffer.clear()
                    print(f"Checkpointed {completed} new rows (total ckpt: {ckpt.height})")

            # top up after draining some
            submit_more(rows_iter, ex, futures, in_flight_cap, zero_shot_system, few_shot_system)

        # final flush
        if buffer:
            buf_df = pl.DataFrame(buffer, schema=schema)
            ckpt = ckpt.vstack(buf_df)
            atomic_write_parquet(ckpt, CHECKPOINT_PATH)
            buffer.clear()
            print(f"Final checkpointed. Total in ckpt: {ckpt.height}")

print(f"All done. Newly processed: {completed}. Total in checkpoint: {ckpt.height}.")

Checkpointed 100 new rows (total ckpt: 100)
the JSON object must be str, bytes or bytearray, not NoneType
Checkpointed 200 new rows (total ckpt: 200)
Checkpointed 300 new rows (total ckpt: 300)
Checkpointed 400 new rows (total ckpt: 400)
Checkpointed 500 new rows (total ckpt: 500)
Checkpointed 600 new rows (total ckpt: 600)
Checkpointed 700 new rows (total ckpt: 700)
Checkpointed 800 new rows (total ckpt: 800)
Checkpointed 900 new rows (total ckpt: 900)
Checkpointed 1000 new rows (total ckpt: 1000)
(content_filter) The response was filtered due to the prompt triggering Microsoft's content management policy. Please modify your prompt and retry.
Code: content_filter
Message: The response was filtered due to the prompt triggering Microsoft's content management policy. Please modify your prompt and retry.
Inner error: {
    "code": "ResponsibleAIPolicyViolation",
    "content_filter_result": {
        "hate": {
            "filtered": false,
            "severity": "safe"
        },
      

In [None]:
llm_df = pl.read_parquet(CHECKPOINT_PATH)
llm_df.head()

row_id,zero_star,zero_explanation,few_star,few_explanation
i64,i64,str,i64,str
5,5,"""Exceptional massage therapist,…",5,"""Exceptional massage therapist …"
2,4,"""Good food and value, but overp…",3,"""Good food but overpriced, with…"
0,4,"""Good food and service, but som…",4,"""Good food and service, but som…"
1,4,"""A bustling diner with a rich h…",4,"""Friendly service, quick turnov…"
7,5,"""The reviewer thoroughly enjoye…",5,"""The reviewer had a delightful …"


In [46]:
all_df = df.join(llm_df, on="row_id", how="left")
all_df.head()

row_id,label,text,zero_star,zero_explanation,few_star,few_explanation
u32,i64,str,i64,str,i64,str
0,4,"""So happy to discover that Edin…",4,"""Good food and service, but som…",4,"""Good food and service, but som…"
1,4,"""We avoided the Saturday Badger…",4,"""A bustling diner with a rich h…",4,"""Friendly service, quick turnov…"
2,4,"""The food is good...but..I thin…",4,"""Good food and value, but overp…",3,"""Good food but overpriced, with…"
3,4,"""Gets the job done .had a Weiss…",4,"""The reviewer found the beer to…",3,"""The reviewer found the beer to…"
4,4,"""The Fountains of Bellagio. Som…",5,"""The reviewer highly recommends…",5,"""A free, beautiful, and iconic …"


In [52]:
all_df.write_parquet("../data/zero_few_shot_review.parquet")

In [49]:
y_test = all_df.select("label").to_numpy()
zero_pred = all_df.select("zero_star").to_numpy()
few_pred = all_df.select("few_star").to_numpy()

In [50]:
print(classification_report(y_test, zero_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.68      0.92      0.78      1000
           2       0.64      0.48      0.54      1000
           3       0.71      0.41      0.52      1000
           4       0.56      0.50      0.53      1000
           5       0.64      0.90      0.75      1000

    accuracy                           0.64      5000
   macro avg       0.54      0.53      0.52      5000
weighted avg       0.65      0.64      0.62      5000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [51]:
print(classification_report(y_test, few_pred))

              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.66      0.90      0.76      1000
           2       0.61      0.47      0.53      1000
           3       0.70      0.48      0.57      1000
           4       0.62      0.63      0.62      1000
           5       0.73      0.81      0.77      1000

    accuracy                           0.66      5000
   macro avg       0.55      0.55      0.54      5000
weighted avg       0.66      0.66      0.65      5000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
