In [1]:
from dotenv import load_dotenv
load_dotenv()

True

In [22]:
import json, os, time, uuid, concurrent.futures as cf
import polars as pl
from concurrent.futures import FIRST_COMPLETED
from typing import Any, Dict, List

from azure.ai.inference import ChatCompletionsClient
from azure.ai.inference.models import SystemMessage, UserMessage
from azure.core.credentials import AzureKeyCredential

from sklearn.metrics import accuracy_score, classification_report

In [3]:
CHECKPOINT_PATH = "../data/direct_vs_cot_ckpt.parquet"
FINAL_PATH = "../data/direct_vs_cot_ckpt.parquet"   # optional merged output
TEXT_COL = "text"
KEY_COL = "row_id"
MAX_WORKERS = 8
MAX_RETRIES = 3
BASE_BACKOFF = 1.0
FLUSH_EVERY = 100 
IN_FLIGHT_FACTOR = 4               # max queued = MAX_WORKERS * IN_FLIGHT_FACTOR

In [4]:
df = pl.read_parquet('../data/zero_few_shot_review.parquet')

In [5]:
# For Serverless API or Managed Compute endpoints
from os import getenv


client = ChatCompletionsClient(
    endpoint="https://portfolio-resource.services.ai.azure.com/models",
    credential=AzureKeyCredential(getenv("AZURE_AI_INFERENCE_KEY")),
    api_version="2024-05-01-preview"
)
model_name = "llama-31_8b"

In [25]:
REASONING_PROMPT = """You are given a Yelp review.
First explain your reasoning about sentiment and content.
Then, rate the review based on the reasoning.

Rubric:
1 = very negative, would not return
2 = negative with some redeeming detail
3 = mixed/neutral
4 = positive with minor issues
5 = very positive, enthusiastic

Respond in the JSON Format:
{{"stars": <int>, "reason": "<1-3 sentence>"}}

Let's think step by step."""

ZERO_SHOT_USER_TMPL = """Review:
\"\"\"{review_text}\"\"\""""

In [26]:
schema = {KEY_COL: pl.Int64, "cot_star": pl.Int64, "reasoning": pl.Utf8}

In [27]:
if os.path.exists(CHECKPOINT_PATH):
    ckpt = pl.read_parquet(CHECKPOINT_PATH).select(
        pl.col(KEY_COL).cast(pl.Int64),
        pl.col("cot_star").cast(pl.Int64),
        pl.col("reasoning").cast(pl.Utf8)
    )
else:
    ckpt = pl.DataFrame({k: pl.Series([], dtype=t) for k, t in schema.items()}, schema=schema)

In [28]:
seen_ids = set(ckpt[KEY_COL].to_list()) if ckpt.height else set()
work_df = df.filter(~pl.col(KEY_COL).is_in(list(seen_ids))).select([KEY_COL, TEXT_COL])

In [29]:
print(f"Total rows: {df.height} | Already done: {len(seen_ids)} | Remaining: {work_df.height}")
if work_df.is_empty():
    # (Optional) produce merged DF and exit
    zero_shot_df = df.join(ckpt, on=KEY_COL, how="left")
    zero_shot_df.write_parquet(FINAL_PATH)
    print(f"No remaining rows. Merged output saved to {FINAL_PATH}")
    # You still have zero_shot_df in memory
    # return/exit if in a script

Total rows: 5000 | Already done: 0 | Remaining: 5000


In [30]:
def get_rating(text: str, system_prompt: str, user: str):
    try:
        response = client.complete(
            messages=[
                SystemMessage(content=system_prompt),
                UserMessage(content=user.format_map({'review_text': text})),
            ],
            max_tokens=128,
            temperature=0.1,
            top_p=0.95,
            model=model_name,
            response_format="json_object"
        )
        output = response.choices[0].message.content
        json_output = json.loads(output)
        return {"cot_star": int(json_output['stars']), "reasoning": json_output['reason']}
    except Exception as ex:
        print(str(ex))
        return {"cot_star": 0, "reasoning": ""}


