In [21]:
# imports
import os
import json
import re

import torch 
from torch.utils import data
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime

import numpy as np

from MDP import MDP

import stable_baselines3
import sb3_contrib

import gym

In [22]:
# check torch
# Use the first CUDA device when PyTorch can see one, otherwise fall back to the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
# get_device_name() raises on machines without CUDA, so only query it when a GPU
# is actually available; the cell still displays which device is in use.
torch.cuda.get_device_name(device) if use_cuda else "cpu"

'NVIDIA GeForce RTX 3080'

In [23]:
# create Neural Network

class Net(nn.Module):
    """
    Small convolutional classifier over grid observations.

    input  : batch of 2 X 4 X 4 grids (three unpadded 2x2 convolutions
             reduce the 4 X 4 plane to 1 X 1, leaving 32 features)
    output : 6 raw scores per sample, one per move
    """
    def __init__(self):
        super().__init__()
        # convolutional stack: channels 2 -> 8 -> 16 -> 32, 2x2 kernels throughout
        self.conv1 = nn.Conv2d(in_channels=2, out_channels=8, kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=2)
        self.conv3 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=2)

        # dense head: one hidden layer, then one score per move
        self.fc1 = nn.Linear(in_features=32, out_features=32)
        self.out = nn.Linear(in_features=32, out_features=6)

    def forward(self, x):
        # inputs may arrive as integer grids; the layers need floats
        features = x.float()

        for conv in (self.conv1, self.conv2, self.conv3):
            features = F.relu(conv(features))

        # keep the batch dimension, collapse everything else
        features = features.flatten(start_dim=1)

        hidden = F.relu(self.fc1(features))
        return self.out(hidden)

In [24]:
#creating model
net = Net()
# Move the model to the device selected in the "check torch" cell instead of
# calling .cuda() unconditionally, which raises on machines without a CUDA GPU.
net.to(device)
print(net)

# NOTE: this counts parameter tensors (one weight and one bias per layer),
# not the number of scalar weights.
params = list(net.parameters())
print(f"number of parameters: {len(params)}")

#loss function
loss = nn.CrossEntropyLoss()

#optimizer (Adam with its default hyper-parameters)
optimizer = torch.optim.Adam(net.parameters())
optimizer

Net(
  (conv1): Conv2d(2, 8, kernel_size=(2, 2), stride=(1, 1))
  (conv2): Conv2d(8, 16, kernel_size=(2, 2), stride=(1, 1))
  (conv3): Conv2d(16, 32, kernel_size=(2, 2), stride=(1, 1))
  (fc1): Linear(in_features=32, out_features=32, bias=True)
  (out): Linear(in_features=32, out_features=6, bias=True)
)
number of parameters: 10


Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    eps: 1e-08
    lr: 0.001
    weight_decay: 0
)

In [25]:
# Restore the saved weights from the local checkpoint file "Net".
# map_location remaps tensors that were saved on a GPU onto the active device,
# so the checkpoint also loads on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
net.load_state_dict(torch.load("Net", map_location=device))

<All keys matched successfully>

In [27]:
#custom environment
from gym import spaces

