# Solving Mean Field Games (Deep Galerkin)
#### Frederik Kelbel, Imperial College London

In [None]:
import torch
import plotly.graph_objects as go
import numpy as np
from operators import div, Δ, D, mdotb, bdotm, mdotm, bdotb, m, p, H, cat
from DGM import DGMPIASolver, DeepPDESolver
from pdes import HBJ, PDE
from scipy.integrate import quad
from plotly.subplots import make_subplots
from configs import CONFIG_HBJS as MODEL_CONFIG
from FBSDEs import FBSDESolver
from pdes import FBSDE
from sampling import PATH_SPACES
import torch.distributions as dist
import os
from plotly.offline import init_notebook_mode
init_notebook_mode()
torch.manual_seed(0)
np.random.seed(0)

In [None]:
def plot_losses(losses, avg_over=10):
    """Plot a moving average of the training losses.

    Args:
        losses: One-dimensional sequence of per-iteration loss values.
        avg_over: Requested width of the moving-average window. It is clamped
            to ``[1, len(losses)]``; an empty ``losses`` yields an empty plot.
    """
    values = np.atleast_1d(np.asarray(losses))
    if values.size:
        # np.convolve(..., 'valid') swaps its operands when the kernel is
        # longer than the data, which would produce a meaningless curve for
        # short runs, so the window is never wider than the data.
        window = max(1, min(int(avg_over), values.size))
        avgs = np.convolve(values, np.ones(window), 'valid') / window
    else:
        # np.convolve rejects empty input; there is simply nothing to draw.
        avgs = values
    fig = make_subplots(rows=1, cols=1)
    fig.add_trace(go.Scatter(x=np.arange(len(avgs)), y=avgs, mode='lines', name="Loss (moving average)"), row=1, col=1)
    fig.update_layout(
        title="Loss",
        xaxis_title="Iterations",
        yaxis_title="Loss",
        font=dict(
            family="Courier New, monospace",
            size=14
        )
    )
    fig.show()

## Linear-Quadratic-Control

The Hopf-Cole transform was applied before solving!
The networks were not able to solve the problem without the transform.
Note: Can we apply other transformations to improve performance?
Note: Can we somehow utilize the Fokker-Planck equation, i.e. sample from a distribution that satisfies it?

In [None]:
class LQC(PDE):
    """Linear-quadratic control problem for N = 5 particles.

    ``var`` is a list of N + 1 tensors: the N state coordinates followed by
    time as the last entry. ``div``, ``Δ``, ``D`` and ``cat`` come from the
    project's ``operators`` module; presumably they are the time derivative,
    Laplacian, gradient and concatenation - verify against that module.
    """

    def __init__(self):
        """Set the model constants and register residual, terminal condition and samplers."""
        super().__init__()
        self.l = 5.0       # weight of the terminal cost
        self.sigma = 0.01  # diffusion coefficient
        self.gamma = 0.05  # control-cost weight
        N = 5
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1
            
        # Residual: div(u, t) + (sigma/2)*Δ(u, x) - |D(u, x)|^2 / (2*gamma).
        self.equation = lambda u, var: div(u, var[-1]) + (1/2)*self.sigma*Δ(u, var[:-1]) - (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
        # Interior points: resampled states, time entry scaled by 5.
        # The second tuple entry (128) is presumably the batch size - TODO confirm.
        self.domain_func = [(lambda var: self.sampling(var[:-1]) + [5*var[-1]], 128)]
        # Terminal condition: u = (l/2) * sum_i (mean(x) - x_i)^2.
        self.boundary_cond = [lambda u, var: u - (self.l*torch.sum((sum(var[:-1])/N-cat(var[:-1]))**2, dim=-1, keepdims=True)/2)] 
        # Terminal points: resampled states with the time entry replaced by the constant 5.
        self.boundary_func = [(lambda var: self.sampling(var[:-1]) + [0*var[-1]+5], 128)]
    
    def sampling(self, var):
        """Shift every coordinate by -0.5 and add one random offset per sample row.

        The offset is drawn uniformly from [-1, 1) and shared by all
        coordinates of a row. Assumes each ``var[i]`` has one row per sample
        (only ``shape[0]`` is read) - TODO confirm the expected input range.
        """
        means = (torch.rand((var[-1].shape[0], 1))-0.5)*2
        return [(var[i]-0.5) + means for i in range(len(var))]
    # The string literal below is a disabled alternative sampler kept for reference.
    '''
    def sampling(self, var):
        means_clusters = torch.randint(0, self.var_dim-1, size=(var[-1].shape[0], self.var_dim-1))
        compactness_clusters = torch.randint(0, self.var_dim-1, size=(var[-1].shape[0], self.var_dim-1))
        means = (torch.rand((var[-1].shape[0], self.var_dim-1))-0.5)*2
        compactness = torch.rand((var[-1].shape[0], self.var_dim-1))
        re = [(var[i]-0.5)*compactness.gather(-1, compactness_clusters[:, None, i]) + means.gather(-1, means_clusters[:, None, i]) for i in range(len(var))] 
        return re 
    '''

In [None]:
# Register the LQC dynamics for the project's path sampler, then build and train the solver.
# The lambdas look up the global `eq` when they are called, so defining them
# before `eq = LQC()` is fine.
PATH_SPACES["LQC"] = {
        "SDE": lambda X, u, t, dt, dW: u*dt + np.sqrt(eq.sigma)*dW,
        "terminal_time": 5.0,
        "N_range": (100, 101),
        "control": lambda J, X, t: -(1/eq.gamma)*D(J, X)
}
eq = LQC()
# NOTE: this rebinds MODEL_CONFIG, shadowing the value imported from `configs`.
model = MODEL_CONFIG = {
    "hidden_dim": 64,
    "learning_rate": 1e-3,
    "loss_weights": (1, 2),
    "sampling_method": "uniform",
    "sampling_method_boundary": "uniform",
    "lr_decay": 0.99,
    "network_type": "RES",
    "optimiser": "Adam",
    "method": "Galerkin"
}
solver = DeepPDESolver(model, eq)
# solver.train(600) is consumed as an iterable of loss values.
losses = list(solver.train(600))
plot_losses(losses)

In [None]:
def plot_game(T):
    """Simulate controlled and uncontrolled particle paths on [0, T].

    Both systems start from the same two-cluster initial state and are driven
    by the same Brownian increments (Euler-Maruyama). The controlled system
    additionally applies the feedback ``-(1/gamma) * D_u`` obtained from the
    trained global ``solver``.

    Args:
        T: Time horizon. The grid has ``int(40*T)`` points and matches
            ``np.linspace(0, T, n)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)``; column ``i``
        holds the state at time ``i*dt``.
    """
    n = int(40*T)
    N = eq.var_dim-1
    # Spacing of np.linspace(0, T, n): the last column corresponds to time T
    # and the time handed to the solver spans the whole horizon.
    dt = T/max(n-1, 1)
    n_low = int(2*N/5)  # size of the first cluster; the second takes the rest
    c_xs = np.zeros((N, n))
    #c_xs[:, 0] = -0.5
    # NOTE(review): the first cluster is centred at -0.8 but clipped to
    # [-0.7, 0.9], so most draws land on -0.7; confirm the intended bounds.
    c_xs[:n_low, 0] = (np.random.randn(n_low)*0.1 -0.8).clip(-0.7, 0.9)
    c_xs[n_low:, 0] = (np.random.randn(N-n_low)*0.1 +0.5).clip(0.4, 0.6)
    #c_xs[:, 0] = (np.random.randn(N)-0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Feedback control evaluated at the current controlled state and time.
        c = -(1/eq.gamma)*np.expand_dims(solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt), axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i]  + np.sqrt(eq.sigma)*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + np.sqrt(eq.sigma)*dW

    return c_xs, uc_xs

