# Solving Mean Field Games (Deep Galerkin)
#### Frederik Kelbel, Imperial College London

In [160]:
import torch
import plotly.graph_objects as go
import numpy as np
from operators import div, Δ, D, mdotb, bdotm, mdotm, bdotb, m, p, H, cat
from DGM import DGMPIASolver, DeepPDESolver
from pdes import HBJ, PDE
from scipy.integrate import quad
from plotly.subplots import make_subplots
from configs import CONFIG_HBJS as MODEL_CONFIG
from FBSDEs import FBSDESolver
from pdes import FBSDE

In [161]:
def plot_losses(losses, avg_over=10):
    """Plot a moving average of the training loss with Plotly.

    Args:
        losses: Sequence (list or 1-D array) of per-iteration loss values.
        avg_over: Requested width of the moving-average window. It is clamped
            to the range ``[1, len(losses)]`` so that a loss history shorter
            than the window is still averaged correctly.

    The figure is displayed with ``fig.show()``; nothing is returned. An empty
    loss history produces an empty, labelled figure instead of an exception.
    """
    fig = make_subplots(rows=1, cols=1)
    n_losses = len(losses)
    if n_losses:
        # np.convolve(..., 'valid') swaps its operands when the kernel is
        # longer than the signal, which would yield a mis-scaled curve, so the
        # window is never allowed to exceed the number of recorded losses.
        window = max(1, min(avg_over, n_losses))
        avgs = np.convolve(losses, np.ones(window), 'valid') / window
        fig.add_trace(
            go.Scatter(x=np.arange(len(avgs)), y=avgs, mode='lines',
                       name="Moving-average loss"),
            row=1, col=1,
        )
    fig.update_layout(
        title="Loss",
        xaxis_title="Iterations",
        yaxis_title="Loss",
        font=dict(
            family="Courier New, monospace",
            size=14
        )
    )
    fig.show()

In [195]:
def sampling(var):
    """Shift every tensor in ``var`` by -0.5 plus one shared random offset.

    A single offset per row is drawn with ``torch.rand`` (then recentred by
    -0.5) and added to every entry of ``var``, so all entries of one row move
    together. The row count is taken from the last entry of ``var``.

    Args:
        var: Non-empty list of tensors.
            # assumes each tensor is (batch, 1) - TODO confirm against the solver

    Returns:
        A new list with one shifted tensor per input tensor.
    """
    batch = var[-1].shape[0]
    shared_shift = torch.rand((batch, 1)) - 0.5
    return [(component - 0.5) + shared_shift for component in var]