class Gridworld(gym.Env):
    """
    Gym environment that serves one MDP task per episode.

    Tasks are enumerated from ``datasets/<dir>/<type>/task`` by an endless
    generator; every ``reset`` loads the next one into an ``MDP`` object and
    every ``step`` forwards the chosen action name to it.

    Parameters
    ----------
    dir : list of dataset folder names under ``datasets``
    type : list of split names (sub-folders of each dataset folder)
    lambda1, lambda2, lambda3 : passed through unchanged to ``MDP``
    """

    metadata = {"render.modes" : ["human"]}

    def MDP_generator(self):
        """Yield ``(dir, type, task_id)`` triples forever, cycling over all datasets."""
        while True:
            for dir_name in self.dir:
                print(dir_name)
                for split in self.type:
                    task_dir = os.sep.join(["datasets", dir_name, split, "task"])
                    # NOTE(review): [:-4] drops the last four *entries* of the
                    # directory listing, not a file extension (the regex below
                    # already strips every non-digit). Kept as-is to preserve the
                    # task set; confirm whether this was intended.
                    for file_name in os.listdir(task_dir)[:-4]:
                        # keep only the digits of the file name as the task id
                        yield dir_name, split, re.sub(r"\D", "", file_name)

    def __init__(self, dir = ["data", "data_easy", "data_medium"], type = ["train"], lambda1 = 0.01, lambda2 = 0.1, lambda3 = 1) -> None:
        super(Gridworld, self).__init__()
        self.action_space = spaces.Discrete(6)
        self.observation_space = spaces.Box(low = 0, high = 10, shape = (2, 4, 4))

        #available MDPs
        # copy the lists so the shared default arguments are never aliased
        self.dir = list(dir)
        self.type = list(type)
        self.lambda1 = lambda1
        self.lambda2 = lambda2
        self.lambda3 = lambda3

        # identity of the currently loaded task; filled in by reset()
        self.nextDir = None
        self.nextType = None
        self.nexti = None

        self.next_MDP = self.MDP_generator()
        # index in this list == discrete action id
        self.actions = ["move", "turnLeft", "turnRight", "pickMarker", "putMarker", "finish"]

    def reset(self):
        """Load the next task and return its initial state."""
        # remember which task is loaded so get_MDP_name() can report it
        self.nextDir, self.nextType, self.nexti = next(self.next_MDP)
        self.currentMDP = MDP(self.nextDir, self.nextType, self.nexti, lambda1= self.lambda1, lambda2 = self.lambda2, lambda3 =self.lambda3)
        self.steps = 0
        return self.currentMDP.get_current_state()

    def step(self, action):
        """
        Apply the action with the given index and return ``(state, reward, done, info)``.

        Once more than 500 steps have been taken the episode is cut off with
        reward -1; otherwise 0.01 is subtracted from the MDP's reward as a
        per-step penalty.
        """
        nextState, rew, done, info = self.currentMDP.sample_next_state_and_reward(self.actions[action])
        self.steps += 1
        if self.steps > 500:
            return nextState, -1, True, info

        return nextState, rew - 0.01, done, info

    def render(self, mode = "human"):
        """Print the current grid. ``mode`` is accepted for Gym API compatibility and ignored."""
        self.currentMDP.print_grid()

    def close(self):
        pass

    def action_masks(self):
        """
        Return the mask of currently allowed actions.

        When both channels of the current state are equal only the last action
        ("finish") is allowed; otherwise the MDP supplies the mask.
        """
        mat = self.currentMDP.get_current_state()
        # presumably channel 0 is the current grid and channel 1 the target - TODO confirm
        if np.array_equal(mat[0], mat[1]):
            return np.array([0,0,0,0,0,1])

        return self.currentMDP.action_mask()

    # functions below are only used for inheritance
    def get_MDP(self):
        """Return the MDP loaded by the last reset()."""
        return self.currentMDP

    def get_MDP_name(self):
        """Return ``(dir, type, task_id)`` of the loaded task (all None before the first reset())."""
        return self.nextDir, self.nextType, self.nexti

In [34]:
def test_RL_models(model):
    valDataset = Gridworld(type = ["val"], lambda1=0, lambda2=0)
    correct, total = 0,5200
    for task in range(int(total)):
        if task % 500 == 499:
            print(f"{(task+1) / total *100} %, running acc: {(correct*100)/(task+1)}")
        currMDP = valDataset.reset()
        done = False
        steps = 0
        while not done and steps < 50:
            action = model.predict(currMDP, action_masks = valDataset.action_masks(), deterministic = True)[0]
            currMDP, rew, done, _ = valDataset.step(action)
            if rew > 0:
                correct += 1
            steps += 1
            
            
    print(f"correct : {correct}, accuracy: {(correct*100)/total } %")
    return (correct*100)/total

In [29]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3.common.env_util import make_vec_env

import torch.nn.functional as F
from torch import nn

class CustomFeatureExtractorTorch(BaseFeaturesExtractor):
    """
    Features extractor that reuses the three convolution layers of the
    module-level ``net`` and flattens their output.

    ``features_dim`` is forwarded to the base class; it defaults to 32.
    """
    def __init__(self, observation_space: gym.Space, features_dim: int = 32):
        super().__init__(observation_space, features_dim)

        # the very same layer objects as in `net` (not copies)
        self.conv1, self.conv2, self.conv3 = net.conv1, net.conv2, net.conv3

    def forward(self, x):
        out = x
        for layer in (self.conv1, self.conv2, self.conv3):
            out = F.relu(layer(out))
        # keep the batch dimension, collapse everything else
        return out.flatten(start_dim=1)

