In [None]:
import base64
import pandas as pd
from mistralai import Mistral
from dotenv import load_dotenv
from pydantic import BaseModel
import json
import os
import time
import numpy as np

pd.options.mode.chained_assignment = None

In [2]:
load_dotenv()
# Retrieve the API key from environment variables
api_key = os.environ["MISTRAL_API_KEY"]
#model = "pixtral-12b-2409"

client = Mistral(api_key=api_key)
# Specify model
model = "pixtral-large-latest"

In [3]:
class MacronutrientEstimations(BaseModel):
    simple_sugars: int
    complex_sugars: int
    proteins: int
    fats: int
    dietary_fibers: int
    explanation: str

macronutrients_instruction = f'''Examine the provided meal image to analyze and estimate its nutritional content accurately. 
Determine the amounts of simple sugars (like industrial sugar and honey), complex sugars (such as starch and whole grains), proteins, fats, and dietary fibers (found in fruits and vegetables), all in grams.
To assist in accurately gauging the scale of the meal, a 1 Swiss Franc coin, which has a diameter of 23 mm, may be present in the picture. 
Provide your assessment of each nutritional component in grams. All estimates should be given as a single whole number. 
Use the size of this coin as a reference to estimate the size of the meal and the amounts of the nutrients more precisely. 
If there is no coin in the picture or the meal is covered partially, estimate anyways.


Here are three examples of how to analyze different types of foods:

Example 1 - Pizza:
{MacronutrientEstimations(
    simple_sugars=6,
    complex_sugars=65,
    proteins=28,
    fats=32,
    dietary_fibers=5,
    explanation="This appears to be a medium-sized cheese pizza (about 12 inches in diameter). The crust provides most of the complex sugars (approximately 65g from flour). The tomato sauce contains some simple sugars (6g). The cheese provides most of the protein (28g) and fats (32g). The small amount of tomato sauce and potentially some vegetables contribute to the dietary fiber content (5g)."
)}

Example 2 - Bread with cheese and salad:
{MacronutrientEstimations(
    simple_sugars=4,
    complex_sugars=30,
    proteins=18,
    fats=15,
    dietary_fibers=8,
    explanation="The meal consists of a slice of bread with cheese and a side salad. The bread provides most of the complex sugars (30g from wheat flour). The cheese contributes protein (10g) and fats (10g). The salad provides dietary fibers (8g) and a small amount of simple sugars (4g from vegetables). The bread also contributes some protein (8g) and fats (5g)."
)}

Example 3 - Bottle of orange juice:
{MacronutrientEstimations(
    simple_sugars=42,
    complex_sugars=2,
    proteins=2,
    fats=0,
    dietary_fibers=1,
    explanation="This appears to be a standard 500ml bottle of orange juice. Orange juice is primarily composed of simple sugars (42g from natural fruit sugars). It contains minimal complex sugars (2g), a small amount of protein (2g), negligible fat (0g), and a small amount of dietary fiber (1g) from pulp residue in the juice."
)}

Format your response in the following JSON format:
{MacronutrientEstimations(
    simple_sugars=0,
    complex_sugars=0,
    proteins=0,
    fats=0,
    dietary_fibers=0,
    explanation=''
)}'''

In [4]:
def encode_image(image_path):
    with open(image_path, "rb") as image_file:
        return "data:image/jpeg;base64," + base64.b64encode(image_file.read()).decode('utf-8')
    
def make_mistral_api_call(image_path):
    # Define the messages for the chat
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "text",
                    "text": macronutrients_instruction
                },
                {
                    "type": "image_url",
                    "image_url": image_path,
                    "response_format": MacronutrientEstimations
                }
            ]
        }
    ]

    # Get the chat response
    chat_response = client.chat.parse(
        model=model,
        messages=messages,
        response_format=MacronutrientEstimations,
        max_tokens=1024,
        temperature=0
    )
    response_json = chat_response.choices[0].message.content
    response = json.loads(response_json.strip('`json\n'))
    return response

In [5]:
patient = '001'
nutrient_features = ['simple_sugars', 'complex_sugars', 'proteins', 'fats', 'dietary_fibers']
selected_images = ['001.jpg', '008.jpg', '010.jpg', '022.jpg', '025.jpg']
n_runs = 100

all_estimates = []

for image in selected_images:
    print(f"Processing {image}...")
    image_path = f"../diabetes_subset_pictures-glucose-food-insulin/{patient}/food_pictures/{image}"
    base64_image = encode_image(image_path)
    
    for run in range(n_runs):
        print(f"  Run {run + 1}/{n_runs}", end='\r')
        try:
            message = make_mistral_api_call(base64_image)
            estimate = {
                'image': image,
                'run': run + 1,
                'simple_sugars': message['simple_sugars'],
                'complex_sugars': message['complex_sugars'],
                'proteins': message['proteins'],
                'fats': message['fats'],
                'dietary_fibers': message['dietary_fibers'],
                'explanation': message['explanation']
            }
            all_estimates.append(estimate)
            time.sleep(0.1)  # Small delay to avoid rate limiting
        except Exception as e:
            print(f"\nError on run {run + 1}: {e}")
            continue
    print(f"  Completed {image}")

estimates_df = pd.DataFrame(all_estimates)

Processing 001.jpg...
  Completed 001.jpg
Processing 008.jpg...
  Run 39/100
Error on run 39: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Completed 008.jpg
Processing 010.jpg...
  Completed 010.jpg
Processing 022.jpg...
  Run 31/100
Error on run 31: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Run 62/100
Error on run 62: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Run 95/100
Error on run 95: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Completed 022.jpg
Processing 025.jpg...
  Run 53/100
Error on run 53: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Run 54/100
Error on run 54: [SSL: SSLV3_ALERT_BAD_RECORD_MAC] ssl/tls alert bad record mac (_ssl.c:2638)
  Completed 025.jpg


In [34]:
estimates_df.to_csv('raw_results.csv', index=False)

In [36]:

# Calculate total macronutrients for each run
nutrient_cols = ['simple_sugars', 'complex_sugars', 'proteins', 'fats', 'dietary_fibers']
estimates_df['total_macronutrients'] = estimates_df[nutrient_cols].sum(axis=1)

# Calculate statistics for each image
images = ['001.jpg', '008.jpg', '010.jpg', '022.jpg', '025.jpg']
stats = []

for image in images:
    image_data = estimates_df[estimates_df['image'] == image]['total_macronutrients']
    stats.append({
        'image': image.replace('.jpg', ''),
        'mean': image_data.mean(),
        'median': image_data.median(),
        'std': image_data.std(),
        'cv': (image_data.std() / image_data.mean()) * 100,
        'min': image_data.min(),
        'max': image_data.max()
    })

# Create LaTeX table
latex_table = """\\begin{table}[htbp]
\\centering
\\caption{Total Macronutrient Reproducibility Analysis}
\\label{tab:total_macronutrient_reproducibility}
\\begin{tabular}{lcccccc}
\\toprule
Image & Mean (g) & Median (g) & Std (g) & CV (\\%) & Min (g) & Max (g) \\\\
\\midrule
"""

for stat in stats:
    cv_str = f"{stat['cv']:.1f}" if not np.isnan(stat['cv']) else "N/A"
    latex_table += f"{stat['image']} & {stat['mean']:.1f} & {stat['median']:.1f} & {stat['std']:.2f} & {cv_str} & {stat['min']:.1f} & {stat['max']:.1f} \\\\\n"
latex_table += """\\bottomrule
\\end{tabular}
\\end{table}"""

# Save table
with open('reproducibility_table.tex', 'w') as f:
    f.write(latex_table)