In [1]:
import os
import base64
from PIL import Image
from io import BytesIO
from langchain_openai import ChatOpenAI
from scripts.vllm_server import VLLMServer
import json
from pathlib import Path
from pydantic import BaseModel, Field
from sklearn.metrics import mean_absolute_error
from tqdm.notebook import tqdm
import random

# Seed the stdlib RNG so the random.sample() call that picks the evaluation
# subset returns the same dishes on every fresh-kernel run.
random.seed(42)

In [2]:
# Models this notebook can be pointed at. The chosen name is reused to locate
# configs/<name>.yaml, as the model name sent to the server, and in EXP_NAME.
AVAILABLE_MODELS = ["Qwen2.5-VL-7B-Instruct", "gemma-3-27b-it"]


# Pick the model under evaluation by index into AVAILABLE_MODELS.
MODEL_NAME = AVAILABLE_MODELS[1]
# Always true while MODEL_NAME is taken from the list by index; guards against
# a hand-typed name that has no matching entry.
assert MODEL_NAME in AVAILABLE_MODELS

In [3]:
# Absolute path of the per-model server config; the relative part is resolved
# against the current working directory.
config_path = os.path.abspath(f"configs/{MODEL_NAME}.yaml")

In [4]:
# Create the project's VLLMServer helper from the config. Judging by the logged
# output of this cell, construction kills a previously running vLLM process,
# launches the server and waits until its port is open — confirm against
# scripts/vllm_server.py.
server = VLLMServer(config_path)

Убиваем старый процесс vLLM с PGID=1296976
🚀 Запуск vLLM-сервера на localhost:1337 с моделью /home/student/vllm_models/google/gemma-3-27b-it…
⏳ Ожидание порта 1337 на localhost…
✅ Порт 1337 на localhost открыт — сервер готов.
📜 Логи сервера пишутся в файл: vllm_server.log


In [5]:
# Experiment identifier and root output directory; together they form the
# folder where metrics.json is written at the end of the notebook.
EXP_NAME = f"{MODEL_NAME}_zeroshot"
ARTIFACTS_DIR = "vlm_experiments"

In [6]:
# OpenAI-compatible chat client pointed at the local server's /v1 endpoint
# (port 1337 matches the server start-up log).
# NOTE(review): api_key is a placeholder value — presumably the local server
# does not validate it; confirm before pointing this at a real endpoint.
# temperature is 0 to keep the model's answers as repeatable as possible.
llm = ChatOpenAI(
    base_url="http://localhost:1337/v1",
    api_key="test",
    model_name=MODEL_NAME,
    temperature=0.,
    max_tokens=32768
)

In [7]:
# Smoke test: send a short greeting to confirm the server answers before the
# evaluation starts.
llm.invoke("Привет!")

AIMessage(content='Привет! Чем могу помочь? 😊\n', additional_kwargs={'refusal': None}, response_metadata={'token_usage': {'completion_tokens': 9, 'prompt_tokens': 11, 'total_tokens': 20, 'completion_tokens_details': None, 'prompt_tokens_details': None}, 'model_name': 'gemma-3-27b-it', 'system_fingerprint': None, 'id': 'chatcmpl-4a824330cebf4f5ca227fe75c971bb87', 'service_tier': None, 'finish_reason': 'stop', 'logprobs': None}, id='run--99031e9d-c19b-4e7e-ba4f-a9c46cf6056e-0', usage_metadata={'input_tokens': 11, 'output_tokens': 9, 'total_tokens': 20, 'input_token_details': {}, 'output_token_details': {}})

In [8]:
# Schema of the structured answer requested from the model: four per-100 g
# nutrition values. Only its JSON schema (model_json_schema()) is used in this
# notebook; the Field descriptions become part of that schema.
# NOTE(review): documented with comments rather than a class docstring on
# purpose — pydantic copies a class docstring into the generated schema's
# "description", which would change what is sent to the model.
class NutritionFacts(BaseModel):
    energy_kcal: float = Field(description="Calories per 100g, in kcal")
    protein_g: float = Field(description="Protein per 100g, in grams")
    fat_g: float = Field(description="Fat per 100g, in grams")
    carbs_g: float = Field(description="Carbohydrates per 100g, in grams")

In [9]:
# Wrap the chat client so responses are parsed against the NutritionFacts JSON
# schema. Because a schema dict (not the pydantic class) is passed, results
# come back as plain dicts, as the single-sample output later in the notebook
# shows.
# NOTE(review): this rebinds `llm` to the wrapped object, so the cell is not
# idempotent — running it a second time calls with_structured_output on the
# already-wrapped runnable. Re-run the ChatOpenAI cell first if this cell has
# to be repeated.
llm = llm.with_structured_output(NutritionFacts.model_json_schema())

In [10]:
def image_to_base64(image_path: str, images_dir_path: str) -> str:
    """Load an image, re-encode it as JPEG and return it as a base64 string.

    Args:
        image_path: Path of the image relative to ``images_dir_path``.
            Backslashes are treated as path separators so Windows-style
            paths stored in the dataset also work on POSIX systems.
        images_dir_path: Directory that ``image_path`` is relative to.

    Returns:
        The JPEG-encoded image bytes, base64-encoded and decoded to ``str``.

    Raises:
        FileNotFoundError: If the resolved image file does not exist.
        PIL.UnidentifiedImageError: If the file is not a readable image.
    """
    full_path = Path(images_dir_path) / Path(image_path.replace('\\', '/'))
    with Image.open(full_path) as img:
        # JPEG cannot store alpha or palette data: saving e.g. an RGBA or P
        # mode image raises "cannot write mode ... as JPEG". Convert such
        # images to RGB first; modes JPEG already handles are left untouched.
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        buffered = BytesIO()
        img.save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

