In [1]:
%matplotlib inline

import gym
import math
import matplotlib
import matplotlib.pyplot as plt
import torch
import random
from torchdiffeq import odeint
import numpy as np


    
from collections import namedtuple, deque
from itertools import count
torch.manual_seed(0)
random.seed(0)

In [2]:
env = gym.make('CartPole-v1')

is_ipython = 'inline' in matplotlib.get_backend()
if is_ipython:
    from IPython import display
    
plt.ion()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
print(torch.__version__)

cpu
1.13.0


In [3]:
import tensorflow as tf
if tf.test.gpu_device_name(): 
    print('Default GPU Device:{}'.format(tf.test.gpu_device_name()))
else:
    print("Please install GPU version of TF")

Please install GPU version of TF


In [4]:
Transition = namedtuple("Transition", ('state', 'action','next_state','reward'))

class ReplayMemory:
    def __init__(self, capacity: int):
        self._memory = deque([], maxlen=capacity)
        
    def __len__(self):
        return len(self._memory)
        
    def push(self, state, action, next_state, reward):
        self._memory.append(Transition(state, action, next_state, reward))
        
    def sample(self, batch_size):
        return random.sample(self._memory, batch_size)
    
class DQN(torch.nn.Module):
    def __init__(self, n_observation, n_actions):
        super(DQN, self).__init__()
        self.layer1 = torch.nn.Linear(n_observation, 128)
        self.layer2 = torch.nn.Linear(128, 128)
        self.layer3 = torch.nn.Linear(128, n_actions)
        
    def forward(self, x):
        x = torch.nn.functional.relu(self.layer1(x))
        x = torch.nn.functional.relu(self.layer2(x))
        x = torch.nn.functional.relu(self.layer3(x))
        return x

In [5]:
gamma = 0.99
batch_size = 128

tau = 0.005
epsilon_start = 0.95
epsilon_end = 0.05
espilon_decay = 2500
learning_rate = 1e-4

state, info = env.reset()
n_observations = len(state)
n_actions = env.action_space.n

policy_net = DQN(n_observations, n_actions).to(device)
target_net = DQN(n_observations, n_actions).to(device)
target_net.load_state_dict(policy_net.state_dict())

steps_done = 0
episode_durations = [] 
episode_epsilon_end = []
episode_training_error = []
replay_memory = ReplayMemory(1000)
#num_episodes = 600 if torch.cuda.is_available() else 50
num_episodes = 1500

optimizer = torch.optim.AdamW(policy_net.parameters(), lr=learning_rate, amsgrad=True)

In [6]:
def select_action(state):
    global steps_done
    global eps_threshold
    eps_threshold = epsilon_end + (epsilon_start - epsilon_end) * math.exp(-1. * steps_done / espilon_decay)
    if steps_done % 1000 == 0 and False:
        print(f"eps: {eps_threshold} steps: {steps_done}")
    steps_done += 1
    with torch.no_grad():
        if random.random() > eps_threshold:
            action = policy_net(state).max(1)[1].view(1, 1)
        else:
            action = torch.tensor([[env.action_space.sample()]], device=device, dtype=torch.long)
    return action

def plot_durations(show_result=False):
    durations = torch.tensor(episode_durations, dtype=torch.float)
    
    plt.figure(1)
    plt.clf()
    plt.subplot(121)
    plt.title('Lenght of an episode')
    plt.xlabel('Episode')
    plt.ylabel('Duration')
    plt.plot(durations.numpy())
    
    if len(durations) >= 100:
        means = durations.unfold(0, 100, 1).mean(1).view(-1)
        means = torch.cat((torch.zeros(99), means))
        plt.plot(means.numpy())
        
    plt.subplot(122)
    plt.title('Training Parameters')
    plt.xlabel('Episode')
    plt.ylabel('Epsilon')
    plt.plot(episode_epsilon_end, label="Epsilon")
    plt.gca().twinx().plot(episode_training_error, color = 'r', label="Training Error")
    plt.legend()
    plt.tight_layout()
    
    plt.pause(0.01)
    if is_ipython:
        if not show_result:
            display.display(plt.gcf())
            display.clear_output(wait=True)
        else:
            display.display(plt.gcf())

            
