In [None]:
import os
import random
import tempfile

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Categorical
from tqdm import trange  # progress bar

# -----------------------
# 1. Problem & Cost
# -----------------------
p = 1       # cost per unit ordered
a_cost = 1  # cost per unit of non-negative (on-hand) inventory
b_cost = 2  # cost per unit of negative inventory (backlog)


def inventory_cost(s_next, a):
    """Return the one-period cost of ordering ``a`` units.

    The cost is the linear ordering cost ``p * a`` plus a charge on the
    post-demand inventory ``s_next``. Positive stock is charged ``a_cost``
    per unit, and a shortfall (``s_next < 0``) is charged ``b_cost`` per unit.
    """
    if s_next >= 0:
        level_charge = a_cost * s_next
    else:
        level_charge = b_cost * (-s_next)
    return p * a + level_charge

# -----------------------
# 2. Actor–Critic Model
# -----------------------
class ActorCritic(nn.Module):
    """Actor-critic network with two independent MLP heads.

    The actor maps a state feature vector to one unnormalised logit per
    discrete action. The critic maps the same vector to a scalar value.
    The two heads share no parameters.

    Args:
        state_dim: Number of input features per state.
        num_actions: Number of discrete actions (size of the logit vector).
        hidden_dim: Width of both hidden layers in each head (default 64).
    """

    def __init__(self, state_dim, num_actions, hidden_dim=64):
        # Must be the dunder ``__init__``: a single-underscore ``_init_`` is
        # never invoked by Python, so nn.Module would get the positional args
        # and no sub-networks would be registered.
        super().__init__()
        self.actor = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_actions),
        )
        self.critic = nn.Sequential(
            nn.Linear(state_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Return ``(logits, value)`` for state features ``x``.

        ``x`` may carry leading batch dimensions; its last dimension must be
        ``state_dim``. ``logits`` ends in ``num_actions`` and ``value`` ends in 1.
        """
        return self.actor(x), self.critic(x)

    def get_action(self, state):
        """Sample one action from the policy at ``state``.

        Returns:
            A tuple ``(action, log_prob, value)``:
              - ``action`` is a Python int. This requires an unbatched
                ``state``, because ``.item()`` needs a single element.
              - ``log_prob`` is the sampled action's log-probability tensor,
                still attached to the autograd graph.
              - ``value`` is the critic output with size-1 dims squeezed.
        """
        logits, value = self.forward(state)
        dist = Categorical(logits=logits)
        action = dist.sample()
        return action.item(), dist.log_prob(action), value.squeeze()

# -----------------------
# 3. Hyperparameters
# -----------------------
# Order quantities are the action indices 0 .. num_actions-1.
num_actions = 10
# Number of periods in one episode (finite horizon).
T = 100
# Inventory is clipped to [min_inv, max_inv]; negative levels are backlog.
min_inv, max_inv = -100, 100
# Network input: (inventory / max_inv, t / (T - 1)).
state_dim = 2
gamma = 1.0           # no discounting over the finite horizon
lr = 1e-3             # Adam learning rate
num_episodes = 2000   # training episodes

# -----------------------
# 4. Setup
# -----------------------
# Prefer the GPU when one is visible; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
agent = ActorCritic(state_dim, num_actions)
agent.to(device)  # nn.Module.to moves parameters in place (and returns self)
# A single Adam optimizer updates both the actor and the critic parameters.
opt = optim.Adam(agent.parameters(), lr=lr)

# -----------------------
# 5. Training Loop
# -----------------------
# Online one-step actor-critic: one gradient update per environment step.
# The critic V(s, t) estimates expected *cost*-to-go (not reward), and the
# actor is pushed towards actions that turn out cheaper than V predicted.
for ep in trange(num_episodes, desc="Training"):
    # start from a random inventory level
    s = random.randint(min_inv, max_inv)
    for t in range(T):
        # normalize state into [–1,1]×[0,1]
        # (s/max_inv stays in [-1, 1] only while min_inv == -max_inv)
        st      = torch.tensor([s/max_inv, t/(T-1)],
                               dtype=torch.float32, device=device)
        a, logp, v = agent.get_action(st)

        # step in environment
        # w is the random demand, uniform on {0, ..., 10}; the next inventory
        # is clipped to [min_inv, max_inv] (negative = backlog).
        w   = random.randint(0, 10)
        s2  = max(min_inv, min(max_inv, s + a - w))
        cost= inventory_cost(s2, a)

        # estimate value of next state
        # Computed without grad so only V(current) is trained (semi-gradient
        # TD(0)). At the last step t2 wraps to 0 just to build a valid input;
        # that value is masked out by done_mask below.
        with torch.no_grad():
            t2   = 0 if t == T-1 else (t+1)
            st2  = torch.tensor([s2/max_inv, t2/(T-1)],
                                dtype=torch.float32, device=device)
            _, v2 = agent.forward(st2)
            v2     = v2.squeeze()

        done_mask = float(t == T-1)
        # advantage = (immediate cost + γ·V(next)) – V(current)
        advantage = cost + gamma * v2 * (1 - done_mask) - v

        # actor: increase log-prob of actions that lower cost
        # Minimising logp * advantage lowers the log-prob of actions that cost
        # more than expected (advantage > 0) and raises it when advantage < 0.
        # The advantage is detached so the actor loss does not train the critic.
        actor_loss  = logp * advantage.detach()
        # critic: squared one-step TD error
        critic_loss = advantage.pow(2)
        loss        = actor_loss + critic_loss

        opt.zero_grad()
        loss.backward()
        opt.step()

        s = s2

# -----------------------
# 6. Build Policy & Value Tables
# -----------------------
# Tabulate the greedy policy and the critic's value on every (inventory, t)
# pair. Row index = inventory - min_inv, column index = period t.
agent.eval()
num_states = max_inv - min_inv + 1

with torch.no_grad():
    # Same feature encoding as in training: (s / max_inv, t / (T - 1)).
    inv_feat = torch.arange(min_inv, max_inv + 1,
                            dtype=torch.float32, device=device) / max_inv
    time_feat = torch.arange(T, dtype=torch.float32, device=device) / (T - 1)
    # (num_states, T, 2) grid evaluated in a single batched forward pass
    # instead of num_states * T single-sample calls.
    grid = torch.stack(torch.meshgrid(inv_feat, time_feat, indexing="ij"), dim=-1)
    logits, v = agent(grid)  # (num_states, T, num_actions), (num_states, T, 1)

    # Greedy action = most probable action = LARGEST logit. Training lowers
    # the log-prob of actions that cost more than the critic expected, so the
    # smallest logit is the policy's *worst* action, not its cheapest.
    policy_table = logits.argmax(dim=-1).cpu().numpy().astype(int)
    V_table = v.squeeze(-1).cpu().numpy().astype(float)

# -------------------------
# 7. Compute Base-Stock σ_t
# -------------------------
# sigma[t] = lowest inventory level at which the tabulated policy orders
# nothing (action 0) in period t. If every level still orders, max_inv is used.
sigma = []
for t in range(T):
    no_order_rows = np.flatnonzero(policy_table[:, t] == 0)
    if no_order_rows.size:
        sigma.append(int(no_order_rows[0]) + min_inv)
    else:
        sigma.append(max_inv)

print("Period | σ_t")
print("---------------")
for t, z in enumerate(sigma):
    print(f"{t:6d} | {z:4d}")

# -----------------------
# 8. Simulate & Print One Rollout
# -----------------------
# Simulate one episode from zero inventory, following the tabulated greedy
# policy. Demand is drawn uniformly from {0, ..., 10} each period.
print("\nTime | Inv | σ_t | Demand | Action")
print("------------------------------------")
inventory = 0
for t in range(T):
    demand = random.randint(0, 10)
    order = policy_table[inventory - min_inv, t]
    print(f"{t:4d} | {inventory:4d} | {sigma[t]:4d} | {demand:6d} | {order:6d}")
    # Clip the post-demand inventory to the modelled range.
    inventory = max(min_inv, min(max_inv, inventory + order - demand))

# -----------------------
# 9. Plot V(s) Evolution as GIF
# -----------------------
# Render V(s, t) for every period as a PNG frame and stitch them into a GIF.
# Frames go into a private temporary directory that is deleted automatically.
# This avoids the hard-coded ./frames + rmtree, which would also wipe a
# pre-existing ./frames directory and any unrelated files in it.
inventory_levels = range(min_inv, max_inv + 1)
with tempfile.TemporaryDirectory(prefix="vf_frames_") as frame_dir:
    frames = []
    for t in range(T):
        plt.figure(figsize=(5, 3))
        plt.plot(inventory_levels, V_table[:, t], marker='o')
        plt.title(f"V(s) at t={t}")
        plt.xlabel("Inventory s")
        plt.ylabel("V(s,t)")
        plt.grid(True)
        fn = os.path.join(frame_dir, f'frame_{t:03d}.png')
        plt.savefig(fn)
        plt.close()
        frames.append(fn)

    # NOTE(review): depending on the installed imageio version / GIF backend,
    # `duration` may be interpreted as seconds or as milliseconds — confirm
    # the intended frame rate.
    with imageio.get_writer('vf_evolution_ac_2.gif', mode='I', duration=0.1) as writer:
        for fn in frames:
            writer.append_data(imageio.v2.imread(fn))

print("GIF saved as 'vf_evolution_ac_2.gif'")