# Layer sizes handed to the policy: three integer entries followed by a dict with
# separate "vf" and "pi" size lists.
# NOTE(review): presumably the leading ints are layers shared by both heads -
# confirm against the installed Stable-Baselines3 version.
# NOTE(review): both names are reassigned in a later cell (In [31]) before the
# model is built.
net_arch = [32, 16, 8, {"vf": [16, 8, 4], "pi": [16, 8, 4]}]

policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractorTorch,
    "net_arch": net_arch,
}

In [30]:
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
from sb3_contrib.ppo_mask import MaskablePPO
from stable_baselines3.common.env_util import make_vec_env
from sb3_contrib.common.maskable.policies import MaskableActorCriticPolicy

In [31]:
# Layer sizes handed to the policy: four integer entries followed by a dict with
# separate "vf" and "pi" size lists.
# NOTE(review): presumably the leading ints are layers shared by both heads -
# confirm against the installed Stable-Baselines3 version.
net_arch = [128, 64, 16, 8, {"vf": [8, 4, 2], "pi": [8, 4, 2]}]

policy_kwargs = {
    "features_extractor_class": CustomFeatureExtractorTorch,
    "net_arch": net_arch,
}

In [32]:
# Four parallel Gridworld copies, each cycling over the listed dataset folders
# with lambda1 = lambda2 = 0.
FinalEnv = make_vec_env(
    Gridworld,
    n_envs = 4,
    env_kwargs = dict(
        lambda1 = 0,
        lambda2 = 0,
        dir = ["data_easy", "generated_easy", "data_medium", "generated_med", "data", "data_hard"],
    ),
)

# 4 envs x 500 steps per rollout = 2000 timesteps per iteration (see the log below).
FinalModel = MaskablePPO(
    MaskableActorCriticPolicy,
    FinalEnv,
    policy_kwargs = policy_kwargs,
    verbose = 1,
    gamma = 0.75,
    n_steps = 500,
)

FinalModel.learn(3*1e5)

Using cuda device
data_easy
data_easy
data_easy
data_easy
---------------------------------
| rollout/           |          |
|    ep_len_mean     | 4        |
|    ep_rew_mean     | 0.96     |
| time/              |          |
|    fps             | 2145     |
|    iterations      | 1        |
|    time_elapsed    | 0        |
|    total_timesteps | 2000     |
---------------------------------
-------------------------------------------
| rollout/                |               |
|    ep_len_mean          | 184           |
|    ep_rew_mean          | -1.57         |
| time/                   |               |
|    fps                  | 1100          |
|    iterations           | 2             |
|    time_elapsed         | 3             |
|    total_timesteps      | 4000          |
| train/                  |               |
|    approx_kl            | 0.00067855447 |
|    clip_fraction        | 0             |
|    clip_range           | 0.2           |
|    entropy_loss         | -1

<sb3_contrib.ppo_mask.ppo_mask.MaskablePPO at 0x7fea61909df0>

In [33]:
# Evaluate the trained model on the validation tasks.
# NOTE(review): the output recorded below advances 3.33 % per 500 tasks, which does
# not match total = 5200 in the test_RL_models definition (that cell is In [34],
# i.e. it was re-run after this one) - re-run top to bottom to refresh the output.
test_RL_models(FinalModel)

data
3.3333333333333335 %, running acc: 29.0
6.666666666666667 %, running acc: 28.6
10.0 %, running acc: 28.0
13.333333333333334 %, running acc: 28.25
data_easy
16.666666666666664 %, running acc: 30.52
data_medium
20.0 %, running acc: 35.766666666666666
23.333333333333332 %, running acc: 36.34285714285714
data
26.666666666666668 %, running acc: 36.625
30.0 %, running acc: 35.8
33.33333333333333 %, running acc: 35.1
36.666666666666664 %, running acc: 34.30909090909091
40.0 %, running acc: 33.833333333333336
data_easy
43.333333333333336 %, running acc: 34.38461538461539
data_medium
46.666666666666664 %, running acc: 36.27142857142857
50.0 %, running acc: 36.53333333333333
data
53.333333333333336 %, running acc: 36.675
56.666666666666664 %, running acc: 36.21176470588235
60.0 %, running acc: 35.8


KeyboardInterrupt: 