In [3]:
import asyncio

import tenacity
from PIL.Image import Image
%reload_ext autoreload
%autoreload 2
%autoreload now

import logging
import re

from semaphore import set_semaphore
from utils_openai import call_gpt4

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

from dotenv import load_dotenv
from openai import AsyncOpenAI

from torchvision import datasets

import pandas as pd

load_dotenv()

True

In [4]:
set_semaphore("gpt-4", 100)

In [5]:
client = AsyncOpenAI()

In [6]:
fashion_mnist = datasets.FashionMNIST(root='./data', train=False, download=True)

In [7]:
fashion_mnist[0]

(<PIL.Image.Image image mode=L size=28x28>, 9)

In [8]:
fashion_mnist_labels = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot'
}

In [8]:
print(PROMPT_TEMPLATE)

In [44]:
LABELS_MESSAGE = "\n".join([f"{k}: {v}" for k, v in fashion_mnist_labels.items()])
PROMPT_TEMPLATE = f"""You are presented with a greyscale low-resolution image of a piece of clothing.
It is of one of the possible classes:
{LABELS_MESSAGE}

Your task is to classify the image into one of the classes.
Format your output like this:
First, reason about what features let you classify the image correctly. If there's some ambiguity - reason about it and try to come to the most probable conclusion. Your reasoning should be 2-4 sentences at max.
You last line should always be class label, nothing else.

Your last line should always be a class label from the list above! If you are not sure - try to make your best guess. 
"""

In [49]:
def log_before_sleep(retry_state):
    logging.info(f"Retrying: attempt #{retry_state.attempt_number}, waiting {retry_state.next_action.sleep} seconds due to {retry_state.outcome.exception()}")


@tenacity.retry(
    wait=tenacity.wait_fixed(1),
    stop=tenacity.stop_after_attempt(10),
    retry=tenacity.retry_if_exception_type(Exception),
    before_sleep=log_before_sleep,
)
async def get_img_label(img: Image, ind: int) -> tuple[int, str, int]:
    response = await call_gpt4(
        client,
        PROMPT_TEMPLATE,
        imgs=[
            img,
        ]
    )
    try:
        label = int(re.search(r"\d", response).group())
    except:
        return 0, response, ind
    reasoning = "\n".join(response.split("\n")[:-1]).strip()
    return label, reasoning, ind

In [50]:
results = [None] * len(fashion_mnist) # Each element shall be {"label_pred": int, "label_true": int, "reasoning": str}
# tasks = [get_img_label(img, i) for i, (img, _) in enumerate(fashion_mnist)]
# tasks = [get_img_label(img, i) for i, (img, _) in enumerate([fashion_mnist[i] for i in range(10)])]
# tasks = [get_img_label(img, i) for i, (img, _) in enumerate(fashion_mnist) if results[i] is None]
for i, task in enumerate(asyncio.as_completed(tasks)):
    if i % 100 == 0:
        print(f"Processed {i} images")
    label_pred, reasoning, ind = await task
    label_true = fashion_mnist[ind][1]
    results[ind] = {
        "ind": ind,
        "label_pred": label_pred,
        "label_true": label_true,
        "reasoning": reasoning
    }

Processed 0 images


2024-09-07 19:06:58,679 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:58,736 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:58,740 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:58,743 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:58,855 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:58,941 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:59,023 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:59,026 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2024-09-07 19:06:59,049 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "

In [51]:
len([get_img_label(img, i) for i, (img, _) in enumerate(fashion_mnist) if results[i] is None])

0

In [52]:
pd.DataFrame(results).to_csv("results.csv", index=False)

In [53]:
import os

# Those 30 images where not classified unless you specifically ask LLM to ALWAYS guess a class label
bad_imgs = [(i, img) for i, (img, _) in enumerate(fashion_mnist) if results[i] is None]
os.makedirs("bad_imgs", exist_ok=True)
for i, img in bad_imgs:
    img.save(os.path.join("bad_imgs", f'image_{i}.png'))

In [54]:
from sklearn.metrics import accuracy_score, classification_report

y_pred = [res['label_pred'] for res in results]
y_true = [res['label_true'] for res in results]

accuracy = accuracy_score(y_true, y_pred)
print(f"Accuracy: {accuracy:.4f}")

report = classification_report(y_true, y_pred)
print("Classification Report:")
print(report)

Accuracy: 0.8085
Classification Report:
              precision    recall  f1-score   support

           0       0.68      0.70      0.69      1000
           1       0.99      0.98      0.99      1000
           2       0.69      0.88      0.78      1000
           3       0.73      0.93      0.81      1000
           4       0.86      0.56      0.68      1000
           5       0.96      0.86      0.91      1000
           6       0.52      0.51      0.51      1000
           7       0.82      0.98      0.89      1000
           8       0.99      0.82      0.90      1000
           9       0.97      0.87      0.92      1000

    accuracy                           0.81     10000
   macro avg       0.82      0.81      0.81     10000
weighted avg       0.82      0.81      0.81     10000

