# Solving Mean Field Games (Deep Galerkin)
#### Frederik Kelbel, Imperial College London

In [1]:
import torch
import plotly.graph_objects as go
import numpy as np
from operators import div, Δ, D, mdotb, bdotm, mdotm, bdotb, m, p, H, cat
from DGM import DGMPIASolver, DeepPDESolver
from pdes import HBJ, PDE
from scipy.integrate import quad
from plotly.subplots import make_subplots
from configs import CONFIG_HBJS as MODEL_CONFIG

In [2]:
def plot_losses(losses, avg_over=10):
    """Plot a moving average of per-iteration training losses.

    Args:
        losses: sequence of loss values, one per training iteration.
        avg_over: width of the moving-average window. It is clamped to
            ``[1, len(losses)]``: with ``mode='valid'``, ``np.convolve``
            swaps its operands when the window is longer than the data,
            which would otherwise yield a wrong-length curve of bogus values.

    Returns None (the figure is shown, not returned, so a notebook cell
    ending in this call does not display it twice). Nothing is plotted for
    an empty ``losses``.
    """
    losses = np.asarray(losses)
    if losses.size == 0:
        # np.convolve raises on empty input; there is nothing to plot.
        return
    window = max(1, min(int(avg_over), losses.size))
    avgs = np.convolve(losses, np.ones(window), 'valid') / window
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Scatter(x=np.arange(len(avgs)), y=avgs, mode='lines', name="Loss (moving average)"), row=1, col=1)
    fig.update_layout(
        title="Loss",
        xaxis_title="Iterations",
        yaxis_title="Loss",
        font=dict(
            family="Courier New, monospace",
            size=14
        )
    )
    fig.show()

In [128]:
class OPINION(PDE):
    """HJB equation for an N-agent controlled opinion-dynamics problem.

    The input is ``var = (x_1, ..., x_N, t)``. The residual driven to zero is

        u_t + |x - x_d|^2 / (2N)
            + sum_i beta (1 - x_i^2) (mean(x) - x_i) du/dx_i
            + sigma * tr(Hess_x u)
            - N / (2 gamma) * |grad_x u|^2,

    with terminal condition u(x, T) = 0. Minimising over a control a of
    a . grad_x u + gamma/(2N) |a|^2 gives the quadratic Hamiltonian term
    above and the feedback control a = -(N / gamma) grad_x u.

    Args:
        N: number of agents (spatial dimension of the PDE).
        beta: coefficient of the consensus drift.
        sigma: diffusion coefficient; 0 skips the Hessian term entirely.
        gamma: weight of the quadratic control cost.
        x_d: reference opinion penalised by the running cost.
        T: time horizon (terminal condition is imposed at t = T).
    """

    def __init__(self, N=10, beta=-1, sigma=0.0, gamma=0.05, x_d=-0.5, T=5):
        super().__init__()
        self.N = N
        self.beta = beta
        self.sigma = sigma
        self.gamma = gamma
        self.x_d = x_d
        self.T = T
        self.var_dim = N + 1  # var = (x_1, ..., x_N, t)
        self.sol_dim = 1

        def hjb_residual(u, var):
            xs = var[:-1]
            # sum_j (beta/N)(1 - x_i^2)(x_j - x_i) == beta (1 - x_i^2)(mean(x) - x_i),
            # so compute the mean once instead of an O(N^2) double sum.
            mean_x = sum(xs) / N
            residual = div(u, var[-1])
            # Running cost. Without it (and with u(., T) = 0) u == 0 solves
            # the equation and the resulting control is identically zero.
            residual = residual + torch.sum((cat(xs) - self.x_d) ** 2, dim=-1, keepdims=True) / (2 * N)
            residual = residual + sum(
                self.beta * (1 - x_i ** 2) * (mean_x - x_i) * div(u, x_i) for x_i in xs
            )
            if self.sigma != 0:
                # Second derivatives are costly; skip them when they are multiplied by 0.
                residual = residual + self.sigma * torch.sum(H(u, xs), dim=(-1, -2)).unsqueeze(-1)
            residual = residual - (N / (2 * self.gamma)) * sum(div(u, x_i) ** 2 for x_i in xs)
            return residual

        self.equation = hjb_residual
        # Rescale raw samples (presumably drawn in [0, 1] by the solver -- TODO
        # confirm) to opinions in [-1, 1] and times in [0, T].
        self.domain_func = [(lambda var: [(v - 0.5) * 2 for v in var[:-1]] + [T * var[-1]], 128)]
        # Terminal condition u(x, T) = 0.
        self.boundary_cond = [lambda u, var: u]
        self.boundary_func = [(lambda var: [(v - 0.5) * 2 for v in var[:-1]] + [0 * var[-1] + T], 64)]

In [129]:
# Build the opinion-dynamics HJB, train a feed-forward Deep Galerkin network
# on it for 400 iterations and show the smoothed loss curve.
eq = OPINION()
# `HEAT_MODEL_CONFIG` is kept as an alias of the same dict; the name looks
# like a leftover from a heat-equation config (TODO confirm before renaming).
model = HEAT_MODEL_CONFIG = dict(
    hidden_dim=64,
    learning_rate=5e-3,
    loss_weights=(2, 1),
    sampling_method="uniform",
    lr_decay=0.98,
    network_type="FF",
    optimiser="Adam",
    method="Galerkin",
)
solver = DeepPDESolver(model, eq)
training_iterations = 400
losses = list(solver.train(training_iterations))
plot_losses(losses)