class LQG(PDE):
    """Linear-quadratic game with N = 3 agents, set up for the deep PDE solver.

    The variable list is ``var = (x_1, ..., x_N, t)``: the last entry is time
    and the others are agent states. The instance exposes

    * ``equation``      - residual of the PDE on interior points,
    * ``domain_func``   - interior sampler, time rescaled to ``5*t``,
    * ``boundary_cond`` - mismatch between ``u`` and the quadratic cost
      ``l/2 * sum_i (mean(x) - x_i)^2``,
    * ``boundary_func`` - sampler for the boundary points, time rescaled to
      ``0.1*t + 5``.

    NOTE(review): ``div``, ``D``, ``H`` and ``cat`` come from the project
    ``operators`` module; they are presumably derivative, gradient, Hessian
    and concatenation helpers - confirm there. The integer paired with each
    sampler is presumably the number of sample points - confirm in the solver.
    """

    def __init__(self):
        super().__init__()
        N = 3
        self.l = 1.0
        self.sigma = 0.0
        self.gamma = 0.05
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1

        def residual(u, var):
            # Time derivative + diffusion term - quadratic control cost.
            time_term = div(u, var[-1])
            diffusion_term = (1/2)*self.sigma*torch.sum(H(u, var[:-1]), dim=(-1, -2)).unsqueeze(-1)
            control_term = (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
            return time_term + diffusion_term - control_term

        def interior_points(var):
            # States get the shared random shift; time is stretched by 5.
            return sampling(var[:-1]) + [5*var[-1]]

        def terminal_mismatch(u, var):
            # u minus l/2 times the squared distance of each state to the mean.
            states = var[:-1]
            return u - self.l*torch.sum((sum(states)/N-cat(states))**2, dim=-1, keepdims=True)/2

        def terminal_points(var):
            # Same state sampling, time mapped to 0.1*t + 5.
            return sampling(var[:-1]) + [0.1*var[-1]+5]

        self.equation = residual
        self.domain_func = [(interior_points, 128)]
        self.boundary_cond = [terminal_mismatch]
        self.boundary_func = [(terminal_points, 128)]

In [196]:
# Train the deep PDE solver on the LQG problem and plot the smoothed loss.
# Note that MODEL_CONFIG rebinds the name imported from `configs` above.
eq = LQG()
MODEL_CONFIG = dict(
    hidden_dim=64,
    learning_rate=1e-3,
    loss_weights=(1, 4),
    sampling_method="uniform",
    lr_decay=0.98,
    network_type="FF",
    optimiser="Adam",
    method="Galerkin",
)
model = MODEL_CONFIG
solver = DeepPDESolver(model, eq)
# solver.train(...) is consumed as an iterable of per-iteration losses.
losses = list(solver.train(600))
plot_losses(losses)

100%|██████████| 600/600 [00:08<00:00, 68.34 it/s]


In [197]:
def plot_game(T):
    """Simulate agent paths of the LQG game with and without feedback control.

    Despite its name this function does not plot anything; it only returns
    the simulated paths. It reads the module-level ``eq`` (``var_dim``,
    ``gamma``, ``sigma``) and ``solver`` (``D_u``).

    Both path sets start from the same initial states and are driven by the
    same Brownian increments; the controlled set additionally receives the
    drift ``-(1/gamma) * D_u(x, t)``.

    Args:
        T: Time horizon. The grid has ``n = int(40*T)`` points with step
            ``T/n``, so the simulated times cover ``[0, T)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)`` holding the
        controlled and uncontrolled trajectories, ``N = eq.var_dim - 1``.
    """
    n = int(40*T)
    N = eq.var_dim-1

    # Split the agents into two clusters. The second group takes whatever is
    # left so the sizes always add up to N (int(2N/5) + int(3N/5) != N when N
    # is not a multiple of 5).
    n_low = int(2*N/5)
    n_high = N - n_low

    c_xs = np.zeros((N, n))
    # NOTE(review): clip(-0.7, 0.9) on samples centred at -0.8 pins most of
    # them to -0.7; clip(-0.9, -0.7) may have been intended - confirm.
    c_xs[:n_low, 0] = (np.random.randn(n_low)*0.1 - 0.8).clip(-0.7, 0.9)
    c_xs[n_low:, 0] = (np.random.randn(n_high)*0.1 + 0.5).clip(0.4, 0.6)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]

    # Step size of the time grid: the horizon divided by the number of grid
    # points (previously 1/n, which simulated [0, 1] regardless of T).
    dt = T/n
    noise_scale = np.sqrt(eq.sigma)
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Feedback control from the learned solution's spatial derivatives.
        grad_u = solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt)
        c = -(1/eq.gamma)*np.expand_dims(grad_u, axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i] + noise_scale*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + noise_scale*dW

    return c_xs, uc_xs

In [198]:
# Simulate controlled / uncontrolled trajectories with T=5; the figure cell
# below reads c_xs and uc_xs.
c_xs, uc_xs = plot_game(5)

In [199]:
# Animated scatter of the agents' states: one frame per simulated time index,
# controlled and uncontrolled agents drawn on the same horizontal line y = 1.
def _lqg_frame_traces(step):
    """Build the controlled/uncontrolled marker traces for time index ``step``."""
    controlled = go.Scatter(
        x=c_xs[:, step].flatten(), y=np.ones(eq.var_dim-1),
        mode='markers',
        name="Controlled", marker=dict(color='SkyBlue', size=15),
    )
    uncontrolled = go.Scatter(
        x=uc_xs[:, step].flatten(), y=np.ones(eq.var_dim-1),
        mode='markers',
        name="Uncontrolled", opacity=0.8,
    )
    return [controlled, uncontrolled]


