In [None]:
import numpy as np
import torch
from torch import nn
import csv
import os
import copy
import pandas as pd
import numpy as np
import random
import time
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from utils import get_rho_from_u, plot_3d, plot_diff, get_rho_network_from_u, get_rho_network_from_actor, train_actor_from_u, train_rho_network_n_step
from model import Critic, RhoNetwork

In [None]:
# Experiment configuration.
option = 'non-sep'   # suffix of the reference files data/rho-<option>.txt and data/u-<option>.txt
T_terminal = 1       # time horizon
n_cell = 8           # grid resolution; state coordinates are scaled by 1 / n_cell downstream

# d: column 0 of the reference rho table, flattened column-major. It is the
# second argument of get_rho_from_u (presumably the t = 0 density -- TODO confirm).
d = np.loadtxt(f"data/rho-{option}.txt")[:, 0].flatten(order='F')

In [None]:
# --- Initialization --------------------------------------------------------
# Seed every RNG the notebook imports so network initialization (and any
# randomness inside the helper modules) is reproducible under Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

delta_T = 1 / n_cell
# Number of time steps. round() before int(): the float quotient can land just
# below an integer (e.g. 0.3 / 0.1 == 2.9999999999999996), and a bare int()
# would then silently drop a time step.
T = int(round(T_terminal / delta_T))

# Initial speed field: 0.5 in every (cell, time) entry.
u_hist = [.5 * np.ones((n_cell, T))]
rho = get_rho_from_u(u_hist[0], d)
# Store an independent copy so the history entry cannot be changed through `rho`.
rho_hist = [np.copy(rho)]
u_loss_hist, rho_loss_hist = list(), list()
u_gap_hist, rho_gap_hist = list(), list()
# Reference solutions; the training loop only uses them to log errors.
u_res = np.loadtxt(f"data/u-{option}.txt")
rho_res = np.loadtxt(f"data/rho-{option}.txt")

# Networks take 2-element inputs (presumably the scaled (x, t) state -- TODO confirm).
rho_network = RhoNetwork(2)
rho_optimizer = torch.optim.Adam(rho_network.parameters(), lr=1e-3)
rho_network = train_rho_network_n_step(n_cell, T_terminal, rho, rho_network, rho_optimizer, n_iterations=1)

# fake_critic supplies the next-step value inside the critic's regression
# targets; it starts as an untrained network.
# NOTE(review): an alternative initialization,
# train_critic_fake(n_cell, T_terminal, np.zeros((n_cell + 1, T + 1))),
# was previously left commented out here.
fake_critic = Critic(2)
critic = Critic(2)
critic_optimizer = torch.optim.Adam(critic.parameters(), lr=1e-3)

In [None]:
# --- Training loop ---------------------------------------------------------
# Each outer iteration:
#   1. builds regression targets for the critic on the (n_cell + 1) x (T + 1)
#      grid and records the clipped greedy speed u[i, t] for every cell i < n_cell,
#   2. fits the critic to those targets with N_CRITIC_STEPS optimizer steps,
#   3. averages all speed fields so far, recomputes rho from the average, and logs
#      errors against the reference solution plus iteration-to-iteration gaps,
#   4. refits rho_network to the averaged speed field.
N_OUTER_ITERATIONS = 300  # outer iterations
N_CRITIC_STEPS = 100      # critic optimizer steps per outer iteration
PLOT_EVERY = 20           # save u / rho plots every PLOT_EVERY iterations (skipping it == 0)

# Create output folders up front so a missing directory cannot crash the run
# after training has already been done.
for out_dir in ("./fig/u", "./fig/rho", "./diff"):
    os.makedirs(out_dir, exist_ok=True)


