# Import Everything

In [None]:
""" Various auxiliary utilities """
import math
from os.path import join, exists
import torch
from torch import optim
from torchvision import transforms
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
from models import MDRNN,MDRNNCell, VAE, Controller,Dtild,HiddenVAE
from models.mdrnn import gmm_loss

import gym
import gym.envs.box2d
import random
import gym_minigrid
from gym_minigrid.wrappers import *
from gym_minigrid.window import Window
import cv2
from itertools import chain
import scipy.stats
import matplotlib.pyplot as plt
from sklearn import neighbors
from sklearn.metrics import mean_squared_error
from math import sqrt
from collections import namedtuple
from utils import dqn_utils

from torch.distributions.categorical import Categorical



from sklearn.metrics import confusion_matrix
import seaborn as sn

# All Q learning stuff

In [None]:
# Sizes handed to the VAE / MDRNN constructors further down in this file
# (LSIZE, ASIZE and RSIZE are used there; RED_SIZE and SIZE are not used in
# the visible code -- presumably image sizes, TODO confirm).
ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE = 1, 32, 256, 64, 64

# One transition stored in the replay memory.
Experience = namedtuple(
    'Experience',
    ('state', 'action', 'next_state', 'reward')
)

# Mean squared error, used as a distance between latent vectors.
criterion = nn.MSELoss()


# --- DQN hyper-parameters -------------------------------------------------
batch_size = 10
learning_rate = 0.0001
# NOTE(review): discount_rate duplicates gamma; only gamma is read by the
# Q-learning update visible in this file.
discount_rate = 0.99
gamma = 0.99
exploration_decay_rate = 0.001
memory_size = 100000
target_update = 10


exploration_rate = 0.8
max_exploration_rate = 0.8
min_exploration_rate = 0.1
# Fall back to the CPU when CUDA is not available so that the `.to(device)`
# calls below do not fail on CPU-only machines.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Controllerstrategy = dqn_utils.EpsilonGreedyStrategy(max_exploration_rate, min_exploration_rate, exploration_decay_rate)
Controlleragent = dqn_utils.Agent(Controllerstrategy,'controller', device)
Controllermemory = dqn_utils.ReplayMemory(memory_size)

#2. Initialise policy network with random weights

Controllerpolicy_net = dqn_utils.DQN(288, 32,6,dropout= True, dropout_prob= 0.2).to(device)

# The target network starts as a frozen copy of the policy network.
Controllertarget_net = dqn_utils.DQN(288, 32,6,dropout= True, dropout_prob= 0.2).to(device)
Controllertarget_net.load_state_dict(Controllerpolicy_net.state_dict())
Controllertarget_net.eval()
Controlleroptimizer = optim.Adam(params=Controllerpolicy_net.parameters(), lr=learning_rate)

# Helper functions: action sampling, checkpointing and parameter (un)flattening

In [None]:
def sample_continuous_policy(action_space, seq_len, dt):
    """ Sample a continuous policy.

    Atm, action_space is supposed to be a box environment. The policy is
    sampled as a brownian motion a_{t+1} = a_t + sqrt(dt) N(0, 1), with every
    action clipped to [action_space.low, action_space.high].

    :args action_space: gym action space (must provide sample(), low, high)
    :args seq_len: number of actions returned
    :args dt: temporal discretization

    :returns: list of exactly seq_len actions (empty when seq_len <= 0)
    """
    if seq_len <= 0:
        return []

    step_scale = math.sqrt(dt)
    actions = [action_space.sample()]
    # The first action is already drawn above, so only seq_len - 1 brownian
    # increments are needed to return seq_len actions in total.
    for _ in range(seq_len - 1):
        daction_dt = np.random.randn(*actions[-1].shape)
        actions.append(
            np.clip(actions[-1] + step_scale * daction_dt,
                    action_space.low, action_space.high))
    return actions

def save_checkpoint(state, is_best, filename, best_filename):
    """ Write state to filename, and to best_filename as well when is_best.

    :args state: object serialised with torch.save
    :args is_best: when truthy, a second copy goes to best_filename
    :args filename: path of the regular checkpoint
    :args best_filename: path of the "best so far" checkpoint
    """
    destinations = [filename]
    if is_best:
        destinations.append(best_filename)
    for path in destinations:
        torch.save(state, path)

