In [None]:
import os, random, imageio, shutil
from collections import deque

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm import trange

# -----------------------
# 1) ENVIRONMENT
# -----------------------
class InventoryEnv:
    """Finite-horizon inventory control environment with random integer demand.

    The observation is a float32 array ``[inventory, period]``. Inventory is
    kept inside ``[min_inv, max_inv]`` by clamping; negative inventory is
    charged at the backlog rate instead of the holding rate.
    """

    def __init__(self, min_inv=-50, max_inv=50, T=50):
        """Store the inventory bounds and horizon ``T``, then start an episode."""
        self.min_inv, self.max_inv, self.T = min_inv, max_inv, T
        self.reset()

    def reset(self):
        """Start a new episode at period 0 from a uniformly drawn inventory level.

        Returns:
            float32 array ``[inventory, 0]``.
        """
        self.t = 0
        self.s = random.randint(self.min_inv, self.max_inv)
        return np.array([self.s, self.t], dtype=np.float32)

    def step(self, a):
        """Order ``a`` units, draw the demand, and advance one period.

        Returns:
            Tuple ``(observation, reward, done)`` where the reward is the
            negative of ordering cost plus holding/backlog cost.
        """
        unit_price, holding_rate, backlog_rate = 1, 1, 2

        # Demand is an integer drawn uniformly from 0..10 (both inclusive).
        demand = random.randint(0, 10)

        # Inventory balance, clamped from above and then from below.
        level = self.s + a - demand
        level = min(self.max_inv, level)
        level = max(self.min_inv, level)

        # Positive stock pays the holding rate, a shortfall pays the backlog rate.
        if level >= 0:
            stock_cost = holding_rate * level
        else:
            stock_cost = backlog_rate * (-level)
        reward = -(unit_price * a + stock_cost)

        self.s = level
        self.t += 1
        done = self.t >= self.T

        # The observation of a finished episode carries period 0, not T.
        reported_period = 0 if done else self.t
        observation = np.array([self.s, reported_period], dtype=np.float32)
        return observation, reward, done