def optimize_model():
    if len(replay_memory) < batch_size:
        return
    transitions = replay_memory.sample(batch_size)
    batch = Transition(*zip(*transitions))
    
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), 
                                  device=device, dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    
    # Compute Q(s_t, a) - the model computes Q(s_t), then we select the
    # columns of actions taken. These are the actions which would've been taken
    # for each batch state according to policy_net
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    
    # Compute V(s_{t+1}) for all next states.
    # Expected values of actions for non_final_next_states are computed based
    # on the "older" target_net; selecting their best reward with max(1)[0].
    # This is merged based on the mask, such that we'll have either the expected
    # state value or 0 in case the state was final.
    next_state_values = torch.zeros(batch_size, device=device)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    
    expected_state_action_values = (next_state_values * gamma) + reward_batch
    
    criterion = torch.nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
    return loss.item()

In [7]:
def run_training(num_episodes):
    for i_episode in range(num_episodes):
        state, info = env.reset()
        state = torch.tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
        for t in count():
            action = select_action(state)
            observation, reward, terminated, truncated, _ = env.step(action.item())
            reward = torch.tensor([reward], device=device)
            done = terminated or truncated

            if terminated:
                next_state = None
            else:
                next_state = torch.tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

            replay_memory.push(state, action, next_state, reward)
            state = next_state

            loss_value = optimize_model()

            target_net_state_dict = target_net.state_dict()
            policy_net_state_dict = policy_net.state_dict()

            for key in policy_net_state_dict:
                target_net_state_dict[key] = policy_net_state_dict[key]*tau + target_net_state_dict[key]*(1 - tau)
            target_net.load_state_dict(target_net_state_dict)

            if done:
                episode_durations.append(t + 1)
                episode_epsilon_end.append(eps_threshold)
                episode_training_error.append(loss_value)
                plot_durations()
                break

    print('Complete')
    plot_durations(show_result=True)
    plt.ioff()
    plt.show()

In [8]:
#run_training(num_episodes)

In [9]:
class ODENet(torch.nn.Module):
    def __init__(self,period):
        super(ODENet, self).__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(1, 50), 
            torch.nn.Tanh(),
            torch.nn.Linear(50, 1))
    
        for m in self.net.modules():
            if isinstance(m, torch.nn.Linear):
                torch.nn.init.normal_(m.weight, mean=0, std=0.1)
                torch.nn.init.constant_(m.bias, val=0)
            
    def forward(self, t, y):
        return self.net(y)

In [12]:
batch_size = 32
period_lenght = 10
step_size = 5

func = ODENet(1).to(device)
optimizer = torch.optim.RMSprop(func.parameters(), lr=1e-4)
t = torch.linspace(-100, 100, 201).to(device)
y_true = torch.linspace(-100, 100, 201).to(device) ** 2 



def get_batch(y, t, period_lenght):
    
    s = torch.from_numpy(np.random.choice(np.arange(y.size()[0] - period_lenght, dtype=np.int64),
                                          size=batch_size, replace=False))
    y0_batch2 = torch.reshape(y[s], (batch_size,1, 1))
    y_batch2 = torch.stack([y[s + i] for i in range(0, period_lenght, step_size)], dim=0)
    y_batch2 = torch.reshape(y_batch2,(period_lenght,batch_size, 1, 1) )

    
    idx = int(random.uniform(0, y.size()[0] - period_lenght))
    y0_batch = torch.reshape(y[idx], (1,))
    y_batch = y[idx:idx+period_lenght:step_size]
    t_batch = t[idx:idx+period_lenght:step_size]
    return y0_batch.to(device), y_batch.to(device), t_batch.to(device)

loss_list = []
y_pred_previous = y_true

epochs = 100000
for itr in range(1, epochs):
    optimizer.zero_grad()
    y0_batch, y_batch, t_batch = get_batch(y_true, t, period_lenght)
    """
    print(y0_batch.shape)
    print(t_batch.shape)
    print(y_batch.shape)
    print(y_pred.shape)
    """

    y_pred = odeint(func, y0_batch, t_batch).to(device)
    

    loss = torch.mean(torch.abs(y_pred - y_batch))
    loss.backward()
    optimizer.step()
    
    if itr % 50 == 0:
        y_pred = odeint(func, torch.reshape(y_true[0], (1,)), t)
        loss = torch.mean(torch.abs(y_pred - y_true))
        loss_list.append(loss.item())
        
        with torch.no_grad():
            plt.figure(1)
            plt.clf()
            plt.plot(t.cpu().numpy(), y_pred.cpu().numpy(), label="pred")
            plt.plot(t.cpu().numpy(), y_pred_previous.cpu().numpy(), label="prev pred")
            plt.plot(t.cpu().numpy(), y_true.cpu().numpy(), label="true")
            plt.legend()
            plt.title(itr)
            display.display(plt.gcf())
            if itr < epochs - 1:
                display.clear_output(wait=True)
            
        y_pred_previous = y_pred

