# In [None]:
import logging
import sys
import datetime
import os 
def set_logger(logger, base_dir="runs"):
    """Route ``logger`` to stdout and to a fresh timestamped run directory.

    Handlers already attached to ``logger`` are closed and removed first, so
    calling this again (e.g. re-running a notebook cell) neither leaks open log
    files nor duplicates output. Two INFO-level handlers sharing one format are
    then attached: a ``StreamHandler`` on ``sys.stdout`` and a ``FileHandler``
    writing ``<base_dir>/<timestamp>/run.log``.

    Args:
        logger: the ``logging.Logger`` to configure (mutated in place).
        base_dir: parent directory in which the per-run directory is created.

    Returns:
        ``(logger, logdir)``: the same logger and the created run directory.
    """
    formatter = logging.Formatter(
        "%(asctime)s - %(filename)s:%(lineno)s - %(levelname)s - %(message)s"
    )

    logger.setLevel(logging.INFO)
    # Close before dropping: a discarded FileHandler would otherwise keep its
    # file open for the rest of the process.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    current_time = datetime.datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(base_dir, current_time)
    os.makedirs(logdir, exist_ok=True)

    file_handler = logging.FileHandler(os.path.join(logdir, "run.log"))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger, logdir


# Configure the root logger; set_logger returns (logger, run_directory), so
# unpack both instead of binding the whole tuple to ``logdir``.
logger = logging.getLogger()
logger, logdir = set_logger(logger)

logger.info("Running RCRC Reward Prediction")


import argparse

import numpy as np

import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Beta
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
#from environment import make_single_env
from torch.utils.tensorboard import SummaryWriter

parser = argparse.ArgumentParser(description='Train a RCRC to predict rewards for the CarRacing-v0')
# NOTE(review): --action-repeat and --img-stack are parsed but not passed to Env
# below (Conv expects 3 input channels, Env defaults to img_stack=3).
parser.add_argument('--action-repeat', type=int, default=8, metavar='N', help='repeat action in N frames (default: 8)')
parser.add_argument('--img-stack', type=int, default=4, metavar='N', help='stack N image in a state (default: 4)')
parser.add_argument('--seed', type=int, default=np.random.randint(np.int32(2**31-1)), metavar='N', help='random seed (default: random)')
parser.add_argument('--render', action='store_true', help='render the environment')
parser.add_argument('--tb', action='store_true', help='use tb')
parser.add_argument(
    '--log-interval', type=int, default=10, metavar='N', help='interval between training status logs (default: 10)')
# parse_known_args ignores arguments this parser does not define, such as the
# "-f <kernel>.json" a Jupyter kernel puts on sys.argv, so the code runs both
# as a script and inside a notebook.
args, _unknown_args = parser.parse_known_args()

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
# Seed torch (CPU and CUDA) and numpy's global RNG so a run can be reproduced
# from the logged seed.
torch.manual_seed(args.seed)
np.random.seed(args.seed)
if use_cuda:
    torch.cuda.manual_seed(args.seed)
logger.info("Seed: %s", args.seed)

# Structured record holding a 513-value target vector and a 513-value prediction.
transition = np.dtype([('true', np.float64, (513,)),
                       ('pred', np.float64, (513,)),])