In [None]:
# Simulate the LQC particle system with T = 5, the terminal time registered for this model.
c_xs, uc_xs = plot_game(5)

In [None]:
# Animated scatter of the particle positions: one frame per simulated time step,
# controlled particles (large sky-blue markers) together with the uncontrolled ones.
fig = go.Figure(
    # Initial picture: first column of each trajectory array, all markers at y = 1.
    data=[go.Scatter(x=c_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
         go.Scatter(x=uc_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)],
    # Fixed axis ranges keep the view stable during playback; "Play" starts the animation.
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 25, 
                                                                        "redraw": False},
                                                              "fromcurrent": True, 
                                                              "transition": {"duration": 0.1}}])])]
    ),
    # One frame per column of c_xs, i.e. per time step.
    frames=[go.Frame(data=[go.Scatter(x=c_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
                          go.Scatter(x=uc_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)]) for i in range(c_xs.shape[1])]
)
fig.update_layout(transition = {'duration': 0.1})
fig.show()

## Sznajd Model

We make the substitution $J = \frac{\phi}{N}$.

In [None]:
class OPINION(PDE):
    """Controlled opinion-dynamics problem for N = 10 agents with target ``x_d``.

    ``var`` is a list of N + 1 tensors: the N opinions followed by time as the
    last entry. ``div``, ``Δ``, ``D`` and ``cat`` come from the project's
    ``operators`` module; presumably time derivative, Laplacian, gradient and
    concatenation - verify against that module.
    """

    def __init__(self):
        """Set the model constants and register residual, terminal condition and samplers."""
        super().__init__()
        self.beta = -3     # strength of the interaction drift
        self.sigma = 0.01  # diffusion coefficient
        self.gamma = 0.04  # control-cost weight
        self.l = 1.0       # weight of the distance-to-target cost
        N = 10
        # Target opinion 0.2 for every agent; stored on the GPU, so CUDA is required.
        self.x_d = torch.ones((1, N)).cuda()*0.2
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1       
            
        # Residual: time term + running cost (l/2)*|x - x_d|^2
        #   + interaction drift beta*(1 - x^2)*(mean(x) - x) applied to D(u, x)
        #   + sigma*Δ(u, x) - |D(u, x)|^2 / (2*gamma).
        self.equation = lambda u, var: div(u, var[-1]) + self.l*torch.sum((cat(var[:-1]) - self.x_d)**2, dim=-1, keepdims=True)/2 \
                        + torch.sum((self.beta*(1 - cat(var[:-1])**2)
                                     *(torch.mean(cat(var[:-1]), dim=-1, keepdims=True) - cat(var[:-1]))
                                     * D(u, var[:-1])), dim=-1, keepdims=True) + self.sigma*Δ(u, var[:-1]) \
                        - (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
        # Interior points: clustered states, time entry scaled by 5.
        # The second tuple entry is presumably the batch size - TODO confirm.
        self.domain_func = [(lambda var: self.sampling(var[:-1]) + [5*var[-1]], 256)]
        # Terminal condition: u = (l/2)*|x - x_d|^2.
        self.boundary_cond = [lambda u, var: u -self.l*torch.sum((cat(var[:-1]) - self.x_d)**2, dim=-1, keepdims=True)/2] 
        # Terminal points: clustered states with the time entry replaced by the constant 5.
        self.boundary_func = [(lambda var: self.term_sampling(var[:-1]) + [0*var[-1]+5], 64)]
        
    def sampling(self, var, domain=(-0.99, 0.99)):
        """Map the raw samples to clustered opinion configurations inside ``domain``.

        Every coordinate of every row is assigned a cluster index (a rounded
        Gamma(3, 2) draw clamped to [0, N-1]) and takes that cluster's mean,
        drawn uniformly from ``domain``. The raw sample is centred with
        ``var[i] - 0.5`` and scaled by the smaller of a uniform random width
        and the distance from the mean to the nearest domain edge.
        Assumes the raw samples lie in [0, 1) - TODO confirm.

        Returns:
            List with one tensor per entry of ``var``.
        """
        exp = dist.Gamma(3, 2)
        means_clusters = torch.clamp(torch.round(exp.sample((var[-1].shape[0], self.var_dim-1))), min=0, max=self.var_dim-2).long()
        means = domain[0] + (domain[1]-domain[0])*(torch.rand((var[-1].shape[0], self.var_dim-1)))
        samples = []
        for i in range(len(var)):
            # Mean of the cluster that coordinate i was assigned to, one value per row.
            m = means.gather(-1, means_clusters[:, None, i])
            # Distance from that mean to the nearest domain edge.
            b = torch.min(torch.abs(m-domain[0]), torch.abs(m-domain[1]))
            c = torch.rand((var[-1].shape[0], 1))
            samples.append((var[i]-0.5)*torch.min(c, b) + m)
        return samples
    
    def term_sampling(self, var):
        """Sample like ``sampling`` but overwrite the first row with the target ``x_d``."""
        re = self.sampling(var)
        for i, v in enumerate(re):
            # Row 0 of coordinate i becomes the target value of agent i.
            v[0] = self.x_d[:, i]
        return re

In [None]:
# Register the opinion dynamics for the project's path sampler, then build and train the solver.
# The lambdas look up the global `eq` when they are called, so defining them
# before `eq = OPINION()` is fine.
PATH_SPACES["OPINION"] = {
        "SDE": lambda X, u, t, dt, dW: eq.beta*(1 - X**2)*(torch.mean(X, dim=-1, keepdims=True) - X)*dt + u*dt + np.sqrt(2*eq.sigma)*dW,
        "terminal_time": 5.0,
        "N_range": 100,
        "control": lambda J, X, t: -(1/eq.gamma)*D(J, X),
        "domain": (-1, 1)
    }
eq = OPINION()
# NOTE: this rebinds MODEL_CONFIG, shadowing the value imported from `configs`.
model = MODEL_CONFIG = {
    "hidden_dim": 128, # We roughly have to double the hidden dimensions every time we double N--> starting with N=5, h=32
    "learning_rate": 1e-3,
    "loss_weights": (1, 5),
    "sampling_method": "path",
    "sampling_method_boundary": "uniform",
    "lr_decay": 0.99,
    "network_type": "RES",
    "optimiser": "Adam",
    "method": "Galerkin"
}
solver = DeepPDESolver(model, eq)
losses = list(solver.train(1700)) # We also add 100 extra training iterations every time we double N, starting: N=5, it=200
plot_losses(losses)

In [None]:
def plot_opinions(T):
    """Simulate the opinion dynamics on [0, T] under three control strategies.

    All three systems share the initial state and the Brownian increments:
    the learned feedback ``-(1/gamma) * D_u`` from the global ``solver``
    ("controlled"), no control ("uncontrolled") and a fixed proportional
    feedback towards the target ("alpha").

    Args:
        T: Time horizon. The grid always has 100 points and matches
            ``np.linspace(0, T, 100)``.

    Returns:
        Tuple ``(c_xs, uc_xs, alpha_xs, c_cost, uc_cost, alpha_cost)``: three
        ``(N, n)`` trajectory arrays followed by three length-``n`` arrays of
        accumulated cost.
    """
    n = 100
    N = eq.var_dim-1
    # Spacing of np.linspace(0, T, n): the last column corresponds to time T
    # and the time handed to the solver spans the whole horizon.
    dt = T/(n-1)
    def sde_sum(X):
        # Interaction drift beta*(1 - x^2)*(mean(x) - x).
        return eq.beta*(1-X**2)*(sum(X)/N-X)
    # Cluster boundaries; sizes come from the slices, so any N works.
    lo, hi = int(2*N/5), int(3*N/5)
    c_xs = np.zeros((N, n))
    c_cost = np.zeros(n)
    uc_cost = np.zeros(n)
    alpha_cost = np.zeros(n)
    #c_xs[:, 0] = -0.5
    #c_xs[:, 0] = np.array([n.flatten().numpy() for n in eq.sampling([torch.zeros(size=(1, 1)).uniform_() for _ in range(N)])]).flatten()
    c_xs[:lo, 0] = (np.random.randn(lo)*0.1 -0.7).clip(-0.8, -0.6)
    c_xs[lo:hi, 0] = (np.random.randn(hi-lo)*0.2 +0.75).clip(0.6, 0.9)
    c_xs[hi:, 0] = (np.random.randn(N-hi)*0.2 -0.25).clip(-0.4, -0.1)
    #c_xs[:, 0] = (np.random.randn(N)-0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]
    alpha_xs = np.zeros((N, n))
    alpha_xs[:, 0] = c_xs[:, 0]
    x_d = eq.x_d.cpu().view(-1, 1).numpy()
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Learned feedback control at the current controlled state and time.
        c = -(1/eq.gamma)*np.expand_dims(solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt), axis=1)
        # Baseline: proportional feedback towards the target with fixed gain 6.
        alpha_c = 6*(x_d - alpha_xs[:, None, i])
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + np.sqrt(2*eq.sigma)*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i]) *dt + np.sqrt(2*eq.sigma)*dW
        alpha_xs[:, None, i+1] = alpha_xs[:, None, i] + sde_sum(alpha_xs[:, None, i])*dt + np.sqrt(2*eq.sigma)*dW + alpha_c*dt
        # NOTE(review): the running cost is summed per step without a dt
        # factor - confirm whether a time integral was intended.
        c_cost[i+1] = c_cost[i] + (1/(2*N))*(np.sum((c_xs[:, None, i] - x_d)**2) + eq.gamma*np.sum(c**2))
        uc_cost[i+1] = uc_cost[i] + (1/(2*N))*(np.sum((uc_xs[:, None, i] - x_d)**2))
        alpha_cost[i+1] = alpha_cost[i] + (1/(2*N))*(np.sum((alpha_xs[:, None, i] - x_d)**2) + eq.gamma*np.sum(alpha_c**2))
    # The last entry is overwritten with the previous total plus the terminal cost.
    c_cost[-1] = c_cost[-2] + (1/N)*(np.sum((c_xs[:, None, -1] - x_d)**2))
    uc_cost[-1] = uc_cost[-2] + (1/N)*(np.sum((uc_xs[:, None, -1] - x_d)**2))
    alpha_cost[-1] = alpha_cost[-2] + (1/N)*(np.sum((alpha_xs[:, None, -1] - x_d)**2))

    return c_xs, uc_xs, alpha_xs, c_cost, uc_cost, alpha_cost