_lqg_play_button = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 25, "redraw": False},
                 "fromcurrent": True,
                 "transition": {"duration": 0.1}}],
)

fig = go.Figure(
    data=_lqg_frame_traces(0),
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(type="buttons", buttons=[_lqg_play_button])],
    ),
    frames=[go.Frame(data=_lqg_frame_traces(step)) for step in range(c_xs.shape[1])],
)
fig.update_layout(transition={'duration': 0.1})
fig.show()

# ---------

In [152]:
def sampling(var):
    """Recentre the tensors in ``var`` and add one random offset shared per row.

    One ``torch.rand`` value per row (recentred by -0.5) is added to every
    tensor after each has been shifted by -0.5. The number of rows is read
    from the last tensor in ``var``.

    Args:
        var: Non-empty list of tensors.
            # assumes each tensor is (batch, 1) - TODO confirm against the solver

    Returns:
        A new list containing one shifted tensor per input tensor.
    """
    n_rows = var[-1].shape[0]
    offset = torch.rand((n_rows, 1)) - 0.5
    shifted = []
    for coord in var:
        shifted.append((coord - 0.5) + offset)
    return shifted
          
def init_sampling(var):
    """Overwrite the first five entries of ``var`` with clustered random states.

    Entries 0-1 are drawn as ``N(-0.8, 0.1^2)`` clamped to ``[-0.7, 0.9]`` and
    entries 2-4 as ``N(0.5, 0.1^2)`` clamped to ``[0.4, 0.6]``. Every draw has
    the fixed shape ``(64, 1)``, whatever the shape of the tensors passed in.

    NOTE(review): clamping samples centred at -0.8 to [-0.7, 0.9] pins most of
    them to -0.7; clamp(-0.9, -0.7) may have been intended - confirm.

    Args:
        var: List with at least five entries; it is modified in place.

    Returns:
        The same list object, after the overwrite.
    """
    def draw_low():
        return (torch.randn((64, 1))*0.1 - 0.8).clamp(-0.7, 0.9)

    def draw_high():
        return (torch.randn((64, 1))*0.1 + 0.5).clamp(0.4, 0.6)

    # The draw order matters for reproducibility under a fixed seed.
    layout = (draw_low, draw_low, draw_high, draw_high, draw_high)
    for slot, draw in enumerate(layout):
        var[slot] = draw()
    return var
        