class Env(gym.Wrapper):
    """
    Environment wrapper for CarRacing.

    Each frame is converted to normalized grayscale (``rgb2gray``), optionally
    resampled to ``resize_shape``, and pushed onto a rolling stack of the last
    ``img_stack`` frames. Each ``step`` repeats the action up to
    ``action_repeat`` times and returns the accumulated, shaped reward.

    Args:
        env: the gym environment to wrap (its ``spec.reward_threshold`` is read).
        resize: whether to resample frames to ``resize_shape``.
        img_stack: number of frames in the returned stack.
        action_repeat: number of underlying env steps per ``step`` call (>= 1).
        resize_shape: ``(height, width)`` used when ``resize`` is True.
    """

    def __init__(self, env, resize=False, img_stack=3, action_repeat=8, resize_shape=(64, 64)):
        super(Env, self).__init__(env)
        if action_repeat < 1:
            raise ValueError("action_repeat must be >= 1, got {}".format(action_repeat))
        self.env = env
        self.reward_threshold = self.env.spec.reward_threshold
        self.resize = resize
        self.resize_shape = tuple(resize_shape)
        self.img_stack = img_stack
        self.action_repeat = action_repeat

    def reset(self):
        """Reset the wrapped env and return the initial frame stack.

        The first processed frame is repeated ``img_stack`` times. The result is
        a float64 array of shape ``(img_stack, H, W)``.
        """
        self.counter = 0
        self.av_r = self.reward_memory()

        self.die = False
        img_gray = self._preprocess(self.env.reset())
        self.stack = [img_gray] * self.img_stack
        return np.array(self.stack).astype(np.float64)

    def step(self, action):
        """Apply ``action`` up to ``action_repeat`` times.

        Returns:
            ``(frame_stack, total_reward, done, die)``. ``die`` is the wrapped
            env's own terminal flag. ``done`` is True if ``die`` is, or if the
            mean of the last 100 per-step rewards dropped to -0.1 or below.
        """
        total_reward = 0
        done = False
        die = False
        for _ in range(self.action_repeat):
            img_rgb, reward, die, _info = self.env.step(action)
            # don't penalize "die state" (presumably offsets the env's terminal
            # penalty; TODO confirm against the env's reward definition)
            if die:
                reward += 100
            # green penalty: a mostly-green frame presumably means the car is on grass
            if np.mean(img_rgb[:, :, 1]) > 185.0:
                reward -= 0.05
            total_reward += reward
            # if no reward recently, end the episode
            done = bool(self.av_r(reward) <= -0.1)
            if done or die:
                break
        self.stack.pop(0)
        self.stack.append(self._preprocess(img_rgb))
        assert len(self.stack) == self.img_stack
        done = bool(done or die)
        return np.array(self.stack).astype(np.float64), total_reward, done, die

    def render(self, *arg):
        """Forward to the wrapped env's ``render`` and return its result."""
        return self.env.render(*arg)

    def _preprocess(self, img_rgb):
        """Convert one RGB frame to normalized grayscale, resizing it if enabled."""
        img_gray = self.rgb2gray(img_rgb)
        if self.resize:
            img_gray = self._resize(img_gray, self.resize_shape)
        return img_gray

    @staticmethod
    def _resize(img, shape=(64, 64)):
        """Resample a 2-D float image to ``shape`` with area interpolation.

        Replaces a call to ``rsz``, which is neither defined nor imported in this
        file. Area interpolation averages source pixels when downsampling.
        """
        tensor = torch.from_numpy(np.ascontiguousarray(img, dtype=np.float64))[None, None]
        resized = F.interpolate(tensor, size=tuple(shape), mode="area")
        return resized[0, 0].numpy()

    @staticmethod
    def rgb2gray(rgb, norm=True):
        """Convert an ``(..., 3)`` RGB image to luma.

        The weights sum to 1, so 0-255 input gives 0-255 gray. With ``norm`` the
        result is mapped by ``gray / 128 - 1`` to roughly ``[-1, 0.99]``.
        """
        gray = np.dot(rgb[..., :], [0.299, 0.587, 0.114])
        if norm:
            # normalize
            gray = gray / 128. - 1.
        return gray

    @staticmethod
    def reward_memory():
        """Return a closure that records a reward and returns the mean of the last 100.

        The history starts as zeros, so early means are diluted by them.
        """
        count = 0
        length = 100
        history = np.zeros(length)

        def memory(reward):
            nonlocal count
            history[count] = reward
            count = (count + 1) % length
            return np.mean(history)

        return memory

class Conv(nn.Module):
    """Three stacked strided convolutions with no nonlinearity in between.

    For a ``(B, in_channels, 64, 64)`` input the spatial size goes
    64 -> 17 -> 5 -> 2, so the output is ``(B, 128, 2, 2)``: 512 values once
    flattened.

    Args:
        in_channels: number of stacked input frames/channels (default 3).
    """

    def __init__(self, in_channels=3):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, 32, 32, stride=2)
        self.conv2 = nn.Conv2d(32, 64, 8, stride=2)
        self.conv3 = nn.Conv2d(64, 128, 2, stride=2)

    def forward(self, image_stack):
        """Apply the three convolutions in sequence to a ``(B, C, H, W)`` batch."""
        x = self.conv1(image_stack)
        x = self.conv2(x)
        return self.conv3(x)

    


def init_W(n, m, sparsity=0.2, spectral_scale=0.95):
    """Return a sparse random ``(n, m)`` reservoir weight matrix.

    Entries are drawn i.i.d. from N(0, 1). Exactly ``int(sparsity * n * m)``
    distinct entries are then set to zero. Finally the matrix is rescaled so
    that its largest singular value equals ``spectral_scale``. For square
    matrices the spectral radius is at most the largest singular value, so a
    value below 1 is the usual sufficient condition for a stable reservoir.

    Args:
        n: number of rows.
        m: number of columns.
        sparsity: fraction of entries forced to zero.
        spectral_scale: target largest singular value.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(n, m)``.
    """
    weight = torch.normal(mean=torch.zeros((n, m)), std=torch.ones((n, m)))

    num_zeros = int(sparsity * n * m)
    if num_zeros > 0:
        # Distinct flat positions, so exactly ``num_zeros`` entries are zeroed;
        # flat indexing avoids any row/column arithmetic for non-square shapes.
        flat_idx = torch.randperm(n * m)[:num_zeros]
        weight.view(-1)[flat_idx] = 0

    # Rescale after sparsifying: zeroing entries afterwards could raise the
    # largest singular value above ``spectral_scale`` again. Scaling the whole
    # matrix by a scalar scales every singular value by the same factor.
    s_max = torch.svd(weight, compute_uv=False)[1].max()
    if s_max > 0:
        weight = weight * (spectral_scale / s_max)
    return weight


