# Imports

```python
%load_ext sql
%sql sqlite:///TestQ.db

import pandas as pd
import sqlite3
# Connect to DB
conn = sqlite3.connect("TestQ.db")
```

# Question

Let’s say we want to improve Uber’s driver-rider matching algorithm. To improve matching, the engineering team has added a `weight` column to the `drivers` table that holds a weighted value for each driver.

**Given this table of drivers, write a query to perform a weighted random selection of a driver based on the driver weight.**

```sql
%%sql

CREATE TABLE drivers (
    driver_id INTEGER PRIMARY KEY,
    weight INTEGER
);

INSERT INTO drivers (driver_id, weight) VALUES
(1, 5),
(2, 10),
(3, 15),
(4, 20),
(5, 5),
(6, 10),
(7, 10),
(8, 5),
(9, 10),
(10, 1000);
```

```
 * sqlite:///TestQ.db
Done.
10 rows affected.
```

# Answer

```python
query = """
WITH driver_weights AS (
    SELECT 
        driver_id,
        weight,
        SUM(weight) OVER () AS total_weight,
        SUM(weight) OVER (ORDER BY driver_id ASC) AS cum_weight
    FROM drivers
)

SELECT * FROM driver_weights;
"""

df = pd.read_sql_query(query, conn)
df
```

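| driver_id | weight | total_weight | cum_weight |
|-----------|--------|--------------|------------|
| 1         | 5      | 1090         | 5          |
| 2         | 10     | 1090         | 15         |
| 3         | 15     | 1090         | 30         |
| 4         | 20     | 1090         | 50         |
| 5         | 5      | 1090         | 55         |
| 6         | 10     | 1090         | 65         |
| 7         | 10     | 1090         | 75         |
| 8         | 5      | 1090         | 80         |
| 9         | 10     | 1090         | 90         |
| 10        | 1000   | 1090         | 1090       |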
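The two window functions above compute a grand total and a running total. As a minimal cross-check, plain Python reproduces both columns. This assumes the sample weights from the `INSERT` statement, listed in `driver_id` order:

```python
from itertools import accumulate

# Sample weights in driver_id order, copied from the INSERT statement above.
weights = [5, 10, 15, 20, 5, 10, 10, 5, 10, 1000]

total_weight = sum(weights)              # like SUM(weight) OVER ()
cum_weight = list(accumulate(weights))   # like SUM(weight) OVER (ORDER BY driver_id)

print(total_weight)  # 1090
print(cum_weight)    # [5, 15, 30, 50, 55, 65, 75, 80, 90, 1090]
```

`driver_id` is the primary key, so the `ORDER BY` has no ties. The default window frame (`RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW`) therefore gives a plain running total. With duplicate sort keys, tied rows would share the same `cum_weight`.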


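```python
import random

# Draw ONE uniform number u in [0, 1) in Python and bind it as :u.
# If RANDOM() is written inside the WHERE clause instead, SQLite draws a fresh
# value for every row, which distorts the weights. A second RANDOM() in the
# SELECT list would also not show the value actually used for the comparison.
# (MySQL's RAND() returns a float in [0, 1) directly but is likewise
# re-evaluated per row in a WHERE clause; SQLite only has RANDOM(), which
# returns a signed 64-bit integer.)
u = random.random()

query = """
WITH driver_weights AS (
    SELECT 
        driver_id,
        weight,
        SUM(weight) OVER () AS total_weight,
        SUM(weight) OVER (ORDER BY driver_id ASC) AS cum_weight
    FROM drivers
)

SELECT 
    *, 
    :u AS random_num
FROM driver_weights
WHERE cum_weight > total_weight * :u  -- first driver whose interval contains u * total_weight
ORDER BY cum_weight                   -- LIMIT 1 needs an explicit order to be well-defined
LIMIT 1;
"""

df = pd.read_sql_query(query, conn, params={"u": u})
df
```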

| driver_id | weight | total_weight | cum_weight | random_num |
|-----------|--------|--------------|------------|------------|
| 10        | 1000   | 1090         | 1090       | 0.146621   |
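A single result row cannot show whether the selection is really weighted. A quick sanity check is to run the query many times against an in-memory copy of the sample data and compare each driver's pick rate with `weight / total_weight`.

The sketch below assumes Python's `sqlite3` module backed by SQLite 3.25 or later, which is needed for window functions. The names `pick_rates`, `SINGLE_DRAW` and `PER_ROW_DRAW` are illustrative only. The sketch contrasts two versions of the query:

- **Single draw:** the random number is drawn once per query in Python and bound as a parameter.
- **Per-row draw:** `RANDOM()` is written directly in the `WHERE` clause, where SQLite evaluates it again for each row.

```python
import random
import sqlite3
from collections import Counter

# Sample data copied from the notebook's INSERT statement.
WEIGHTS = {1: 5, 2: 10, 3: 15, 4: 20, 5: 5, 6: 10, 7: 10, 8: 5, 9: 10, 10: 1000}

CTE = """
WITH driver_weights AS (
    SELECT
        driver_id,
        SUM(weight) OVER () AS total_weight,
        SUM(weight) OVER (ORDER BY driver_id) AS cum_weight
    FROM drivers
)
SELECT driver_id FROM driver_weights
"""

# One uniform draw per query, generated in Python and bound as :u.
SINGLE_DRAW = CTE + "WHERE cum_weight > total_weight * :u ORDER BY cum_weight LIMIT 1"

# RANDOM() written inside WHERE: SQLite evaluates it again for every row.
PER_ROW_DRAW = CTE + (
    "WHERE cum_weight > total_weight * ((RANDOM() / 9223372036854775807.0 + 1) / 2) "
    "ORDER BY cum_weight LIMIT 1"
)


def pick_rates(query, n_trials=20_000, seed=0):
    """Run `query` n_trials times and return each driver's share of the picks."""
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE drivers (driver_id INTEGER PRIMARY KEY, weight INTEGER)")
    conn.executemany("INSERT INTO drivers VALUES (?, ?)", WEIGHTS.items())
    rng = random.Random(seed)  # seeds :u only; SQLite's RANDOM() cannot be seeded
    counts = Counter()
    for _ in range(n_trials):
        # Extra dict keys are ignored when the query has no :u placeholder.
        row = conn.execute(query, {"u": rng.random()}).fetchone()
        if row is not None:  # PER_ROW_DRAW can (very rarely) match no row at all
            counts[row[0]] += 1
    conn.close()
    return {d: counts[d] / n_trials for d in WEIGHTS}


if __name__ == "__main__":
    total = sum(WEIGHTS.values())
    single, per_row = pick_rates(SINGLE_DRAW), pick_rates(PER_ROW_DRAW)
    for d, w in WEIGHTS.items():
        print(f"driver {d:>2}: expected {w / total:.3f}  "
              f"single draw {single[d]:.3f}  per-row RANDOM() {per_row[d]:.3f}")
```

With the single draw, each observed share should land close to `weight / total_weight`; for driver 10 that is about 1000 / 1090 ≈ 0.917. So driver 10 being selected above is the expected outcome.

With `RANDOM()` in the `WHERE` clause, every row gets its own independent draw. Driver 10 is then picked only when all nine earlier rows fail their tests. The product of (1 - cum_weight / 1090) over drivers 1 to 9 puts that probability near 0.64.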


### 🎯 Goal
We want to randomly select one driver from a table, but the probability of selecting each driver should be **proportional to their weight**.

---

### 👣 Step-by-Step Explanation

#### 1. **Visualizing Weights as a Number Line**
Imagine we have drivers and their weights:

| Driver | Weight |
|--------|--------|
| 1      | 5      |
| 2      | 10     |
| 3      | 15     |

This gives a total weight of **30**. Now imagine a number line from `0` to `30`, divided like so:

- Driver 1 owns [0, 5)
- Driver 2 owns [5, 15)
- Driver 3 owns [15, 30)

Each driver "owns" a portion of the line proportional to their weight.
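The same idea can be written in symbols; the notation here is introduced only for illustration. Let $w_i$ be driver $i$'s weight, $C_i$ the cumulative weight through driver $i$, and $W$ the total weight. A uniform draw $r$ on $[0, W)$ lands in driver $i$'s interval with probability equal to that interval's length divided by $W$:

$$
C_0 = 0, \qquad C_i = \sum_{j=1}^{i} w_j, \qquad W = C_n
$$

$$
P(\text{driver } i) = P\big(C_{i-1} \le r < C_i\big) = \frac{C_i - C_{i-1}}{W} = \frac{w_i}{W}
$$

For the three drivers above, this gives $5/30 = 1/6$, $10/30 = 1/3$, and $15/30 = 1/2$.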

---

#### 2. **Random Selection on the Line**
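We draw a uniform random number `r` in `[0, 30)` and select the driver whose interval contains `r`, so drivers with higher weights are proportionally more likely to be picked.

Equivalently, the selected driver is the **first driver whose cumulative weight is strictly greater than `r`**. For example, with `r = 12`, driver 1's cumulative weight (5) is too small, but driver 2's (15) exceeds 12, so driver 2 is chosen.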
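The number-line idea fits in a few lines of Python. This is a minimal sketch, and the function name `weighted_pick` is hypothetical. It builds the cumulative weights and scales a uniform draw by the total. It then uses `bisect.bisect_right` to find the first cumulative weight strictly greater than `r`, which identifies the interval containing `r`:

```python
import bisect
import random
from itertools import accumulate


def weighted_pick(driver_ids, weights, rng=random):
    """Return one driver_id, chosen with probability weight / sum(weights)."""
    cum_weights = list(accumulate(weights))  # right edge of each driver's interval
    r = rng.random() * cum_weights[-1]       # uniform r in [0, total_weight)
    # bisect_right returns the first index whose cumulative weight is > r, i.e. the
    # interval [previous cum, cum) that contains r. hi = len - 1 guards against
    # floating-point rounding pushing r up to exactly the total.
    i = bisect.bisect_right(cum_weights, r, 0, len(cum_weights) - 1)
    return driver_ids[i]


print(weighted_pick([1, 2, 3], [5, 10, 15]))
```

A driver with weight 0 gets an empty interval and is never returned. In CPython, `random.choices(driver_ids, weights=weights, k=1)` uses this same cumulative-weight bisection, so the standard-library call is a reasonable alternative outside SQL.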

---

#### 3. **SQL Implementation**
Here's the SQL that implements this logic, in MySQL syntax, where `RAND()` returns a float in `[0, 1)`:

```sql
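-- Draw the random number once. RAND() written directly in the WHERE clause
-- is re-evaluated for every row, which distorts the weights.
SET @r = RAND();

WITH driver_weights AS (
    SELECT 
        driver_id,
        weight,
        SUM(weight) OVER () AS total_weight,
        SUM(weight) OVER (ORDER BY driver_id ASC) AS cum_weight
    FROM drivers
)
SELECT driver_id
FROM driver_weights 
WHERE cum_weight > total_weight * @r  -- first driver whose interval contains @r * total_weight
ORDER BY cum_weight
LIMIT 1;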