def _target_and_speed(i, t, critic, fake_critic, rho_network):
    """Return ``(target, speed)`` for grid node ``(i, t)`` with ``t < T``.

    With V = ``critic`` and every grid coordinate scaled by 1 / n_cell:

        speed  = clip((V(i, t+1) - V(j, t+1)) / delta_T + 1 - rho(k, t), 0, 1)
        target = delta_T * (0.5 * speed**2 + rho(k, t) * speed - speed)
                 + fake_critic(k + speed, t + 1)

    For interior nodes j = i + 1 and k = i. For the last node (i == n_cell)
    j = k = 0, so x = 1 is mapped back to x = 0.
    """
    at_boundary = i == n_cell
    k = 0 if at_boundary else i
    j = 0 if at_boundary else i + 1
    # NOTE(review): if x = 1 is identified with x = 0 (periodic domain), the
    # analogue of the interior formula at the boundary would be V(0) - V(1);
    # this keeps the original V(n_cell) - V(0). Confirm this is intended.
    rho_k_t = float(rho_network(np.array([k, t]) / n_cell))
    v_diff = float(critic(np.array([i, t + 1]) / n_cell) - critic(np.array([j, t + 1]) / n_cell))
    speed = min(max(v_diff / delta_T + 1 - rho_k_t, 0), 1)
    running_cost = delta_T * (0.5 * speed ** 2 + rho_k_t * speed - speed)
    target = running_cost + float(fake_critic(np.array([k + speed, t + 1]) / n_cell))
    return target, speed


init_time = time.time()
for it in range(N_OUTER_ITERATIONS):
    states = list()
    truths = list()
    u = np.ones((n_cell, T))
    # Targets are fixed regression labels: evaluate the networks without
    # recording an autograd graph.
    with torch.no_grad():
        for i in range(n_cell + 1):
            for t in range(T + 1):
                states.append(np.array([i, t]) / n_cell)
                if t == T:
                    truths.append(0.0)  # terminal time: target is 0
                    continue
                target, speed = _target_and_speed(i, t, critic, fake_critic, rho_network)
                truths.append(target)
                if i < n_cell:  # u only has rows for the n_cell cells
                    u[i, t] = speed

    # Plain constant labels; the original requires_grad=True was meaningless here.
    truths = torch.tensor(truths)
    state_batch = np.array(states)  # loop-invariant; build once
    for c_it in range(N_CRITIC_STEPS):
        preds = torch.reshape(critic(state_batch), (1, len(truths)))
        critic_loss = (truths - preds).abs().mean()  # L1 regression loss
        critic_optimizer.zero_grad()
        critic_loss.backward()
        critic_optimizer.step()

    # Snapshot the trained critic as next iteration's bootstrap network. The
    # original `fake_critic = critic` made both names refer to one object; the
    # copy keeps them independent.
    fake_critic = copy.deepcopy(critic)

    u_hist.append(u)
    u = np.array(u_hist).mean(axis=0)  # average of all speed fields so far
    rho = get_rho_from_u(u, d)
    rho_hist.append(rho)
    u_loss_hist.append(np.mean(abs(u - u_res)))
    rho_loss_hist.append(np.mean(abs(rho - rho_res)))
    # NOTE(review): the u gap compares the last two per-iteration speed fields,
    # while the rho gap compares densities of the averaged fields.
    u_gap_hist.append(np.mean(abs(u_hist[-1] - u_hist[-2])))
    rho_gap_hist.append(np.mean(abs(rho_hist[-1] - rho_hist[-2])))

    rho_network = get_rho_network_from_u(n_cell, T_terminal, u, d, rho_network, rho_optimizer, n_iterations=100)
    if it % PLOT_EVERY == 0 and it > 0:
        plot_3d(n_cell, T_terminal, u, "u", f"./fig/u/{it}.pdf")
        plot_3d(n_cell, T_terminal, rho, r"$\rho$", f"./fig/rho/{it}.pdf")

print(time.time() - init_time)
pd.DataFrame(u_gap_hist).to_csv(f"./diff/u-gap-{option}.csv")
pd.DataFrame(rho_gap_hist).to_csv(f"./diff/rho-gap-{option}.csv")
pd.DataFrame(u_loss_hist).to_csv(f"./diff/u-loss-{option}.csv")
pd.DataFrame(rho_loss_hist).to_csv(f"./diff/rho-loss-{option}.csv")