# In [None]:  (notebook cell marker left over from export; kept as a comment so the file parses)
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.general import get_logger
from utils.test_env import EnvTest
from q3_schedule import LinearExploration, LinearSchedule
from q4_linear_torch import Linear
import logging


from configs.q5_nature import config


class NatureQN(Linear):
    """Convolutional Q-network variant of ``Linear``.

    Both the online network (``self.q_network``) and the target network
    (``self.target_network``) share one architecture:
    ``Conv2d(3x3, 16 maps) -> ReLU -> Flatten -> Linear(128) -> ReLU ->
    Linear(num_actions)``.
    """

    @staticmethod
    def _build_network(img_height, img_width, n_channels, num_actions):
        """Return a freshly initialised conv + MLP Q-value network.

        Args:
            img_height: height of one input frame.
            img_width: width of one input frame.
            n_channels: number of input channels fed to the convolution.
            num_actions: number of outputs (one Q-value per action).

        Raises:
            ValueError: if the frame is smaller than the 3x3 kernel.
        """
        conv_maps = 16
        kernel_size = 3
        hidden_units = 128

        # A stride-1, unpadded convolution shrinks each spatial dimension by
        # (kernel_size - 1); e.g. an 8x8 frame gives 16 * 6 * 6 = 576 features.
        out_height = img_height - (kernel_size - 1)
        out_width = img_width - (kernel_size - 1)
        if out_height <= 0 or out_width <= 0:
            raise ValueError(
                "state of size {}x{} is smaller than the {}x{} kernel".format(
                    img_height, img_width, kernel_size, kernel_size
                )
            )
        flat_features = conv_maps * out_height * out_width

        return nn.Sequential(
            nn.Conv2d(n_channels, conv_maps, kernel_size=kernel_size, stride=1),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(flat_features, out_features=hidden_units),
            nn.ReLU(),
            nn.Linear(hidden_units, out_features=num_actions),
        )

    def initialize_models(self):
        """Create ``self.q_network`` and ``self.target_network``.

        The input geometry is read from ``self.env.state_shape()``, unpacked
        as ``(height, width, channels)``, and the output size from
        ``self.env.num_actions()``. The two networks have the same layout but
        are initialised independently.
        """
        img_height, img_width, n_channels = self.env.state_shape()
        num_actions = self.env.num_actions()

        self.q_network = self._build_network(
            img_height, img_width, n_channels, num_actions
        )
        self.target_network = self._build_network(
            img_height, img_width, n_channels, num_actions
        )

    def get_q_values(self, state, network):
        """Return Q-values for every state in a batch.

        Args:
            state: tensor laid out as ``(batch, height, width, channels)``.
            network: ``'q_network'`` or ``'target_network'``, selecting which
                model evaluates the batch.

        Returns:
            Tensor of shape ``(batch, num_actions)``.

        Raises:
            ValueError: if ``network`` is not one of the two known names.
        """
        if network == 'q_network':
            model = self.q_network
        elif network == 'target_network':
            model = self.target_network
        else:
            raise ValueError(
                "network must be 'q_network' or 'target_network', "
                "got {!r}".format(network)
            )

        # Conv2d expects channels-first input, so move the trailing channel
        # axis in front of the spatial axes for the whole batch at once
        # (instead of looping and keeping only the last sample's output).
        return model(state.permute(0, 3, 1, 2))


if __name__ == "__main__":
    # Mute matplotlib's font-manager logger so its warnings do not clutter
    # the training output.
    font_manager_log = logging.getLogger("matplotlib.font_manager")
    font_manager_log.disabled = True

    # Test environment built from the tuple (8, 8, 6).
    env = EnvTest((8, 8, 6))

    # Exploration schedule, parameterised by the eps_* config entries.
    exp_schedule = LinearExploration(
        env,
        config.eps_begin,
        config.eps_end,
        config.eps_nsteps,
    )

    # Learning-rate schedule, parameterised by the lr_* config entries.
    lr_schedule = LinearSchedule(
        config.lr_begin,
        config.lr_end,
        config.lr_nsteps,
    )

    # Build the model and run it with both schedules.
    model = NatureQN(env, config)
    model.run(exp_schedule, lr_schedule, run_idx=1)