In [None]:
# Simulate all three strategies with T = 5, the terminal time registered for this model.
c_xs, uc_xs, alpha_xs, c_cost, uc_cost, alpha_cost = plot_opinions(5)

In [None]:
# Final-time snapshot: last column of the controlled and the uncontrolled
# trajectories, with every marker drawn at y = 1.
fig = make_subplots(rows=1, cols=1)
fig.add_trace(
    go.Scatter(
        x=c_xs[:, -1].flatten(),
        y=np.ones(eq.var_dim - 1),
        mode="markers",
        name="Controlled",
        marker={"color": "SkyBlue", "size": 15},
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=uc_xs[:, -1].flatten(),
        y=np.ones(eq.var_dim - 1),
        mode="markers",
        name="Uncontrolled",
        opacity=0.8,
    ),
    row=1,
    col=1,
)
fig.update_layout(
    title="Opinion distribution",
    xaxis_title="X",
    yaxis_title="Num",
    xaxis_range=[-1.2, 1.2],
    font={"family": "Courier New, monospace", "size": 14},
)
fig.show()

In [None]:
# Animated scatter of the opinions: one frame per simulated time step,
# controlled agents (large sky-blue markers) together with the uncontrolled ones.
fig = go.Figure(
    # Initial picture: first column of each trajectory array, all markers at y = 1.
    data=[go.Scatter(x=c_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
         go.Scatter(x=uc_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)],
    # Fixed axis ranges keep the view stable during playback; "Play" starts the animation.
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 25, 
                                                                        "redraw": False},
                                                              "fromcurrent": True, 
                                                              "transition": {"duration": 0.1}}])])]
    ),
    # One frame per column of c_xs, i.e. per time step.
    frames=[go.Frame(data=[go.Scatter(x=c_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
                          go.Scatter(x=uc_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)]) for i in range(c_xs.shape[1])]
)
fig.update_layout(transition = {'duration': 0.1})
fig.show()