# -----------------------
# 2) A2C MODEL
# -----------------------
class A2C(nn.Module):
    """Actor-critic network with a shared two-layer trunk and two linear heads.

    The actor head outputs one logit per action; the critic head outputs a
    single state-value estimate.
    """

    def __init__(self, state_dim, num_actions):
        """Build the trunk (two 64-unit ReLU layers), then the actor and critic heads."""
        super().__init__()
        hidden = 64
        trunk_layers = [
            nn.Linear(state_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
        ]
        self.shared = nn.Sequential(*trunk_layers)
        self.actor = nn.Linear(hidden, num_actions)
        self.critic = nn.Linear(hidden, 1)

    def forward(self, x):
        """Return ``(action_logits, state_value)`` for input ``x``."""
        features = self.shared(x)
        policy_logits = self.actor(features)
        state_value = self.critic(features)
        return policy_logits, state_value

    def get_action(self, state):
        """Sample an action from the categorical policy at ``state``.

        Returns:
            Tuple ``(action, log_prob, value)``: the sampled action as a Python
            int, its log-probability tensor, and the squeezed critic output.
        """
        policy_logits, state_value = self.forward(state)
        policy = torch.distributions.Categorical(logits=policy_logits)
        action = policy.sample()
        return action.item(), policy.log_prob(action), state_value.squeeze()

# -----------------------
# 3) HYPERPARAMETERS
# -----------------------
# Problem size
num_actions = 10      # number of discrete actions, i.e. network output width
T = 50                # horizon: periods per episode
min_inv = -50         # lower inventory bound
max_inv = 50          # upper inventory bound
state_dim = 2         # observation is (inventory, period)

# Optimisation
gamma = 1.0           # discount factor
lr = 0.001            # Adam learning rate
episodes = 50_000     # number of training episodes
critic_coef = 0.5     # weight of the critic loss in the total loss
entropy_coef = 0.01   # weight reserved for an entropy bonus

# -----------------------
# 4) SETUP
# -----------------------
# Run on the GPU when PyTorch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Environment first, then the network, then its optimiser.
env = InventoryEnv(min_inv=min_inv, max_inv=max_inv, T=T)
agent = A2C(state_dim=state_dim, num_actions=num_actions)
agent = agent.to(device)
opt = optim.Adam(agent.parameters(), lr=lr)

# Per-episode mean |advantage|, collected for convergence monitoring.
adv_history = list()

# -----------------------
# 5) TRAIN
# -----------------------
for ep in trange(episodes, desc="Training"):
    # Per-step buffers for one full-episode (Monte-Carlo) rollout.
    logps = []
    values = []
    rewards = []

    obs = env.reset()
    for t in range(T):
        # Scale the observation: inventory by max_inv, period index by T-1.
        inv_scaled = obs[0] / max_inv
        time_scaled = obs[1] / (T - 1)
        st = torch.tensor([inv_scaled, time_scaled],
                          dtype=torch.float32, device=device)

        a, logp, v = agent.get_action(st)
        obs, r, done = env.step(a)

        logps.append(logp)
        values.append(v)
        rewards.append(r)
        if done:
            break

    # Discounted return-to-go, accumulated from the last step backwards and
    # then flipped into chronological order.
    returns = []
    G = 0.0
    for r in rewards[::-1]:
        G = r + gamma * G
        returns.append(G)
    returns.reverse()

    returns = torch.tensor(returns, dtype=torch.float32, device=device)
    values = torch.stack(values)
    logps = torch.stack(logps)

    # Advantage = Monte-Carlo return minus the critic's estimate.
    advs = returns - values
    adv_history.append(advs.abs().mean().item())

    # The actor sees the advantage as a constant (detached); the critic
    # regresses its value onto the return via the squared advantage.
    weighted_logps = logps * advs.detach()
    actor_loss = -weighted_logps.mean()
    critic_loss = (advs ** 2).mean()

    # No entropy bonus is applied; entropy_coef is not used in the loss.
    loss = actor_loss + critic_coef * critic_loss

    # One gradient step per episode.
    opt.zero_grad()
    loss.backward()
    opt.step()

# -----------------------
# 6) PLOT ADVANTAGE CONVERGENCE
# -----------------------
# Mean absolute advantage per training episode.
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(adv_history)
ax.set_xlabel("Episode")
ax.set_ylabel("Mean |Advantage|")
ax.set_title("Critic Convergence (Advantage → 0)")
ax.grid()
fig.tight_layout()
plt.show()

# -----------------------
# 7) EXTRACT POLICY & V(s,t)
# -----------------------
agent.eval()
num_states = max_inv - min_inv + 1

# pi_tab[t, i] holds the greedy action and V_tab[t, i] the critic value for
# inventory level (min_inv + i) at period t.
pi_tab = np.zeros((T, num_states), dtype=int)
V_tab = np.zeros((T, num_states), dtype=float)

with torch.no_grad():
    for t in range(T):
        for i, s in enumerate(range(min_inv, max_inv + 1)):
            # Same scaling as in training: s / max_inv and t / (T - 1).
            st = torch.tensor([[s / max_inv, t / (T - 1)]],
                              dtype=torch.float32, device=device)
            logits, v = agent.forward(st)
            # The policy is Categorical(logits=...), so the most probable
            # (greedy) action is the one with the LARGEST logit. Using argmin
            # here would tabulate the least likely action instead.
            pi_tab[t, i] = int(logits.argmax(dim=-1).item())
            V_tab[t, i] = v.item()

# -----------------------
# 8) COMPUTE BASE‐STOCK σ_t
# -----------------------
# sigma[t] is the lowest inventory level at which the tabulated policy picks
# action 0 in period t; it falls back to max_inv when action 0 never appears.
sigma = []
for period in range(T):
    zero_action_levels = (
        idx + min_inv
        for idx, action in enumerate(pi_tab[period])
        if action == 0
    )
    sigma.append(next(zero_action_levels, max_inv))

print("Period | σ_t")
print("---------------")
for period, level in enumerate(sigma):
    print(f"{period:6d} | {level:4d}")

# -----------------------
# 9) ROLLOUT DEMO
# -----------------------
header = "\nTime | Inv | σ_t | Demand | Action"
rule = "------------------------------------"
print(header)
print(rule)

# Simulate one episode from inventory 0 using the tabulated policy.
inventory = 0
for period in range(T):
    demand = random.randint(0, 10)
    order = pi_tab[period, inventory - min_inv]
    row = f"{period:4d} | {inventory:4d} | {sigma[period]:4d} | {demand:6d} | {order:6d}"
    print(row)
    # Inventory balance, clamped from above and then from below.
    inventory = min(max_inv, inventory + order - demand)
    inventory = max(min_inv, inventory)

# -----------------------
# 10) GIF OF V(s) EVOLUTION
# -----------------------
# Render one PNG per period into a scratch directory, assemble them into a GIF,
# and remove the scratch directory afterwards, even if rendering or encoding
# fails part-way through.
os.makedirs("frames", exist_ok=True)
states = np.arange(min_inv, max_inv+1)
try:
    frame_paths = []
    for t in range(T):
        plt.figure(figsize=(4,3))
        plt.plot(states, V_tab[t], marker='o')
        plt.title(f"V(s) at t={t}")
        plt.xlabel("Inventory")
        plt.ylabel("V(s,t)")
        plt.grid(True)
        fn = f"frames/{t:03d}.png"
        # savefig already writes a complete PNG; there is no need to read it
        # back and re-encode it before building the GIF.
        plt.savefig(fn)
        plt.close()
        frame_paths.append(fn)

    with imageio.get_writer("vf_evolution_a2c.gif", mode="I", duration=0.1) as writer:
        for fn in frame_paths:
            writer.append_data(imageio.imread(fn))
finally:
    # Never leave the temporary frames behind.
    shutil.rmtree("frames")

print("GIF saved as vf_evolution_a2c.gif")
