In [1]:
import requests
from pydantic import BaseModel
from datetime import datetime
import pandas as pd

In [2]:
# Define the Pydantic input data model
class InputData(BaseModel):
    """Request schema for a single sample sent to the prediction API.

    Every feature is declared as ``float``; pydantic validates (and coerces
    where possible) the values when the model is constructed, so malformed
    input fails here rather than on the server.

    Field names use underscores. The test CSV's headers contain spaces and are
    converted to this form (spaces -> ``_``) before rows are validated against
    this model. The feature set presumably mirrors the wine-quality dataset
    the served model was trained on — TODO confirm against the server schema.
    """

    # NOTE: declaration order defines the field order of the serialised
    # payload (``.dict()``); keep it aligned with the server-side model.
    fixed_acidity: float
    volatile_acidity: float
    citric_acid: float
    residual_sugar: float
    chlorides: float
    free_sulfur_dioxide: float
    total_sulfur_dioxide: float
    density: float
    pH: float
    sulphates: float
    alcohol: float


In [3]:
# One hand-written observation used to exercise the single-prediction endpoint.
# Keyword construction keeps the keys identical to the InputData field names.
sample_data = dict(
    fixed_acidity=7.9,
    volatile_acidity=0.8,
    citric_acid=1.0,
    residual_sugar=1.9,
    chlorides=0.076,
    free_sulfur_dioxide=11.0,
    total_sulfur_dioxide=34.0,
    density=0.9978,
    pH=3.51,
    sulphates=0.56,
    alcohol=9.4,
)

In [4]:
# Instantiate the input data model; pydantic validates the sample locally
# before anything is sent over the network.
input_data = InputData(**sample_data)

endpoint_url = "http://localhost:8000/predict/"

# Always pass a timeout: without one, requests waits forever if the server
# accepts the connection but never responds, freezing "Run All".
# NOTE(review): `.dict()` is deprecated in Pydantic v2 (`model_dump()`);
# switch once the installed Pydantic major version is confirmed.
response = requests.post(endpoint_url, json=input_data.dict(), timeout=10)
if response.status_code == 200:
    # The endpoint appears to answer with a list of prediction records even
    # for a single input object, hence the `[0]`.
    prediction = response.json()[0]["prediction"]
    print("Prediction:", prediction)
else:
    print("Error:", response.text)

Prediction: 5.03


In [5]:
# Read the test CSV (pandas is imported in the first cell). Header names
# contain spaces (e.g. "fixed acidity"), so strip stray whitespace and replace
# inner spaces with underscores to match the InputData field names.
test_df = pd.read_csv('../data/test-data.csv')
test_df.columns = test_df.columns.str.strip().str.replace(' ', '_')

# Validate every row against the schema before sending anything.
# `to_dict(orient="records")` yields one plain dict per row, which is faster
# than iterrows() and does not leak a loop variable into the notebook namespace.
input_data_list = [InputData(**record) for record in test_df.to_dict(orient="records")]

# Send all rows in ONE POST to the same /predict/ endpoint, as a JSON list.
# The timeout is more generous than for a single sample because the server
# scores the whole batch in one call.
endpoint_url = "http://localhost:8000/predict/"
response = requests.post(
    endpoint_url,
    json=[data.dict() for data in input_data_list],
    timeout=30,
)

if response.status_code == 200:
    # One prediction record per submitted row; the enumeration index is
    # assumed to match the CSV row order — confirm the server preserves order.
    predictions = response.json()
    print("Predictions:")
    for i, pred in enumerate(predictions):
        print(f"{i}: {pred['prediction']}")
else:
    print("Error:", response.text)

Predictions:
0: 5.01
1: 5.05
2: 5.06
3: 5.75
4: 5.01
5: 5.03
6: 4.99
7: 6.27


In [6]:
# Query the predictions the service has stored within a date window.
start_date = datetime(2024, 1, 1)
# NOTE(review): this is midnight on 2024-04-03, so whether predictions made
# later that day are included depends on the server's comparison — confirm.
end_date = datetime(2024, 4, 3)

endpoint_url = "http://localhost:8000/get_past_predictions/"

# NOTE(review): requests falls back to str() for datetime query params
# ("YYYY-MM-DD HH:MM:SS"); confirm that is the format the server expects.
# A timeout prevents the cell from hanging forever on an unresponsive server.
response = requests.get(
    endpoint_url,
    params={"start_date": start_date, "end_date": end_date},
    timeout=10,
)

if response.status_code == 200:
    # Kept as `df` because later cells may read it.
    df = pd.DataFrame(response.json())

    print("Past Predictions:")
    print(df)
else:
    print("Error:", response.text)


Past Predictions:
    fixed_acidity  volatile_acidity  citric_acid  residual_sugar  chlorides  \
0             7.9              0.80         1.00             1.9      0.076   
1             7.9              0.80         1.00             1.9      0.076   
2             7.4              0.70         0.00             1.9      0.076   
3             7.8              0.88         0.00             2.6      0.098   
4             7.8              0.76         0.04             2.3      0.092   
5            11.2              0.28         0.56             1.9      0.075   
6             7.4              0.70         0.00             1.9      0.076   
7             7.4              0.66         0.00             1.8      0.075   
8             7.9              0.60         0.06             1.6      0.069   
9             7.3              0.65         0.00             1.2      0.065   
10            7.9              0.80         1.00             1.9      0.076   
11            7.9              0.8