# Multi-Agent Deep Variational Reinforcement Learning Architecture

This notebook visualizes the Multi-Agent Deep Variational Reinforcement Learning (MA-DVRL) architecture: its network components, mathematical foundations, and training process.

## Setup
First, let's import the required packages for visualization:

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from IPython.display import display, Markdown, Latex
import graphviz

## Overall Architecture

The MA-DVRL model consists of four key components:
1. Encoder Network (for private beliefs)
2. Inference Network (for opponent modeling)
3. Policy Network (for action selection)
4. Value Network (for state-value estimation)

Let's visualize each component:

In [None]:
def create_architecture_diagram():
    """Build a Graphviz diagram of the four MA-DVRL networks and how they connect."""
    dot = graphviz.Digraph(comment='MA-DVRL Architecture')
    dot.attr(rankdir='LR')
    
    # Add nodes
    with dot.subgraph(name='cluster_0') as c:
        c.attr(label='Encoder Network')
        c.node('obs', 'Observation\n(cards, pot, etc.)')
        c.node('enc1', 'FC Layer\n(256)')
        c.node('enc2', 'FC Layer\n(256)')
        c.node('belief', 'Belief State\n(256)')
    
    with dot.subgraph(name='cluster_1') as c:
        c.attr(label='Inference Network')
        c.node('opp_obs', 'Opponent Actions\n+ Public Info')
        c.node('inf1', 'FC Layer\n(256)')
        c.node('inf2', 'FC Layer\n(256)')
        c.node('opp_belief', 'Opponent Belief\n(256)')
    
    with dot.subgraph(name='cluster_2') as c:
        c.attr(label='Policy Network')
        c.node('joint', 'Joint Belief\n(512)')
        c.node('pol1', 'FC Layer\n(256)')
        c.node('pol2', 'FC Layer\n(256)')
        c.node('action', 'Action Probs\n(4)')
    
    with dot.subgraph(name='cluster_3') as c:
        c.attr(label='Value Network')
        c.node('val1', 'FC Layer\n(256)')
        c.node('val2', 'FC Layer\n(256)')
        c.node('value', 'State Value\n(1)')
    
    # Add edges
    dot.edge('obs', 'enc1')
    dot.edge('enc1', 'enc2')
    dot.edge('enc2', 'belief')
    
    dot.edge('opp_obs', 'inf1')
    dot.edge('inf1', 'inf2')
    dot.edge('inf2', 'opp_belief')
    
    dot.edge('belief', 'joint')
    dot.edge('opp_belief', 'joint')
    dot.edge('joint', 'pol1')
    dot.edge('pol1', 'pol2')
    dot.edge('pol2', 'action')
    
    dot.edge('joint', 'val1')
    dot.edge('val1', 'val2')
    dot.edge('val2', 'value')
    
    return dot

architecture = create_architecture_diagram()
display(architecture)

## Mathematical Foundations

### 1. Variational Inference

The MA-DVRL model uses variational inference to learn belief states. The key equations are:

**Prior Distribution (Encoder)**:
$$p(z_t|o_t) = \mathcal{N}(\mu_\phi(o_t), \sigma_\phi^2(o_t))$$

where:
- $z_t$ is the belief state at time $t$
- $o_t$ is the observation at time $t$
- $\phi$ are the encoder network parameters
- $\mu_\phi$ and $\sigma_\phi$ are the mean and standard deviation output by the encoder (diagonal covariance)

**Posterior Distribution (Inference)**:
$$q(z_t|o_{1:t}, a_{1:t-1}) = \mathcal{N}(\mu_\theta(h_t), \sigma_\theta^2(h_t))$$

where:
- $h_t$ is the history encoding up to time $t$
- $\theta$ are the inference network parameters

**ELBO Objective** (maximized during training):
$$\mathcal{L}_{ELBO} = \mathbb{E}_{q(z_t|o_{1:t}, a_{1:t-1})}[\log p(o_t|z_t)] - D_{KL}(q(z_t|o_{1:t}, a_{1:t-1}) \,\|\, p(z_t|o_t))$$
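
The KL term has a closed form when both distributions are diagonal Gaussians, and the expectation can be estimated from a single reparameterized sample. The following is a minimal NumPy sketch under those assumptions. The function names and the scalar `log_likelihood` argument are illustrative only, since the notebook does not specify the decoder $p(o_t|z_t)$.

```python
import numpy as np

def kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p):
    """KL(q || p) for diagonal Gaussians, summed over latent dimensions."""
    var_q, var_p = sigma_q ** 2, sigma_p ** 2
    per_dim = np.log(sigma_p / sigma_q) + (var_q + (mu_q - mu_p) ** 2) / (2.0 * var_p) - 0.5
    return np.sum(per_dim, axis=-1)

def sample_belief(mu_q, sigma_q, rng):
    """Reparameterized sample z = mu + sigma * eps with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu_q.shape)
    return mu_q + sigma_q * eps

def elbo_estimate(log_likelihood, mu_q, sigma_q, mu_p, sigma_p):
    """Single-sample ELBO: reconstruction term minus the KL term."""
    return log_likelihood - kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p)

# Illustrative values only: a 256-dimensional belief state as in the diagram.
rng = np.random.default_rng(0)
mu_p, sigma_p = np.zeros(256), np.ones(256)           # prior p(z_t | o_t)
mu_q, sigma_q = np.full(256, 0.1), np.full(256, 0.9)  # posterior q(z_t | ...)
z = sample_belief(mu_q, sigma_q, rng)
kl = kl_diag_gaussians(mu_q, sigma_q, mu_p, sigma_p)
```

The KL is zero when the posterior equals the prior and grows as the two diverge, which is what regularizes the inferred belief toward the encoder's prior.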

### 2. Policy Learning

The policy is learned by maximizing the following entropy-regularized objective:

$$\mathcal{L}_{policy} = \mathbb{E}_{\pi_\psi}\left[\sum_{t=0}^T r_t\right] + \alpha H(\pi_\psi)$$

where:
- $\pi_\psi$ is the policy network with parameters $\psi$
- $r_t$ is the reward at time t
- $H(\pi_\psi)$ is the entropy of the policy
- $\alpha$ is the entropy coefficient
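
One common way to optimize an objective of this form is a score-function (REINFORCE-style) surrogate, whose gradient with respect to $\psi$ estimates the gradient of the expected return. The sketch below assumes that estimator, averages over time steps rather than summing, and uses hypothetical helper names; the notebook does not state which policy-gradient variant MA-DVRL uses.

```python
import numpy as np

def rewards_to_go(rewards, gamma=1.0):
    """Return G_t = r_t + gamma * G_{t+1} for every step of one episode."""
    returns = np.zeros(len(rewards))
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns

def policy_entropy(probs, eps=1e-8):
    """Entropy of each categorical action distribution (one row per step)."""
    return -np.sum(probs * np.log(probs + eps), axis=-1)

def policy_objective(probs, actions, returns, alpha=0.01, eps=1e-8):
    """Surrogate to maximize: mean(log pi(a_t) * G_t) + alpha * mean entropy."""
    log_pi = np.log(probs[np.arange(len(actions)), actions] + eps)
    return np.mean(log_pi * returns) + alpha * np.mean(policy_entropy(probs))

# Illustrative episode: 3 steps, 4 actions, uniform policy.
probs = np.full((3, 4), 0.25)
actions = np.array([0, 2, 1])
returns = rewards_to_go([1.0, 1.0, 1.0])
objective = policy_objective(probs, actions, returns)
```

With `gamma=1.0` the returns match the undiscounted sum in the objective above; the value of `alpha` here is a placeholder, not a tuned setting.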

### 3. Value Function

The value function is learned with one-step TD learning, minimizing the squared TD error:

$$\mathcal{L}_{value} = \mathbb{E}[(V_\omega(s_t) - (r_t + \gamma V_\omega(s_{t+1})))^2]$$

where:
- $V_\omega$ is the value network with parameters $\omega$
- $\gamma$ is the discount factor
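
As a concrete illustration, the squared TD error can be computed over a batch of transitions as follows. This is a sketch: the `dones` mask for terminal states is an assumption not present in the formula above, and the bootstrap target is treated as a fixed number rather than differentiated through.

```python
import numpy as np

def td_value_loss(values, next_values, rewards, dones, gamma=0.99):
    """Mean squared one-step TD error.

    `dones` is 1.0 where s_{t+1} is terminal, so no value is bootstrapped there.
    """
    targets = rewards + gamma * (1.0 - dones) * next_values
    return np.mean((values - targets) ** 2)

# Illustrative batch of two transitions; the second one ends the episode.
values = np.array([1.0, 0.5])
next_values = np.array([0.5, 0.0])
rewards = np.array([0.0, 1.0])
dones = np.array([0.0, 1.0])
loss = td_value_loss(values, next_values, rewards, dones, gamma=1.0)
```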

In [None]:
def create_training_flow_diagram():
    """Build a Graphviz diagram of one interaction step during training."""
    dot = graphviz.Digraph(comment='MA-DVRL Training Flow')
    dot.attr(rankdir='TB')
    
    # Add nodes
    dot.node('obs', 'Observation')
    dot.node('enc', 'Encoder\nNetwork')
    dot.node('inf', 'Inference\nNetwork')
    dot.node('belief', 'Belief\nState')
    dot.node('policy', 'Policy\nNetwork')
    dot.node('value', 'Value\nNetwork')
    dot.node('action', 'Action')
    dot.node('env', 'Environment')
    dot.node('reward', 'Reward')
    
    # Add edges with labels
    dot.edge('obs', 'enc', 'o_t')
    dot.edge('obs', 'inf', 'o_{1:t}')
    dot.edge('enc', 'belief', 'p(z_t|o_t)')
    dot.edge('inf', 'belief', 'q(z_t|o,a)')
    dot.edge('belief', 'policy', 'z_t')
    dot.edge('belief', 'value', 'z_t')
    dot.edge('policy', 'action', 'π(a_t|z_t)')  # plain Unicode avoids the invalid '\p' escape
    dot.edge('action', 'env')
    dot.edge('env', 'reward')
    dot.edge('env', 'obs', 'o_{t+1}')
    
    return dot

training_flow = create_training_flow_diagram()
display(training_flow)

## Network Details

### Encoder Network Architecture
```
Input (obs_dim) → FC(256) → ReLU → FC(256) → ReLU → 
                                              → FC(256) → μ
                                              → FC(256) → σ
```

### Inference Network Architecture
```
Input (hist_dim) → FC(256) → ReLU → FC(256) → ReLU → 
                                              → FC(256) → μ
                                              → FC(256) → σ
```
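
Both the encoder and the inference network share this two-headed Gaussian layout, so one forward function can illustrate them. The sketch below uses plain NumPy with random weights. The softplus on the $\sigma$ head (to keep it positive), the He-style initialization, and `obs_dim = 32` are assumptions for illustration, not details taken from the model.

```python
import numpy as np

def init_linear(rng, in_dim, out_dim):
    """Random weights and zero bias for one fully connected layer."""
    scale = np.sqrt(2.0 / in_dim)
    return rng.standard_normal((in_dim, out_dim)) * scale, np.zeros(out_dim)

def gaussian_network(x, params):
    """FC -> ReLU -> FC -> ReLU, then separate linear heads for mu and sigma."""
    (w1, b1), (w2, b2), (w_mu, b_mu), (w_sig, b_sig) = params
    h = np.maximum(x @ w1 + b1, 0.0)
    h = np.maximum(h @ w2 + b2, 0.0)
    mu = h @ w_mu + b_mu
    # Softplus keeps sigma strictly positive (an assumption of this sketch).
    sigma = np.logaddexp(0.0, h @ w_sig + b_sig) + 1e-6
    return mu, sigma

rng = np.random.default_rng(0)
obs_dim = 32  # hypothetical input size; use hist_dim for the inference network
params = [
    init_linear(rng, obs_dim, 256),
    init_linear(rng, 256, 256),
    init_linear(rng, 256, 256),  # mu head
    init_linear(rng, 256, 256),  # sigma head
]
mu, sigma = gaussian_network(rng.standard_normal((8, obs_dim)), params)
```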

### Policy Network Architecture
```
Input (512) → FC(256) → ReLU → FC(256) → ReLU → FC(4) → Softmax
```

### Value Network Architecture
```
Input (512) → FC(256) → ReLU → FC(256) → ReLU → FC(1)
```
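
The policy and value networks both take the 512-dimensional joint belief, formed by concatenating the two 256-dimensional beliefs. A minimal NumPy sketch of that wiring, with random weights and hypothetical helper names, is:

```python
import numpy as np

def init_linear(rng, in_dim, out_dim):
    """Random weights and zero bias for one fully connected layer."""
    scale = np.sqrt(2.0 / in_dim)
    return rng.standard_normal((in_dim, out_dim)) * scale, np.zeros(out_dim)

def mlp(x, layers):
    """Apply ReLU after every layer except the last, which stays linear."""
    for w, b in layers[:-1]:
        x = np.maximum(x @ w + b, 0.0)
    w, b = layers[-1]
    return x @ w + b

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
belief = rng.standard_normal((8, 256))      # stand-in for the private belief
opp_belief = rng.standard_normal((8, 256))  # stand-in for the opponent belief
joint = np.concatenate([belief, opp_belief], axis=-1)

policy_layers = [init_linear(rng, 512, 256), init_linear(rng, 256, 256), init_linear(rng, 256, 4)]
value_layers = [init_linear(rng, 512, 256), init_linear(rng, 256, 256), init_linear(rng, 256, 1)]

action_probs = softmax(mlp(joint, policy_layers))  # shape (8, 4)
state_value = mlp(joint, value_layers)             # shape (8, 1)
```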

## Training Process

1. **Forward Pass**:
   - Encode observation to get prior belief
   - Use inference network to get posterior belief
   - Sample belief state from posterior
   - Get action probabilities and value estimate

2. **Action Selection**:
   - Sample action from policy
   - Execute in environment
   - Store transition (o, a, r, o')

3. **Backward Pass**:
   - Compute ELBO loss
   - Compute policy gradient loss
   - Compute value loss
   - Update all networks
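
The action-selection step above can be sketched as sampling from the policy's categorical output and appending the resulting transition to a buffer. `Transition`, `RolloutBuffer`, and `select_action` are hypothetical names introduced for illustration, and the reward and next observation are placeholders for what an environment would return.

```python
import numpy as np
from collections import namedtuple

# One stored step: (o, a, r, o').
Transition = namedtuple("Transition", ["obs", "action", "reward", "next_obs"])

def select_action(action_probs, rng):
    """Sample an action index from the policy's categorical distribution."""
    return int(rng.choice(len(action_probs), p=action_probs))

class RolloutBuffer:
    """Minimal in-memory store for transitions collected during a rollout."""

    def __init__(self):
        self.transitions = []

    def store(self, obs, action, reward, next_obs):
        self.transitions.append(Transition(obs, action, reward, next_obs))

    def __len__(self):
        return len(self.transitions)

rng = np.random.default_rng(0)
buffer = RolloutBuffer()
obs = np.zeros(32)                      # placeholder observation
action = select_action(np.array([0.1, 0.2, 0.3, 0.4]), rng)
reward, next_obs = 1.0, np.ones(32)     # placeholders for the environment's response
buffer.store(obs, action, reward, next_obs)
```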

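The ELBO and the policy objective are maximized, while the value loss is minimized. Written as a single loss to minimize, the total is:

$$\mathcal{L}_{total} = -\mathcal{L}_{ELBO} - \lambda_1\mathcal{L}_{policy} + \lambda_2\mathcal{L}_{value}$$

where $\lambda_1$ and $\lambda_2$ are weighting coefficients.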
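
A small sketch of combining the three terms into one scalar follows. It assumes $\mathcal{L}_{ELBO}$ and $\mathcal{L}_{policy}$ are the quantities to maximize as defined earlier, so they are negated before being added to the value loss. The default weights are placeholders, not values from the model.

```python
def total_loss(elbo, policy_objective, value_loss, lambda_1=1.0, lambda_2=0.5):
    """Scalar to minimize: negate the maximized terms, add the weighted value loss."""
    return -elbo - lambda_1 * policy_objective + lambda_2 * value_loss

# Illustrative numbers only.
loss = total_loss(elbo=-2.0, policy_objective=1.5, value_loss=0.25)
```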