RuntimeError: shape '[10, 32, 1, 1]' is invalid for input of size 64

In [None]:
plt.figure(1)
plt.plot(loss_list)
plt.title("Loss")
plt.show()

In [None]:
y_true = torch.linspace(-100, 100, 201) ** 2 


In [None]:
y.reshape(10, 32, 1, 2).shape

In [None]:
"""
import jax.numpy as jnp
from jax import jit, grad, vmap
from jax.experimental.ode import odeint
from itertools import zip_longest
import numpy.random as npr

def mlp(params, inputs):
    for w, b in params:
        outputs = jnp.dot(inputs, w) + b
        inputs = jnp.tanh(outputs)
    return outputs

def nn_dynamics(state, time, params):
    state_and_time = jnp.hstack([state, jnp.array(time)])
    return mlp(params, state_and_time)

def odenet(params, inputs):
    start_and_end_times = jnp.array([0.0, 1.0])
    init_state, final_state = odeint(nn_dynamics, inputs,
                                     start_and_end_times, params)
    return final_state

def resnet(params, inputs, depth):
    for i in range(depth):
        outputs = mlp(params, inputs) + inputs
    return outputs

def resnet_squared_loss(params, inputs,  targets):
    preds = resnet(params, inputs, resnet_depth)
    return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

def init_random_params(scale, layer_sizes, rng=npr.RandomState(0)):
    params = []
    for m, n in zip(layer_sizes[:-1], layer_sizes[1:]):
        params += [((scale * rng.randn(m, n)), scale * rng.randn(n))]
    return params

def odenet_loss(params, inputs, targets):
    preds = batched_odenet(params, inputs)
    return jnp.mean(jnp.sum((preds - targets)**2, axis=1))

def ode_net_update(params, inputs, targets, step_size):
    grads = grad(odenet_loss)(params, inputs, targets)
    updates = []
    for (w, b), (dw, db) in zip(params, grads):
        updates += [(w - step_size*dw, b - step_size*db)]
    return updates

@jit
def resnet_update(params, inputs, targets, step_size):
    grads = grad(resnet_squared_loss)(params, inputs, targets)
    update = []
    for (w, b), (dw, db) in zip(params, grads):
        update += [(w - step_size * dw, b - step_size * db)]
    return update

# Toy 1D dataset
inputs = jnp.reshape(jnp.linspace(-2.0, 2.0, 10), (10, 1))
fine_inputs = jnp.reshape(jnp.linspace(-3.0, 3.0, 100), (100,1))
targets = inputs**3 + 0.1 * inputs
batched_odenet = vmap(odenet, in_axes=(None, 0))
odenet_layer_sizes = [2, 20, 1]

# Hyperparameters
layer_sizes = [1, 20, 1]
param_scale = 1.0
step_size = 0.01
train_iters = 1000
resnet_depth = 3

# Init and train
resnet_params = init_random_params(param_scale, layer_sizes)
for i in range(train_iters):
    resnet_params = resnet_update(resnet_params, inputs, targets, step_size)

odenet_params = init_random_params(param_scale, odenet_layer_sizes)

for i in range(train_iters):
    odenet_params = ode_net_update(odenet_params, inputs,
                                   targets, step_size)


fig = plt.figure(figsize=(6, 4), dpi=150)
ax = fig.gca()
ax.scatter(inputs, targets, lw=0.5, color="green")
ax.plot(fine_inputs, resnet(resnet_params, fine_inputs, resnet_depth), lw=0.5, color='blue')
ax.plot(fine_inputs, batched_odenet(odenet_params, fine_inputs), lw=0.5, color='red')
ax.set_xlabel('input')
ax.set_ylabel('output')
plt.legend(('Resnet predictions', 'ODE Net predictions'))
plt.show()
plt.close()
"""

In [None]:
device