# Building a Deep Q-Network (DQN) for our software-defined networking (SDN) system

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np

class DQN(nn.Module):
    """Single-hidden-layer Q-network: state vector in, one value per action out.

    Args:
        state_dim: Size of the input layer.
        action_dim: Size of the output layer (one unit per action).
        hidden_dim: Width of the hidden layer.
    """

    def __init__(self, state_dim, action_dim, hidden_dim=128):
        super(DQN, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, x):
        """Run ``x`` through the network and return its output."""
        return self.net(x)

    @staticmethod
    def compute_reward(packet, delivered=True, threshold=None):
        """Return the scalar reward for the outcome recorded on ``packet``.

        Declared as a static method so it can be called both as
        ``DQN.compute_reward(packet)`` and on a model instance; without the
        decorator an instance call would bind the model itself to ``packet``.

        Args:
            packet: Object exposing ``dropped``, ``suspicion``, ``spawn_t``
                and ``finish_t`` attributes.
            delivered: Currently unused; kept for interface compatibility.
            threshold: Suspicion level at or above which a non-dropped packet
                earns the +1.0 reward. When ``None``, the module-level
                ``upper_thresh`` is used (it must be defined elsewhere).

        Returns:
            ``-1.0`` if the packet was dropped, ``1.0`` if its suspicion is at
            or above the threshold, otherwise ``spawn_t - finish_t``.
        """
        # NOTE(review): every dropped packet is penalised, regardless of its
        # suspicion score -- confirm that dropping a suspicious packet should
        # not be rewarded instead.
        if packet.dropped:
            return -1.0

        if threshold is None:
            # Resolved lazily so the dropped branch never needs the global.
            threshold = upper_thresh

        if packet.suspicion >= threshold:
            return 1.0

        # Negative latency: the faster the delivery, the closer to zero.
        return packet.spawn_t - packet.finish_t


In [None]:
class DQNRouter:
    """Chooses a path by taking the highest-valued action of a Q-network.

    NOTE(review): ``__init__`` is still a placeholder and does not create
    ``self.q_network``. ``select_path`` calls that attribute, so it has to be
    assigned (a callable returning a tensor of action values) before use.
    """

    def __init__(self, state_dim, action_dim):
        # define Q-network here
        pass

    def select_path(self, topology_state, packet, mode="normal"):
        """
        mode = "normal" → choose optimal route (minimize latency, load balance, etc.)
        mode = "quarantine" → choose path to isolation node/subnet

        Any ``mode`` other than "normal" is decoded with ``quarantine=True``.
        Returns whatever ``decode_action_to_path`` returns for the greedy action.
        """
        state = self.encode_state(topology_state, packet)

        # Greedy action selection is pure inference, so skip autograd
        # bookkeeping (same approach as DQNAgent.select_action).
        with torch.no_grad():
            action = self.q_network(state).argmax().item()

        # decode action into actual path
        return self.decode_action_to_path(action, quarantine=(mode != "normal"))

    def encode_state(self, topology_state, packet):
        """Placeholder: turn topology state + packet features into NN input vector.

        Not implemented yet; currently returns ``None``.
        """
        pass

    def decode_action_to_path(self, action, quarantine=False):
        """Placeholder: map action index → path through topology.

        Not implemented yet; currently returns ``None``.
        """
        pass


In [None]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import deque
import pandas as pd