In [None]:
# Controlled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# 100 grid points, matching the number of columns produced by plot_opinions.
ts = np.linspace(0, 5, 100)
for path in c_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(c_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Uncontrolled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# 100 grid points, matching the number of columns produced by plot_opinions.
ts = np.linspace(0, 5, 100)
for path in uc_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(uc_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Trajectories under the proportional "alpha" control: one line per agent plus the mean in blue.
fig = make_subplots(rows=1, cols=1)
# 100 grid points, matching the number of columns produced by plot_opinions.
ts = np.linspace(0, 5, 100)
for path in alpha_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(alpha_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Accumulated cost over time for the three strategies.
fig = make_subplots(rows=1, cols=1)
# Size the time grid from the data so x and y always have the same length.
ts = np.linspace(0, 5, len(c_cost))
fig.add_trace(go.Scatter(x=ts, y=c_cost, mode='lines', name="controlled"), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=uc_cost, mode='lines', name="uncontrolled"), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=alpha_cost, mode='lines', name="alpha-controlled"), row=1, col=1)
# The vertical axis shows cost, not a state coordinate.
fig.update_layout(title='Cost',
                  xaxis_title="t",
                  yaxis_title="cost",
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Diagnostic plot of the first 10 columns of the values collected in
# `solver.domain_sampler.app`, stacked row-wise.
# NOTE(review): the meaning of `app` is not visible here, and the 'Cost' title
# and t/x axis labels look copied from the previous cell - confirm.
caten = torch.vstack(solver.domain_sampler.app)
fig = make_subplots(rows=1, cols=1)
for i in range(10):
    fig.add_trace(go.Scatter(y=caten[:, i], mode='lines', showlegend=False), row=1, col=1)
fig.update_layout(title='Cost',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

### Also plot alpha and weighted_ws

## Hegselmann-Krause

In [None]:
class BOUNDED_OPINION(PDE):
    """Controlled bounded-confidence opinion problem for N = 10 agents with target ``x_d``.

    ``var`` is a list of N + 1 tensors: the N opinions followed by time as the
    last entry. ``div``, ``Δ``, ``D`` and ``cat`` come from the project's
    ``operators`` module; presumably time derivative, Laplacian, gradient and
    concatenation - verify against that module.
    """

    def __init__(self):
        """Set the model constants and register residual, terminal condition and samplers."""
        super().__init__()
        self.sigma = 0.01  # diffusion coefficient
        self.gamma = 0.05  # control-cost weight
        self.l = 1.0       # weight of the distance-to-target cost
        self.kappa = 0.3   # confidence radius used in P
        self.beta = -12    # interaction strength (divided by N where it is used)
        N = 10
        # Target opinion 0.2 for every agent; stored on the GPU, so CUDA is required.
        self.x_d = torch.ones((1, N)).cuda()*0.2
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1    
            
        # Residual: time term + running cost (l/2)*|x - x_d|^2
        #   + interaction drift (beta/N)*P(x) applied to D(u, x)
        #   + sigma*Δ(u, x) - |D(u, x)|^2 / (2*gamma).
        self.equation = lambda u, var: div(u, var[-1]) + self.l*torch.sum((cat(var[:-1]) - self.x_d)**2, dim=-1, keepdims=True)/2 \
                        + torch.sum((self.beta/N)*self.P(cat(var[:-1]))* D(u, var[:-1]), dim=-1, keepdims=True)\
                        + self.sigma*Δ(u, var[:-1]) \
                        - (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
        # Interior points: clustered states, time entry scaled by 5.
        # The second tuple entry is presumably the batch size - TODO confirm.
        self.domain_func = [(lambda var: self.sampling(var[:-1]) + [5*var[-1]], 128)]
        # Terminal condition: u = (l/2)*|x - x_d|^2.
        self.boundary_cond = [lambda u, var: u -self.l*torch.sum((cat(var[:-1]) - self.x_d)**2, dim=-1, keepdims=True)/2] 
        # Terminal points: clustered states with the time entry replaced by the constant 5.
        self.boundary_func = [(lambda var: self.term_sampling(var[:-1]) + [0.0*var[-1]+5], 64)]
        
    def P(self, X):
        """Sum, for each agent, the differences to all agents within distance ``kappa``.

        Args:
            X: Two-dimensional tensor with one row per sample and one column per agent.

        Returns:
            Tensor of the same shape; entry ``(b, i)`` is the sum over ``j`` of
            ``X[b, i] - X[b, j]`` restricted to ``|X[b, i] - X[b, j]| <= kappa``.
        """
        one = torch.ones_like(X).to(X.device)
        # matrix[b, i, j] = X[b, i] - X[b, j]
        matrix = torch.einsum("bi, bj -> bij", X, one) - torch.einsum("bi, bj -> bji", X, one)
        return torch.sum((torch.abs(matrix) <= self.kappa)*matrix, dim=-1)
    
    
    def sampling(self, var, domain=(-0.99, 0.99)):
        """Map the raw samples to clustered opinion configurations inside ``domain``.

        Every coordinate of every row is assigned a cluster index (a rounded
        Gamma(3, 2) draw clamped to [0, N-1]) and takes that cluster's mean,
        drawn uniformly from ``domain``. The raw sample is centred with
        ``var[i] - 0.5`` and scaled by the smaller of a uniform random width
        and the distance from the mean to the nearest domain edge.
        Assumes the raw samples lie in [0, 1) - TODO confirm.

        Returns:
            List with one tensor per entry of ``var``.
        """
        exp = dist.Gamma(3, 2)
        means_clusters = torch.clamp(torch.round(exp.sample((var[-1].shape[0], self.var_dim-1))), min=0, max=self.var_dim-2).long()
        means = domain[0] + (domain[1]-domain[0])*(torch.rand((var[-1].shape[0], self.var_dim-1)))
        samples = []
        for i in range(len(var)):
            # Mean of the cluster that coordinate i was assigned to, one value per row.
            m = means.gather(-1, means_clusters[:, None, i])
            # Distance from that mean to the nearest domain edge.
            b = torch.min(torch.abs(m-domain[0]), torch.abs(m-domain[1]))
            c = torch.rand((var[-1].shape[0], 1))
            samples.append((var[i]-0.5)*torch.min(c, b) + m)
        return samples
    
    def term_sampling(self, var):
        """Sample like ``sampling`` but overwrite the first row with the target ``x_d``."""
        re = self.sampling(var)
        for i, v in enumerate(re):
            # Row 0 of coordinate i becomes the target value of agent i.
            v[0] = self.x_d[:, i]
        return re

In [None]:
# Register the bounded-confidence dynamics for the project's path sampler, then build and train the solver.
# The lambdas look up the global `eq` when they are called, so defining them
# before `eq = BOUNDED_OPINION()` is fine.
PATH_SPACES["BOUNDED_OPINION"] = {
        "SDE": lambda X, u, t, dt, dW: (eq.beta/(eq.var_dim-1))*eq.P(X)*dt + u*dt + np.sqrt(2*eq.sigma)*dW,
        "terminal_time": 5.0,
        "N_range": 100,
        "control": lambda J, X, t: -(1/eq.gamma)*D(J, X),
        "domain": (-1, 1)
}
eq = BOUNDED_OPINION()
# NOTE: this rebinds MODEL_CONFIG, shadowing the value imported from `configs`.
model = MODEL_CONFIG = {
    "hidden_dim": 128,
    "learning_rate": 1e-3,
    "loss_weights": (1, 5),
    "sampling_method": "path",
    "sampling_method_boundary": "uniform",
    "lr_decay": 0.99,
    "network_type": "RES",
    "optimiser": "Adam",
    "method": "Galerkin"
}
solver = DeepPDESolver(model, eq)
# solver.train(500) is consumed as an iterable of loss values.
losses = list(solver.train(500))
plot_losses(losses)

In [None]:
def plot_bounded_opinions(T):
    """Simulate controlled and uncontrolled bounded-confidence opinions on [0, T].

    Both systems start from the same three-cluster initial state and are
    driven by the same Brownian increments (Euler-Maruyama). The controlled
    system additionally applies the feedback ``-(1/gamma) * D_u`` obtained
    from the trained global ``solver``.

    Args:
        T: Time horizon. The grid always has 100 points and matches
            ``np.linspace(0, T, 100)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)``; column ``i``
        holds the state at time ``i*dt``.
    """
    n = 100
    N = eq.var_dim-1
    # Spacing of np.linspace(0, T, n): the last column corresponds to time T
    # and the time handed to the solver spans the whole horizon.
    dt = T/(n-1)
    def sde_sum(X):
        # eq.P expects one row per sample, so the column is passed as a
        # single row and the result is turned back into a column.
        return (eq.beta/N)*eq.P(torch.from_numpy(X).view(1, -1)).numpy().reshape(-1, 1)
    # Cluster boundaries; sizes come from the slices, so any N works.
    lo, hi = int(2*N/5), int(3*N/5)
    c_xs = np.zeros((N, n))
    #c_xs[:, 0] = -0.5
    c_xs[:lo, 0] = (np.random.randn(lo)*0.5 -0.6).clip(-0.95, -0.4)
    c_xs[lo:hi, 0] = (np.random.randn(hi-lo)*0.5 +0.7).clip(0.4, 0.95)
    c_xs[hi:, 0] = (np.random.randn(N-hi)*0.5 +0.1).clip(-0.15, 0.45)
    #c_xs[:, 0] = (np.random.randn(N)-0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Learned feedback control at the current controlled state and time.
        c = -(1/(eq.gamma))*np.expand_dims(solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt), axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + np.sqrt(2*eq.sigma)*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i]) *dt + np.sqrt(2*eq.sigma)*dW

    return c_xs, uc_xs

In [None]:
# Simulate the bounded-confidence model with T = 5, the terminal time registered for this model.
c_xs, uc_xs = plot_bounded_opinions(5)

In [None]:
# Animated scatter of the opinions: one frame per simulated time step,
# controlled agents (large sky-blue markers) together with the uncontrolled ones.
fig = go.Figure(
    # Initial picture: first column of each trajectory array, all markers at y = 1.
    data=[go.Scatter(x=c_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
         go.Scatter(x=uc_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)],
    # Fixed axis ranges keep the view stable during playback; "Play" starts the animation.
    layout=go.Layout(
        xaxis=dict(range=[-1.6, 1.6], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 25, 
                                                                        "redraw": False},
                                                              "fromcurrent": True, 
                                                              "transition": {"duration": 0.1}}])])]
    ),
    # One frame per column of c_xs, i.e. per time step.
    frames=[go.Frame(data=[go.Scatter(x=c_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
                          go.Scatter(x=uc_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)]) for i in range(c_xs.shape[1])]
)
fig.update_layout(transition = {'duration': 0.1})
fig.show()