class OPINION(PDE):
    """Opinion-dynamics control problem with N = 5 agents for the deep PDE solver.

    The variable list is ``var = (x_1, ..., x_N, t)``; the last entry is time.
    ``self.equation`` delegates to the local function ``ex``. The local
    function ``some`` is an alternative formulation (different cost scaling,
    per-variable ``div`` calls) that is defined but never referenced.

    NOTE(review): ``div``, ``D``, ``H`` and ``cat`` come from the project
    ``operators`` module; presumably derivative, gradient, Hessian and
    concatenation helpers - confirm there.
    """
    def __init__(self):
        super().__init__()
        # Target opinion used in the quadratic cost terms.
        x_d = 0.0
        self.beta = -1
        self.sigma = 0.01
        self.gamma = 0.05
        N = 5
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1
        def some(u, var):
            """Unused alternative residual: cost scaled by 1/(2N), control term by N/(2*gamma)."""
            return div(u, var[-1]) + torch.sum((cat(var[:-1]) - x_d)**2, dim=-1, keepdims=True)/(2*N)\
                        + sum([self.beta*(1-x_i**2)*(sum(var[:-1])/N-x_i)*div(u, x_i) for x_i in var[:-1]])\
                        + self.sigma*torch.sum(H(u, var[:-1]), dim=(-1, -2)).unsqueeze(-1) \
                        + sum([-(N/(2*self.gamma))*div(u, x_i)**2 for x_i in var[:-1]])
        
        def ex(u, var):
            """Residual used for training.

            Sum of: time derivative, quadratic distance to ``x_d`` (halved),
            drift ``beta*(1 - x^2)*(mean(x) - x)`` contracted with ``D(u, x)``,
            ``sigma`` times the summed ``H(u, x)`` entries, minus
            ``1/(2*gamma)`` times the squared ``D(u, x)`` entries.
            """
            return div(u, var[-1]) + torch.sum((cat(var[:-1]) - x_d)**2, dim=-1, keepdims=True)/(2) \
                        + torch.sum((self.beta*(1 - cat(var[:-1])**2)
                                     *(torch.mean(cat(var[:-1]), dim=-1, keepdims=True) - cat(var[:-1]))
                                     * D(u, var[:-1])), dim=-1, keepdims=True)\
                        + self.sigma*torch.sum(H(u, var[:-1]), dim=(-1, -2)).unsqueeze(-1) \
                        - (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
        
            
        self.equation = lambda u, var: ex(u, var)
        # Two interior samplers: randomly shifted states with time stretched to
        # 5*t, and states from init_sampling with time compressed to 0.1*t.
        # The second tuple element is presumably the sample count (init_sampling
        # hard-codes 64 rows, matching the 64 here) - confirm in the solver.
        self.domain_func = [(lambda var: sampling(var[:-1]) + [5*var[-1]], 128),
                           (lambda var: init_sampling(var[:-1]) + [0.1*var[-1]], 64)]
        # Boundary residual: u minus half the squared distance of the states to x_d.
        self.boundary_cond = [lambda u, var: u - torch.sum((cat(var[:-1]) - x_d)**2, dim=-1, keepdims=True)/(2)] 
        # Boundary points use the shifted states with time mapped to 0.1*t + 5.
        self.boundary_func = [(lambda var: sampling(var[:-1]) + [0.1*var[-1]+5], 64)]

In [153]:
import os

# Checkpoint written by this cell; the next cell passes the same `path` to solver.load.
path = f"{os.path.abspath(os.getcwd())}/opinion.pth"

# First training stage: the samplers set in OPINION.__init__ are replaced and
# the loss weights are (0, 1).
eq = OPINION()
eq.domain_func = [
    (lambda var: sampling(var[:-1]) + [5*var[-1]], 128),
]
eq.boundary_func = [
    (lambda var: [(v-0.5)*2.2 for v in var[:-1]] + [5*var[-1]], 64),
]
MODEL_CONFIG = dict(
    hidden_dim=64,
    learning_rate=1e-3,
    loss_weights=(0, 1),
    sampling_method="uniform",
    lr_decay=0.98,
    network_type="FF",
    optimiser="Adam",
    method="Galerkin",
)
model = MODEL_CONFIG
solver = DeepPDESolver(model, eq)
losses = list(solver.train(500))
solver.save(path)
plot_losses(losses)

100%|██████████| 500/500 [00:11<00:00, 44.39 it/s]


In [154]:
# Second training stage: a fresh OPINION problem with its default samplers and
# loss weights (1, 1), starting from the checkpoint saved at `path`.
eq = OPINION()
MODEL_CONFIG = dict(
    hidden_dim=64,
    learning_rate=1e-3,
    loss_weights=(1, 1),
    sampling_method="uniform",
    lr_decay=0.98,
    network_type="FF",
    optimiser="Adam",
    method="Galerkin",
)
model = MODEL_CONFIG
solver = DeepPDESolver(model, eq)
solver.load(path)
losses = list(solver.train(600))
plot_losses(losses)

100%|██████████| 600/600 [00:13<00:00, 45.27 it/s]


In [155]:
def plot_opinions(T):
    """Simulate opinion trajectories with and without the learned control.

    Despite its name this function does not plot anything; it only returns
    the simulated paths. It reads the module-level ``eq`` (``var_dim``,
    ``beta``, ``gamma``, ``sigma``) and ``solver`` (``D_u``).

    Both path sets share initial states and Brownian increments and follow the
    drift ``beta*(1 - x^2)*(mean(x) - x)``; the controlled set additionally
    receives the feedback drift ``-(1/gamma) * D_u(x, t)``.

    Args:
        T: Time horizon. The grid has ``n = int(50*T)`` points with step
            ``T/n``, so the simulated times cover ``[0, T)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)`` holding the
        controlled and uncontrolled trajectories, ``N = eq.var_dim - 1``.
    """
    n = int(50*T)
    N = eq.var_dim-1

    def sde_sum(X):
        # Interaction drift of every agent towards/away from the mean opinion.
        return eq.beta*(1-X**2)*(sum(X)/N-X)

    # Split the agents into two clusters. The second group takes whatever is
    # left so the sizes always add up to N (int(2N/5) + int(3N/5) != N when N
    # is not a multiple of 5).
    n_low = int(2*N/5)
    n_high = N - n_low

    c_xs = np.zeros((N, n))
    # NOTE(review): clip(-0.7, 0.9) on samples centred at -0.8 pins most of
    # them to -0.7; clip(-0.9, -0.7) may have been intended - confirm. The
    # same bounds are used by init_sampling for training.
    c_xs[:n_low, 0] = (np.random.randn(n_low)*0.1 - 0.8).clip(-0.7, 0.9)
    c_xs[n_low:, 0] = (np.random.randn(n_high)*0.1 + 0.5).clip(0.4, 0.6)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]

    # Step size of the time grid: the horizon divided by the number of grid
    # points (previously 1/n, which simulated [0, 1] regardless of T).
    dt = T/n
    noise_scale = np.sqrt(2*eq.sigma)
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Feedback control from the learned solution's spatial derivatives.
        grad_u = solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt)
        c = -(1/(eq.gamma))*np.expand_dims(grad_u, axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + noise_scale*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i])*dt + noise_scale*dW

    return c_xs, uc_xs