In [31]:
def call_with_retry(row_id: int, text: str, cot_system: str) -> Dict[str, Any]:
    for attempt in range(MAX_RETRIES):
        try:
            cot_out = get_rating(text, cot_system, ZERO_SHOT_USER_TMPL)
            cot_out[KEY_COL] = row_id
            return cot_out
        except Exception as e:
            sleep_s = BASE_BACKOFF * (2 ** attempt) * (1 + 0.15 * attempt)
            time.sleep(sleep_s)
            if attempt == MAX_RETRIES - 1:
                attemptt_response = {
                    KEY_COL: row_id,
                    "cot_star": 3,  # fallback neutral
                    "reasoning": f"fallback after error: {type(e).__name__}"
                }
                return attemptt_response

In [32]:
def atomic_write_parquet(df: pl.DataFrame, path: str):
    tmp = f"{path}.tmp.{uuid.uuid4().hex}.parquet"
    df.write_parquet(tmp)
    os.replace(tmp, path)  

In [33]:
def submit_more(it, ex, futures, in_flight_cap, cot_system):
    """Submit up to (in_flight_cap - len(futures)) new tasks from iterator it."""
    slots = in_flight_cap - len(futures)
    added = 0
    while slots > 0:
        try:
            row = next(it)  # {'row_id':..., 'text':...}
        except StopIteration:
            break
        fut = ex.submit(call_with_retry, int(row[KEY_COL]), str(row[TEXT_COL]),
                        cot_system)
        futures[fut] = row[KEY_COL]
        slots -= 1
        added += 1
    return added

In [34]:
rows_iter = work_df.iter_rows(named=True)

In [35]:
buffer = []
completed = 0
in_flight_cap = max(MAX_WORKERS * IN_FLIGHT_FACTOR, MAX_WORKERS)

if work_df.height:
    with cf.ThreadPoolExecutor(max_workers=MAX_WORKERS) as ex:
        futures: dict[cf.Future, int] = {}
        submit_more(rows_iter, ex, futures, in_flight_cap, REASONING_PROMPT)  # seed the pool

        # PROCESS EVERYTHING (keep submitting, keep draining)
        while futures:
            done, _pending = cf.wait(set(futures.keys()), return_when=FIRST_COMPLETED)

            # handle all newly-finished futures
            for fut in done:
                rid = futures.pop(fut)  # remove from tracking

                try:
                    cot_res = fut.result()
                except Exception as e:
                    cot_res = {KEY_COL: rid, "cot_star": 3, "reasoning": f"fallback error: {type(e).__name__}"}

                if rid not in seen_ids:
                    buffer.append({
                        KEY_COL: rid,
                        **cot_res
                    })
                    seen_ids.add(rid)
                    completed += 1

                if len(buffer) >= FLUSH_EVERY:
                    buf_df = pl.DataFrame(buffer, schema=schema)
                    ckpt = ckpt.vstack(buf_df)
                    atomic_write_parquet(ckpt, CHECKPOINT_PATH)
                    buffer.clear()
                    print(f"Checkpointed {completed} new rows (total ckpt: {ckpt.height})")

            # top up after draining some
            submit_more(rows_iter, ex, futures, in_flight_cap, REASONING_PROMPT)

        # final flush
        if buffer:
            buf_df = pl.DataFrame(buffer, schema=schema)
            ckpt = ckpt.vstack(buf_df)
            atomic_write_parquet(ckpt, CHECKPOINT_PATH)
            buffer.clear()
            print(f"Final checkpointed. Total in ckpt: {ckpt.height}")

print(f"All done. Newly processed: {completed}. Total in checkpoint: {ckpt.height}.")