In [None]:
# Controlled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# Size the time grid from the trajectories so x and y have the same length
# (a hard-coded 250 points does not match 100-column trajectories).
ts = np.linspace(0, 5, c_xs.shape[1])
for path in c_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(c_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x",
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Uncontrolled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# Size the time grid from the trajectories so x and y have the same length
# (a hard-coded 250 points does not match 100-column trajectories).
ts = np.linspace(0, 5, uc_xs.shape[1])
for path in uc_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(uc_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x",
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

# -----

In [None]:
class HOPF_OPINION(PDE):
    """Linear variant of the opinion problem for N = 5 agents.

    The residual has no quadratic gradient term; a potential term is left
    commented out on the last line of the equation. ``var`` is a list of
    N + 1 tensors: the N opinions followed by time as the last entry.
    ``div``, ``Δ``, ``D`` and ``cat`` come from the project's ``operators``
    module - verify their exact meaning there.
    """

    def __init__(self):
        """Set the model constants and register residual, terminal condition and samplers."""
        super().__init__()
        self.x_d = 0.0     # target opinion (only used by the disabled code)
        self.beta = -1     # strength of the interaction drift
        self.sigma = 0.01  # diffusion coefficient
        self.gamma = 0.05  # control-cost weight
        N = 5
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1       
            
        # Residual: time term + interaction drift beta*(1 - x^2)*(mean(x) - x)
        # applied to D(u, x) + sigma*Δ(u, x).
        self.equation = lambda u, var: div(u, var[-1]) + torch.sum((self.beta*(1 - cat(var[:-1])**2)
                                     *(torch.mean(cat(var[:-1]), dim=-1, keepdims=True) - cat(var[:-1]))
                                     * D(u, var[:-1])), dim=-1, keepdims=True)\
                        + self.sigma*Δ(u, var[:-1])# - u*torch.sum((cat(var[:-1]) - self.x_d)**2, dim=-1, keepdims=True)/(4*self.gamma*self.sigma) \
        # Interior points: rescaled states, time entry scaled by 5.
        # The second tuple entry is presumably the batch size - TODO confirm.
        self.domain_func = [(lambda var: self.sampling(var[:-1]) + [5*var[-1]], 128)]
        # Terminal condition as written: u = 0.
        self.boundary_cond = [lambda u, var: u] 
        # Terminal points: rescaled states with the time entry replaced by the constant 5.
        self.boundary_func = [(lambda var: self.sampling(var[:-1]) + [0*var[-1]+5], 64)]
    
    def sampling(self, var):
        """Apply ``(v - 0.5) * 2`` to every coordinate (maps [0, 1) to [-1, 1)).

        Assumes the raw samples lie in [0, 1) - TODO confirm.
        """
        means = 0 #(torch.rand((var[-1].shape[0], 1))-0.5)*2
        return [(var[i]-0.5)*2 + means for i in range(len(var))]  
    
    # The string literal below is a disabled alternative sampler kept for reference.
    '''def sampling(self, var):
        means_clusters = torch.randint(0, self.var_dim-1, size=(var[-1].shape[0], self.var_dim-1))
        compactness_clusters = torch.randint(0, self.var_dim-1, size=(var[-1].shape[0], self.var_dim-1))
        means = (torch.rand((var[-1].shape[0], self.var_dim-1))-0.5)*2 + self.x_d
        compactness = torch.rand((var[-1].shape[0], self.var_dim-1))
        re = [(var[i]-0.5)*compactness.gather(-1, compactness_clusters[:, None, i]) + means.gather(-1, means_clusters[:, None, i]) for i in range(len(var))] 
        return re'''

In [None]:
# Build and train the solver for the Hopf-type opinion problem.
eq = HOPF_OPINION()
# NOTE: this rebinds MODEL_CONFIG, shadowing the value imported from `configs`.
# NOTE(review): unlike the other configurations in this notebook there is no
# "sampling_method_boundary" entry, and the second loss weight is 0 - confirm
# both are intended.
model = MODEL_CONFIG = {
    "hidden_dim": 32,
    "learning_rate": 1e-3,
    "loss_weights": (1, 0),
    "sampling_method": "uniform",
    "lr_decay": 0.98,
    "network_type": "RES",
    "optimiser": "Adam",
    "method": "Galerkin"
}
solver = DeepPDESolver(model, eq)
# solver.train(300) is consumed as an iterable of loss values.
losses = list(solver.train(300))
plot_losses(losses)

In [None]:
def plot_opinions(T):
    """Simulate controlled and uncontrolled opinion paths on [0, T].

    Both systems start from the same two-cluster initial state and are driven
    by the same Brownian increments (Euler-Maruyama). The controlled system
    applies the feedback ``2*sigma * D_u / u`` computed from the trained
    global ``solver``.

    Args:
        T: Time horizon. The grid has ``int(50*T)`` points and matches
            ``np.linspace(0, T, n)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)``; column ``i``
        holds the state at time ``i*dt``.
    """
    n = int(50*T)
    N = eq.var_dim-1
    # Spacing of np.linspace(0, T, n): the last column corresponds to time T
    # and the time handed to the solver spans the whole horizon.
    dt = T/max(n-1, 1)
    def sde_sum(X):
        # Interaction drift beta*(1 - x^2)*(mean(x) - x).
        return eq.beta*(1-X**2)*(sum(X)/N-X)
    n_low = int(2*N/5)  # size of the first cluster; the second takes the rest
    c_xs = np.zeros((N, n))
    #c_xs[:, 0] = -0.5
    # NOTE(review): the first cluster is centred at -0.8 but clipped to
    # [-0.7, 0.9], so most draws land on -0.7; confirm the intended bounds.
    c_xs[:n_low, 0] = (np.random.randn(n_low)*0.2 -0.8).clip(-0.7, 0.9)
    c_xs[n_low:, 0] = (np.random.randn(N-n_low)*0.2 +0.5).clip(0.4, 0.6)
    #c_xs[:, 0] = (np.random.randn(N)-0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # NOTE(review): the division is undefined where solver.u returns 0.
        c = 2*eq.sigma*np.expand_dims(solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt), axis=1)/solver.u(*[c_xs[j, i] for j in range(N)], i*dt)
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + np.sqrt(2*eq.sigma)*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i]) *dt + np.sqrt(2*eq.sigma)*dW

    return c_xs, uc_xs

