In [93]:
import torch.nn as nn
import torch
from torch import optim
import copy
import pickle
import random
import numpy as np

from sim_data_generation import  StateActionTransition
from src.envs.envs import BlackOilEnv

# Extents of the action grid referenced by select_action() and fit().
# NOTE(review): "WEIGHT" is presumably a misspelling of "WIDTH"; the name is
# kept because other cells reference it.
WEIGHT = 80
HEIGHT = 40

# Device that run_model() and fit() move their tensors to.
device = torch.device("cuda")

In [94]:
# Load the pickled results.
# NOTE: pickle.load can execute arbitrary code - only open files you trust.
with open("saved_results_5.pkl", "rb") as f:
    data = pickle.load(f)

# `data` is iterated as an iterable of iterables; flatten it into one list.
res = []
for group in data:
    res.extend(group)

# Every record exposes `.state`; keep all states plus one sample record
# (index 1) for quick experiments in the cells below.
example_fields = [record.state for record in res]
example_dqn = res[1]
example_field = example_dqn.state

In [95]:
# Display the sample state (the recorded output below shows a nested float array).
# NOTE(review): this dumps the whole array; `example_field.shape` may be more useful.
example_field

array([[[1.80579588e+00, 1.56305696e+00, 3.25635336e-12, ...,
         9.35601278e-24, 2.82256183e+00, 0.00000000e+00],
        [1.80394850e+00, 1.55744098e+00, 2.43528046e-11, ...,
         5.22175017e-22, 2.80954332e+00, 0.00000000e+00],
        [1.80210807e+00, 1.55167870e+00, 1.61212636e-10, ...,
         2.28362679e-20, 2.79629271e+00, 0.00000000e+00],
        ...,
        [1.77916384e+00, 1.39540629e+00, 3.59415496e-06, ...,
         1.22040102e-11, 2.48265640e+00, 0.00000000e+00],
        [1.77989426e+00, 1.39377292e+00, 1.01229167e-06, ...,
         9.68349180e-13, 2.48076843e+00, 0.00000000e+00],
        [1.78098566e+00, 1.39266113e+00, 2.45708847e-07, ...,
         5.70606246e-14, 2.48030950e+00, 0.00000000e+00]],

       [[1.80381408e+00, 1.55993263e+00, 1.32069747e-11, ...,
         1.54102199e-22, 2.81382845e+00, 0.00000000e+00],
        [1.80174353e+00, 1.55418424e+00, 9.87702476e-11, ...,
         8.60044887e-21, 2.80024139e+00, 0.00000000e+00],
        [1.79966098e+00, 

In [98]:
import importlib
import unet_model
# Reload the module so the latest on-disk definition of UNet2 is used.
unet_module = importlib.reload(unet_model)

model = unet_model.UNet2()
model.to_cuda()

import numpy as np
# Count trainable parameters: total number of elements over all tensors
# whose requires_grad flag is set.
model_parameters = (p for p in model.parameters() if p.requires_grad)
params = sum([np.prod(p.size()) for p in model_parameters])
params

263692

In [138]:
def create_new_model(ModelClass: "type[unet_model.UNet2]"):
    """Build the online network, its target network and the optimizer.

    Args:
        ModelClass: a zero-argument callable (normally the ``UNet2`` class)
            that returns a module exposing ``to_cuda()``.

    Returns:
        ``(model, target_model, optimizer)`` where ``target_model`` starts
        with exactly the same weights as ``model`` and ``optimizer`` is Adam
        (lr=3e-4) over the parameters of ``model`` only.
    """
    model = ModelClass()
    target_model = ModelClass()

    # Move both networks to the GPU via the model's own helper.
    model.to_cuda()
    target_model.to_cuda()

    # The target network must start as a copy of the online network;
    # otherwise the first bootstrapped targets come from an unrelated,
    # independently initialised network.
    target_model.load_state_dict(model.state_dict())

    # Only the online network is optimised; the target network is refreshed
    # by copying weights.
    optimizer = optim.Adam(model.parameters(), lr=3e-4)

    return model, target_model, optimizer


def run_model(state, model):
    """Run ``model`` on a batch of states and flatten each sample's output.

    ``state`` is converted to a float tensor on ``device`` and its axes are
    reordered from (0, 1, 2, 3) to (0, 3, 1, 2) before the forward pass
    (presumably batch, dim1, dim2, channels -> batch, channels, dim1, dim2;
    TODO confirm against the environment's observation layout).

    Returns:
        A 2-D tensor: everything after the batch axis of the model output
        is flattened into one axis.
    """
    batch = torch.FloatTensor(state).to(device).permute(0, 3, 1, 2)
    return torch.flatten(model(batch), start_dim=1)


def select_action(state, model, epsilon=0.05):
    """Pick one ``(x, y)`` grid action per batch element, epsilon-greedily.

    With probability ``epsilon`` the whole batch gets uniformly random
    actions (``0 <= x < WEIGHT``, ``0 <= y < HEIGHT``). Otherwise each
    sample gets the cell with the largest flattened model output.

    The flat index is decoded as ``x, y = divmod(index, HEIGHT)``, so the
    greedy branch covers the same ``WEIGHT x HEIGHT`` range as the random one.
    The previous decode (``index // WEIGHT``, ``index % HEIGHT``) mixed two
    different strides and could never yield ``x >= HEIGHT``.
    NOTE(review): this assumes the model's output map is laid out as
    (WEIGHT, HEIGHT), i.e. ``flat = x * HEIGHT + y`` - confirm against the
    UNet2 output shape and the action order expected by the environment.

    Args:
        state: batched states; ``len(state)`` is taken as the batch size.
        model: network passed to ``run_model``.
        epsilon: exploration probability.

    Returns:
        A list with one ``(x, y)`` tuple of Python ints per batch element.
    """
    if random.random() < epsilon:
        # One random action per sample, so both branches return a list of
        # the same length (identical to the old behaviour for batch size 1).
        return [
            (random.randint(0, WEIGHT - 1), random.randint(0, HEIGHT - 1))
            for _ in range(len(state))
        ]

    # Inference only: no autograd graph is needed to take an argmax.
    with torch.no_grad():
        actions_indexes = run_model(state, model).cpu().numpy().argmax(1)

    actions = []
    for index in actions_indexes:
        x, y = divmod(int(index), HEIGHT)
        actions.append((x, y))

    return actions

def get_max_q(state, model):
    """Return, per batch element, the largest value of the flattened model output.

    The maximum over axis 1 of ``run_model(state, model)`` is used as the
    Q-value of the state; the result is a 1-D tensor of length batch size.
    """
    q_values = run_model(state, model)
    best_q, _ = q_values.max(1)
    return best_q.view(-1)    # [BATCH_SIZE]

In [139]:
# Single-sample batch: prepend a batch axis to the example state.
# (The earlier np.stack([...]) assignment was dead code - it was overwritten
# on the very next line - so it has been dropped.)
input_example = example_field[np.newaxis, ...]

In [140]:
# Smoke check: epsilon-greedy action for the one-sample batch (default epsilon=0.05).
select_action(input_example, model)

[(1, 3)]

In [141]:
# Smoke check: maximum of the flattened model output for the one-sample batch.
get_max_q(input_example, model)

tensor([3.8901], device='cuda:0', grad_fn=<ViewBackward0>)

In [153]:
import torch.nn.functional as F


gamma = 0.99  # discount factor applied to the bootstrapped next-state value


def fit(batch, model, target_model, optimizer, verbose=True):
    """Perform one DQN gradient step on a batch of transitions.

    Args:
        batch: five parallel sequences
            ``(state, action, reward, next_state, done)`` as produced by
            ``Memory.sample``; each ``action`` is an ``(x, y)`` pair.
        model: online network being optimised.
        target_model: network used to compute the bootstrapped targets.
        optimizer: optimiser over ``model``'s parameters.
        verbose: when true (default), print a confirmation line after the
            update, as before.

    Returns:
        The scalar MSE loss of this update as a Python float.
    """
    state, action, reward, next_state, done = batch

    # States are converted to tensors inside run_model(); only stack them here.
    state = np.array(state)
    next_state = np.array(next_state)
    reward = torch.tensor(reward).to(device).float()
    action = torch.tensor(action).to(device).long()
    # Force a boolean mask: an integer tensor would index rows, not mask them.
    done = torch.tensor(done).to(device).bool()

    # Targets: r + gamma * max_a' Q_target(s', a'), with the bootstrap term
    # zeroed for terminal transitions.
    with torch.no_grad():
        target_q = get_max_q(next_state, target_model)
        target_q[done] = 0

    target_q = reward + target_q * gamma

    # Flat index of the taken action in the flattened Q-map.
    # action[:, 0] ranges over WEIGHT and action[:, 1] over HEIGHT, so
    # x * HEIGHT + y is unique per cell; the previous y * HEIGHT + x collided
    # whenever x >= HEIGHT.
    # NOTE(review): assumes the Q-map is laid out as (WEIGHT, HEIGHT) - confirm
    # against the UNet2 output shape.
    flatten_index = (action[:, 0] * HEIGHT + action[:, 1]).unsqueeze(-1)

    q = run_model(state, model).gather(1, flatten_index)

    loss = F.mse_loss(q, target_q.unsqueeze(1))

    # Clear the gradients accumulated by the previous step.
    optimizer.zero_grad()

    # Back-propagate the loss.
    loss.backward()

    # Clamp every gradient element to [-1, 1] so updates stay bounded.
    # clip_grad_value_ skips parameters that received no gradient, whereas
    # touching param.grad directly fails when it is None.
    torch.nn.utils.clip_grad_value_(model.parameters(), 1)

    # Apply the optimisation step.
    optimizer.step()

    if verbose:
        print("model update... Ok")

    return loss.item()

In [154]:
class Memory:
    """Fixed-capacity ring buffer of transitions with uniform random sampling."""

    def __init__(self, capacity):
        """Create an empty buffer holding at most ``capacity`` elements.

        Raises:
            ValueError: if ``capacity`` is not positive (such a buffer could
                never store an element and would fail on the first push).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.memory = []
        self.position = 0  # slot that the next push() writes to

    def push(self, element: "StateActionTransition"):
        """Store ``element``, overwriting the oldest entry once the buffer is full."""
        if len(self.memory) < self.capacity:
            # Still growing: position always equals len(memory) here.
            self.memory.append(element)
        else:
            self.memory[self.position] = element
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        """Return a random sample of ``batch_size`` elements, transposed.

        The sampled elements are zipped together, so the result is a list of
        tuples: one tuple per field, each of length ``batch_size``.

        Raises:
            ValueError: (from ``random.sample``) if ``batch_size`` exceeds the
                number of stored elements.
        """
        chosen = random.sample(self.memory, batch_size)
        return list(zip(*chosen))

    def __len__(self):
        """Number of elements currently stored."""
        return len(self.memory)

In [155]:
def make_reward_number(reward) -> float:
    """Normalise an environment reward to a single number.

    Scalars of any numeric kind (Python ``float``/``int``, any NumPy scalar
    such as ``np.float32``, or a 0-d array) are returned as a Python float.
    Anything with at least one dimension is treated as a sequence and its
    first element is returned unchanged.

    The scalar check uses ``np.ndim`` instead of ``isinstance(..., np.float_)``:
    the ``np.float_`` alias was removed in NumPy 2.0, and the old check sent
    ``np.float32`` and integer rewards to ``reward[0]``, which fails on scalars.
    """
    if np.ndim(reward) == 0:
        return float(reward)
    return reward[0]

In [156]:
# Smoke check of the replay buffer: store three rows, then draw all three.
# sample() returns the rows transposed (one tuple per column), in random order.
mem = Memory(10)
for sample_row in ([1, 2, 3], [2, 3, 4], [3, 5, 6]):
    mem.push(sample_row)
mem.sample(3)

[(1, 2, 3), (2, 3, 5), (3, 4, 6)]

In [157]:
from tqdm import tqdm

sampled_batch = None  # last batch drawn inside train(); kept at module level for inspection


def _run_greedy_episode(env, policy_model):
    """Play one episode from ``env.reset()`` with epsilon=0 and return the summed reward."""
    state = env.reset()
    done = False
    total_reward = 0
    while not done:
        model_input = state[np.newaxis, ...]        # add the batch axis
        action = select_action(model_input, policy_model, epsilon=0)[0]
        state, reward, done = env.step(action)
        total_reward += make_reward_number(reward)
    return total_reward


def train(env: BlackOilEnv, max_steps=5000, batch_size=4, target_update=50,
          epsilon=0.25, memory_capacity=2000):
    """Train a DQN agent on ``env`` and return its periodic evaluation rewards.

    Args:
        env: environment exposing ``reset()``, ``observation`` and
            ``step(action) -> (new_state, reward, done)``.
        max_steps: number of environment steps taken during training.
        batch_size: number of transitions sampled for every gradient step.
        target_update: every this many steps the target network is replaced
            by a copy of the online network and one greedy episode is played.
        epsilon: exploration probability used while collecting experience.
        memory_capacity: size of the replay buffer.

    Returns:
        List with the total reward of each greedy evaluation episode, in order.

    NOTE(review): the trained networks are local to this function and are not
    returned or saved; only the reward history survives the call.
    """
    global sampled_batch

    # Fresh networks, optimiser and replay buffer for this run.
    memory = Memory(memory_capacity)
    model, target_model, optimizer = create_new_model(unet_model.UNet2)
    rewards_by_target_updates = []

    env.reset()
    for step in tqdm(range(max_steps)):
        state = env.observation

        # Take one epsilon-greedy step in the environment.
        model_input = state[np.newaxis, ...]        # add the batch axis
        action = select_action(model_input, model, epsilon)[0]

        new_state, reward, done = env.step(action)
        reward = make_reward_number(reward)

        # Remember the transition and restart the environment when the episode ends.
        memory.push((state, action, reward, new_state, done))
        if done:
            env.reset()

        # Gradient step as soon as the buffer can supply a full batch
        # (the guard is the actual precondition of Memory.sample).
        if len(memory) >= batch_size:
            sampled_batch = memory.sample(batch_size)
            fit(sampled_batch, model, target_model, optimizer)

        if (step + 1) % target_update == 0:
            target_model = copy.deepcopy(model)

            # Evaluate the freshly synchronised network with a greedy episode,
            # then reset so training continues from a new episode.
            total_reward = _run_greedy_episode(env, target_model)
            env.reset()
            print(f"Testing... Get reward: {total_reward}")
            rewards_by_target_updates.append(total_reward)

    return rewards_by_target_updates

In [158]:
from src.envs.envs import BlackOilEnv
# Environment instance handed to train() in the next cell; built with days=3.
env = BlackOilEnv(days=3)

In [None]:
# Bind the evaluation rewards to a name: the run takes hours, and a bare call
# leaves the result reachable only through Out[...] / `_`.
rewards_by_target_updates = train(env)
rewards_by_target_updates

  0%|          | 6/5000 [00:37<8:39:54,  6.25s/it]

model update... Ok


  0%|          | 7/5000 [00:43<8:48:56,  6.36s/it]

model update... Ok


  0%|          | 8/5000 [00:50<8:56:56,  6.45s/it]

model update... Ok


  0%|          | 9/5000 [00:56<8:46:06,  6.32s/it]

model update... Ok


  0%|          | 10/5000 [01:02<8:27:56,  6.11s/it]

model update... Ok


  0%|          | 11/5000 [01:08<8:40:15,  6.26s/it]

model update... Ok


  0%|          | 12/5000 [01:14<8:22:55,  6.05s/it]

model update... Ok


  0%|          | 13/5000 [01:20<8:26:37,  6.10s/it]

model update... Ok


  0%|          | 14/5000 [01:26<8:23:01,  6.05s/it]

model update... Ok


  0%|          | 15/5000 [01:31<8:06:15,  5.85s/it]

model update... Ok


  0%|          | 16/5000 [01:37<8:04:51,  5.84s/it]

model update... Ok


  0%|          | 17/5000 [01:43<8:11:44,  5.92s/it]

model update... Ok


  0%|          | 18/5000 [01:49<8:10:17,  5.90s/it]

model update... Ok


  0%|          | 19/5000 [01:56<8:29:14,  6.13s/it]

model update... Ok


  0%|          | 20/5000 [02:02<8:21:13,  6.04s/it]

model update... Ok


  0%|          | 21/5000 [02:08<8:18:54,  6.01s/it]

model update... Ok


  0%|          | 22/5000 [02:12<7:44:50,  5.60s/it]

model update... Ok


  0%|          | 23/5000 [02:18<7:48:59,  5.65s/it]

model update... Ok


  0%|          | 24/5000 [02:23<7:37:38,  5.52s/it]

model update... Ok


  0%|          | 25/5000 [02:29<7:35:41,  5.50s/it]

model update... Ok


  1%|          | 26/5000 [02:34<7:40:16,  5.55s/it]

model update... Ok


  1%|          | 27/5000 [02:40<7:38:09,  5.53s/it]

model update... Ok


  1%|          | 28/5000 [02:46<7:47:21,  5.64s/it]

model update... Ok


  1%|          | 29/5000 [02:51<7:26:37,  5.39s/it]

model update... Ok


  1%|          | 30/5000 [02:55<7:11:40,  5.21s/it]

model update... Ok


  1%|          | 31/5000 [03:01<7:33:05,  5.47s/it]

model update... Ok


  1%|          | 32/5000 [03:07<7:30:48,  5.44s/it]

model update... Ok


  1%|          | 33/5000 [03:12<7:30:42,  5.44s/it]

model update... Ok


  1%|          | 34/5000 [03:18<7:43:39,  5.60s/it]

model update... Ok


  1%|          | 35/5000 [03:24<7:43:11,  5.60s/it]

model update... Ok


  1%|          | 36/5000 [03:30<7:58:59,  5.79s/it]

model update... Ok


  1%|          | 37/5000 [03:36<8:11:39,  5.94s/it]

model update... Ok


  1%|          | 38/5000 [03:43<8:21:24,  6.06s/it]

model update... Ok


  1%|          | 39/5000 [03:49<8:33:08,  6.21s/it]

model update... Ok


  1%|          | 40/5000 [03:56<8:42:20,  6.32s/it]

model update... Ok


  1%|          | 41/5000 [04:02<8:47:55,  6.39s/it]

model update... Ok


  1%|          | 42/5000 [04:08<8:36:10,  6.25s/it]

model update... Ok


  1%|          | 43/5000 [04:14<8:24:50,  6.11s/it]

model update... Ok


  1%|          | 44/5000 [04:20<8:15:07,  5.99s/it]

model update... Ok


  1%|          | 45/5000 [04:26<8:21:21,  6.07s/it]

model update... Ok


  1%|          | 46/5000 [04:32<8:10:33,  5.94s/it]

model update... Ok


  1%|          | 47/5000 [04:38<8:11:00,  5.95s/it]

model update... Ok


  1%|          | 48/5000 [04:44<8:33:00,  6.22s/it]

model update... Ok


  1%|          | 49/5000 [04:50<8:27:37,  6.15s/it]

model update... Ok
model update... Ok


  1%|          | 50/5000 [05:02<10:45:41,  7.83s/it]

Testing... Get reward: -0.5862330515792511


  1%|          | 51/5000 [05:08<9:54:16,  7.20s/it] 

model update... Ok


  1%|          | 52/5000 [05:14<9:15:31,  6.74s/it]

model update... Ok


  1%|          | 53/5000 [05:19<8:54:54,  6.49s/it]

model update... Ok


  1%|          | 54/5000 [05:26<8:51:50,  6.45s/it]

model update... Ok


  1%|          | 55/5000 [05:31<8:23:39,  6.11s/it]

model update... Ok


  1%|          | 56/5000 [05:37<8:20:18,  6.07s/it]

model update... Ok


  1%|          | 57/5000 [05:42<8:00:47,  5.84s/it]

model update... Ok


  1%|          | 58/5000 [05:49<8:08:03,  5.93s/it]

model update... Ok


  1%|          | 59/5000 [05:55<8:24:00,  6.12s/it]

model update... Ok


  1%|          | 60/5000 [06:01<8:16:20,  6.03s/it]

model update... Ok


  1%|          | 61/5000 [06:07<8:11:22,  5.97s/it]

model update... Ok


  1%|          | 62/5000 [06:13<8:05:17,  5.90s/it]

model update... Ok


  1%|▏         | 63/5000 [06:18<8:06:27,  5.91s/it]

model update... Ok


  1%|▏         | 64/5000 [06:24<7:47:01,  5.68s/it]

model update... Ok


  1%|▏         | 65/5000 [06:29<7:43:37,  5.64s/it]

model update... Ok


  1%|▏         | 66/5000 [06:35<7:45:15,  5.66s/it]

model update... Ok


  1%|▏         | 67/5000 [06:40<7:41:35,  5.61s/it]

model update... Ok


  1%|▏         | 68/5000 [06:46<7:47:08,  5.68s/it]

model update... Ok


  1%|▏         | 69/5000 [06:53<8:05:31,  5.91s/it]

model update... Ok


  1%|▏         | 70/5000 [06:59<8:16:29,  6.04s/it]

model update... Ok


  1%|▏         | 71/5000 [07:05<8:15:46,  6.03s/it]

model update... Ok


  1%|▏         | 72/5000 [07:10<8:01:25,  5.86s/it]

model update... Ok


  1%|▏         | 73/5000 [07:16<8:03:33,  5.89s/it]

model update... Ok


  1%|▏         | 74/5000 [07:22<7:55:16,  5.79s/it]

model update... Ok


  2%|▏         | 75/5000 [07:28<7:56:58,  5.81s/it]

model update... Ok


  2%|▏         | 76/5000 [07:35<8:20:58,  6.10s/it]

model update... Ok


  2%|▏         | 77/5000 [07:41<8:20:18,  6.10s/it]

model update... Ok


  2%|▏         | 78/5000 [07:47<8:15:48,  6.04s/it]

model update... Ok


  2%|▏         | 79/5000 [07:52<7:58:47,  5.84s/it]

model update... Ok


  2%|▏         | 80/5000 [07:58<8:03:14,  5.89s/it]

model update... Ok


  2%|▏         | 81/5000 [08:04<7:54:26,  5.79s/it]

model update... Ok


  2%|▏         | 82/5000 [08:09<7:48:18,  5.71s/it]

model update... Ok


  2%|▏         | 83/5000 [08:15<8:02:21,  5.89s/it]

model update... Ok


  2%|▏         | 84/5000 [08:21<7:44:51,  5.67s/it]

model update... Ok


  2%|▏         | 85/5000 [08:26<7:47:47,  5.71s/it]

model update... Ok


  2%|▏         | 86/5000 [08:32<7:47:05,  5.70s/it]

model update... Ok


  2%|▏         | 87/5000 [08:38<8:00:55,  5.87s/it]

model update... Ok


  2%|▏         | 88/5000 [08:45<8:15:09,  6.05s/it]

model update... Ok


  2%|▏         | 89/5000 [08:50<7:50:57,  5.75s/it]

model update... Ok


  2%|▏         | 90/5000 [08:56<7:55:36,  5.81s/it]

model update... Ok


  2%|▏         | 91/5000 [09:02<7:57:24,  5.84s/it]

model update... Ok


  2%|▏         | 92/5000 [09:07<7:47:00,  5.71s/it]

model update... Ok


  2%|▏         | 93/5000 [09:13<7:54:26,  5.80s/it]

model update... Ok


  2%|▏         | 94/5000 [09:18<7:38:59,  5.61s/it]

model update... Ok


  2%|▏         | 95/5000 [09:23<7:25:38,  5.45s/it]

model update... Ok


  2%|▏         | 96/5000 [09:30<7:46:41,  5.71s/it]

model update... Ok


  2%|▏         | 97/5000 [09:37<8:17:27,  6.09s/it]

model update... Ok


  2%|▏         | 98/5000 [09:43<8:22:40,  6.15s/it]

model update... Ok
