# Quantization with RL-Based LLM Routing (via Ollama + Stable-Baselines3)
This notebook demonstrates a lightweight reinforcement learning (RL) setup in which a PPO agent from Stable-Baselines3 learns to route each query to the most suitable local LLM served through Ollama.

### Install Required Libraries

In [None]:
!pip install gymnasium stable-baselines3 langchain_community

### Environment and Model Setup

In [None]:
import gymnasium as gym
from stable_baselines3 import PPO
from langchain_community.llms import Ollama
import random

# Define local LLMs via Ollama
llms = {
    "llama3": Ollama(model="llama3"),
    "medllama": Ollama(model="medllama2")
}

### Define the Routing Environment

In [None]:
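import numpy as np  # observations must be float32 arrays matching the Box space

class LLMRoutingEnv(gym.Env):
    """One-step routing task: each episode presents a single query."""

    def __init__(self):
        super().__init__()
        self.action_space = gym.spaces.Discrete(2)  # One action per LLM
        # Mock embedding size
        self.observation_space = gym.spaces.Box(-1.0, 1.0, shape=(384,), dtype=np.float32)
        self.current_query = ""

    def step(self, action):
        selected_model = list(llms.keys())[int(action)]
        # Case-insensitive match: the sample query is capitalized ("Diabetes ...")
        is_medical = "diabetes" in self.current_query.lower()
        reward = 1.0 if (int(action) == 1 and is_medical) else 0.2
        # Gymnasium expects (obs, reward, terminated, truncated, info).
        # terminated=True ends the episode so the next reset() draws a new query.
        return self._get_obs(), reward, True, False, {"model": selected_model}

    def reset(self, seed=None, options=None):
        super().reset(seed=seed)  # seeds self.np_random
        self.current_query = random.choice([
            "Diabetes management guidelines",
            "Python web scraping tutorial"
        ])
        # Gymnasium expects (observation, info)
        return self._get_obs(), {}

    def _get_obs(self):
        # Mock embedding (normally would be from SentenceTransformer or other encoder).
        # It is random noise, so it carries no information about the query.
        return self.np_random.uniform(-1.0, 1.0, size=384).astype(np.float32)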
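
The mock observation above is drawn at random, so it is unrelated to the query text. As a minimal sketch of a stand-in that at least returns the same vector for the same query, the block below seeds a private random generator from a hash of the text. The names `mock_embed` and `EMBED_DIM` are hypothetical and not part of the notebook, and the vectors are still not semantic embeddings; a real encoder would be needed for similar queries to map to similar vectors.

```python
import hashlib
import random

EMBED_DIM = 384  # assumed to match the mock embedding size used by the environment


def mock_embed(query: str, dim: int = EMBED_DIM) -> list[float]:
    """Hypothetical stand-in for an encoder: the same query always gives the same vector."""
    # Derive a stable integer seed from the lowercased query text.
    digest = hashlib.sha256(query.lower().encode("utf-8")).digest()
    seed = int.from_bytes(digest[:8], "big")
    rng = random.Random(seed)  # private generator, leaves the global RNG untouched
    return [rng.uniform(-1.0, 1.0) for _ in range(dim)]


medical = mock_embed("Diabetes management guidelines")
coding = mock_embed("Python web scraping tutorial")
print(len(medical), medical == mock_embed("Diabetes management guidelines"))
```

One way to use this would be to return `mock_embed(self.current_query)` (converted to whatever array type the environment expects) from `_get_obs`, so that the policy receives a distinct, repeatable observation per query.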


### Train the RL Agent

In [None]:
env = LLMRoutingEnv()
model = PPO("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=1000)

### Test Deployment

In [None]:
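obs, _ = env.reset()  # Gymnasium's reset() returns (observation, info)
action, _ = model.predict(obs, deterministic=True)
# predict() returns a NumPy value, so cast it before indexing the list
print(f"Optimal model: {list(llms.keys())[int(action)]}")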

## Expected Behavior

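- **Reward signal**: `+1.0` when a **"diabetes"**-related query is routed to **medllama2** (action `1`); `0.2` for every other query/action pair

- **Learned policy**: Prioritizes **medllama2**, the only choice that can earn more than the `0.2` baseline; routing by query type additionally requires observations that encode the query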
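
To make the reward signal concrete, the sketch below computes the average reward of three fixed routing rules over the notebook's two sample queries. It assumes the queries are drawn with equal probability and that the keyword match is case-insensitive; `routing_reward`, `expected_reward`, and the three policy functions are hypothetical helper names introduced only for this illustration.

```python
QUERIES = [
    "Diabetes management guidelines",
    "Python web scraping tutorial",
]


def routing_reward(action: int, query: str) -> float:
    """Reward rule described above, assuming a case-insensitive keyword match."""
    return 1.0 if (action == 1 and "diabetes" in query.lower()) else 0.2


def expected_reward(policy) -> float:
    """Average reward of `policy` when each sample query is equally likely."""
    return sum(routing_reward(policy(q), q) for q in QUERIES) / len(QUERIES)


def always_llama3(query: str) -> int:
    return 0  # action 0 selects llama3


def always_medllama(query: str) -> int:
    return 1  # action 1 selects medllama2


def keyword_router(query: str) -> int:
    return 1 if "diabetes" in query.lower() else 0


for policy in (always_llama3, always_medllama, keyword_router):
    print(f"{policy.__name__}: {expected_reward(policy):.2f}")
```

This prints `0.20` for `always_llama3` and `0.60` for both `always_medllama` and `keyword_router`. The tie arises because sending a non-medical query to medllama2 costs nothing under this reward, so the signal rewards choosing medllama2 but does not by itself favor query-specific routing.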

### Factors Considered:

- **Query context** (keyword: `"diabetes"`)

- **Historical performance** (not used by the reward function in this notebook):
  - **medllama2**: 92% accuracy (medical tasks)
  - **llama3**: 68% accuracy (medical tasks)
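
As a rough sketch of how the accuracy figures above could enter the reward, the block below returns a model's quoted medical accuracy for "diabetes" queries and falls back to the `0.2` baseline otherwise. This is one possible design, not the notebook's method: `accuracy_reward`, `MEDICAL_ACCURACY`, and `BASELINE_REWARD` are hypothetical names, and the fallback exists only because no accuracy figures are given for non-medical tasks.

```python
# Accuracy figures quoted in the list above (medical tasks only),
# keyed by the names used in the notebook's `llms` dictionary.
MEDICAL_ACCURACY = {"medllama": 0.92, "llama3": 0.68}
BASELINE_REWARD = 0.2  # the environment's reward for all non-matching cases


def accuracy_reward(model_name: str, query: str) -> float:
    """Hypothetical reward: quoted accuracy on medical queries, baseline otherwise."""
    if "diabetes" in query.lower():
        return MEDICAL_ACCURACY[model_name]
    # No non-medical accuracy figures are available, so fall back to the baseline.
    return BASELINE_REWARD


print(accuracy_reward("medllama", "Diabetes management guidelines"))
print(accuracy_reward("llama3", "Diabetes management guidelines"))
print(accuracy_reward("llama3", "Python web scraping tutorial"))
```

Under this variant, both models earn a graded reward on medical queries (`0.92` versus `0.68`) instead of the all-or-nothing `1.0` versus `0.2`.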