In [156]:
# Simulate controlled / uncontrolled opinion trajectories with T=5; the figure
# cells below read c_xs and uc_xs.
c_xs, uc_xs = plot_opinions(5)

In [157]:
# Snapshot of the last simulated column: controlled vs. uncontrolled opinions,
# all markers placed on the horizontal line y = 1.
fig = make_subplots(rows=1, cols=1)
controlled_final = go.Scatter(
    x=c_xs[:, -1].flatten(), y=np.ones(eq.var_dim-1),
    mode='markers',
    name="Controlled", marker=dict(color='SkyBlue', size=15),
)
uncontrolled_final = go.Scatter(
    x=uc_xs[:, -1].flatten(), y=np.ones(eq.var_dim-1),
    mode='markers',
    name="Uncontrolled", opacity=0.8,
)
fig.add_trace(controlled_final, row=1, col=1)
fig.add_trace(uncontrolled_final, row=1, col=1)
fig.update_layout(
    title="Opinion distribution",
    xaxis_title="X",
    yaxis_title="Num",
    xaxis_range=[-1.2,1.2],
    font=dict(family="Courier New, monospace", size=14),
)
fig.show()

In [158]:
# Animated scatter of the opinions: one frame per simulated time index, with
# controlled and uncontrolled agents on the same horizontal line y = 1.
def _opinion_frame_traces(step):
    """Build the controlled/uncontrolled marker traces for time index ``step``."""
    controlled = go.Scatter(
        x=c_xs[:, step].flatten(), y=np.ones(eq.var_dim-1),
        mode='markers',
        name="Controlled", marker=dict(color='SkyBlue', size=15),
    )
    uncontrolled = go.Scatter(
        x=uc_xs[:, step].flatten(), y=np.ones(eq.var_dim-1),
        mode='markers',
        name="Uncontrolled", opacity=0.8,
    )
    return [controlled, uncontrolled]


_opinion_play_button = dict(
    label="Play",
    method="animate",
    args=[None, {"frame": {"duration": 25, "redraw": False},
                 "fromcurrent": True,
                 "transition": {"duration": 0.1}}],
)