def flatten_parameters(params):
    """ Concatenate parameters into one flat numpy vector.

    :args params: iterable of tensors (e.g. what module.parameters() yields)

    :returns: 1D numpy array holding all entries of every tensor, in
        iteration order
    """
    flat_chunks = []
    for tensor in params:
        # detach so the result does not take part in autograd
        flat_chunks.append(tensor.detach().view(-1))
    joined = torch.cat(flat_chunks, dim=0)
    return joined.cpu().numpy()

def unflatten_parameters(params, example, device):
    """ Cut a flat parameter vector back into correctly shaped tensors.

    :args params: all parameters as one 1D np array
    :args example: iterable of tensors (e.g. module.parameters()) whose sizes
        define how params is sliced and reshaped
    :args device: device on which the resulting tensors are stored

    :returns: list of tensors, one per element of example, each a view of the
        matching slice of params
    """
    flat = torch.Tensor(params).to(device)
    pieces = []
    start = 0
    for template in example:
        stop = start + template.numel()
        pieces.append(flat[start:stop].view(template.size()))
        start = stop
    return pieces

def load_parameters(params, controller):
    """ Copy a flat parameter vector into a module, in place.

    :args params: parameters as a single 1D np array
    :args controller: module whose parameters are overwritten
    """
    # the new values are created on the device of the first parameter
    target_device = next(controller.parameters()).device
    reshaped = unflatten_parameters(
        params, controller.parameters(), target_device)

    for destination, source in zip(controller.parameters(), reshaped):
        destination.data.copy_(source)

# Rollout Class