# ============================
# Replay Buffer
# ============================
class ReplayBuffer:
    """Bounded FIFO store of ``(state, action, reward, next_state, done)`` tuples.

    Args:
        capacity: Maximum number of transitions kept; once full, appending a
            new transition evicts the oldest one.
    """

    def __init__(self, capacity=10000):
        self.buffer = deque(maxlen=capacity)

    def push(self, state, action, reward, next_state, done):
        """Append one transition to the buffer."""
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            Tuple ``(states, actions, rewards, next_states, dones)`` of CPU
            tensors. ``actions`` is int64; the other four are float32
            (``dones`` as 0.0 / 1.0).

        Raises:
            ValueError: If ``batch_size`` exceeds the number of stored
                transitions (raised by ``random.sample``).
        """
        batch = random.sample(self.buffer, batch_size)
        states, actions, rewards, next_states, dones = zip(*batch)

        # Convert with explicit dtypes instead of the legacy typed
        # constructors (torch.FloatTensor / torch.LongTensor), so the result
        # does not depend on how NumPy inferred the column dtype
        # (float64 rewards, bool dones, ...).
        return (
            torch.as_tensor(np.asarray(states, dtype=np.float32)),
            torch.as_tensor(np.asarray(actions, dtype=np.int64)),
            torch.as_tensor(np.asarray(rewards, dtype=np.float32)),
            torch.as_tensor(np.asarray(next_states, dtype=np.float32)),
            torch.as_tensor(np.asarray(dones, dtype=np.float32)),
        )

    def __len__(self):
        """Return the number of transitions currently stored."""
        return len(self.buffer)

# ============================
# DQN Model
# ============================
class DQN(nn.Module):
    """Fully connected Q-network with two ReLU-activated hidden layers.

    Args:
        input_dim: Number of input features.
        output_dim: Number of outputs (one per action).
        hidden_dim: Width shared by both hidden layers.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=128):
        super().__init__()
        # Layers are created in forward order and kept in a Sequential named
        # ``net`` so parameter names (net.0, net.2, net.4) stay stable.
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the network output for the input tensor ``x``."""
        outputs = self.net(x)
        return outputs

# ============================
# Agent
# ============================
class DQNAgent:
    """Epsilon-greedy DQN agent with an online network, a target network and replay memory.

    Args:
        input_dim: Input size passed to both networks.
        output_dim: Number of actions.
        lr: Learning rate of the Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        eps_start: Initial exploration probability.
        eps_end: Lower bound for the exploration probability.
        eps_decay: Factor epsilon is multiplied by after each optimization step.
    """

    def __init__(self, input_dim, output_dim, lr=1e-3, gamma=0.99, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
        use_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if use_cuda else "cpu")

        # Online network is built first; the target starts as an exact copy.
        online = DQN(input_dim, output_dim).to(self.device)
        target = DQN(input_dim, output_dim).to(self.device)
        target.load_state_dict(online.state_dict())
        self.q_network = online
        self.target_network = target
        self.optimizer = optim.Adam(online.parameters(), lr=lr)

        # Exploration / discount settings.
        self.gamma = gamma
        self.epsilon = eps_start
        self.eps_end = eps_end
        self.eps_decay = eps_decay

        self.memory = ReplayBuffer()
        self.steps_done = 0  # initialised here; not updated anywhere in this class
        self.output_dim = output_dim

    def select_action(self, state):
        """Return an action index: random with probability epsilon, else greedy."""
        explore = random.random() < self.epsilon
        if explore:
            return random.randrange(self.output_dim)

        # Greedy branch: add a batch dimension and query the online network.
        batch = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        with torch.no_grad():
            scores = self.q_network(batch)
        return scores.argmax().item()

    def optimize_model(self, batch_size=64):
        """Run one gradient step on a sampled minibatch, then decay epsilon.

        Does nothing (and leaves epsilon untouched) while the replay memory
        holds fewer than ``batch_size`` transitions.
        """
        if len(self.memory) < batch_size:
            return

        sampled = self.memory.sample(batch_size)
        states, actions, rewards, next_states, dones = (
            tensor.to(self.device) for tensor in sampled
        )

        # Q(s, a) for the actions that were actually taken.
        action_index = actions.unsqueeze(1)
        predicted = self.q_network(states).gather(1, action_index).squeeze(1)

        # One-step TD target; (1 - dones) removes the bootstrap term on
        # terminal transitions.
        bootstrap = self.target_network(next_states).max(1)[0]
        targets = rewards + self.gamma * bootstrap * (1 - dones)

        criterion = nn.MSELoss()
        loss = criterion(predicted, targets.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # Epsilon decay, clamped at eps_end.
        decayed = self.epsilon * self.eps_decay
        self.epsilon = max(self.eps_end, decayed)

    def update_target(self):
        """Copy the online network's weights into the target network."""
        online_weights = self.q_network.state_dict()
        self.target_network.load_state_dict(online_weights)