fig = go.Figure(
    data=_opinion_frame_traces(0),
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(type="buttons", buttons=[_opinion_play_button])],
    ),
    frames=[go.Frame(data=_opinion_frame_traces(step)) for step in range(c_xs.shape[1])],
)
fig.update_layout(transition={'duration': 0.1})
fig.show()

In [159]:
# Surface of the learned solution on the slice where every agent holds the
# same opinion x, evaluated over x in [-1.1, 1.1] and t in [0, 5].
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'surface'}]])
xs = np.linspace(-1.1, 1.1, 100)
ys = np.linspace(0, 5, 100)
# solver.u receives x repeated once per agent, followed by the time y.
us_pred = np.array([
    [solver.u(*([x] + [x]*(eq.var_dim-2)), y).item() for x in xs]
    for y in ys
])
# No reference solution is plotted; the prediction is the only surface.
us = us_pred
x_mesh, y_mesh = np.meshgrid(xs, ys)
fig.add_trace(go.Surface(x=xs, y=ys, z=us, showscale=False), row=1, col=1)
fig.update_layout(
    title='Solution | Approximation',
    scene=dict(xaxis_title="x", yaxis_title="t", zaxis_title="u(x, t)"),
    scene2=dict(xaxis_title="x", yaxis_title="t", zaxis_title="u(x, t)"),
    margin=dict(l=50, r=50, b=50, t=50),
)
# one row is printed as x axes
fig.show()

# -------

In [None]:
# Hyper-parameters for the FBSDE solver; `model` is an alias of the same dict
# and is read by the OPINION(FBSDE) class below via model["batch_size"].
model = MODEL_CONFIG = dict(
    batch_size=128,
    num_discretisation_steps=50,
    hidden_dim=128,
    learning_rate=5e-3,
    lr_decay=0.99,
    network_type="MINI",
    optimiser="Adam",
)
class OPINION(FBSDE):
    """Opinion-dynamics control problem with N = 10 agents in FBSDE form.

    Attributes set for the solver:

    * ``h(X, Y, Z, t)`` - quadratic distance to ``x_d`` scaled by ``1/(2N)``
      plus ``N/(2*gamma)`` times the squared ``Z`` entries,
    * ``b(X, t)`` - drift ``beta*(1 - X^2)*(mean(X) - X)``,
    * ``sigma(X, t)`` - constant tensor ``sqrt(2*0.3)`` of shape
      ``(batch_size, 1, 1)``, built once at construction time,
    * ``terminal_condition(X)`` - zeros of shape ``(batch_size, 1)``,
    * ``init_sampling_func(X)`` - the map ``(X - 0.5)*2``.

    ``batch_size`` is read from the module-level ``model`` dict.
    NOTE(review): how FBSDESolver uses these callables is not visible here;
    confirm their roles in the ``FBSDEs`` module.
    """

    def __init__(self):
        super().__init__()
        N = 10
        x_d = 0.5
        self.beta = -1
        self.gamma = 0.05
        sigma = np.sqrt(2*0.3)*torch.ones((model["batch_size"], 1, 1))

        def target_cost(X):
            return torch.sum((X - x_d)**2, dim=-1, keepdims=True)/(2*N)

        def driver(X, Y, Z, t):
            return target_cost(X) + (N/(2*self.gamma))* torch.sum(Z**2, dim=-1, keepdims=True)

        def drift(X, t):
            return self.beta*(1 - X**2)*(torch.sum(X, dim=-1, keepdims=True)/N - X)

        def diffusion(X, t):
            # Ignores its arguments; always the tensor built above.
            return sigma

        def terminal_value(X):
            # batch_size is looked up on every call, not frozen at init.
            return torch.zeros((model["batch_size"], 1))

        def rescale_initial(X):
            return (X-0.5)*2

        self.h = driver
        self.b = drift
        self.sigma = diffusion
        self.terminal_condition = terminal_value
        self.var_dim = N
        self.terminal_time = 3
        self.init_sampling_func = rescale_initial
        self.control_noise = 0.1

