In [2]:
!pip install -q fastapi "uvicorn[standard]" nest_asyncio scikit-learn

### Setup: Install required libraries

In this step we install the core tools we’ll use to simulate a **real web + ML stack**:

- **FastAPI** – Python web framework to build our backend API.
- **uvicorn** – ASGI server that runs the FastAPI app.
- **scikit-learn** – to train simple ML models.
- **nest_asyncio** – patches `asyncio` to allow nested event loops, so the server can run inside a Colab notebook, which already runs its own event loop.

This is the only “environment setup” step. After this, everything else is just Python and JavaScript.
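You can confirm what the install step actually put into the runtime with a small check like the one below. It is a sketch that uses only the standard library's `importlib.metadata`. The helper name `installed_versions` is our own. A package reported as missing may simply be registered under a slightly different distribution name.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Report the packages installed by the pip cell (installed_versions is our own helper).
for pkg, ver in installed_versions(
    ["fastapi", "uvicorn", "scikit-learn", "nest_asyncio"]
).items():
    print(f"{pkg:>14}: {ver or 'NOT INSTALLED'}")
```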


In [3]:
from fastapi import FastAPI
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from typing import Optional
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor
from sklearn.svm import SVR
from sklearn.metrics import r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
import nest_asyncio
import uvicorn
from google.colab import output
import threading, time

### Define the regression problem & request schema

In this step we set up the **ML side** of our demo and describe what kind of
data the backend expects from the frontend.

**What happens here:**

1. **Create a simple regression dataset**

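   We generate synthetic data following an approximately linear relationship:
   \[
   y = 3x + 5 + \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0,\, 4^2)
   \]
   - `X` holds 80 evenly spaced points from 0 to 10.
   - The Gaussian noise (standard deviation 4) keeps the fit from being perfect, which is more realistic.
   - With a single input feature, the problem stays easy to visualize and reason about.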

2. **Train/test split**

   We split the data 80/20 (64 training points, 16 test points):
   - `X_train, y_train` → used to fit models.
   - `X_test, y_test` → used later to evaluate how good each model is
     (R² and MAE).

3. **Grid for plotting**

   - `grid_x` holds 200 evenly spaced x-values over the same 0 to 10 range.
   - We’ll use it to compute smooth predictions and draw the **fitted curve**
     for whichever model + hyperparameters the user selects.

4. **Define the API request model (`PredictRegRequest`)**

   This `BaseModel` (from Pydantic) describes the JSON payload our backend
   expects when the frontend calls the `/predict_reg` endpoint:

   - `model_name`: which model to use  
     (`"linear"`, `"rf"`, or `"svr"`).
   - `x_value`: the input x at which we want a prediction.
   - Optional hyperparameters (only used for some models):
     - `n_estimators`, `max_depth` for **Random Forest**.
     - `C`, `gamma` for **SVR (RBF)**.

   FastAPI will:
   - Automatically parse incoming JSON into this Python object.
   - Validate types for us.
   - Make it easy to access all values in our endpoint function.

Together, the code cell below defines the **problem**, the **evaluation setup**, and the **contract** between the frontend and the backend.
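`LinearRegression` solves an ordinary least-squares problem. A quick way to see what "approximately linear" means here is to solve that problem directly. The sketch below regenerates the same data with NumPy only. It fits all 80 points rather than the training split. The names `x_demo` and `y_demo` are ours, so the notebook's `X` and `y` are not overwritten. The fitted slope and intercept should land near the true values 3 and 5, and the exact numbers depend on the noise draw.

```python
import numpy as np

# Same recipe as the dataset cell: 80 points on [0, 10], noise with std 4.
np.random.seed(42)
x_demo = np.linspace(0, 10, 80)
y_demo = 3 * x_demo + 5 + np.random.normal(scale=4, size=80)

# Ordinary least squares: find [slope, intercept] minimizing ||A @ w - y||^2,
# where A has one column for x and one column of ones for the intercept.
A = np.column_stack([x_demo, np.ones_like(x_demo)])
(slope, intercept), *_ = np.linalg.lstsq(A, y_demo, rcond=None)

print(f"slope ≈ {slope:.2f} (true 3), intercept ≈ {intercept:.2f} (true 5)")
```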


In [4]:
# ---------- Simple regression dataset ----------
np.random.seed(42)
X = np.linspace(0, 10, 80)
y = 3 * X + 5 + np.random.normal(scale=4, size=80)  # noisy linear-ish
X = X.reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

grid_x = np.linspace(0, 10, 200).reshape(-1, 1)  # for smooth fitted line


class PredictRegRequest(BaseModel):
    model_name: str          # "linear", "rf", "svr"
    x_value: float
    # Random Forest
    n_estimators: Optional[int] = None
    max_depth: Optional[int] = None
    # SVR
    C: Optional[float] = None
    gamma: Optional[float] = None
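You can see the parsing and validation that FastAPI delegates to Pydantic by exercising the schema directly, without a server. This sketch repeats the `PredictRegRequest` definition so it runs on its own. Inside FastAPI, a validation failure like the one below becomes an HTTP 422 response before `predict_reg` is ever called.

```python
from typing import Optional

from pydantic import BaseModel, ValidationError


class PredictRegRequest(BaseModel):
    # Same schema as the cell above, repeated so this sketch is self-contained.
    model_name: str
    x_value: float
    n_estimators: Optional[int] = None
    max_depth: Optional[int] = None
    C: Optional[float] = None
    gamma: Optional[float] = None


# A payload shaped like the frontend's: unused hyperparameters arrive as
# JSON null (Python None) or are omitted and take their default of None.
ok = PredictRegRequest(**{"model_name": "linear", "x_value": "5", "C": None})
print(ok.x_value, ok.n_estimators)  # the string "5" is coerced to the float 5.0

# A non-numeric x_value is rejected with a ValidationError naming the field.
try:
    PredictRegRequest(model_name="svr", x_value="abc")
except ValidationError as exc:
    print("rejected field:", exc.errors()[0]["loc"])
```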

### Create the FastAPI application

Here we initialize our web application:

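```python
app = FastAPI()
```

The route decorators (`@app.get`, `@app.post`) register their endpoints on this `app` object, and `uvicorn` serves it later in the notebook.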

In [5]:
app = FastAPI()

### Build the Frontend: HTML + JavaScript inside a FastAPI route

This endpoint defines the **entire web interface** for our mini app.

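```python
@app.get("/", response_class=HTMLResponse)
def home():
    return """ ... """
```

The route returns a single HTML string. `response_class=HTMLResponse` tells FastAPI to send it as `text/html` rather than encoding it as JSON.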


In [6]:
@app.get("/", response_class=HTMLResponse)
def home():
    return """<!DOCTYPE html>
<html>
<head>
  <meta charset="UTF-8" />
  <title>SAI ML Regression Demo</title>
  <script src="https://cdn.jsdelivr.net/npm/chart.js"></script>
  <style>
    body { font-family: Arial, sans-serif; margin: 30px; max-width: 980px; }
    h1 { color: #2c3e50; margin-bottom: 5px; }
    .subtitle { color: #555; margin-bottom: 20px; }
    .layout {
      display: flex;
      gap: 24px;
      align-items: flex-start;
      flex-wrap: wrap;
    }
    .left { flex: 2 1 520px; }
    .right { flex: 1 1 280px; }

    .chart-wrapper {
      width: 100%;
      height: 320px;
      position: relative;
    }
    #regChart {
      width: 100% !important;
      height: 100% !important;
    }

    label { display: block; margin-top: 10px; font-weight: 500; }
    select, input {
      padding: 6px;
      margin-top: 4px;
      width: 100%;
      box-sizing: border-box;
      border-radius: 4px;
      border: 1px solid #ccc;
      font-size: 0.9em;
    }
    .hyper-section {
      margin-top: 8px;
      padding: 8px;
      border-radius: 4px;
      background: #fafafa;
      border: 1px solid #eee;
      font-size: 0.8em;
    }
    .hyper-title {
      font-weight: 600;
      margin-bottom: 4px;
      color: #444;
    }
    .hyper-row {
      display: flex;
      gap: 6px;
      align-items: center;
      margin-top: 4px;
    }
    .hyper-row span {
      flex: 1 0 60px;
    }
    .hyper-row input {
      flex: 1 0 80px;
      margin-top: 0;
    }

    button {
      margin-top: 14px;
      padding: 10px 16px;
      cursor: pointer;
      border: none;
      background: #2c3e50;
      color: #fff;
      border-radius: 4px;
      font-size: 0.9em;
    }
    button:hover { background: #1a242f; }

    table {
      width: 100%;
      border-collapse: collapse;
      margin-top: 16px;
      font-size: 0.9em;
    }
    th, td {
      padding: 6px 8px;
      border-bottom: 1px solid #eee;
      text-align: left;
    }
    th { background: #f4f4f4; }
    .metric-label { color: #444; }
    .note { margin-top: 10px; font-size: 0.82em; color: #666; }
    .pill {
      display: inline-block;
      padding: 2px 6px;
      border-radius: 10px;
      background: #eef3ff;
      font-size: 0.72em;
      color: #445;
      margin-left: 4px;
    }
  </style>
</head>
<body>
  <h1>SAI Mini ML Web App — Regression</h1>
  <div class="subtitle">
    Frontend (HTML/JS) → Backend (FastAPI) → ML (sklearn) → Interactive plot & metrics.
  </div>

  <div class="layout">
    <div class="left">
      <div class="chart-wrapper">
        <canvas id="regChart"></canvas>
      </div>
      <div class="note">
        Blue = data points. Red = fitted curve for chosen model & hyperparameters.
        Orange = prediction at your input x.
      </div>
    </div>

    <div class="right">
      <label for="model">Choose a model</label>
      <select id="model" onchange="toggleHyperparams()">
        <option value="linear">Linear Regression</option>
        <option value="rf">Random Forest</option>
        <option value="svr">SVR (RBF)</option>
      </select>

      <div class="hyper-section" id="linear-hypers">
        <div class="hyper-title">
          Linear Regression
          <span class="pill">no main hyperparameters</span>
        </div>
        <div class="note">
          Fits a straight line. A good baseline for comparing under- and overfitting against the other models.
        </div>
      </div>

      <div class="hyper-section" id="rf-hypers" style="display:none;">
        <div class="hyper-title">Random Forest Hyperparameters</div>
        <div class="hyper-row">
          <span>Estimators</span>
          <input id="rf-n-est" type="number" value="200" min="10" step="10" />
        </div>
        <div class="hyper-row">
          <span>Max depth</span>
          <input id="rf-max-depth" type="number" value="" min="1" step="1"
                 placeholder="empty = no limit" />
        </div>
        <div class="note">
          More estimators → smoother & stabler. Limiting depth → less overfitting.
        </div>
      </div>

      <div class="hyper-section" id="svr-hypers" style="display:none;">
        <div class="hyper-title">SVR (RBF) Hyperparameters</div>
        <div class="hyper-row">
          <span>C</span>
          <input id="svr-C" type="number" value="10" step="1" />
        </div>
        <div class="hyper-row">
          <span>gamma</span>
          <input id="svr-gamma" type="number" value="0.5" step="0.1" />
        </div>
        <div class="note">
          Higher C & gamma → more flexible curve (risk of overfit).
        </div>
      </div>

      <label for="xVal">Input x value</label>
      <input id="xVal" type="number" step="0.1" value="5.0" />

      <button onclick="runModel()">Run Model</button>

      <table>
        <thead>
          <tr><th colspan="2">Model Performance (on test set)</th></tr>
        </thead>
        <tbody>
          <tr><td>Model</td><td id="m-name" class="metric-label">-</td></tr>
          <tr><td>R²</td><td id="m-r2">-</td></tr>
          <tr><td>MAE</td><td id="m-mae">-</td></tr>
          <tr><td>Prediction at x</td><td id="m-pred">-</td></tr>
        </tbody>
      </table>

      <div class="note">
        Each click retrains the chosen model with these hyperparameters,
        evaluates on a held-out test set, and sends metrics + curve as JSON.
      </div>
    </div>
  </div>

  <script>
    let regChart = null;
    let rawData = { x: [], y: [] };

    function toggleHyperparams() {
      const model = document.getElementById('model').value;
      document.getElementById('linear-hypers').style.display =
        (model === 'linear') ? 'block' : 'none';
      document.getElementById('rf-hypers').style.display =
        (model === 'rf') ? 'block' : 'none';
      document.getElementById('svr-hypers').style.display =
        (model === 'svr') ? 'block' : 'none';
    }

    async function loadData() {
      const res = await fetch('/init_data');
      rawData = await res.json();
      const ctx = document.getElementById('regChart').getContext('2d');

      if (regChart) {
        regChart.destroy();
      }

      regChart = new Chart(ctx, {
        type: 'scatter',
        data: {
          datasets: [
            {
              label: 'Data Points',
              data: rawData.x.map((x, i) => ({ x: x, y: rawData.y[i] })),
              pointRadius: 3,
            }
          ]
        },
        options: {
          responsive: true,
          maintainAspectRatio: false,
          plugins: { legend: { display: true } },
          scales: {
            x: { title: { display: true, text: 'x' } },
            y: { title: { display: true, text: 'y' } }
          }
        }
      });
    }

    async function runModel() {
      const modelKey = document.getElementById('model').value;
      const xVal = parseFloat(document.getElementById('xVal').value);
      if (isNaN(xVal)) {
        alert('Please enter a numeric x value.');
        return;
      }

      // Read hyperparameters from UI
      let nEst = null, maxDepth = null, C = null, gamma = null;

      if (modelKey === 'rf') {
        const estVal = parseInt(document.getElementById('rf-n-est').value);
        if (!isNaN(estVal) && estVal > 0) nEst = estVal;

        const mdVal = parseInt(document.getElementById('rf-max-depth').value);
        if (!isNaN(mdVal) && mdVal > 0) maxDepth = mdVal;
      }

      if (modelKey === 'svr') {
        const Cval = parseFloat(document.getElementById('svr-C').value);
        if (!isNaN(Cval) && Cval > 0) C = Cval;

        const gVal = parseFloat(document.getElementById('svr-gamma').value);
        if (!isNaN(gVal) && gVal > 0) gamma = gVal;
      }

      const payload = {
        model_name: modelKey,
        x_value: xVal,
        n_estimators: nEst,
        max_depth: maxDepth,
        C: C,
        gamma: gamma
      };

      const res = await fetch('/predict_reg', {
        method: 'POST',
        headers: { 'Content-Type': 'application/json' },
        body: JSON.stringify(payload)
      });
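      // Non-2xx answers (e.g. HTTP 422 for a malformed payload) carry no metrics.
      if (!res.ok) {
        alert(`Request failed (HTTP ${res.status})`);
        return;
      }
      const data = await res.json();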

      if (data.error) {
        alert(data.error);
        return;
      }

      // Update metrics table
      document.getElementById('m-name').textContent = data.model_name_pretty;
      document.getElementById('m-r2').textContent = data.r2.toFixed(3);
      document.getElementById('m-mae').textContent = data.mae.toFixed(3);
      document.getElementById('m-pred').textContent =
        `x = ${data.pred_x.toFixed(2)}, ŷ = ${data.pred_y.toFixed(2)}`;

      if (!regChart) return;

      const lineData = data.line_x.map((x, i) => ({ x: x, y: data.line_y[i] }));
      const predPoint = { x: data.pred_x, y: data.pred_y };

      // Dataset[0] = scatter, [1] = line, [2] = prediction
      if (regChart.data.datasets.length < 2) {
        regChart.data.datasets.push({
          label: 'Fitted Curve',
          type: 'line',
          data: lineData,
          pointRadius: 0,
          borderWidth: 2,
          borderColor: 'red'
        });
      } else {
        regChart.data.datasets[1].data = lineData;
      }

      if (regChart.data.datasets.length < 3) {
        regChart.data.datasets.push({
          label: 'Prediction',
          type: 'scatter',
          data: [predPoint],
          pointRadius: 5,
          backgroundColor: 'orange'
        });
      } else {
        regChart.data.datasets[2].data = [predPoint];
      }

      regChart.update();
    }

    // Initial state
    toggleHyperparams();
    loadData();
  </script>
</body>
</html>"""
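The SVR panel's note ("Higher C & gamma → more flexible curve") follows from how the RBF kernel works. scikit-learn's RBF kernel scores the similarity of two inputs as \(K(x, x') = \exp(-\gamma \lVert x - x' \rVert^2)\). A larger `gamma` makes that similarity fall off faster with distance, so each training point influences a narrower neighborhood and the fitted curve can bend more. A larger `C` penalizes errors outside the epsilon margin more heavily, which also pushes the fit to follow the data more closely. The small NumPy sketch below shows the `gamma` effect numerically. The function name `rbf_kernel` is ours.

```python
import numpy as np


def rbf_kernel(x, x_prime, gamma):
    """RBF similarity exp(-gamma * ||x - x'||^2), the kernel behind SVR(kernel="rbf")."""
    diff = np.asarray(x, dtype=float) - np.asarray(x_prime, dtype=float)
    return float(np.exp(-gamma * np.dot(diff, diff)))


# How similar are two points one unit apart under different gamma values?
for gamma in (0.1, 0.5, 2.0):
    print(f"gamma={gamma}: K(0, 1) = {rbf_kernel([0.0], [1.0], gamma):.3f}")
```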

`/init_data` is our simplest example of an API endpoint and a perfect place to connect concepts like HTTP methods, JSON, and the frontend–backend contract.

When the browser loads the page, our JavaScript calls `GET /init_data` using `fetch()`. Because this route is decorated with `@app.get("/init_data")`, FastAPI knows that every GET request to that path should run the `init_data` function. Inside it, we return a Python dictionary containing two keys: `"x"` and `"y"`, which hold all the input values and target values from our synthetic regression dataset. FastAPI automatically converts this dictionary into JSON before sending it back to the browser.

On the frontend side, that JSON response is used to draw the blue scatter plot: each `(x, y)` pair becomes a point on the chart. This is a minimal, real example of a read-only API:

- it uses the HTTP GET method (we are only requesting data, not changing anything),
- it returns JSON as the common “language” between Python on the server and JavaScript in the browser,
- and it defines a clear contract: `/init_data` will always respond with `{ "x": [...], "y": [...] }`, and our frontend code depends on that structure.

You can use this endpoint to explain that many modern web apps work this way: the UI calls small, well-defined APIs to fetch data, then uses that data to render visualizations or interfaces.
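The `.tolist()` calls in the endpoint below matter. Python's `json` module cannot encode NumPy arrays, and FastAPI's default encoder is not designed for them either, so the endpoint hands FastAPI plain lists. This standalone sketch uses our own small arrays, `x_col` and `y_vals`. It shows the failure, the conversion, and how the frontend's `map()` call turns the two lists back into `{x, y}` points.

```python
import json

import numpy as np

# A column vector shaped like the notebook's X, plus matching targets.
x_col = np.linspace(0, 10, 5).reshape(-1, 1)
y_vals = 3 * x_col.squeeze() + 5

# The standard-library encoder rejects raw NumPy arrays...
try:
    json.dumps({"x": x_col})
except TypeError as exc:
    print("raw array rejected:", exc)

# ...so flatten with squeeze() and convert to plain Python lists first.
body = json.dumps({"x": x_col.squeeze().tolist(), "y": y_vals.tolist()})
print(body)

# The browser does the equivalent of this zip to build Chart.js scatter points.
data = json.loads(body)
points = [{"x": xv, "y": yv} for xv, yv in zip(data["x"], data["y"])]
print(points[:2])
```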


In [7]:
@app.get("/init_data")
def init_data():
    return {"x": X.squeeze().tolist(), "y": y.tolist()}

`/predict_reg` is our main **ML API endpoint**. It shows how a frontend can send inputs and hyperparameters to a backend, trigger real model training and evaluation, and receive structured results as JSON.

Because it’s defined with `@app.post("/predict_reg")`, this route handles HTTP **POST** requests. Unlike `GET /init_data` (which just reads data), here the client sends a **JSON body** describing what it wants the backend to do. FastAPI automatically parses that JSON into a `PredictRegRequest` object, so inside the function we can work with clean Python attributes like `req.model_name`, `req.x_value`, `req.n_estimators`, etc.

The logic is:

1. **Validate the requested model**
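   - We check `req.model_name` against the keys of `name_map` (`"linear"`, `"rf"`, `"svr"`).
   - If it’s not one of those, we return a JSON error with HTTP status 200, which is why the frontend checks for an `error` key: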
     ```json
     { "error": "Invalid model name: ..." }
     ```

2. **Build the model using user-defined hyperparameters**
   - If `model_name == "linear"`:
     - We use plain `LinearRegression` (no main hyperparameters).
   - If `model_name == "rf"`:
     - We read `n_estimators` and `max_depth` from the request, falling back to `n_estimators=200` and no depth limit if they are missing.
     - We construct a `RandomForestRegressor` with those values.
   - If `model_name == "svr"`:
     - We read `C` and `gamma` from the request, falling back to `C=10.0` and `gamma=0.5` if they are missing.
     - We build an `SVR` with an RBF kernel using those settings.

   This shows how an API can let the **client control hyperparameters** while the training itself stays on the backend.

3. **Train and evaluate on the fly**
   - For each request, we fit the chosen model on `X_train, y_train`.
   - We evaluate on `X_test, y_test` to compute:
     - `r2` (how well the model explains the variance),
     - `mae` (average absolute prediction error).
   - This mimics how a backend service might evaluate or monitor models.

4. **Generate data for visualization**
   - We use `grid_x` to compute a smooth sequence of predictions (`line_y`).
   - These values are returned so the frontend can draw the **red fitted curve** for this exact model + hyperparameter choice.

5. **Compute the user’s prediction**
   - We take `req.x_value`,
   - Run `model.predict([[x_val]])`,
   - Return `pred_x` and `pred_y` so the frontend can:
     - show the number in the table,
     - mark the **orange point** on the chart.
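The two metrics in step 3 are simple formulas: \(R^2 = 1 - \sum_i (y_i - \hat{y}_i)^2 / \sum_i (y_i - \bar{y})^2\) and \(\text{MAE} = \frac{1}{n}\sum_i \lvert y_i - \hat{y}_i \rvert\). The sketch below computes them by hand with NumPy. The helper `r2_and_mae` is ours. For a test set whose targets are not all identical, it should agree with scikit-learn's `r2_score` and `mean_absolute_error`. R² on a test set can be negative when a model does worse than always predicting the mean.

```python
import numpy as np


def r2_and_mae(y_true, y_pred):
    """Return (R², MAE) computed directly from their definitions."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    residuals = y_true - y_pred
    ss_res = np.sum(residuals ** 2)                 # variation the model leaves unexplained
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)  # total variation around the mean
    r2 = 1.0 - ss_res / ss_tot
    mae = np.mean(np.abs(residuals))                # average absolute error
    return float(r2), float(mae)


demo_r2, demo_mae = r2_and_mae([1, 2, 3, 4], [1, 2, 3, 5])
print(f"R² = {demo_r2:.3f}, MAE = {demo_mae:.3f}")  # R² = 0.800, MAE = 0.250
```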

The function finally returns a JSON object like the following (the numbers are only illustrative):

```json
{
  "model_name_pretty": "Random Forest (n=200, max_depth=None)",
  "r2": 0.89,
  "mae": 2.15,
  "line_x": [...],
  "line_y": [...],
  "pred_x": 5.0,
  "pred_y": 20.3
}
```


In [8]:
@app.post("/predict_reg")
def predict_reg(req: PredictRegRequest):
    name_map = {
        "linear": "Linear Regression",
        "rf": "Random Forest",
        "svr": "SVR (RBF)"
    }
    if req.model_name not in name_map:
        return {"error": f"Invalid model name: {req.model_name}"}

    # Build & train model with requested hyperparameters
    if req.model_name == "linear":
        model = LinearRegression()
        model_name_pretty = name_map["linear"]

    elif req.model_name == "rf":
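        # Fall back to defaults when a value is missing or not positive, so an
        # out-of-range value sent straight to the API (e.g. from /docs) cannot crash fit().
        n_estimators = req.n_estimators if req.n_estimators and req.n_estimators > 0 else 200
        max_depth = req.max_depth if req.max_depth and req.max_depth > 0 else None
        model = RandomForestRegressor(
            n_estimators=n_estimators,
            max_depth=max_depth,
            random_state=0
        )
        model_name_pretty = f"Random Forest (n={n_estimators}, max_depth={max_depth or 'None'})"

    else:  # "svr"
        # Same fallback rule for the SVR hyperparameters.
        C = req.C if req.C and req.C > 0 else 10.0
        gamma = req.gamma if req.gamma and req.gamma > 0 else 0.5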
        model = SVR(kernel="rbf", C=C, gamma=gamma)
        model_name_pretty = f"SVR (RBF, C={C}, gamma={gamma})"

    model.fit(X_train, y_train)

    # Evaluate
    y_test_pred = model.predict(X_test)
    r2 = float(r2_score(y_test, y_test_pred))
    mae = float(mean_absolute_error(y_test, y_test_pred))

    # Fitted curve
    line_y = model.predict(grid_x).tolist()
    line_x = grid_x.squeeze().tolist()

    # Single prediction
    x_val = float(req.x_value)
    y_pred = float(model.predict([[x_val]])[0])

    return {
        "model_name_pretty": model_name_pretty,
        "r2": r2,
        "mae": mae,
        "line_x": line_x,
        "line_y": line_y,
        "pred_x": x_val,
        "pred_y": y_pred,
    }
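The browser is only one possible client, and any HTTP client can send the same JSON body. This sketch builds the equivalent of the frontend's `fetch()` call with the standard library's `urllib.request`. The helper name `build_predict_request` is ours. The base URL assumes you call the server from inside the notebook, where uvicorn (started in the next cell) listens on port 8002. Building the request needs no running server, and the commented lines show how it could be sent once the server is up.

```python
import json
import urllib.request


def build_predict_request(base_url, model_name, x_value,
                          n_estimators=None, max_depth=None, C=None, gamma=None):
    """Build (but do not send) a POST /predict_reg request with a JSON body.

    Unset hyperparameters are sent as JSON null, like the frontend's payload,
    and arrive in the endpoint as None.
    """
    payload = {
        "model_name": model_name,
        "x_value": x_value,
        "n_estimators": n_estimators,
        "max_depth": max_depth,
        "C": C,
        "gamma": gamma,
    }
    return urllib.request.Request(
        url=base_url.rstrip("/") + "/predict_reg",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
        method="POST",
    )


demo_request = build_predict_request("http://127.0.0.1:8002", "rf", 5.0, n_estimators=100)
print(demo_request.get_method(), demo_request.full_url)
print(demo_request.data.decode("utf-8"))

# Once the server is running, sending the request could look like this:
# with urllib.request.urlopen(demo_request) as resp:
#     result = json.load(resp)
#     print(result["model_name_pretty"], result["r2"], result["pred_y"])
```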

This cell runs our FastAPI app inside Colab and gives us a link that behaves like `localhost`.

Step by step:

1. `nest_asyncio.apply()`  
   Colab already runs an event loop under the hood. This line patches it so `uvicorn` (our web server) can run cleanly inside the notebook.

2. `PORT = 8002`  
   We choose a port for our server. In local development you’d often see `8000` or `5000`. Here we use `8002` to avoid conflicts, but the idea is the same.

3. “If a previous server exists, signal it to stop”  
   In a notebook, you might re-run this cell several times as you tweak code.  
   The `try/except` block:
   - sets `server.should_exit = True` when a `server` from an earlier run exists, which asks that uvicorn instance to shut down,
   - then waits one second so the old server can release the port,
   - and does nothing on the first run, when `server` is not defined yet and the `NameError` is caught.

   This is just notebook hygiene: it avoids “address already in use” errors without restarting the whole runtime.

4. `run()` function and background thread  
   We configure and start a `uvicorn.Server` with:
   - `app` → our FastAPI application,
   - `host="0.0.0.0"` and `port=PORT` → listen on that port inside the Colab environment.
   
   We run it in a **daemon thread** so:
   - the server keeps running in the background,
   - the notebook cell can finish and let us keep working.

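5. `proxyPort` → get a browser URL  
   Colab doesn’t expose `localhost:8002` to your browser directly, so we ask Colab for a proxied URL:

   ```python
   proxy_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT})")
   ```

   The cell prints this URL as the app address. It also prints the same URL with `/docs` appended, where FastAPI serves its interactive API documentation.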


In [9]:
nest_asyncio.apply()

PORT = 8002  # keep fixed, or change if you really want a new one

# If a previous server exists, signal it to stop
try:
    server.should_exit = True
    time.sleep(1)
except NameError:
    pass

def run():
    global server
    config = uvicorn.Config(app, host="0.0.0.0", port=PORT, log_level="info")
    server = uvicorn.Server(config)
    server.run()

thread = threading.Thread(target=run, daemon=True)
thread.start()

time.sleep(2)

proxy_url = output.eval_js(f"google.colab.kernel.proxyPort({PORT})")
print("App URL:", proxy_url)
print("Docs URL:", proxy_url + "/docs")


INFO:     Started server process [298]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://0.0.0.0:8002 (Press CTRL+C to quit)


App URL: https://8002-m-s-14ocfjy7fw9wx-b.us-central1-1.prod.colab.dev
Docs URL: https://8002-m-s-14ocfjy7fw9wx-b.us-central1-1.prod.colab.dev/docs
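The fixed `time.sleep(2)` in the server cell usually works. On a slow runtime, though, the URL may be printed before uvicorn is listening. A more deliberate approach polls until the port accepts connections. Both helpers, `wait_until` and `port_is_open`, are ours and use only the standard library. The demo below uses a `threading.Event` as a stand-in for "server started", so it runs without a server.

```python
import socket
import threading
import time


def wait_until(predicate, timeout=10.0, interval=0.1):
    """Poll predicate() until it returns True or `timeout` seconds pass.

    Returns True if the predicate became true in time, False otherwise.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if predicate():
            return True
        time.sleep(interval)
    return bool(predicate())  # one last check at the deadline


def port_is_open(host, port):
    """True if a TCP connection to (host, port) succeeds, i.e. something is listening."""
    try:
        with socket.create_connection((host, port), timeout=0.5):
            return True
    except OSError:
        return False


# Stand-in for a server that finishes starting after a short delay.
started = threading.Event()
threading.Timer(0.3, started.set).start()
print("ready:", wait_until(started.is_set, timeout=5.0))
```

In the server cell, `time.sleep(2)` could then become `wait_until(lambda: port_is_open("127.0.0.1", PORT))`. Like the fixed sleep, this assumes any previous server on that port has already shut down. Otherwise the check could succeed against the old instance.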