In [11]:
# Upper bound on the number of dishes evaluated; the whole dataset is used
# when it contains fewer items than this.
N_SAMPLES = 500

# Nutrition fields kept as regression targets (same keys as NutritionFacts).
kcal_bju_keys = ['energy_kcal', 'protein_g', 'fat_g', 'carbs_g']

# JSON is UTF-8 by specification; state the encoding explicitly so reading
# does not depend on the platform's locale default.
with open('data/test_fix.json', 'r', encoding='utf-8') as f:
    raw_data = json.load(f)

# random.sample raises ValueError when asked for more items than exist, so
# clamp the sample size to the dataset size.
sampled_batch = random.sample(raw_data, min(N_SAMPLES, len(raw_data)))

# One record per sampled dish: its title, all of its images as base64 JPEG
# strings, and the four target nutrition values.
processed_data = []
for sample in sampled_batch:
    title = sample["title"]
    images = [image_to_base64(image["valid_path"], "data/test_valid") for image in sample["images"]]
    nutrition_targets = sample["nutr_per100g"]

    kcal_bju = {k: nutrition_targets[k] for k in kcal_bju_keys}
    processed_data.append({"title": title, "images": images, "nutrition_targets": kcal_bju})

In [12]:
# System prompt: frames the model as a nutrition expert that estimates
# per-100 g calories, protein, fat and carbohydrates from a dish image.
# The literal starts and ends with a newline because the triple quotes sit on
# their own lines.
SYSTEM_PROMPT = \
"""
You are a nutrition expert.  
Your task is to analyze an image of a dish and estimate its nutritional value per 100 grams: calories, protein, fat, and carbohydrates.

If the name of the dish is provided, use it together with the image to make a more accurate assessment.
"""

# User prompt template; `{title}` is filled in with the dish title via
# str.format when the query is built.
USER_PROMPT = "Determine the amount of calories, protein, fat, and carbohydrates per 100 grams for the dish name `{title}`."

In [13]:
def build_query(dish_data, max_images=1):
    """Assemble the chat messages (system + user) for one dish.

    Args:
        dish_data: Mapping with a ``"title"`` string and an ``"images"`` list
            of base64-encoded JPEG strings.
        max_images: Non-negative number of images, taken from the front of
            the list, to attach to the user message. Defaults to 1 (the
            original behaviour); ``None`` attaches every image.

    Returns:
        A two-element list: the system message, then a user message whose
        content is the formatted text part followed by the image parts.
    """
    system_message = {
        "role": "system",
        "content": SYSTEM_PROMPT,
    }

    text_part = {
        "type": "text",
        "text": USER_PROMPT.format(title=dish_data["title"]),
    }

    # Slicing with None keeps the whole list, so max_images=None means "all".
    image_parts = [
        {
            "type": "image",
            "source_type": "base64",
            "data": image,
            "mime_type": "image/jpeg",
        }
        for image in dish_data["images"][:max_images]
    ]

    user_message = {
        "role": "user",
        "content": [text_part, *image_parts],
    }
    return [system_message, user_message]

In [14]:
# Build the request for the first sampled dish to try the pipeline on a
# single example before the full run.
message = build_query(processed_data[0])

In [15]:
# Single-sample check of the structured call; the printed result is a plain
# dict with the four NutritionFacts keys.
response = llm.invoke(message)
print(response)

{'energy_kcal': 280, 'protein_g': 2.5, 'fat_g': 12, 'carbs_g': 40}


In [16]:
# Ground-truth and predicted nutrition dicts, index-aligned per dish.
targets = []
predictions = []

for sample in tqdm(processed_data):
    # Query first and record the pair only after the call succeeds, so an
    # exception or a manual interrupt never leaves the two lists with
    # different lengths.
    prediction = llm.invoke(build_query(sample))
    targets.append(sample["nutrition_targets"])
    predictions.append(prediction)

  0%|          | 0/500 [00:00<?, ?it/s]

In [17]:
# Nutrient names are taken from the first target record; every target and
# prediction is expected to carry the same keys.
nutrients = list(targets[0])

# Regroup the per-dish dicts into one list of values per nutrient.
y_true = {name: [row[name] for row in targets] for name in nutrients}
y_pred = {name: [row[name] for row in predictions] for name in nutrients}

# Mean absolute error of the predictions, computed separately per nutrient.
mae_scores = {
    name: mean_absolute_error(y_true[name], y_pred[name])
    for name in nutrients
}

# Persist the metrics under <ARTIFACTS_DIR>/<EXP_NAME>/metrics.json.
output_dir = f"{ARTIFACTS_DIR}/{EXP_NAME}"
os.makedirs(output_dir, exist_ok=True)
with open(f"{output_dir}/metrics.json", "w", encoding="utf-8") as f:
    json.dump(mae_scores, f, indent=4, ensure_ascii=False)

In [15]:
# Stop the vLLM server started at the top of the notebook.
# NOTE(review): under Run All this cell is only reached if every earlier cell
# succeeded; after a failure the server has to be stopped manually.
server.stop()

🛑 Остановка vLLM-сервера (SIGINT группе)…
✅ Сервер завершился корректно.