In [None]:
# Build the FBSDE opinion problem and its solver. Note that `solver` is
# rebound here, replacing the DeepPDESolver instance from the earlier cells.
eq_2 = OPINION()
solver = FBSDESolver(model, eq_2)
# solver.train(500) is consumed as an iterable of loss values.
loss = np.array(list(solver.train(500)))
plot_losses(loss)

In [None]:
def plot_opinions(T):
    """Simulate opinion trajectories for the FBSDE model, with and without control.

    Despite its name this function does not plot anything; it only returns
    the simulated paths. All model parameters are read from the module-level
    FBSDE problem ``eq_2`` (previously ``beta``, ``gamma`` and ``sigma`` were
    read from ``eq``, the unrelated PDE problem of the earlier section), and
    the control uses the module-level ``solver`` (``D_J``).

    Both path sets share initial states and Brownian increments and follow the
    drift ``beta*(1 - x^2)*(mean(x) - x)``; the controlled set additionally
    receives the feedback drift ``-(N/gamma) * D_J(x, t)``.

    Args:
        T: Time horizon. The grid has ``n = int(50*T)`` points with step
            ``T/n``, so the simulated times cover ``[0, T)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)`` holding the
        controlled and uncontrolled trajectories, ``N = eq_2.var_dim``.
    """
    n = int(50*T)
    N = eq_2.var_dim

    def sde_sum(X):
        # Interaction drift of every agent towards/away from the mean opinion.
        return eq_2.beta*(1-X**2)*(sum(X)/N-X)

    # eq_2.sigma is a callable (X, t) -> tensor. The OPINION(FBSDE) class in
    # this notebook returns a constant-filled tensor that ignores its
    # arguments, so any single entry is the scalar noise coefficient.
    # NOTE(review): assumes this value multiplies dW directly (it already
    # contains the sqrt(2 * ...) factor) - confirm against FBSDESolver.
    noise_scale = float(eq_2.sigma(None, 0.0).flatten()[0])

    # Split the agents into two clusters. The second group takes whatever is
    # left so the sizes always add up to N (int(2N/5) + int(3N/5) != N when N
    # is not a multiple of 5).
    n_low = int(2*N/5)
    n_high = N - n_low

    c_xs = np.zeros((N, n))
    c_xs[:n_low, 0] = (np.random.randn(n_low)*0.1 - 0.8).clip(-1, 1)
    c_xs[n_low:, 0] = (np.random.randn(n_high)*0.1 + 0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]

    # Step size of the time grid: the horizon divided by the number of grid
    # points (previously 1/n, which simulated [0, 1] regardless of T).
    dt = T/n
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Feedback control from the solver's derivative estimate.
        grad_J = solver.D_J(*[c_xs[j, i] for j in range(N)], i*dt)
        c = -(N/eq_2.gamma)*np.expand_dims(grad_J, axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + noise_scale*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i])*dt + noise_scale*dW

    return c_xs, uc_xs

In [None]:
# Simulate controlled / uncontrolled opinion trajectories with T=3; the figure
# cell below reads c_xs and uc_xs.
c_xs, uc_xs = plot_opinions(3)

In [None]:
# Snapshot of the last simulated column: controlled vs. uncontrolled opinions,
# all markers placed on the horizontal line y = 1.
# The y vectors are sized from the trajectory arrays themselves. The previous
# `eq.var_dim-1` referred to the PDE problem of the earlier section, whose
# agent count differs from the one simulated here.
fig = make_subplots(rows=1, cols=1)
fig.add_trace(go.Scatter(x=c_xs[:, -1].flatten(), y=np.ones(c_xs.shape[0]),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)), row=1, col=1)
fig.add_trace(go.Scatter(x=uc_xs[:, -1].flatten(), y=np.ones(uc_xs.shape[0]),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8), row=1, col=1)
fig.update_layout(
    title="Opinion distribution",
    xaxis_title="X",
    yaxis_title="Num",
    xaxis_range=[-1.2,1.2],
    font=dict(
        family="Courier New, monospace",
        size=14
    )
)
fig.show()