In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.test_env import EnvTest
from core.deep_q_learning_torch import DQN
from q3_schedule import LinearExploration, LinearSchedule
import numpy as np
from configs.q4_linear import config
import logging


class Linear(DQN):
    """DQN agent whose Q-function is a single fully-connected (linear) layer.

    Both the online network (``q_network``) and the target network
    (``target_network``) map one flattened stack of ``config.state_history``
    observations to one Q-value per action.
    """

    def initialize_models(self):
        """Create the online and target linear Q-networks.

        Input size is ``height * width * channels * config.state_history``
        (the element count of one flattened stacked observation); output size
        is ``env.num_actions()``.
        """
        img_height, img_width, n_channels = self.env.state_shape()
        num_actions = self.env.num_actions()
        in_features = img_height * img_width * n_channels * self.config.state_history
        self.q_network = nn.Linear(in_features, num_actions)
        self.target_network = nn.Linear(in_features, num_actions)

    def get_q_values(
        self, state: torch.Tensor, network: str = "q_network"
    ) -> torch.Tensor:
        """Return the Q-value of every action for a batch of states.

        Args:
            state: batch of stacked observations with the batch on dim 0;
                presumably (batch, height, width, channels * state_history)
                -- TODO confirm against the caller. All dims after the first
                are flattened, so each sample must hold exactly the layer's
                ``in_features`` elements.
            network: ``"q_network"`` or ``"target_network"``.

        Returns:
            Tensor of shape (batch, num_actions).

        Raises:
            ValueError: if ``network`` names neither network.
        """
        if network == "q_network":
            net = self.q_network
        elif network == "target_network":
            net = self.target_network
        else:
            raise ValueError(
                f"unknown network {network!r}; "
                "expected 'q_network' or 'target_network'"
            )
        # Flatten every sample (all dims after the batch dim) and push the
        # whole batch through the layer in one call, instead of looping over
        # samples and concatenating the results.
        return net(torch.flatten(state, start_dim=1))

    def update_target(self):
        """Copy the online network's weights into the target network.

        The copy happens in memory via ``state_dict``. This avoids writing a
        temporary file, and avoids pickling the whole module, which
        ``torch.load`` refuses by default (``weights_only=True``) in recent
        PyTorch releases.
        """
        self.target_network.load_state_dict(self.q_network.state_dict())

    def calc_loss(
        self,
        q_values: torch.Tensor,
        target_q_values: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        done_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the DQN mean-squared TD loss and store it on ``self.loss``.

        Args:
            q_values: online-network Q-values, (batch, num_actions).
            target_q_values: target-network Q-values for the next states,
                (batch, num_actions).
            actions: index of the action taken in each transition, (batch,).
            rewards: reward of each transition, (batch,).
            done_mask: truthy where the transition ended the episode; bool or
                0/1 tensor, (batch,).

        Returns:
            Scalar loss ``mean((Q(s, a) - Q_samp)^2)``, where ``Q_samp`` is
            ``r`` for terminal transitions and otherwise
            ``r + gamma * max_a' Q_target(s', a')``.
        """
        gamma = self.config.gamma
        max_next_q = target_q_values.max(dim=1).values
        # torch.where needs a bool condition; also accept 0/1 numeric masks.
        done = done_mask.bool()
        q_samp = torch.where(done, rewards, rewards + gamma * max_next_q)
        # The bootstrapped target is a fixed regression label: keep gradients
        # from flowing back into the target network.
        q_samp = q_samp.detach()
        # Pick Q(s, a) for the action actually taken in each transition.
        q_taken = q_values.gather(1, actions.to(torch.int64).unsqueeze(1)).squeeze(1)

        self.loss = F.mse_loss(q_taken, q_samp)
        return self.loss

    def add_optimizer(self):
        """Create an Adam optimizer over the online network's parameters.

        The target network is deliberately excluded; in this class it only
        changes through ``update_target``. Adam's default learning rate is
        used here. NOTE(review): presumably the DQN base class applies the
        lr schedule to ``self.optimizer``; confirm.
        """
        self.optimizer = torch.optim.Adam(self.q_network.parameters())


if __name__ == "__main__":
    # Silence matplotlib's font-manager warnings.
    font_logger = logging.getLogger("matplotlib.font_manager")
    font_logger.disabled = True

    # Test environment constructed with shape (5, 5, 1).
    test_env = EnvTest((5, 5, 1))

    # Exploration strategy, parameterized by config.eps_begin / eps_end / eps_nsteps.
    exploration = LinearExploration(
        test_env,
        config.eps_begin,
        config.eps_end,
        config.eps_nsteps,
    )

    # Learning-rate schedule, parameterized by config.lr_begin / lr_end / lr_nsteps.
    learning_rate = LinearSchedule(
        config.lr_begin,
        config.lr_end,
        config.lr_nsteps,
    )

    # Build the linear DQN agent and train it.
    agent = Linear(test_env, config)
    agent.run(exploration, learning_rate, run_idx=1)