# ============================
# CSV Stream Environment
# ============================
class CSVStreamEnvironment:
    """Episodic environment that replays the rows of a CSV file one per step.

    Each row is one state (a float32 vector of all columns). The last column
    is read as the suspicion score used by the example reward logic.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).
        suspicion_threshold: Score at or above which a row counts as
            suspicious when computing rewards. Defaults to 0.5.

    Raises:
        ValueError: If a column cannot be converted to float32 (raised by
            ``DataFrame.to_numpy`` during construction).
    """

    def __init__(self, csv_path, suspicion_threshold=0.5):
        self.data = pd.read_csv(csv_path)
        # Convert the whole table once; per-step DataFrame row access is far
        # slower than indexing a NumPy matrix.
        self._states = self.data.to_numpy(dtype=np.float32)
        self.suspicion_threshold = suspicion_threshold
        self.index = 0
        self.max_index = len(self.data)

    def reset(self):
        """Rewind to the first row and return it as the initial state."""
        self.index = 0
        return self._get_state()

    def step(self, action):
        """Score ``action`` against the current row and advance to the next.

        Example reward logic (0 = route, 1 = quarantine, anything else = drop):
            route:      +1.0 below the threshold, -1.0 otherwise
            quarantine: +0.5 at/above the threshold, -0.5 otherwise
            drop:       -0.2 below the threshold, +0.5 otherwise

        Returns:
            ``(next_state, reward, done)``. After the final row ``done`` is
            True and ``next_state`` is an all-zero vector of the state shape.

        Raises:
            IndexError: If called after the episode has finished.
        """
        current_state = self._get_state()
        suspicion_score = current_state[-1]  # assuming last feature = suspicion
        threshold = self.suspicion_threshold

        if action == 0:  # route
            reward = 1.0 if suspicion_score < threshold else -1.0
        elif action == 1:  # quarantine
            reward = 0.5 if suspicion_score >= threshold else -0.5
        else:  # drop
            reward = -0.2 if suspicion_score < threshold else 0.5

        self.index += 1
        done = self.index >= self.max_index
        next_state = np.zeros_like(current_state) if done else self._get_state()

        return next_state, reward, done

    def _get_state(self):
        """Return a fresh float32 copy of the row at the current index."""
        if self.index >= self.max_index:
            raise IndexError(
                f"row index {self.index} is past the end of the data "
                f"({self.max_index} rows); call reset() to start a new episode"
            )
        # Copy so callers never hold a view into the cached matrix.
        return self._states[self.index].copy()

# ============================
# Glue Training Loop
# ============================
def train_dqn(csv_path, episodes=10, batch_size=64, target_update=10, save_path="dqn_model.pth"):
    """Train a DQN agent on a CSV-backed environment and save its weights.

    Args:
        csv_path: CSV source handed to ``CSVStreamEnvironment``.
        episodes: Number of full passes over the CSV rows.
        batch_size: Minibatch size for each optimization step.
        target_update: The target network is synced after every episode whose
            zero-based index is a multiple of this value.
        save_path: Destination for the online network's ``state_dict``.
            Defaults to the previously hard-coded ``"dqn_model.pth"``.

    Returns:
        The trained ``DQNAgent``, so callers can evaluate or keep training
        without reloading the checkpoint.
    """
    env = CSVStreamEnvironment(csv_path)
    state_dim = env.reset().shape[0]
    action_dim = 3  # route, quarantine, drop
    agent = DQNAgent(state_dim, action_dim)

    for ep in range(episodes):
        state = env.reset()
        total_reward = 0
        done = False

        while not done:
            action = agent.select_action(state)
            next_state, reward, done = env.step(action)
            agent.memory.push(state, action, reward, next_state, done)
            state = next_state
            total_reward += reward

            # One learning step per environment step (no-op until the replay
            # memory holds at least batch_size transitions).
            agent.optimize_model(batch_size)

        # NOTE(review): with the defaults (episodes=10, target_update=10) this
        # fires only once, after the first episode -- confirm that is intended.
        if ep % target_update == 0:
            agent.update_target()

        print(f"Episode {ep+1}/{episodes}, Total Reward: {total_reward:.2f}, Epsilon: {agent.epsilon:.3f}")

    torch.save(agent.q_network.state_dict(), save_path)
    print(f"Model saved as {save_path}")
    return agent