Checkpointed 100 new rows (total ckpt: 100)
Checkpointed 200 new rows (total ckpt: 200)
Checkpointed 300 new rows (total ckpt: 300)
Checkpointed 400 new rows (total ckpt: 400)
Checkpointed 500 new rows (total ckpt: 500)
Checkpointed 600 new rows (total ckpt: 600)
Checkpointed 700 new rows (total ckpt: 700)
Unterminated string starting at: line 1 column 24 (char 23)
Checkpointed 800 new rows (total ckpt: 800)
Invalid \uXXXX escape: line 1 column 276 (char 275)
Checkpointed 900 new rows (total ckpt: 900)
Checkpointed 1000 new rows (total ckpt: 1000)
(content_filter) The response was filtered due to the prompt triggering Microsoft's content management policy. Please modify your prompt and retry.
Code: content_filter
Message: The response was filtered due to the prompt triggering Microsoft's content management policy. Please modify your prompt and retry.
Inner error: {
    "code": "ResponsibleAIPolicyViolation",
    "content_filter_result": {
        "hate": {
            "filtered": false

In [36]:
llm_df = pl.read_parquet(CHECKPOINT_PATH)
llm_df.head()

row_id,cot_star,reasoning
i64,i64,str
1,5,"""The reviewer had a very positi…"
5,5,"""The reviewer highly recommends…"
8,4,"""The reviewer is generally very…"
4,5,"""The reviewer highly recommends…"
2,4,"""The reviewer has mixed opinion…"


In [37]:
all_df = df.join(llm_df, on="row_id", how="left")
all_df.head()

row_id,label,text,zero_star,zero_explanation,few_star,few_explanation,cot_star,reasoning
u32,i64,str,i64,str,i64,str,i64,str
0,4,"""So happy to discover that Edin…",4,"""Good food and service, but som…",4,"""Good food and service, but som…",4,"""The reviewer is very happy wit…"
1,4,"""We avoided the Saturday Badger…",4,"""A bustling diner with a rich h…",4,"""Friendly service, quick turnov…",5,"""The reviewer had a very positi…"
2,4,"""The food is good...but..I thin…",4,"""Good food and value, but overp…",3,"""Good food but overpriced, with…",4,"""The reviewer has mixed opinion…"
3,4,"""Gets the job done .had a Weiss…",4,"""The reviewer found the beer to…",3,"""The reviewer found the beer to…",4,"""The reviewer seems to be gener…"
4,4,"""The Fountains of Bellagio. Som…",5,"""The reviewer highly recommends…",5,"""A free, beautiful, and iconic …",5,"""The reviewer highly recommends…"


In [38]:
all_df.write_parquet("../data/zero_cot_review.parquet")

In [39]:
y_test = all_df.select("label").to_numpy()
zero_pred = all_df.select("zero_star").to_numpy()
cot_pred = all_df.select("cot_star").to_numpy()

In [40]:
print("Accuracy: ", accuracy_score(y_test, zero_pred))
print(classification_report(y_test, zero_pred))

Accuracy:  0.6412
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.68      0.92      0.78      1000
           2       0.64      0.48      0.54      1000
           3       0.71      0.41      0.52      1000
           4       0.56      0.50      0.53      1000
           5       0.64      0.90      0.75      1000

    accuracy                           0.64      5000
   macro avg       0.54      0.53      0.52      5000
weighted avg       0.65      0.64      0.62      5000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [41]:
print("Accuracy: ", accuracy_score(y_test, cot_pred))
print(classification_report(y_test, cot_pred))

Accuracy:  0.6194
              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       0.68      0.91      0.78      1000
           2       0.64      0.46      0.53      1000
           3       0.69      0.36      0.47      1000
           4       0.50      0.45      0.47      1000
           5       0.61      0.92      0.74      1000

    accuracy                           0.62      5000
   macro avg       0.52      0.52      0.50      5000
weighted avg       0.63      0.62      0.60      5000



  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


# Report:
- Zero-shot COT reasoning reduced the classification than direct zero-shot classification
- Reason: The reviews have a hidden meaning behind then which cannot be extracted unless we have example to provide as to why a review is positive or negative.