In [30]:
from wovenv.ai.q_learning import QLearningAgent
from wovenv.venv.env import Env
from IPython.display import clear_output
import matplotlib.pyplot as plt
import numpy as np

# Окружение и агент

In [37]:
# Environment instance shared by the cells below.
env = Env()

# Agent under training.
# NOTE(review): alpha is presumably the learning rate and discount the
# reward discount factor — confirm against QLearningAgent.
agent = QLearningAgent(
    alpha=0.5,
    discount=0.1,
)

In [32]:
def generate_session(env: Env, agent: QLearningAgent, t_max=1000, train=False):
    """Play one episode and, when training, replay it through the agent.

    The episode is played to the end first; the recorded transitions are
    fed to ``agent.update`` only afterwards, in the order they occurred.

    Args:
        env: environment to play in. Its ``write_log`` attribute is set to
            ``not train``, so only non-training games are logged.
        agent: agent that picks actions via ``get_action`` and learns via
            ``update``.
        t_max: maximum number of steps before the episode is cut off.
        train: if True, update the agent on every recorded transition;
            if False, finalise the environment log instead.

    Returns:
        A ``(total_reward, loss)`` tuple. ``loss`` is the value returned by
        the last ``agent.update`` call, or None when no update was made
        (``train`` is False or no transition was recorded).
    """
    env.write_log = not train
    state = env.reset()

    total_reward = 0
    transitions = []

    for _ in range(t_max):
        action = agent.get_action(state)
        next_state, reward, done = env.step(action)

        transitions.append((state, action, next_state, reward, done))
        total_reward += reward
        state = next_state

        if done:
            break

    loss = None
    if train:
        for transition in transitions:
            loss = agent.update(*transition)
    else:
        # Finalise the log for every logged game, including the ones cut off
        # at t_max; previously a truncated game left its log unfinished.
        # NOTE(review): assumes Env.finish_log is safe to call on an episode
        # that has not reached a terminal state — confirm against Env.
        env.finish_log()

    return total_reward, loss

# Визуализация

In [33]:
def plot(rewards, losses):
    """Redraw the training curves, replacing the previous cell output.

    Args:
        rewards: per-episode rewards; the mean of the last ten entries is
            shown in the title of the first panel.
        losses: per-episode loss values, drawn in the second panel.
    """
    # wait=True keeps the old output on screen until the new one is ready.
    clear_output(wait=True)
    plt.figure(figsize=(20, 5))

    # Panels live on a 1x3 grid; only the first two slots are used.
    plt.subplot(1, 3, 1)
    recent_mean = np.mean(rewards[-10:])
    plt.title('reward: %s' % recent_mean)
    plt.plot(rewards)

    plt.subplot(1, 3, 2)
    plt.title('loss')
    plt.plot(losses)

    plt.show()

# Горячий старт (при необходимости)

In [34]:
def hot_start(env: Env, agent: QLearningAgent, ind, t_max=1000):
    """Warm up the agent by stepping through the environment's own action list.

    After a reset, each sweep asks the environment for actions via
    ``env._get_actions()``, executes them one after another and updates the
    agent on every resulting transition. The function returns as soon as a
    step reports ``done``, or after ``t_max`` sweeps.

    Args:
        env: environment to step through.
        agent: agent whose ``update`` is called on each transition.
        ind: run index; only used in the printed progress message.
        t_max: maximum number of sweeps over the action list.
    """
    clear_output()
    print(f'start {ind}')
    state = env.reset()

    for _ in range(t_max):
        # NOTE(review): the action list is fetched once per sweep although the
        # state changes after every step — presumably every listed action is
        # valid in any state; confirm against Env.
        for action in env._get_actions():
            next_state, reward, done = env.step(action)
            agent.update(state, action, next_state, reward, done)
            state = next_state
            if done:
                return

In [35]:
# Optional warm-up (disabled by default): uncomment the two lines below to
# run 1000 hot_start passes over the shared env/agent.
#for i in range(1000):
#    hot_start(env, agent, i)

# Обучение

In [None]:
# Per-episode history passed to plot(): the total reward and the loss value
# returned by generate_session.
rewards, losses = [], []

# 100 cycles; each one restarts epsilon at 1 and decays it back down.
for _ in range(100):
    agent.epsilon = 1
    # Multiplying by 0.999 after each episode takes about 2,300 episodes
    # to bring epsilon from 1 down to 0.1.
    while agent.epsilon > 0.1:
        reward, loss = generate_session(env,agent,train=True)
        rewards.append(reward)
        losses.append(loss)
        agent.epsilon *= 0.999
        # NOTE(review): the full history is redrawn after every episode, so
        # the plotting cost grows with the run — consider redrawing every N
        # episodes if training becomes slow.
        plot(rewards, losses)

In [None]:
# Display the raw list of per-episode loss values.
losses

# Сериализация

In [None]:
import pickle

---
### Сохранение обученного агента

In [None]:
# Serialize the whole agent object to data.pickle (path is relative to the
# current working directory); an existing file is overwritten.
with open('data.pickle', "wb") as f:
    pickle.dump(agent, f)

---
### Загрузка сохраненного агента

In [None]:
# Restore a previously saved agent, replacing the one currently in memory.
# The context manager closes the file handle as soon as loading finishes.
# NOTE: unpickling can execute arbitrary code — only load files you created
# yourself.
with open('data.pickle', 'rb') as f:
    agent = pickle.load(f)

# Генерация игры с записью в лог для последующей визуализации

In [None]:
# Play a single logged game. train defaults to False, so generate_session
# turns env.write_log on and performs no agent updates.
# NOTE(review): epsilon = 0 presumably disables random exploration —
# confirm against QLearningAgent.
agent.epsilon = 0
# The cell output is the returned (total_reward, loss) pair; loss is None
# because no training update is made.
generate_session(env, agent)