In [None]:
# Simulate the Hopf-type opinion model with T = 5.
c_xs, uc_xs = plot_opinions(5)

In [None]:
# Final-time snapshot: last column of the controlled and the uncontrolled
# trajectories, with every marker drawn at y = 1.
fig = make_subplots(rows=1, cols=1)
fig.add_trace(
    go.Scatter(
        x=c_xs[:, -1].flatten(),
        y=np.ones(eq.var_dim - 1),
        mode="markers",
        name="Controlled",
        marker={"color": "SkyBlue", "size": 15},
    ),
    row=1,
    col=1,
)
fig.add_trace(
    go.Scatter(
        x=uc_xs[:, -1].flatten(),
        y=np.ones(eq.var_dim - 1),
        mode="markers",
        name="Uncontrolled",
        opacity=0.8,
    ),
    row=1,
    col=1,
)
fig.update_layout(
    title="Opinion distribution",
    xaxis_title="X",
    yaxis_title="Num",
    xaxis_range=[-1.2, 1.2],
    font={"family": "Courier New, monospace", "size": 14},
)
fig.show()

In [None]:
# Animated scatter of the opinions: one frame per simulated time step,
# controlled agents (large sky-blue markers) together with the uncontrolled ones.
fig = go.Figure(
    # Initial picture: first column of each trajectory array, all markers at y = 1.
    data=[go.Scatter(x=c_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
         go.Scatter(x=uc_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)],
    # Fixed axis ranges keep the view stable during playback; "Play" starts the animation.
    layout=go.Layout(
        xaxis=dict(range=[-1.1, 1.1], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 25, 
                                                                        "redraw": False},
                                                              "fromcurrent": True, 
                                                              "transition": {"duration": 0.1}}])])]
    ),
    # One frame per column of c_xs, i.e. per time step.
    frames=[go.Frame(data=[go.Scatter(x=c_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
                          go.Scatter(x=uc_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)]) for i in range(c_xs.shape[1])]
)
fig.update_layout(transition = {'duration': 0.1})
fig.show()