class FixedRandomModel(nn.Module):
    """Random ``Conv`` encoder feeding a leaky-integrator recurrent reservoir.

    The reservoir state lives on the module as ``x_esn`` between calls. It is
    (re)initialised whenever it is ``None`` or the batch size changes.

    Args:
        alpha: leak rate mixing the previous state with the new candidate state.
    """

    def __init__(self, alpha):
        super().__init__()
        self.conv = Conv()
        # 515 inputs: presumably 512 flattened conv features (64x64 frames)
        # plus a 3-value action -- TODO confirm if input sizes change.
        self.W_in = nn.Linear(515, 512, bias=False)
        self.W = nn.Linear(512, 512, bias=False)
        self.W.weight.data = init_W(512, 512)
        self.x_esn = None
        self.alpha = alpha

    def forward(self, obs, prev_action):
        """Advance the reservoir one step.

        Returns:
            ``(reservoir_input, new_state)``. ``reservoir_input`` is the flattened
            conv features of ``obs`` concatenated with ``prev_action``.
        """
        batch = obs.shape[0]
        features = self.conv(obs).view(batch, -1)
        reservoir_input = torch.cat((features, prev_action), dim=1)
        driven = self.W_in(reservoir_input)

        needs_fresh_state = self.x_esn is None or self.x_esn.shape[0] != batch
        if needs_fresh_state:
            new_state = torch.tanh(driven)
        else:
            candidate = torch.tanh(driven + self.W(self.x_esn))
            new_state = (1 - self.alpha) * self.x_esn + self.alpha * candidate

        self.x_esn = new_state
        return reservoir_input, new_state


class WM(nn.Module):
    """World model: a frozen reservoir ``model`` plus a trainable linear readout.

    ``forward`` builds ``S = [reservoir input, reservoir state, 1]`` and maps
    it through ``self.future`` (Linear 1028 -> 513).

    Args:
        model: module exposing ``.conv`` and returning
            ``(reservoir_input, reservoir_state)`` from ``model(obs, action)``.
            Its parameters are frozen here.
    """

    def __init__(self, model):
        super(WM, self).__init__()
        self.model = model
        # The reservoir stays fixed; only the readout below is trained.
        for p in self.model.parameters():
            p.requires_grad = False
        self.future = nn.Sequential(nn.Linear(1028, 513))

    def forward(self, obs, action, conv_features = False):
        """Predict the readout output, or return flattened conv features.

        Args:
            obs: batched observation tensor ``(B, ...)``.
            action: batched action tensor, passed to ``model`` (ignored when
                ``conv_features`` is True).
            conv_features: if True, return ``model.conv(obs)`` flattened to
                ``(B, -1)`` without touching the reservoir state.
        """
        B = obs.shape[0]
        if conv_features:
            return self.model.conv(obs).view(B, -1)
        x_esn_input, x_esn = self.model(obs, action)
        # Constant column built on the same dtype/device as the reservoir state,
        # so this works without a global ``device`` and for any model precision.
        bias = x_esn.new_ones((B, 1))
        S = torch.cat((x_esn_input, x_esn, bias), dim=1)
        return self.future(S)
    


