# Coinbase Dueling DQN Training Pipeline

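This notebook implements a Deep Q-Network (DQN) agent with the following enhancements:
- **Dueling network architecture** (separate state-value and advantage streams)
- **Prioritized Experience Replay (PER)**
- **Batch Normalization**
- **Live market data polled from the Coinbase spot-price REST API**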
- **Signal handling for graceful shutdown**

The goal is to train a robust reinforcement learning agent to identify profitable trades in real time on the USDC-USD market using Coinbase data.


# Coinbase Dueling DQN Agent
A professional training notebook using real-time USDC price data from Coinbase.

## Imports, Configuration, and Logging
Import dependencies (including `signal` and `threading` for shutdown handling), define training hyperparameters and API settings, and configure logging.

In [None]:
import threading
import signal
import time
import logging
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import requests
import json
from collections import deque
from datetime import datetime
import random
import matplotlib.pyplot as plt

# ---------------------------------------------------------------------------
# Hyperparameters
# ---------------------------------------------------------------------------
GAMMA = 0.99              # discount applied to the bootstrapped next-state Q-value
LR = 1e-3                 # Adam learning rate
EPSILON = 1.0             # initial exploration rate for epsilon-greedy selection
EPSILON_DECAY = 0.995     # multiplicative decay applied to EPSILON
MIN_EPSILON = 0.01        # lower bound for the decayed exploration rate
BATCH_SIZE = 32           # transitions sampled per gradient step
BUFFER_CAPACITY = 10_000  # maximum transitions kept in replay memory
TARGET_UPDATE = 10        # steps between target-network synchronizations

# ---------------------------------------------------------------------------
# Data source and checkpointing
# ---------------------------------------------------------------------------
CACHE_FILE = "coinbase_cache.json"
API_URL = "https://api.coinbase.com/v2/prices/USDC-USD/spot"
SAVE_MODEL_FREQ = 50      # steps between periodic policy checkpoints
SAVE_BEST_MODEL = True    # whether to track and save the best model

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("coinbase_dqn")

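## Define the Dueling Network and the Prioritized Experience Replay (PER) Buffer
The dueling network splits Q-values into a state-value stream and an advantage stream. The replay buffer samples transitions with probability proportional to their TD-error priority, so high-error transitions are replayed more often.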

In [None]:
class DuelingDQN(nn.Module):
    """Dueling Q-network mapping a 4-feature state to Q-values for 2 actions.

    Two shared fully connected layers (each followed by BatchNorm and ReLU)
    feed a scalar state-value head and a 2-way advantage head, combined as
    ``Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)``.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 128)
        self.fc2 = nn.Linear(128, 128)
        self.value_fc = nn.Linear(128, 1)
        self.advantage_fc = nn.Linear(128, 2)
        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)

    @staticmethod
    def _batch_norm(bn, x):
        """Apply ``bn`` to ``x``, tolerating a single-sample batch in training mode.

        BatchNorm1d cannot compute batch statistics from one sample while
        training (it raises), so in that case normalize with the layer's
        running statistics instead, without updating them.
        """
        if bn.training and x.size(0) == 1:
            return F.batch_norm(
                x, bn.running_mean, bn.running_var, bn.weight, bn.bias,
                training=False, eps=bn.eps,
            )
        return bn(x)

    def forward(self, x):
        """Return Q-values of shape (batch, 2) for an input of shape (batch, 4)."""
        x = F.relu(self._batch_norm(self.bn1, self.fc1(x)))
        x = F.relu(self._batch_norm(self.bn2, self.fc2(x)))
        value = self.value_fc(x)
        advantage = self.advantage_fc(x)
        # Center the advantages per sample (over the action axis), not over the
        # whole batch, so each row's Q-values depend only on that row.
        return value + advantage - advantage.mean(dim=1, keepdim=True)

class PrioritizedReplayBuffer:
    """Fixed-capacity replay buffer with proportional prioritized sampling.

    Each transition's priority is ``max(|td_error|, 1e-6)``. The sampling
    probability is ``p_i**alpha / sum_j p_j**alpha``, and the importance-sampling
    weights ``(N * P(i))**(-beta)`` are normalized by their maximum. When the
    buffer is full, the oldest transition and its priority are evicted together.
    """

    _MIN_PRIORITY = 1e-6  # keeps every transition sampleable

    def __init__(self, capacity, alpha=0.6, beta=0.4):
        self.buffer = deque(maxlen=capacity)
        self.alpha = alpha
        self.beta = beta
        # Index-aligned with ``buffer``; both deques share maxlen and evict in lockstep.
        self.priorities = deque(maxlen=capacity)

    def _priority(self, td_error):
        """Map a (possibly negative) TD error to a strictly positive priority."""
        return max(abs(float(td_error)), self._MIN_PRIORITY)

    def push(self, state, action, reward, next_state, td_error):
        """Store a transition; its priority is the magnitude of ``td_error``."""
        self.buffer.append((state, action, reward, next_state, td_error))
        self.priorities.append(self._priority(td_error))

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions, with replacement, by priority.

        Returns:
            A tuple ``(batch, weights, indices)``. When ``batch_size`` is not
            positive or fewer than ``batch_size`` transitions are stored, all
            three elements are empty, so callers can still unpack the result.
        """
        if batch_size <= 0 or len(self.buffer) < batch_size:
            return [], np.empty(0), np.empty(0, dtype=int)
        scaled = np.array(self.priorities) ** self.alpha
        probs = scaled / scaled.sum()
        indices = np.random.choice(len(self.buffer), batch_size, p=probs)
        batch = [self.buffer[idx] for idx in indices]
        weights = (len(self.buffer) * probs[indices]) ** (-self.beta)
        weights /= weights.max()
        return batch, weights, indices

    def update_priorities(self, indices, td_errors):
        """Overwrite the priorities of previously sampled transitions."""
        for idx, td_error in zip(indices, td_errors):
            self.priorities[idx] = self._priority(td_error)

    def __len__(self):
        return len(self.buffer)

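## Build the DQN Model and Helpers
Instantiate the policy and target networks (fully connected layers with Batch Normalization), the Adam optimizer, and the replay buffer. Then define helpers for fetching prices, caching experience, and saving the best model.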

In [None]:
# Shared training state used by the cells below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_net = DuelingDQN().to(device)
target_net = DuelingDQN().to(device)
target_net.load_state_dict(policy_net.state_dict())
# The target network is only evaluated, never optimized. Eval mode makes its
# BatchNorm layers use the running statistics copied from the policy net,
# instead of per-batch statistics, and lets single-sample forwards succeed
# (BatchNorm raises on a batch of one while training).
target_net.eval()
optimizer = optim.Adam(policy_net.parameters(), lr=LR)
buffer = PrioritizedReplayBuffer(BUFFER_CAPACITY)
step_count = 0                      # global environment-step counter
best_model = None                   # state dict of the best model seen so far
best_loss = float('inf')
shutdown_event = threading.Event()  # cross-thread shutdown signal

def fetch_data():
    """Fetch the current spot price from ``API_URL``.

    Returns:
        The price as a float, or ``None`` when the request fails, the body is
        not valid JSON, or it lacks a numeric ``data.amount`` field. Failures
        are logged rather than raised. A missing amount used to be reported as
        0.0, which fed a fabricated price into training.
    """
    try:
        response = requests.get(API_URL, timeout=5)
        response.raise_for_status()
        amount = response.json()["data"]["amount"]
        return float(amount)
    except (requests.RequestException, ValueError, KeyError, TypeError) as e:
        logger.error(f"Data fetch failed: {e}")
        return None

def save_cache():
    try:
        with open(CACHE_FILE, 'w') as f:
            json.dump([exp for exp in buffer.buffer], f)
        logger.info(f"Cache saved to {CACHE_FILE}")
    except Exception as e:
        logger.error(f"Failed to save cache: {e}")

def save_best_model(loss=None):
    """Save the policy network to ``best_dqn_model.pth`` when its loss improves.

    Args:
        loss: Loss value (float or scalar tensor) to compare with the best seen
            so far. If omitted, the loss is evaluated on a fixed dummy
            transition (all-zero state and next state, action 0, reward 0).
            Both networks are temporarily in eval mode with gradients disabled
            for that evaluation, so single-sample BatchNorm works and nothing
            is recorded for autograd.

    Updates the module-level ``best_loss`` (a plain float) and ``best_model``.
    """
    global best_model, best_loss
    if loss is None:
        # NOTE(review): one all-zero transition is a weak proxy for model
        # quality; consider passing a real training/validation loss instead.
        modes = (policy_net.training, target_net.training)
        policy_net.eval()
        target_net.eval()
        try:
            with torch.no_grad():
                # Plain NumPy inputs: calculate_loss wraps them via np.array,
                # which fails on CUDA tensors.
                zeros = np.zeros((1, 4), dtype=np.float32)
                loss = calculate_loss(zeros, [0], [0], zeros)
        finally:
            policy_net.train(modes[0])
            target_net.train(modes[1])
    current_loss = float(loss)
    if current_loss < best_loss:
        best_loss = current_loss
        # state_dict() returns references to the live parameters; clone them so
        # the in-memory "best" snapshot is not overwritten by later updates.
        best_model = {k: v.detach().clone() for k, v in policy_net.state_dict().items()}
        torch.save(best_model, "best_dqn_model.pth")
        logger.info("New best model saved!")

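## Define Loss and Training Helpers
Compute the TD loss and TD errors, select actions with an ε-greedy policy, apply gradient updates, and decay ε. The Adam optimizer is created in the model-building cell above.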

In [None]:
def calculate_loss(states, actions, rewards, next_states, reduction="mean"):
    """Compute the one-step TD regression loss for a batch of transitions.

    The target is ``reward + GAMMA * max_a' Q_target(next_state, a')``. There
    is no terminal-state masking.

    Args:
        states: Sequence of state vectors (stacked with ``np.array``).
        actions: Sequence of integer action indices.
        rewards: Sequence of scalar rewards.
        next_states: Sequence of next-state vectors.
        reduction: Forwarded to ``F.mse_loss``. ``"none"`` returns per-sample
            losses of shape (batch, 1), so callers can weight them individually.

    Returns:
        The loss tensor (a scalar for the default ``"mean"`` reduction).
    """
    states = torch.FloatTensor(np.array(states)).to(device)
    next_states = torch.FloatTensor(np.array(next_states)).to(device)
    actions = torch.LongTensor(actions).unsqueeze(1).to(device)
    rewards = torch.FloatTensor(rewards).unsqueeze(1).to(device)

    q_values = policy_net(states).gather(1, actions)
    # The regression target is a constant, so no autograd graph is needed for it.
    with torch.no_grad():
        q_next = target_net(next_states).max(1)[0].unsqueeze(1)
    target = rewards + (GAMMA * q_next)

    return F.mse_loss(q_values, target, reduction=reduction)

def compute_td_error(state, action, reward, next_state):
    """Return the signed one-step TD error of a single transition as a float.

    ``reward + GAMMA * max_a' Q_target(next_state, a') - Q_policy(state, action)``

    Both networks run without gradients and temporarily in eval mode, because
    BatchNorm cannot use batch statistics for a single sample while training.
    Their previous modes are restored afterwards.
    """
    state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
    next_state_tensor = torch.FloatTensor(next_state).unsqueeze(0).to(device)
    modes = (policy_net.training, target_net.training)
    policy_net.eval()
    target_net.eval()
    try:
        with torch.no_grad():
            q_value = policy_net(state_tensor)[0, action].item()
            q_next = target_net(next_state_tensor).max(1)[0].item()
    finally:
        policy_net.train(modes[0])
        target_net.train(modes[1])
    return reward + GAMMA * q_next - q_value

def select_action(state):
    """Choose an action for ``state`` with an epsilon-greedy policy.

    With probability ``EPSILON``, return a uniformly random action from
    ``{0, 1}``. Otherwise, return the argmax of the policy network's Q-values.
    The greedy forward pass runs without gradients, with the policy network
    temporarily in eval mode (BatchNorm cannot train on a single sample).
    """
    # EPSILON is only read here, so no ``global`` declaration is needed.
    if np.random.rand() < EPSILON:
        return np.random.choice([0, 1])
    was_training = policy_net.training
    policy_net.eval()
    try:
        with torch.no_grad():
            state_tensor = torch.FloatTensor(state).unsqueeze(0).to(device)
            return policy_net(state_tensor).argmax().item()
    finally:
        policy_net.train(was_training)

def update_model(batch, weights):
    """Take one optimizer step on a prioritized batch, weighting each sample.

    Args:
        batch: Sequence of ``(state, action, reward, next_state, td_error)``
            tuples; the stored ``td_error`` is ignored.
        weights: Per-sample importance-sampling weights, one per transition.

    Returns:
        NumPy array of the batch's absolute TD errors, computed before the
        update. They can be passed to ``update_priorities``.
    """
    states, actions, rewards, next_states, _ = zip(*batch)
    states_t = torch.as_tensor(np.array(states), dtype=torch.float32, device=device)
    next_states_t = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=device)
    actions_t = torch.as_tensor(np.array(actions), dtype=torch.long, device=device).unsqueeze(1)
    rewards_t = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=device).unsqueeze(1)
    weights_t = torch.as_tensor(np.asarray(weights), dtype=torch.float32, device=device).unsqueeze(1)

    q_values = policy_net(states_t).gather(1, actions_t)
    with torch.no_grad():
        q_next = target_net(next_states_t).max(1)[0].unsqueeze(1)
    td_errors = rewards_t + GAMMA * q_next - q_values

    # Weight each sample's squared error before averaging. Weighting an
    # already-averaged loss would only rescale it by mean(weights).
    loss = (weights_t * td_errors.pow(2)).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return td_errors.detach().abs().squeeze(1).cpu().numpy()

def update_epsilon(step_count):
    """Decay the global exploration rate ``EPSILON`` by one step.

    While ``EPSILON`` is above ``MIN_EPSILON``, multiply it by ``EPSILON_DECAY``
    and clamp the result at ``MIN_EPSILON``, so the decay never undershoots
    the floor.

    Args:
        step_count: Current global step. It is unused (the schedule decays once
            per call) and is kept for interface compatibility.
    """
    global EPSILON
    if EPSILON > MIN_EPSILON:
        EPSILON = max(MIN_EPSILON, EPSILON * EPSILON_DECAY)

## Training Loop
Manages environment interaction, model updates, and experience replay.

In [None]:
reward_history = []  # total reward of each episode run by train_loop()


def train_loop(num_episodes=100, retry_delay=1.0):
    """Run online DQN training on live price data.

    Each step does the following:
    - picks an epsilon-greedy action;
    - fetches the latest price as the first feature of the next state;
    - stores the transition with priority |TD error|;
    - once ``BATCH_SIZE`` transitions are buffered, trains on a prioritized
      batch and refreshes the sampled transitions' priorities.

    An episode ends at random, with probability 0.01 per step.

    Training stops early when ``shutdown_event`` is set; the last recorded
    episode may then be partial. After a failed price fetch, the loop waits
    up to ``retry_delay`` seconds (or until shutdown) instead of hammering
    the API.

    Args:
        num_episodes: Number of episodes to run (default 100).
        retry_delay: Seconds to wait after a failed price fetch.
    """
    global step_count
    for _ in range(num_episodes):
        if shutdown_event.is_set():
            break
        state = np.random.rand(4)
        total_reward = 0
        done = False
        while not done and not shutdown_event.is_set():
            action = select_action(state)
            next_price = fetch_data()
            if next_price is None:
                shutdown_event.wait(retry_delay)
                continue
            next_state = np.array([next_price, state[1], state[2], state[3]])
            # NOTE(review): the reward depends only on the action, not on the
            # price move, so the agent cannot learn anything market-related.
            reward = 0.1 if action == 0 else -0.1
            total_reward += reward

            td_error = compute_td_error(state, action, reward, next_state)
            # PER priorities are proportional to the TD-error magnitude.
            buffer.push(state, action, reward, next_state, abs(td_error))

            if len(buffer) >= BATCH_SIZE:
                batch, weights, indices = buffer.sample(BATCH_SIZE)
                update_model(batch, weights)
                # Re-prioritize the sampled transitions using their post-update
                # TD errors; otherwise priorities would never change after insertion.
                new_errors = [abs(compute_td_error(s, a, r, ns)) for s, a, r, ns, _ in batch]
                buffer.update_priorities(indices, new_errors)

            state = next_state
            step_count += 1
            update_epsilon(step_count)

            if step_count % TARGET_UPDATE == 0:
                target_net.load_state_dict(policy_net.state_dict())
            if step_count % SAVE_MODEL_FREQ == 0:
                torch.save(policy_net.state_dict(), f"dqn_model_{step_count}.pth")
            if SAVE_BEST_MODEL:
                save_best_model()

            done = np.random.rand() < 0.01
        reward_history.append(total_reward)
    if shutdown_event.is_set():
        logger.info("Training stopped early: shutdown requested.")
    else:
        logger.info("Training finished.")

## Plot Training Rewards
Plot the total reward collected in each episode. Run this cell after the training cell below has finished; before training there is nothing to plot.

In [None]:
# Plot the total reward of each episode recorded by train_loop().
# Episodes are numbered from 1, so the x-axis starts at 1.
if reward_history:
    plt.figure(figsize=(10, 5))
    plt.plot(range(1, len(reward_history) + 1), reward_history)
    plt.title('Total Rewards per Episode')
    plt.xlabel('Episode')
    plt.ylabel('Total Reward')
    plt.grid(True)
    plt.show()
else:
    logger.warning("No episode rewards recorded yet; run train_loop() before plotting.")

## Run Main Training Routine
Execute the training loop on live (polled) price data and exit gracefully when interrupted.

In [None]:
# Run training. A kernel interrupt (Ctrl+C / "Interrupt") stops it cleanly,
# and the replay buffer is written to the cache file whether training
# completes or is interrupted.
try:
    train_loop()
except KeyboardInterrupt:
    logger.warning("Training interrupted by user.")
finally:
    save_cache()