In [None]:
# Surface plot of the learned solution along a one-dimensional slice of the state space.
fig = make_subplots(rows=1, cols=1, specs=[[{'type': 'surface'}]])
xs = np.linspace(-1.1, 1.1, 100)
ys = np.linspace(0, 5, 100)
# The first four coordinates are fixed at -0.7, the remaining eq.var_dim-5
# coordinates are set to x, and y is passed as the time argument.
us_pred = np.array([[solver.u(*([-0.7, -0.7, -0.7, -0.7] + [x for _ in range(eq.var_dim-5)]), y).item() for x in xs] for y in ys])
#us_pred = np.array([[sum([np.abs(x+0.25) for _ in range(10)]).item() for x in xs] for y in ys])
us = us_pred #np.array([[sol(x, t) for x in xs] for t in ts])
# NOTE: x_mesh and y_mesh are not used by the plot below.
x_mesh, y_mesh = np.meshgrid(xs, ys)
fig.add_trace(go.Surface(x=xs, y=ys, z=us, showscale=False), row=1, col=1)
#fig.add_trace(go.Surface(x=ys, y=xs, z=us_pred), row=1, col=2)
# NOTE: `scene2` configures a second 3-D subplot, but this figure only has one cell.
fig.update_layout(title='Solution | Approximation',
                  scene = dict(
                    xaxis_title="x",
                    yaxis_title="t",
                    zaxis_title="u(x, t)"),
                  scene2 = dict(
                    xaxis_title="x",
                    yaxis_title="t",
                    zaxis_title="u(x, t)"),
                  margin=dict(l=50, r=50, b=50, t=50))
# one row is printed as x axes
fig.show()

## Hegselmann-Krause (Wasserstein-distance)

In [None]:
class BOUNDED_OPINION(PDE):
    """Bounded-confidence opinion problem whose cost compares sorted opinions with a sorted target.

    ``var`` is a list of N + 1 tensors (N = 10): the N opinions followed by
    time as the last entry. ``div``, ``Δ``, ``D`` and ``cat`` come from the
    project's ``operators`` module; presumably time derivative, Laplacian,
    gradient and concatenation - verify against that module.
    """

    def __init__(self):
        """Draw the target sample, set the constants and register residual, terminal condition and samplers."""
        super().__init__()
        self.sigma = 0.01  # diffusion coefficient
        self.gamma = 0.05  # control-cost weight
        self.kappa = 0.01  # confidence radius used in P
        self.l = 2.0       # weight of the distance cost
        self.beta = -10    # interaction strength (divided by N where it is used)
        N = 10
        # Target: N sorted draws of 0.2*randn clamped to [-1, 1] and shifted by 0.3.
        self.x_d = torch.sort((torch.randn((1, N))*0.2).clamp(-1, 1) + 0.3, dim=-1)[0]
        print("Target distribution: N({0:7.5f}, {1:7.5f})".format(np.mean(self.x_d.numpy().flatten()), np.std(self.x_d.numpy().flatten())))
        # GPU copy of the target used by p_ws_dist; requires CUDA.
        self.x_d_cuda = self.x_d.to(torch.device("cuda"))
        self.var_dim = N+1 # var = (x, t)
        self.sol_dim = 1    
            
        # Residual: time term + running cost p_ws_dist(x)/2
        #   + interaction drift (beta/N)*P(x) applied to D(u, x)
        #   + sigma*Δ(u, x) - |D(u, x)|^2 / (2*gamma).
        self.equation = lambda u, var: div(u, var[-1]) + self.p_ws_dist(cat(var[:-1]))/2 \
                        + torch.sum((self.beta/N)*self.P(cat(var[:-1]))* D(u, var[:-1]), dim=-1, keepdims=True)\
                        + self.sigma*Δ(u, var[:-1]) \
                        - (1/(2*self.gamma))*torch.sum(D(u, var[:-1])**2, dim=-1, keepdims=True)
        # Interior points: clustered states, time entry scaled by 5.
        # The second tuple entry is presumably the batch size - TODO confirm.
        self.domain_func = [(lambda var: self.sampling(var[:-1]) + [5*var[-1]], 128)]
        # Terminal condition: u = p_ws_dist(x)/2.
        self.boundary_cond = [lambda u, var: u -self.p_ws_dist(cat(var[:-1]))/2] 
        # Terminal points: clustered states with the time entry replaced by the constant 5.
        self.boundary_func = [(lambda var: self.term_sampling(var[:-1]) + [0.0*var[-1]+5], 64)]
        
    def P(self, X):
        """Sum, for each agent, the differences to all agents within distance ``kappa``.

        Args:
            X: Two-dimensional tensor with one row per sample and one column per agent.

        Returns:
            Tensor of the same shape; entry ``(b, i)`` is the sum over ``j`` of
            ``X[b, i] - X[b, j]`` restricted to ``|X[b, i] - X[b, j]| <= kappa``.
        """
        one = torch.ones_like(X).to(X.device)
        # matrix[b, i, j] = X[b, i] - X[b, j]
        matrix = torch.einsum("bi, bj -> bij", X, one) - torch.einsum("bi, bj -> bji", X, one)
        return torch.sum((torch.abs(matrix) <= self.kappa)*matrix, dim=-1)
    
    def p_ws_dist(self, X, p=2):
        """Return ``l`` times the mean of ``|sort(X) - x_d|**p`` along the last dimension.

        Both ``X`` (sorted here) and the target (sorted at construction) are
        compared entry by entry; the p-th root is left commented out.
        """
        X, X_idxs = torch.sort(X, dim=-1)
        return self.l*torch.mean(torch.abs(X - self.x_d_cuda)**p, dim=-1, keepdims=True) # **(1/p)
    
    def sampling(self, var, domain=(-0.99, 0.99)):
        """Map the raw samples to clustered opinion configurations inside ``domain``.

        Every coordinate of every row is assigned a cluster index (a rounded
        Gamma(3, 2) draw clamped to [0, N-1]) and takes that cluster's mean,
        drawn uniformly from ``domain``. The raw sample is centred with
        ``var[i] - 0.5`` and scaled by the smaller of a uniform random width
        and the distance from the mean to the nearest domain edge.
        Assumes the raw samples lie in [0, 1) - TODO confirm.

        Returns:
            List with one tensor per entry of ``var``.
        """
        exp = dist.Gamma(3, 2)
        means_clusters = torch.clamp(torch.round(exp.sample((var[-1].shape[0], self.var_dim-1))), min=0, max=self.var_dim-2).long()
        means = domain[0] + (domain[1]-domain[0])*(torch.rand((var[-1].shape[0], self.var_dim-1)))
        samples = []
        for i in range(len(var)):
            # Mean of the cluster that coordinate i was assigned to, one value per row.
            m = means.gather(-1, means_clusters[:, None, i])
            # Distance from that mean to the nearest domain edge.
            b = torch.min(torch.abs(m-domain[0]), torch.abs(m-domain[1]))
            c = torch.rand((var[-1].shape[0], 1))
            samples.append((var[i]-0.5)*torch.min(c, b) + m)
        return samples
    
    def term_sampling(self, var):
        """Sample like ``sampling`` but overwrite the first row with the target ``x_d``."""
        re = self.sampling(var)
        for i, v in enumerate(re):
            # Row 0 of coordinate i becomes the i-th (sorted) target value.
            v[0] = self.x_d[:, i]
        return re