class Agent():
    """
    Agent that trains the readout (``WM.future``) of the reservoir world model.

    ``WM.__init__`` freezes ``FixedRandomModel``, so the readout input
    ``S = [reservoir input, reservoir state, 1]`` of a step never changes once
    computed. The rollout stores ``S`` together with its regression target, and
    ``update`` fits ``WM.future`` on random mini-batches of those pairs. Storing
    ``S``, rather than finished predictions, keeps something to back-propagate
    through.
    """
    buffer_capacity, batch_size = 2000, 128
    # 'true':  regression target (next-frame conv features + reward); 513
    #          values, the output size of WM.future.
    # 'state': readout input S; 1028 values, the input size of WM.future.
    transition_dtype = np.dtype([('true', np.float64, (513,)),
                                 ('state', np.float64, (1028,))])

    def __init__(self):
        self.training_step = 0
        self.fixed_model = FixedRandomModel(0.5).double().to(device)
        self.wm_model = WM(self.fixed_model).double().to(device)
        self.buffer = np.empty(self.buffer_capacity, dtype=self.transition_dtype)
        self.counter = 0
        self.loss = F.mse_loss
        # Per-batch losses (floats) of the most recent update(); empty until then.
        self.loss_l = []
        self.optimizer = optim.Adam(self.wm_model.parameters(), lr=1e-3)

    @torch.no_grad()
    def readout_input(self, obs, action):
        """Advance the reservoir by one step and return the readout input ``S``.

        Builds ``S`` the same way ``WM.forward`` does before applying
        ``WM.future``. Call it once per environment step, because it updates
        the reservoir state ``fixed_model.x_esn``.

        Args:
            obs: batched double observation tensor on ``device``.
            action: batched double action tensor on ``device``.

        Returns:
            A tensor without gradient history. Its shape is ``(B, 1028)`` for the
            64x64 / 3-channel / 3-value-action setup used below.
        """
        x_esn_input, x_esn = self.fixed_model(obs, action)
        bias = x_esn.new_ones((x_esn.shape[0], 1))
        return torch.cat((x_esn_input, x_esn, bias), dim=1)

    def save_param(self, path='param/ppo_net_params.pkl'):
        """Save the world-model ``state_dict`` to ``path``, creating its directory."""
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        torch.save(self.wm_model.state_dict(), path)

    def store(self, transition):
        """Write one ``(true, state)`` record into the ring buffer.

        Returns True exactly when this write filled the buffer. The write index
        then wraps to 0, signalling that ``update`` should run.
        """
        self.buffer[self.counter] = transition
        self.counter += 1
        if self.counter == self.buffer_capacity:
            self.counter = 0
            return True
        return False

    def update(self):
        """Run one pass of mini-batch regression of ``WM.future`` over the buffer."""
        self.training_step += 1

        r_true = torch.tensor(self.buffer['true'], dtype=torch.double).to(device)
        states = torch.tensor(self.buffer['state'], dtype=torch.double).to(device)

        self.loss_l = []
        for index in BatchSampler(SubsetRandomSampler(range(self.buffer_capacity)), self.batch_size, False):
            r_pred = self.wm_model.future(states[index])
            loss = self.loss(r_pred, r_true[index])
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            # Keep a plain float: it releases the graph and lets np.mean work.
            self.loss_l.append(loss.item())


def _to_batch(array):
    """Return ``array`` as a double tensor on ``device`` with a leading batch dim of 1."""
    return torch.as_tensor(np.asarray(array), dtype=torch.double, device=device).unsqueeze(0)


if __name__ == "__main__":
    agent = Agent()
    env = gym.make("CarRacing-v0")
    env = Env(env, resize = True)
    parameters = sum(p.numel() for p in agent.wm_model.parameters())
    train_parameters = sum(p.numel() for p in agent.wm_model.parameters() if p.requires_grad)
    logger.info("Total Parameters : %s ", parameters)
    logger.info("Trainable Params : %s", train_parameters)
    logger.info(agent.wm_model)

    writer = SummaryWriter(log_dir="./tb/") if args.tb else None
    min_loss = 1000000
    for i_ep in range(100000):
        state = env.reset()
        # Start every episode from an empty reservoir state.
        agent.wm_model.model.x_esn = None

        for t in range(1000):
            action = env.action_space.sample()
            # Readout input from the current frame stack and the sampled action.
            S = agent.readout_input(_to_batch(state), _to_batch(action))
            # NOTE(review): this affine map assumes policy outputs in [0, 1];
            # action_space.sample() already yields steering in [-1, 1] -- confirm.
            state, reward, done, die = env.step(action * np.array([2., 1., 1.]) + np.array([-1., 0., 0.]))
            # Target: conv features of the next frame stack plus the step reward.
            with torch.no_grad():
                next_features = agent.wm_model(_to_batch(state), None, conv_features=True)
            true_features = torch.cat((next_features, next_features.new_tensor([[reward]])), dim=1)
            if args.render:
                env.render()
            record = (true_features.squeeze(0).cpu().numpy(), S.squeeze(0).cpu().numpy())
            if agent.store(record):
                logger.info('updating')
                agent.update()

            if done or die:
                break

        losses = agent.loss_l
        # Nothing to report until the buffer has filled and update() has run.
        if i_ep % args.log_interval == 0 and losses:
            avg_loss = float(np.mean(losses))
            if writer is not None:
                writer.add_scalar(
                    "loss_avg", avg_loss, global_step=i_ep
                )

            logger.info('Ep {}\tLoss AVG: {:.2f}\t'.format(i_ep, avg_loss))
            if avg_loss < min_loss:
                min_loss = avg_loss
                agent.save_param()
                logger.info("Saving a new model, min avg loss is {}".format(min_loss))

    if writer is not None:
        writer.close()

