In [1]:
import datetime
import wandb
from torch.utils.data import random_split, DataLoader
import argparse
import yaml
import sys
import torch.nn as nn
from torch.autograd import Variable
import torch
import numpy as np
import torch.nn.functional as F
import copy

sys.path.append('../')
from VizDoom.VizDoom_src.utils import get_vizdoom_iter_dataset, ViZDoomIterDataset
from VizDoom.VizDoom_src.train import trainer

import os
import sys

import pickle
from tqdm import tqdm
#import env_vizdoom2
import matplotlib.pyplot as plt
from itertools import count
import time
import random

os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"   
os.environ["CUDA_VISIBLE_DEVICES"]="0"
# os.environ["CUDA_LAUNCH_BLOCKING"]="1"

os.environ["MKL_NUM_THREADS"] = "1" 
os.environ["NUMEXPR_NUM_THREADS"] = "1" 
os.environ["OMP_NUM_THREADS"] = "1" 
os.environ["TORCH_USE_CUDA_DSA"] = "1" 

In [2]:
with open("../VizDoom/VizDoom_src/config.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

config["training_config"]["batch_size"] = 128
max_length = config["training_config"]["sections"]*config["training_config"]["context_length"]
print(f"{max_length=}")

max_length=90


In [3]:
%load_ext autoreload
%autoreload 2

from lstm_agent_cql import DecisionLSTM

agent = DecisionLSTM(4, 1, 128, mode='doom')

# PATH = f'DOOM_BC_270_sar'
# weights = torch.load(f"{PATH}.ckpt", map_location="cpu")

# agent.load_state_dict(weights, strict=True)
agent.train()
agent.to(agent.device)

optimizer = torch.optim.AdamW(agent.parameters(), lr=config["training_config"]["learning_rate"], 
                                      weight_decay=config["training_config"]["weight_decay"], 
                                      betas=(config["training_config"]["beta_1"], config["training_config"]["beta_2"]))

In [4]:
path_to_splitted_dataset = '../../../RATE/VizDoom/VizDoom_data/iterative_data/'
train_dataset = ViZDoomIterDataset(path_to_splitted_dataset, 
                                 gamma=config["data_config"]["gamma"], 
                                 max_length=max_length, 
                                 normalize=config["data_config"]["normalize"])

train_dataloader = DataLoader(train_dataset, 
                             batch_size=config["training_config"]["batch_size"],
                             shuffle=True, 
                             num_workers=8)

Filtering data...


100%|██████████| 5000/5000 [00:27<00:00, 184.58it/s]


In [5]:
with open("../wandb_config.yaml") as f:
    wandb_config = yaml.load(f, Loader=yaml.FullLoader)
os.environ['WANDB_API_KEY'] = wandb_config['wandb_api']

# os.environ['WANDB_API_KEY'] = 'WANDB_API_KEY'
EXP_NAME = 'doom_cql_sar'
wandb.init(project="RATE_DOOM_CQL", name=f'{EXP_NAME}')

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mcherepanovegor2018[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [6]:
criterion_all = nn.CrossEntropyLoss(ignore_index=-10, reduction='mean')
agent.train()

CQL_ALPHA = 1.0
DISCOUNT = 0.99
#TARGET_UPDATE_FREQ = 10
TAU = 0.005  # Soft update parameter
TARGET_UPDATE_FREQ = 10

target_q1 = copy.deepcopy(agent.q1)
target_q2 = copy.deepcopy(agent.q2)

for param in target_q1.parameters():
    param.requires_grad = False
for param in target_q2.parameters():
    param.requires_grad = False

for epochs in range(3600):
    
    agent.train()
    
    for it, batch in enumerate(train_dataloader):
        s, a, rtg, d, timesteps, masks = batch
        d[d==2] = 1.
        d = 1-d
        d = d.unsqueeze(-1).cuda()
        s = s.cuda()
        a = a.cuda()
        rtg = rtg.cuda().float()

        agent.init_hidden(s.shape[0])

        action_preds, q1_pred, q2_pred, cql_loss = agent(s,a,rtg, stacked_input=True)

        with torch.no_grad():
            next_s = s[:, 1:]
            next_a = a[:, 1:]
            next_rtg = rtg[:, 1:]

            _, next_q1, next_q2, _ = agent(next_s, next_a, next_rtg, stacked_input=True)
            next_q = torch.min(next_q1, next_q2)
            target_q = rtg[:, :-1] + DISCOUNT * (1 - d[:, :-1]) * next_q

        q1_loss = F.mse_loss(q1_pred[:, :-1], target_q)
        q2_loss = F.mse_loss(q2_pred[:, :-1], target_q)   

    #     break
    # break
    
        action_preds = action_preds.reshape(-1, action_preds.size(-1))
        target_actions = a.reshape(-1).long()
        bc_loss = criterion_all(action_preds, target_actions)
        
        total_loss = q1_loss + q2_loss + cql_loss + bc_loss
        #total_loss = bc_loss
    
        optimizer.zero_grad()
        total_loss.backward()#retain_graph=False)
        torch.nn.utils.clip_grad_norm_(agent.parameters(), config["training_config"]["grad_norm_clip"])
        optimizer.step()
    
    if it % 1 == 0:
        with torch.no_grad():
            for param, target_param in zip(agent.q1.parameters(), target_q1.parameters()):
                target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)
            for param, target_param in zip(agent.q2.parameters(), target_q2.parameters()):
                target_param.data.copy_(TAU * param.data + (1 - TAU) * target_param.data)
        
    print(f'Epochs: {epochs} It: {it} Train Loss: {total_loss.item()} '
          f'(BC: {bc_loss.item():.4f}, Q1: {q1_loss.item():.4f}, '
          f'Q2: {q2_loss.item():.4f}, CQL: {cql_loss.item():.4f})')

    wandb.log({'BC':bc_loss.item()})
    wandb.log({'Q1':q1_loss.item()})
    wandb.log({'Q2':q2_loss.item()})
    wandb.log({'CQL':cql_loss.item()})


    if epochs%10==0:
        PATH = f'./ckpt/Doom_lstm_SAR_90_CQL'
        os.makedirs(PATH, exist_ok=True)
        torch.save(agent.state_dict(),f"{PATH}.ckpt")

Epochs: 0 It: 38 Train Loss: 14.58031177520752 (BC: 0.8446, Q1: 2.7800, Q2: 2.2558, CQL: 8.7000)
Epochs: 1 It: 38 Train Loss: 11.613214492797852 (BC: 0.7044, Q1: 0.9505, Q2: 0.9202, CQL: 9.0381)
Epochs: 2 It: 38 Train Loss: 10.923055648803711 (BC: 0.6561, Q1: 0.7344, Q2: 0.7383, CQL: 8.7943)
Epochs: 3 It: 38 Train Loss: 10.676511764526367 (BC: 0.6502, Q1: 0.6661, Q2: 0.6743, CQL: 8.6859)
Epochs: 4 It: 38 Train Loss: 10.935297012329102 (BC: 0.6430, Q1: 0.8107, Q2: 0.8106, CQL: 8.6710)
Epochs: 5 It: 38 Train Loss: 10.791973114013672 (BC: 0.6394, Q1: 0.7971, Q2: 0.7980, CQL: 8.5574)
Epochs: 6 It: 38 Train Loss: 10.514238357543945 (BC: 0.6417, Q1: 0.7139, Q2: 0.7052, CQL: 8.4534)
Epochs: 7 It: 38 Train Loss: 10.358912467956543 (BC: 0.6402, Q1: 0.6507, Q2: 0.6335, CQL: 8.4345)
Epochs: 8 It: 38 Train Loss: 10.652642250061035 (BC: 0.6365, Q1: 0.7632, Q2: 0.7652, CQL: 8.4877)
Epochs: 9 It: 38 Train Loss: 10.682952880859375 (BC: 0.6421, Q1: 0.7503, Q2: 0.7431, CQL: 8.5474)
Epochs: 10 It: 38 Tra

KeyboardInterrupt: 

# Test

In [5]:
sys.path.append('../VizDoom/VizDoom_notebooks/')
from VizDoom.VizDoom_notebooks.doom_environment2 import DoomEnvironment
import env_vizdoom2

In [6]:
env_args = {
    'simulator':'doom', 
    'scenario':'custom_scenario{:003}.cfg', #custom_scenario_no_pil{:003}.cfg
    'test_scenario':'', 
    'screen_size':'320X180', 
    'screen_height':64, 
    'screen_width':112, 
    'num_environments':16,# 16
    'limit_actions':True, 
    'scenario_dir':'../VizDoom/VizDoom_src/env/',
    'test_scenario_dir':'', 
    'show_window':False, 
    'resize':True, 
    'multimaze':True, 
    'num_mazes_train':16, 
    'num_mazes_test':1, # 64 
    'disable_head_bob':False, 
    'use_shaping':False, 
    'fixed_scenario':False, 
    'use_pipes':False, 
    'num_actions':0, 
    'hidden_size':128, 
    'reload_model':'', 
    'model_checkpoint':'./VizDoom/VizDoom_notebooks/two_col_p1_checkpoint_0198658048.pth.tar',   # two_col_p0_checkpoint_0049154048.pth.tar',  #two_col_p0_checkpoint_0198658048.pth.tar', 
    'conv1_size':16, 
    'conv2_size':32, 
    'conv3_size':16, 
    'learning_rate':0.0007, 
    'momentum':0.0, 
    'gamma':0.99, 
    'frame_skip':4, 
    'train_freq':4, 
    'train_report_freq':100, 
    'max_iters':5000000, 
    'eval_freq':1000, 
    'eval_games':50, 
    'model_save_rate':1000, 
    'eps':1e-05, 
    'alpha':0.99, 
    'use_gae':False, 
    'tau':0.95, 
    'entropy_coef':0.001, 
    'value_loss_coef':0.5, 
    'max_grad_norm':0.5, 
    'num_steps':128, 
    'num_stack':1, 
    'num_frames':200000000, 
    'use_em_loss':False, 
    'skip_eval':False, 
    'stoc_evals':False, 
    'model_dir':'', 
    'out_dir':'./', 
    'log_interval':100, 
    'job_id':12345, 
    'test_name':'test_000', 
    'use_visdom':False, 
    'visdom_port':8097, 
    'visdom_ip':'http://10.0.0.1'                 
}


In [7]:
scene = 0
scenario = env_args['scenario_dir'] + env_args['scenario'].format(scene) # 0 % 63
config = scenario
device = 'cuda:0'

env = env_vizdoom2.DoomEnvironmentDisappear(
    scenario=config,
    show_window=False,
    use_info=True,
    use_shaping=False, #if False bonus reward if #shaping reward is always: +1,-1 in two_towers
    frame_skip=2,
    no_backward_movement=True)

In [7]:
# PATH = f'Doom_lstm_S_90_CQL'
PATH = f'ckpt/1/Doom_lstm_SAR_90_CQL'
# PATH = f'Doom_lstm_SAR_90_CQL'

weights = torch.load(f"{PATH}.ckpt", map_location="cpu")

agent.load_state_dict(weights, strict=True)
agent.train()
agent.to(agent.device)

DecisionLSTM(
  (lstm): LSTM(128, 128, num_layers=2, batch_first=True)
  (predict_action): Linear(in_features=128, out_features=5, bias=False)
  (embed_state): Sequential(
    (0): Conv2d(3, 32, kernel_size=(8, 8), stride=(4, 4))
    (1): ReLU()
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2))
    (3): ReLU()
    (4): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1))
    (5): ReLU()
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=2560, out_features=128, bias=True)
    (8): Tanh()
  )
  (embed_action_toq): Sequential(
    (0): Embedding(5, 5)
    (1): Tanh()
  )
  (embed_action): Sequential(
    (0): Embedding(5, 128)
    (1): Tanh()
  )
  (embed_return): Sequential(
    (0): Linear(in_features=1, out_features=128, bias=True)
    (1): Tanh()
  )
  (q1): Sequential(
    (0): Linear(in_features=129, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_

In [10]:
EPISODE_TIMEOUT = 4200 # 90
CQL = False
stacked_input = 'SAR' in PATH

NUMBER_OF_TRAIN_DATA = 100
returns_red, returns_green = [], []
agent.eval()

for i in tqdm(range(NUMBER_OF_TRAIN_DATA)):
    obsList, actList, rewList, doneList, isRedList = [], [], [], [], []
    times = []
    obs = env.reset()
    state = torch.zeros(1, env_args['hidden_size']).to(device)
    mask = torch.ones(1,1).to(device)
    done = False
    agent.init_hidden(1)
    action = 0
    rtg = 60.

    for t in count():
        times.append(t)
        obsList.append(obs['image'])
        #result = policy(torch.from_numpy(obs['image']).unsqueeze(0).to(device), state, mask)
        #action, state = result['actions'], result['states']

        states = torch.from_numpy(obs['image']).unsqueeze(0).unsqueeze(0).to(device)

        with torch.no_grad():
            q_values = []
            for possible_action in range(0,5):  # 5 возможных действия
                action_tensor = torch.tensor([[[possible_action]]], 
                                           dtype=torch.float32, 
                                           device=device).long()
                rtg_tensor = torch.tensor([[[rtg]]], 
                                           dtype=torch.float32, 
                                           device=device)#.long()
                if CQL:
                    update_lstm_hidden = possible_action==4
                else:
                    update_lstm_hidden = True
                    
                action_preds, q1, q2, _ = agent.forward(
                    states = states,
                    actions = action_tensor,
                    returns_to_go = rtg_tensor,
                    update_hidden = update_lstm_hidden,
                    stacked_input = stacked_input,
                )
                q_value = torch.minimum(q1, q2)
                q_values.append(q_value)

                if not CQL:
                    break

            # Select action with max Q-value
            if CQL:
                q_values = torch.cat(q_values, dim=-1)
                action = torch.argmax(q_values).item() #+ 3
            else:
                action = torch.argmax(torch.softmax(action_preds, dim=-1).squeeze()).item()

        #action = random.choice([3,4])
        #print(t,action, q_values)
        obs, reward, done, info = env.step(action)
        rtg -= reward

        is_red = info['is_red']
        rewList.append(reward)
        actList.append(action)
        doneList.append(int(done))
        isRedList.append(is_red)

        if done or t == EPISODE_TIMEOUT-1:

            if is_red == 1.0:
                returns_red.append(np.sum(rewList))
            else:
                returns_green.append(np.sum(rewList))

            break


100%|██████████| 100/100 [04:50<00:00,  2.90s/it]


In [11]:
print(np.mean(returns_red))

41.575116279069775


In [12]:
print(np.mean(returns_green))

55.78526315789474


In [14]:
# BC S
# 20.574042553191493
# 20.609433962264156

# BC SAR
# 5.929761904761905
# 6.481379310344828

# CQL S
# 22.045434782608698
# 27.417037037037034

# CQL SAR
# 4.904339622641509
# 5.614893617021276

In [13]:
for i in range(1, 6+1):
    # PATH = f'Doom_lstm_S_90_CQL'
    # PATH = f'ckpt/{i}/Doom_lstm_S_90_CQL'
    PATH = f'ckpt/{i}/Doom_lstm_S_90_BC'
    # PATH = f'Doom_lstm_SAR_90_CQL'

    weights = torch.load(f"{PATH}.ckpt", map_location="cpu")

    agent.load_state_dict(weights, strict=True)
    agent.train()
    agent.to(agent.device)


    EPISODE_TIMEOUT = 4200 # 90
    CQL = False
    stacked_input = 'SAR' in PATH

    NUMBER_OF_TRAIN_DATA = 100
    returns_red, returns_green, returns_total = [], [], []
    agent.eval()

    for i in tqdm(range(NUMBER_OF_TRAIN_DATA)):
        obsList, actList, rewList, doneList, isRedList = [], [], [], [], []
        times = []
        obs = env.reset()
        state = torch.zeros(1, env_args['hidden_size']).to(device)
        mask = torch.ones(1,1).to(device)
        done = False
        agent.init_hidden(1)
        action = 0
        rtg = 60.

        for t in count():
            times.append(t)
            obsList.append(obs['image'])
            #result = policy(torch.from_numpy(obs['image']).unsqueeze(0).to(device), state, mask)
            #action, state = result['actions'], result['states']

            states = torch.from_numpy(obs['image']).unsqueeze(0).unsqueeze(0).to(device)

            with torch.no_grad():
                q_values = []
                for possible_action in range(0,5):  # 5 возможных действия
                    action_tensor = torch.tensor([[[possible_action]]], 
                                            dtype=torch.float32, 
                                            device=device).long()
                    rtg_tensor = torch.tensor([[[rtg]]], 
                                            dtype=torch.float32, 
                                            device=device)#.long()
                    if CQL:
                        update_lstm_hidden = possible_action==4
                    else:
                        update_lstm_hidden = True
                        
                    action_preds, q1, q2, _ = agent.forward(
                        states = states,
                        actions = action_tensor,
                        returns_to_go = rtg_tensor,
                        update_hidden = update_lstm_hidden,
                        stacked_input = stacked_input,
                    )
                    q_value = torch.minimum(q1, q2)
                    q_values.append(q_value)

                    if not CQL:
                        break

                # Select action with max Q-value
                if CQL:
                    q_values = torch.cat(q_values, dim=-1)
                    action = torch.argmax(q_values).item() #+ 3
                else:
                    action = torch.argmax(torch.softmax(action_preds, dim=-1).squeeze()).item()

            #action = random.choice([3,4])
            #print(t,action, q_values)
            obs, reward, done, info = env.step(action)
            rtg -= reward

            is_red = info['is_red']
            rewList.append(reward)
            actList.append(action)
            doneList.append(int(done))
            isRedList.append(is_red)

            if done or t == EPISODE_TIMEOUT-1:

                if is_red == 1.0:
                    returns_red.append(np.sum(rewList))
                else:
                    returns_green.append(np.sum(rewList))

                returns_total.append(np.sum(rewList))

                break

    print(f"\nResults for checkpoint {i}:")
    print(f"Red team average return:   {np.mean(returns_red):.2f}")
    print(f"Green team average return: {np.mean(returns_green):.2f}")
    print(f"Total average return:      {np.mean(returns_total):.2f}")
    print("-" * 50)

100%|██████████| 100/100 [05:22<00:00,  3.23s/it]



Results for checkpoint 99:
Red team average return:   36.28
Green team average return: 73.71
Total average return:      55.74
--------------------------------------------------


100%|██████████| 100/100 [04:59<00:00,  2.99s/it]



Results for checkpoint 99:
Red team average return:   23.96
Green team average return: 77.50
Total average return:      52.33
--------------------------------------------------


100%|██████████| 100/100 [04:03<00:00,  2.43s/it]



Results for checkpoint 99:
Red team average return:   25.21
Green team average return: 59.70
Total average return:      42.11
--------------------------------------------------


100%|██████████| 100/100 [03:54<00:00,  2.35s/it]



Results for checkpoint 99:
Red team average return:   40.58
Green team average return: 41.17
Total average return:      40.91
--------------------------------------------------


100%|██████████| 100/100 [02:41<00:00,  1.62s/it]



Results for checkpoint 99:
Red team average return:   9.84
Green team average return: 41.12
Total average return:      26.11
--------------------------------------------------


100%|██████████| 100/100 [04:24<00:00,  2.65s/it]


Results for checkpoint 99:
Red team average return:   38.78
Green team average return: 49.66
Total average return:      44.44
--------------------------------------------------