# Instantiate right away: this draws and prints the random target.
eq = BOUNDED_OPINION()

In [None]:
# Register the bounded-confidence dynamics for the project's path sampler, then build and train the solver.
# `eq` was created in the previous cell; the lambdas look it up when they are called.
PATH_SPACES["BOUNDED_OPINION"] = {
        "SDE": lambda X, u, t, dt, dW: (eq.beta/(eq.var_dim-1))*eq.P(X)*dt + u*dt + np.sqrt(2*eq.sigma)*dW,
        "terminal_time": 5.0,
        "N_range": 100,
        "control": lambda J, X, t: -(1/eq.gamma)*D(J, X),
        "domain": (-1, 1)
}
# NOTE: this rebinds MODEL_CONFIG, shadowing the value imported from `configs`.
model = MODEL_CONFIG = {
    "hidden_dim": 32,
    "learning_rate": 1e-3,
    "loss_weights": (1, 5),
    "sampling_method": "path",
    "sampling_method_boundary": "uniform",
    "lr_decay": 0.99,
    "network_type": "RES",
    "optimiser": "Adam",
    "method": "Galerkin"
}
solver = DeepPDESolver(model, eq)
# solver.train(1500) is consumed as an iterable of loss values.
losses = list(solver.train(1500))
plot_losses(losses)

In [None]:
def plot_bounded_opinions(T):
    """Simulate controlled and uncontrolled bounded-confidence opinions on [0, T].

    Both systems start from the same three-cluster initial state and are
    driven by the same Brownian increments (Euler-Maruyama). The controlled
    system additionally applies the feedback ``-(1/gamma) * D_u`` obtained
    from the trained global ``solver``.

    Args:
        T: Time horizon. The grid has ``int(50*T)`` points and matches
            ``np.linspace(0, T, n)``.

    Returns:
        Tuple ``(c_xs, uc_xs)`` of arrays with shape ``(N, n)``; column ``i``
        holds the state at time ``i*dt``.
    """
    n = int(50*T)
    N = eq.var_dim-1
    # Spacing of np.linspace(0, T, n): the last column corresponds to time T
    # and the time handed to the solver spans the whole horizon.
    dt = T/max(n-1, 1)
    def sde_sum(X):
        # eq.P expects one row per sample, so the column is passed as a
        # single row and the result is turned back into a column.
        return (eq.beta/N)*eq.P(torch.from_numpy(X).view(1, -1)).numpy().reshape(-1, 1)
    # Cluster boundaries; sizes come from the slices, so any N works.
    lo, hi = int(2*N/5), int(3*N/5)
    c_xs = np.zeros((N, n))
    #c_xs[:, 0] = -0.5
    c_xs[:lo, 0] = (np.random.randn(lo)*0.5 -0.6).clip(-0.95, -0.4)
    c_xs[lo:hi, 0] = (np.random.randn(hi-lo)*0.5 +0.7).clip(0.4, 0.95)
    c_xs[hi:, 0] = (np.random.randn(N-hi)*0.5 +0.1).clip(-0.15, 0.45)
    #c_xs[:, 0] = (np.random.randn(N)-0.5).clip(-1, 1)
    uc_xs = np.zeros((N, n))
    uc_xs[:, 0] = c_xs[:, 0]
    for i in range(n-1):
        dW = np.sqrt(dt)*np.random.randn(N, 1)
        # Learned feedback control at the current controlled state and time.
        c = -(1/(eq.gamma))*np.expand_dims(solver.D_u(*[c_xs[j, i] for j in range(N)], i*dt), axis=1)
        c_xs[:, None, i+1] = c_xs[:, None, i] + sde_sum(c_xs[:, None, i])*dt + np.sqrt(2*eq.sigma)*dW + c*dt
        uc_xs[:, None, i+1] = uc_xs[:, None, i] + sde_sum(uc_xs[:, None, i]) *dt + np.sqrt(2*eq.sigma)*dW

    return c_xs, uc_xs

In [None]:
# Simulate the bounded-confidence model with T = 5, the terminal time registered for this model.
c_xs, uc_xs = plot_bounded_opinions(5)

In [None]:
# Animated scatter of the opinions: one frame per simulated time step,
# controlled agents (large sky-blue markers) together with the uncontrolled ones.
fig = go.Figure(
    # Initial picture: first column of each trajectory array, all markers at y = 1.
    data=[go.Scatter(x=c_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
         go.Scatter(x=uc_xs[:, 0].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)],
    # Fixed axis ranges keep the view stable during playback; "Play" starts the animation.
    layout=go.Layout(
        xaxis=dict(range=[-1.6, 1.6], autorange=False),
        yaxis=dict(range=[0, 2], autorange=False),
        title="Opinions over time",
        updatemenus=[dict(
            type="buttons",
            buttons=[dict(label="Play",
                          method="animate",
                          args=[None, {"frame": {"duration": 25, 
                                                                        "redraw": False},
                                                              "fromcurrent": True, 
                                                              "transition": {"duration": 0.1}}])])]
    ),
    # One frame per column of c_xs, i.e. per time step.
    frames=[go.Frame(data=[go.Scatter(x=c_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Controlled", marker=dict(color='SkyBlue', size=15)),
                          go.Scatter(x=uc_xs[:, i].flatten(), y=np.ones(eq.var_dim-1),
                         mode='markers',
                         name="Uncontrolled", opacity=0.8)]) for i in range(c_xs.shape[1])]
)
fig.update_layout(transition = {'duration': 0.1})
fig.show()

In [None]:
# Controlled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# 250 grid points, matching the int(50*T) columns that plot_bounded_opinions(5) produces.
ts = np.linspace(0, 5, int(50*5))
for path in c_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(c_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()

In [None]:
# Uncontrolled trajectories: one line per agent plus the across-agent mean in blue.
fig = make_subplots(rows=1, cols=1)
# 250 grid points, matching the int(50*T) columns that plot_bounded_opinions(5) produces.
ts = np.linspace(0, 5, int(50*5))
for path in uc_xs:
    fig.add_trace(go.Scatter(x=ts, y=path, mode='lines', showlegend=False), row=1, col=1)
fig.add_trace(go.Scatter(x=ts, y=np.mean(uc_xs, axis=0), mode='lines', showlegend=False, line=dict(color="blue", width=2)), row=1, col=1)
fig.update_layout(title='Solution | Approximation',
                  xaxis_title="t",
                  yaxis_title="x", 
                  margin=dict(l=50, r=50, b=50, t=50))
fig.show()