In [13]:
import gym
import random, math
import numpy as np
from collections import deque
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = (16, 10)

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical
torch.manual_seed(0)

import base64, io

# For visualization
from gym.wrappers.monitoring import video_recorder
from IPython.display import HTML
from IPython import display
import glob
from reinforce_rwd2go import reinforce_rwd2go, rollout, make_pref_dataset
from utils import pref_save, pref_load

from model import Policy
import pickle

%load_ext autoreload
%autoreload 2
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


device(type='cpu')

In [14]:
K = 100
SEED=0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

env = gym.make('CartPole-v0')
print('observation space:', env.observation_space)
print('action space:', env.action_space)

observation space: Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)
action space: Discrete(2)


In [15]:
policy1 = Policy().to(device)
policy1.load_state_dict(torch.load(f"policy1_{K}.pth"))
policy2 = Policy().to(device)
policy2.load_state_dict(torch.load(f"policy2_{K}.pth"))

pref_data = pref_load(f"pref_data_{K}")

FileNotFoundError: [Errno 2] No such file or directory: 'pref_data_100'

In [16]:
import torch
import torch.nn.functional as F
import copy

beta      = 5.0       # how strongly to enforce preferences
lr        = 5e-3     # DPO learning rate
epochs    = 1

def copyPolicy(policy):
    pi_ref = copy.deepcopy(policy).to(device)

    for p in pi_ref.parameters():
        p.requires_grad = False

    return pi_ref

policy = copy.deepcopy(policy2).to(device)
pi_ref = copyPolicy(policy)

optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

def trajectory_logprob(pi, states, actions):
    logp = torch.tensor(0., device=device)
    for s, a in zip(states, actions):
        s_t = torch.tensor(s, dtype=torch.float32, device=device)
        probs = pi(s_t.unsqueeze(0)).squeeze(0)      # shape [action_dim]
        logp += torch.log(probs[a])
    return logp

for epoch in range(1, epochs+1):
    total_loss = 0.0
    
    for s0, tau_plus, tau_minus in pref_data:
        # trajectory log-probs
        logp_pos = trajectory_logprob(policy,
                                      tau_plus ["states"],
                                      tau_plus ["actions"])
        logp_neg = trajectory_logprob(policy,
                                      tau_minus["states"],
                                      tau_minus["actions"])

        logp_ref_pos = trajectory_logprob(pi_ref,
                                      tau_plus ["states"],
                                      tau_plus ["actions"])

        logp_ref_neg = trajectory_logprob(pi_ref,
                                      tau_minus["states"],
                                      tau_minus["actions"])
        # DPO preference loss
        diff   = beta * (logp_pos - logp_ref_pos) - beta * (logp_neg - logp_ref_neg)
        total_loss += -F.logsigmoid(diff)
    
    total_loss = total_loss / len(pref_data)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    
    print(f"Epoch {epoch}/{epochs} — avg DPO loss: {total_loss:.4f}")

Epoch 1/1 — avg DPO loss: 0.6932


In [17]:
returns = []
eval_episodes = 100
for ep in range(eval_episodes):
    state, done, total_r = env.reset(), False, 0.0
    while not done:
        # choose greedy or stochastic—here greedy
        with torch.no_grad():
            s_t = torch.tensor(state, dtype=torch.float32, device=device)
            probs = policy(s_t.unsqueeze(0)).squeeze(0)
            action = torch.argmax(probs).item()
        state, r, done, _ = env.step(action)
        total_r += r
    returns.append(total_r)

mean_return = sum(returns) / len(returns)
print(f"Evaluation over {eval_episodes} episodes: mean return = {mean_return:.2f}")

Evaluation over 100 episodes: mean return = 165.62


In [18]:
returns = []
eval_episodes = 100
for ep in range(eval_episodes):
    state, done, total_r = env.reset(), False, 0.0
    while not done:
        # choose greedy or stochastic—here greedy
        with torch.no_grad():
            s_t = torch.tensor(state, dtype=torch.float32, device=device)
            probs = policy2(s_t.unsqueeze(0)).squeeze(0)
            action = torch.argmax(probs).item()
        state, r, done, _ = env.step(action)
        total_r += r
    returns.append(total_r)

mean_return = sum(returns) / len(returns)
print(f"Evaluation over {eval_episodes} episodes: mean return = {mean_return:.2f}")

Evaluation over 100 episodes: mean return = 151.09