100%|██████████| 400/400 [00:17<00:00, 22.49 it/s]


In [130]:
def plot_opinions(T):
    """Simulate controlled and uncontrolled opinion trajectories on [0, T].

    Despite its name, this function does not plot. It returns trajectories
    for the plotting cells. It uses an Euler-Maruyama scheme with about
    50 steps per unit time. Both systems share the same initial opinions and
    the same Brownian increments. Any difference between them therefore
    comes from the feedback control ``-(N / gamma) * grad_x u``, taken from
    the module-level ``solver``. Model coefficients come from the
    module-level ``eq``.

    Args:
        T: final simulation time.

    Returns:
        ``(c_xs, uc_xs)``: arrays of shape ``(N, n)`` holding the controlled
        and uncontrolled opinion of each of the N agents at the n grid times
        ``np.linspace(0, T, n)``.
    """
    n = max(int(50 * T), 2)  # need at least two grid points for a step size
    N = eq.var_dim - 1
    # Step so the n grid points span exactly [0, T]; a step of 1/n would only
    # reach t ~ 1 whatever T is.
    dt = T / (n - 1)

    def sde_sum(X):
        # Consensus drift beta (1 - x_i^2)(mean(x) - x_i) for a column vector X.
        return eq.beta * (1 - X**2) * (sum(X) / N - X)

    # Initial opinions: about 2/5 of the agents near -0.8, the rest near +0.5.
    n_low = int(2 * N / 5)
    n_high = N - n_low  # remainder, so the two groups always cover all N agents
    c_xs = np.zeros((N, n))
    c_xs[:n_low, 0] = (np.random.randn(n_low) * 0.1 - 0.8).clip(-1, 1)
    c_xs[n_low:, 0] = (np.random.randn(n_high) * 0.1 + 0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]

    noise_scale = np.sqrt(2 * eq.sigma)
    for i in range(n - 1):
        t = i * dt
        dW = np.sqrt(dt) * np.random.randn(N, 1)
        # Assumes solver.D_u returns the N spatial partials of u as a 1-D
        # array -- TODO confirm against DeepPDESolver.
        grad_u = solver.D_u(*c_xs[:, i], t)
        c = -(N / eq.gamma) * np.expand_dims(grad_u, axis=1)
        c_now = c_xs[:, i:i + 1]
        uc_now = uc_xs[:, i:i + 1]
        c_xs[:, i + 1:i + 2] = c_now + sde_sum(c_now) * dt + noise_scale * dW + c * dt
        uc_xs[:, i + 1:i + 2] = uc_now + sde_sum(uc_now) * dt + noise_scale * dW

    return c_xs, uc_xs

In [131]:
# Simulate 5 time units of controlled and uncontrolled opinion dynamics.
c_xs, uc_xs = plot_opinions(T=5)

In [132]:
# Final-time opinion of every agent, controlled vs. uncontrolled, drawn as a
# one-dimensional strip (every marker sits at y = 1).
n_agents = eq.var_dim - 1
fig = make_subplots(rows=1, cols=1)
controlled_trace = go.Scatter(
    x=c_xs[:, -1].flatten(),
    y=np.ones(n_agents),
    mode='markers',
    name="Controlled",
    marker=dict(color='SkyBlue', size=15),
)
uncontrolled_trace = go.Scatter(
    x=uc_xs[:, -1].flatten(),
    y=np.ones(n_agents),
    mode='markers',
    name="Uncontrolled",
    opacity=0.8,
)
for trace in (controlled_trace, uncontrolled_trace):
    fig.add_trace(trace, row=1, col=1)
fig.update_layout(
    title="Opinion distribution",
    xaxis_title="X",
    yaxis_title="Num",
    xaxis_range=[-1.1, 1.1],
    font=dict(family="Courier New, monospace", size=14),
)
fig.show()

In [133]:
# Animated strip plot of all agents' opinions: one frame per simulation step,
# controlled (sky-blue) vs. uncontrolled.
def _opinion_traces(step):
    """Return the [controlled, uncontrolled] scatter traces at time index `step`."""
    return [
        go.Scatter(x=c_xs[:, step].flatten(), y=np.ones(eq.var_dim - 1),
                   mode='markers',
                   name="Controlled", marker=dict(color='SkyBlue', size=15)),
        go.Scatter(x=uc_xs[:, step].flatten(), y=np.ones(eq.var_dim - 1),
                   mode='markers',
                   name="Uncontrolled", opacity=0.8),
    ]


play_button = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 25, "redraw": False},
                 "fromcurrent": True,
                 "transition": {"duration": 0.1}}],
)
fig = go.Figure(
    data=_opinion_traces(0),
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(type="buttons", buttons=[play_button])],
    ),
    frames=[go.Frame(data=_opinion_traces(step)) for step in range(c_xs.shape[1])],
)
fig.update_layout(transition={'duration': 0.1})
fig.show()