In [None]:
class RolloutGenerator(object):
    """ Utility to generate rollouts.

    Encapsulate everything that is needed to generate rollouts in the TRUE ENV
    using a controller with previously trained VAE and MDRNN.

    :attr vae: VAE model loaded from mdir/vae
    :attr mdrnn: MDRNN model loaded from mdir/mdrnn
    :attr controller: Controller, either loaded from mdir/ctrl or randomly
        initialized
    :attr env: instance of the MiniGrid-MultiRoom-N6-v0 gym environment
    :attr device: device used to run VAE, MDRNN and Controller
    :attr time_limit: rollouts have a maximum of time_limit timesteps
    """
    def __init__(self, mdir, device, time_limit, number_goals, hiddengoals: bool, curiosityreward: bool = True):
        """ Build vae, rnn, controller and environment.

        :args mdir: directory holding the sub-directories vae, mdrnn, ctrl,
            dtild and hiddenvae, each with a best.tar checkpoint
        :args device: device on which all models are placed
        :args time_limit: stored as self.time_limit
        :args number_goals: stored as self.number_goals
        :args hiddengoals: stored as self.hiddengoals; also selects the second
            constructor argument of the instance DQN networks (256 vs 32)
        :args curiosityreward: stored as self.curiosityreward (defaults to a
            truthy value, as before)
        """
        # Loading world model and vae
        vae_file, rnn_file, ctrl_file, _dtild_file, hiddenvae_file = \
            [join(mdir, m, 'best.tar') for m in ['vae', 'mdrnn', 'ctrl', 'dtild', 'hiddenvae']]

        assert exists(vae_file) and exists(rnn_file),\
            "Either vae or mdrnn is untrained."

        vae_state, rnn_state, hiddenvae_state = [
            torch.load(fname, map_location={'cuda:0': str(device)})
            for fname in (vae_file, rnn_file, hiddenvae_file)]

        for m, s in (('VAE', vae_state), ('MDRNN', rnn_state), ('HiddenVAE', hiddenvae_state)):
            print("Loading {} at epoch {} "
                  "with test loss {}".format(
                      m, s['epoch'], s['precision']))

        self.vae = VAE(3, LSIZE).to(device)
        self.vae.load_state_dict(vae_state['state_dict'])

        self.HiddenVAE = HiddenVAE(256, LSIZE).to(device)
        self.HiddenVAE.load_state_dict(hiddenvae_state['state_dict'])

        # The single-step cell is loaded from the sequence model's checkpoint
        # with the trailing '_l0' removed from each key. A suffix check is
        # used because str.strip('_l0') strips any of the characters '_',
        # 'l', '0' from BOTH ends and would mangle e.g. a key starting
        # with 'l'.
        suffix = '_l0'
        self.mdrnn = MDRNNCell(LSIZE, ASIZE, RSIZE, 5).to(device)
        self.mdrnn.load_state_dict(
            {(k[:-len(suffix)] if k.endswith(suffix) else k): v
             for k, v in rnn_state['state_dict'].items()})

        self.mdrnnBIG = MDRNN(LSIZE, ASIZE, RSIZE, 5).to(device)
        self.mdrnnBIG.load_state_dict(rnn_state["state_dict"])

        self.controller = Controller(256, 256, 6).to(device)
        self.fmodel = Dtild(32, 1, 32).to(device)

        # load controller if it was previously saved
        if exists(ctrl_file):
            ctrl_state = torch.load(ctrl_file, map_location={'cuda:0': str(device)})
            print("Loading Controller with reward {}".format(
                ctrl_state['reward']))
            self.controller.load_state_dict(ctrl_state['state_dict'])

        self.env = gym.make('MiniGrid-MultiRoom-N6-v0')

        self.device = device
        self.number_goals = number_goals
        self.time_limit = time_limit

        # Kept so that rollout() can restore the MDRNN optimizer state.
        self.rnn_state = rnn_state
        self.curiosityreward = curiosityreward
        self.hiddengoals = hiddengoals

        # Second DQN constructor argument: 256 with hidden goals, else 32
        # (presumably the goal vector size -- TODO confirm in dqn_utils).
        dqn_second_dim = 256 if hiddengoals else 32
        self.Controllerpolicy_net = dqn_utils.DQN(288, dqn_second_dim, 6, dropout=True, dropout_prob=0.2).to(device)
        self.Controllertarget_net = dqn_utils.DQN(288, dqn_second_dim, 6, dropout=True, dropout_prob=0.2).to(device)

        # Initialise the target net and the optimizer from THIS instance's
        # policy net. Using the module-level Controllerpolicy_net here breaks
        # load_state_dict when the two networks differ in size (hiddengoals)
        # and makes the optimizer train the wrong network.
        # NOTE(review): rollout() still trains the module-level Controller*
        # networks rather than these instance-level ones.
        self.Controllertarget_net.load_state_dict(self.Controllerpolicy_net.state_dict())
        self.Controllertarget_net.eval()
        self.Controlleroptimizer = optim.Adam(params=self.Controllerpolicy_net.parameters(), lr=learning_rate)
        
        
        

    def rollout(self, params, render=False):
        """ Run a random bootstrap phase followed by goal-directed exploration.

        1. Bootstrapping: random actions are executed until more than
           self.time_limit steps were taken; the visited latent (or hidden)
           states are collected.
        2. Goal exploration, repeated self.number_goals times: a goal is
           sampled from a gaussian KDE over the collected states, the
           module-level DQN controller acts for 100 steps and is trained from
           the replay memory, then the sequence MDRNN is trained on the
           transitions of these 100 steps and diagnostic curves are plotted.

        :args params: parameters as a single 1D np array, loaded into
            self.fmodel; None leaves self.fmodel untouched
        :args render: currently unused, the environment is always rendered

        :returns: None (results are only plotted)
        """
        # copy params into the forward model
        if params is not None:
            load_parameters(params, self.fmodel)
        MDRNNoptimizer = torch.optim.RMSprop(self.mdrnnBIG.parameters(), lr=1e-3, alpha=.9)
        MDRNNoptimizer.load_state_dict(self.rnn_state["optimizer"])

        zstate_list = []
        obs = self.env.reset()
        obs = obs['image']
        expl_rate = 0.5
        print(obs.shape)
        # This first render is required !
        self.env.render()

        hidden = [
            torch.zeros(1, RSIZE).to(self.device)
            for _ in range(2)]

        i = 0

        # Bootstrapping: act randomly and record the visited states.
        # NOTE(review): the `done` flag of env.step is ignored here, the
        # environment is never reset during bootstrapping.
        while True:
            action = random.randrange(6)

            _, hidden, z, zh = self.transform_obs_hidden(obs, hidden, action)

            obs, _, _, _ = self.env.step(action)
            obs = obs['image']

            if self.hiddengoals:
                zstate_list.append(np.array(hidden[0].cpu().detach().numpy()))  # if we use pure hidden
            else:
                zstate_list.append(np.array(z.cpu().detach().numpy()))  # if we use latent_space

            i += 1
            self.env.render()

            print(i)

            if i > self.time_limit:
                break

        s = obs
        loss_list = []
        WM_loss = []
        rollout_reward = []

        # Goal Exploration
        for c in range(self.number_goals):
            zstate_list = np.array(zstate_list)
            zstate_list = zstate_list.squeeze(1)
            kde = scipy.stats.gaussian_kde(zstate_list.T)

            # sampling_method is not defined in this part of the file;
            # presumably it is provided by another cell -- TODO confirm.
            z_goal = sampling_method(kde)
            z_goal = torch.tensor([z_goal], dtype=torch.float32).to(self.device)  # controller requires both as tensors

            if not self.hiddengoals:
                z_goal_obs = self.vae.decoder(z_goal)
                z_goal_obs = z_goal_obs.reshape(7, 7, 3)
                z_goal_obs = np.array(z_goal_obs.cpu().detach())

                plt.figure('Zgoal')
                plt.cla()
                sn.heatmap(z_goal_obs[:, :, 0], cmap='Reds', annot=True, cbar=False).invert_yaxis()

            # Plain floats are accumulated (instead of tensors) so the curves
            # can be plotted and no autograd graph is kept alive.
            total_reward = 0.0
            total_loss = 0.0
            goal_loss = []

            scur_rollout = []
            snext_rollout = []
            r_rollout = []
            d_rollout = []
            act_rollout = []

            # back to a python list so that new states can be appended
            zstate_list = zstate_list[:, np.newaxis, :]
            zstate_list = zstate_list.tolist()

            for goalattempts in range(100):
                _, _, z = self.tolatent(s)
                state = torch.cat((hidden[0].detach(), z, z_goal), dim=1)
                print('C, goalattempt number', c,goalattempts)
                m = Controlleragent.select_action(torch.cat((hidden[0].detach(), z), dim=1), z_goal, Controllerpolicy_net, expl_rate)

                # mixture parameters of the predicted next latent observation
                hmus, hsigmas, hlogpi, _ = self.predict_next(s, hidden, m)
                # next hidden state, current latent obs, sampled prediction of the next latent obs
                _, hidden, z, zh = self.transform_obs_hidden(s, hidden, m)

                if self.hiddengoals:
                    zstate_list.append(np.array(hidden[0].cpu().detach().numpy()))  # if we use pure hidden
                else:
                    zstate_list.append(np.array(z.cpu().detach().numpy()))  # if we use latent_space

                predicted_next_obs = self.vae.decoder(zh)
                predicted_next_obs = predicted_next_obs.reshape(7, 7, 3)
                p = np.array(predicted_next_obs.cpu().detach())

                plt.figure('Predicted obs')
                plt.cla()
                sn.heatmap(p[:, :, 0], cmap='Reds', annot=True, cbar=False).invert_yaxis()

                # use the action sampled and see if we get to the goal we wanted or how close we got to the goal
                s, _, _, _ = self.env.step(m)
                s = s['image']

                next_mu, next_logsigma, next_z = self.tolatent(s)
                next_state = torch.cat((hidden[0].detach(), next_z, z_goal), dim=1)

                plt.figure('Actual obs')
                plt.cla()
                sn.heatmap(s[:, :, 0], cmap='Reds', annot=True, cbar=False).invert_yaxis()

                scur_rollout.append(np.array(z.cpu().detach()))
                snext_rollout.append(np.array(next_z.cpu().detach()))
                r_rollout.append([0.0])
                # builtin float: the np.float alias was removed from NumPy
                act_rollout.append([[float(m)]])
                d_rollout.append([0.0])

                self.env.render()

                # Surprise of the world model about the observed next latent.
                Curiosityreward = (gmm_loss(next_z.detach(), hmus, hsigmas, hlogpi) / 33).item()
                total_loss += Curiosityreward

                # how far away the achieved step is from the goal
                if self.hiddengoals:
                    # NOTE(review): next_z is a VAE latent while z_goal is
                    # sampled from hidden states in this mode -- confirm that
                    # the two really have matching sizes.
                    goal_loss.append(criterion(next_z.detach(), z_goal).item())
                else:
                    goal_loss.append((gmm_loss(z_goal, next_mu, next_logsigma.exp(), torch.tensor([-1.0], dtype=torch.float32).to(self.device)) / 33).item())

                # if next z matches goal to a certain degree then get reward
                if goal_loss[-1] < 1.2:
                    reward = 4.0
                else:
                    reward = 0.0

                if self.curiosityreward:
                    reward = reward + Curiosityreward

                reward = torch.tensor([reward], device=device, requires_grad=False)
                total_reward += reward.item()

                Controllermemory.push(Experience(state.detach(), torch.tensor([m]).to(device), next_state.detach(), reward))

                if Controllermemory.can_provide_sample(batch_size):

                    # retrieve experiences from batch
                    experiences = Controllermemory.sample(batch_size)

                    states, actions, rewards, next_states = dqn_utils.extract_tensors(experiences)

                    # Get Q Values from policy and target network
                    Controllercurrent_q_values = dqn_utils.ActionQValues.get_current(Controllerpolicy_net, states, actions)
                    Controllernext_q_values = dqn_utils.ActionQValues.get_next(Controllertarget_net, next_states)
                    Controllertarget_q_values = (Controllernext_q_values * gamma) + rewards
                    # Calculate Loss and back propagate
                    Controllerloss = F.mse_loss(Controllercurrent_q_values, Controllertarget_q_values.unsqueeze(1))
                    Controlleroptimizer.zero_grad()
                    Controllerloss.backward()
                    Controlleroptimizer.step()

                # NOTE(review): this check sits inside the step loop, so for
                # matching values of c the target net is refreshed every step.
                if c % target_update == 0:
                    print('Updating target network')
                    Controllertarget_net.load_state_dict(Controllerpolicy_net.state_dict())

            expl_rate = min_exploration_rate + \
                (max_exploration_rate - min_exploration_rate) * np.exp(-exploration_decay_rate*c)

            # Train the sequence MDRNN on the transitions of this goal.
            mdrnnlosses = self.get_loss(
                torch.tensor(np.array(scur_rollout)).to(self.device),
                torch.tensor(act_rollout).to(self.device),
                torch.tensor(r_rollout).to(self.device),
                torch.tensor(d_rollout).to(self.device),
                torch.tensor(np.array(snext_rollout)).to(self.device),
                include_reward=False)

            MDRNNoptimizer.zero_grad()
            mdrnnlosses['loss'].backward()
            MDRNNoptimizer.step()

            WM_loss.append(mdrnnlosses['loss'].item())

            # Periodically copy the trained sequence model into the one-step
            # cell used for prediction. The original condition tested
            # `goalattempts`, which is always 99 after the loop, so the copy
            # never ran. The '_l0' suffix is removed from the keys exactly as
            # when the cell is first loaded in __init__.
            if c % 10 == 0:
                self.mdrnn.load_state_dict({
                    (k[:-3] if k.endswith('_l0') else k): v
                    for k, v in self.mdrnnBIG.state_dict().items()})

            loss_list.append(total_loss/(goalattempts+1))
            rollout_reward.append(total_reward)

            plt.figure('Average Forward model loss')
            plt.plot(loss_list)
            plt.figure('WM_loss')
            plt.plot(WM_loss)
            plt.figure('Distance to goal per step')
            plt.cla()
            plt.plot(goal_loss)
            plt.figure('Reward per rollout')
            # local list; there is no self.rollout_reward attribute
            plt.plot(rollout_reward)

            plt.show()
            
            #input('stop')
            
            
            
    def transform_obs_hidden(self,obs, hidden,m):
        obs = torch.tensor(obs.flatten(),dtype = torch.float32).unsqueeze(0).to(self.device)
        #print('obs shape from transform obs hidden: ', obs.shape)
       
        action =  torch.Tensor([[m]]).to(self.device)
        reconx, latent_mu, logsigma = self.vae(obs)
        
        

        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(latent_mu)
        
        hmus, hsigmas, hlogpi, _, _, next_hidden = self.mdrnn(action, z, tuple(hidden))
        
        
        hlogpi = hlogpi.squeeze()
        mixt = Categorical(torch.exp(hlogpi)).sample().item()
        
        
        zh = hmus[:, mixt, :]  + hsigmas[:, mixt, :] * torch.randn_like(hmus[:, mixt, :])
        
        
        
        return action.squeeze().cpu().numpy(), next_hidden,z,zh
    
    def gpcf(self,z_goal,h):
        """ Return the index of the entry of h that is closest to z_goal.

        Closeness is measured with the module-level `criterion`; ties go to
        the lowest index.

        :args z_goal: goal tensor every candidate is compared with
        :args h: indexable collection of candidate tensors

        :returns: integer index into h
        """
        distances = [criterion(h[idx], z_goal).item() for idx in range(len(h))]
        smallest = min(distances)
        return distances.index(smallest)
    
   
    def get_loss(self, latent_obs, action, reward, terminal,
                 latent_next_obs, include_reward: bool):
        """ Compute the training losses of the sequence MDRNN.

        The returned loss is
        (GMMLoss(latent_next_obs, GMMPredicted) + BCE(terminal, logit_terminal)
             + MSE(reward, predicted_reward)) / (LSIZE + 2)
        when include_reward is set. Otherwise the MSE term is 0 and the sum is
        divided by LSIZE + 1. The division counteracts the fact that the
        GMMLoss scales approximately linearly with LSIZE.

        :args latent_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args action: (BSIZE, SEQ_LEN, ASIZE) torch tensor
        :args reward: (BSIZE, SEQ_LEN) torch tensor
        :args terminal: target tensor for the terminal logits
        :args latent_next_obs: (BSIZE, SEQ_LEN, LSIZE) torch tensor
        :args include_reward: whether the reward MSE is part of the loss

        :returns: dictionary of losses, containing the gmm, the bce, the mse
            and the averaged loss.
        """
        mus, sigmas, logpi, rs, ds = self.mdrnnBIG(action, latent_obs)

        gmm_term = gmm_loss(latent_next_obs, mus, sigmas, logpi)
        bce_term = F.binary_cross_entropy_with_logits(ds, terminal)
        mse_term = F.mse_loss(rs, reward) if include_reward else 0

        # one extra unit per scalar loss term that is actually used
        divisor = LSIZE + (2 if include_reward else 1)
        total = (gmm_term + bce_term + mse_term) / divisor

        return {'gmm': gmm_term, 'bce': bce_term, 'mse': mse_term, 'loss': total}
        
    def tolatent(self,obs):
        obs = torch.tensor(obs.flatten(),dtype = torch.float32).unsqueeze(0).to(self.device)
        #print('obs shape from transform obs hidden: ', obs.shape)
      
        reconx, latent_mu, logsigma = self.vae(obs)
       
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(latent_mu)
       
        return latent_mu,logsigma,z
    
    def tohiddenlatent(self,hidden):
        _,latent_mu,logsigma  =  self.HiddenVAE(hidden[0].detach())
              
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        zhidden = eps.mul(sigma).add_(latent_mu)
        
        return zhidden
    
    def predict_next(self,obs, hidden,m):
        obs = torch.tensor(obs.flatten(),dtype = torch.float32).unsqueeze(0).to(self.device)
        #print('obs shape from transform obs hidden: ', obs.shape)
       
        action =  torch.Tensor([[m]]).to(self.device)
        reconx, latent_mu, logsigma = self.vae(obs)
        
        sigma = logsigma.exp()
        eps = torch.randn_like(sigma)
        z = eps.mul(sigma).add_(latent_mu)
        
        hmus, hsigmas, hlogpi, _, _, next_hidden = self.mdrnn(action, z, tuple(hidden))
        
        hlogpi = hlogpi.squeeze()
        mixt = Categorical(torch.exp(hlogpi)).sample().item()
        zh = hmus[:, mixt, :]  + hsigmas[:, mixt, :] * torch.randn_like(hmus[:, mixt, :])
        
        return hmus,hsigmas,hlogpi, zh 