In [None]:
import pandas as pd
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import root_mean_squared_error
import pickle
import os
from flask import Flask, request, jsonify
import numpy as np
import matplotlib.pyplot as plt
import io
import base64

# =====================
# 1. Train and Save Model
# =====================
def train_model(data_path=r"C:\Users\sarda\Desktop\bootcamp_darshit_sarda\homework\data\processed\amazon_bestsellers_2025_cleaned.csv", model_path="model.pkl"):
    """Fit a price -> number-of-ratings linear regression and pickle it.

    Reads the CSV at *data_path* and normalises ``product_price``:
    it strips the ₹ and $ symbols and the thousands separators.
    Rows are kept only if the price parses as a finite, non-negative number
    and ``product_num_ratings`` is present. A ``LinearRegression`` is fitted
    on an 80/20 train/test split (``random_state=42``), the test RMSE is
    printed, and the fitted model is written to *model_path*.

    Args:
        data_path: CSV file to read. The default is a machine-specific
            absolute path; pass an explicit path on other machines.
        model_path: File the pickled model is written to.

    Returns:
        The fitted ``LinearRegression``. It is trained on a DataFrame with a
        single ``product_price`` column, so predict with that column name.

    Raises:
        ValueError: if no row has a usable price and target value.
    """
    df = pd.read_csv(data_path)

    # Clean price column: remove ₹, $ and thousands-separator commas.
    cleaned_price = (
        df["product_price"]
        .astype(str)
        .str.replace("₹", "", regex=False)
        .str.replace("$", "", regex=False)
        .str.replace(",", "", regex=False)
        .str.strip()
    )
    # Anything unparsable ("", "nan", "N/A", "½", ...) becomes NaN instead of
    # slipping through an isnumeric() check and crashing a later float cast.
    numeric_price = pd.to_numeric(cleaned_price, errors="coerce")

    # Keep finite, non-negative prices (the previous digit-only check also
    # rejected negatives and "inf"). Drop rows without a target, because
    # LinearRegression cannot fit NaN labels.
    keep = (
        np.isfinite(numeric_price)
        & (numeric_price >= 0)
        & df["product_num_ratings"].notna()
    )
    # .copy() makes an independent frame, so the assignment below does not
    # trigger SettingWithCopyWarning.
    df = df.loc[keep].copy()
    df["product_price"] = numeric_price[keep].astype(float)

    if df.empty:
        raise ValueError(
            f"No rows with a valid product_price and product_num_ratings in {data_path}"
        )

    X = df[["product_price"]]
    y = df["product_num_ratings"]

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    model = LinearRegression()
    model.fit(X_train, y_train)

    y_pred = model.predict(X_test)
    rmse = root_mean_squared_error(y_test, y_pred)
    print(f"Model trained. RMSE = {rmse:.2f}")

    with open(model_path, "wb") as f:
        pickle.dump(model, f)

    return model


# Reuse the model pickled by an earlier run if one exists; otherwise train a
# fresh one (train_model writes model.pkl itself).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file,
# so only ever load a model.pkl that this script produced.
if os.path.exists("model.pkl"):
    with open("model.pkl", "rb") as model_file:
        model = pickle.load(model_file)
else:
    model = train_model()

# =====================
# 2. Flask API
# =====================
app = Flask(__name__)

@app.route("/predict", methods=["POST"])
def predict():
    """Predict the number of ratings for a price sent as JSON.

    Expects a body like ``{"Price": 499}``. It responds with
    ``{"Price": <float>, "Predicted_Reviews": <float>}``. It returns HTTP 400
    with an ``error`` message in these cases:

    - the body is not a JSON object;
    - ``Price`` is missing;
    - ``Price`` is not a finite number.
    """
    # silent=True: malformed JSON yields None, so we can answer with a JSON
    # error instead of Flask's default HTML 400 page.
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    price = data.get("Price", None)
    if price is None:
        return jsonify({"error": "Missing Price"}), 400
    try:
        price = float(price)
    except (TypeError, ValueError):
        return jsonify({"error": "Price must be a number"}), 400
    if not np.isfinite(price):
        return jsonify({"error": "Price must be a finite number"}), 400

    # Use the same single-column frame the model was fitted on; a bare numpy
    # array triggers sklearn's "X does not have valid feature names" warning.
    pred = model.predict(pd.DataFrame({"product_price": [price]}))[0]
    return jsonify({"Price": price, "Predicted_Reviews": float(pred)})




@app.route("/predict/<float:price>", methods=["GET"])
@app.route("/predict/<int:price>", methods=["GET"])
def predict_single(price):
    """Predict the number of ratings for a price given in the URL path.

    Werkzeug's ``float`` converter only matches values containing a decimal
    point (e.g. ``/predict/12.5``). The extra ``int`` rule makes whole
    numbers such as ``/predict/100`` work instead of returning 404.

    Returns:
        JSON ``{"Price": <price as given>, "Predicted_Reviews": <float>}``.
    """
    pred = model.predict(pd.DataFrame({"product_price": [float(price)]}))[0]
    return jsonify({"Price": price, "Predicted_Reviews": float(pred)})


@app.route("/plot", methods=["GET"])
def plot():
    """Return an HTML <img> with the model's prediction curve over prices 0-2000.

    The curve is sampled at 50 evenly spaced prices and embedded as a
    base64-encoded PNG data URI.
    """
    # Use the object-oriented Figure API instead of pyplot. pyplot keeps
    # global figure state: previously one figure leaked per request because
    # it was never closed. It is also not safe to use from Flask's worker
    # threads. A standalone Figure is garbage-collected with this function.
    from matplotlib.figure import Figure

    prices = np.linspace(0, 2000, 50)  # Expanded price range
    # Same feature name the model was trained with (avoids sklearn warning).
    preds = model.predict(pd.DataFrame({"product_price": prices}))

    fig = Figure()
    ax = fig.subplots()
    ax.plot(prices, preds, label="Prediction Curve")
    ax.set_xlabel("Price")
    ax.set_ylabel("Predicted Reviews")
    ax.legend()

    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    img_base64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    return f"<img src='data:image/png;base64,{img_base64}'/>"

if __name__ == "__main__":
    # Development server only. debug=True enables the interactive Werkzeug
    # debugger, so leave it bound to the default localhost host.
    # The auto-reloader is turned off. It restarts the script in a child
    # process, which presumably fails when this runs inside a notebook
    # kernel (verify if reused elsewhere).
    app.run(use_reloader=False, debug=True)


 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [28/Aug/2025 15:21:09] "GET / HTTP/1.1" 404 -
127.0.0.1 - - [28/Aug/2025 15:21:09] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [28/Aug/2025 15:25:00] "GET /plot HTTP/1.1" 200 -
127.0.0.1 - - [28/Aug/2025 15:25:00] "GET /favicon.ico HTTP/1.1" 404 -
127.0.0.1 - - [28/Aug/2025 15:27:16] "GET / HTTP/1.1" 404 -
