In [1]:
import torch as th
import higher
from active_critic.model_src.whole_sequence_model import CriticSequenceModel, WholeSequenceModelSetup, WholeSequenceModel
from active_critic.model_src.transformer import (
    ModelSetup, generate_square_subsequent_mask)
from active_critic.utils.pytorch_utils import calcMSE
import copy
th.manual_seed(0)

def make_wsm_setup(seq_len, d_output, weight_decay, device='cuda'):
    wsm = WholeSequenceModelSetup()
    wsm.model_setup = ModelSetup()
    seq_len = seq_len
    d_output = d_output
    wsm.model_setup.d_output = d_output
    wsm.model_setup.nhead = 8
    wsm.model_setup.d_hid = 200
    wsm.model_setup.d_model = 200
    wsm.model_setup.nlayers = 5
    wsm.model_setup.seq_len = seq_len
    wsm.model_setup.dropout = 0
    wsm.lr = 1e-4
    wsm.model_setup.device = device
    wsm.optimizer_class = th.optim.AdamW
    wsm.optimizer_kwargs = {'weight_decay':weight_decay}
    return wsm

class module_parameter(th.nn.Module):
    def __init__(self) -> None:
        super().__init__()

seq_len = 10
device = 'cuda'
weight_decay = 1e-2

batch_size = 2


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def high_step(critic:CriticSequenceModel, plan_enc, verbose, inf_opt_steps, inf_opt_lr, obsvs, acts, rewards, goal_rewards, get_critic_inpt):
    init_plan_enc = copy.deepcopy(plan_enc)
    meta_plan_enc = copy.deepcopy(plan_enc)
    inpt_optim = th.optim.SGD(meta_plan_enc.parameters(), lr=inf_opt_lr)
    if critic.model is None:
        with th.no_grad():
            test_inpt = get_critic_inpt(obsvs, acts, plan_enc.param)
            critic.forward(test_inpt)

    for step in range(inf_opt_steps):
        with higher.innerloop_ctx(critic, critic.optimizer) as (higher_critic, higher_optimizer):
            with higher.innerloop_ctx(meta_plan_enc, inpt_optim) as (higher_meta_plan_enc, higher_inpt_optim):
                meta_inpt = get_critic_inpt(obsvs, acts, higher_meta_plan_enc.param)
                init_meta_inpt = get_critic_inpt(obsvs, acts, init_plan_enc.param)
                meta_res_1 = higher_critic.forward(meta_inpt)
                forward_result = higher_critic.forward(init_meta_inpt)
                forward_loss = calcMSE(forward_result, rewards)
                meta_loss_1 = calcMSE(meta_res_1, goal_rewards)
                higher_inpt_optim.step(meta_loss_1)

                meta_inpt = get_critic_inpt(obsvs, acts, higher_meta_plan_enc.param)
                meta_res_2 = higher_critic.forward(meta_inpt)
                meta_loss_2 = calcMSE(goal_rewards, meta_res_2)

                #higher_optimizer.step(loss2)

                meta_plan_enc.load_state_dict(higher_meta_plan_enc.state_dict())
                grad_of_grads = th.autograd.grad(
                    meta_loss_2 + forward_loss, higher_critic.parameters(time=0))
                if verbose:
                    print(f'{step}___________________________________ meta')
                    print(f'loss2: {meta_loss_2}')
                    print(f'forward: {forward_loss}')
                    print(f'total: {meta_loss_2 + forward_loss}')
                    print(f'diff: {meta_loss_1-meta_loss_2}')
                    #print(grad_of_grads[0])
        critic_param_list = list(critic.parameters())
        max_grad = 0
        for index in range(len(grad_of_grads)):
            critic_param_list[index].grad = grad_of_grads[index]
            abs_grad = th.abs(grad_of_grads[index]).max()
            max_grad = max(max_grad, abs_grad)
        if verbose:
            print(max_grad)
        critic.optimizer.step()

In [3]:
def get_planner_inpt(acts, obsvs):
    return th.cat((acts, obsvs), dim=-1)

def get_actor_inpt(plans, obsvs):
    return th.cat((plans, obsvs), dim=-1)

def get_critic_inpt(obsvs, acts, plan):
    return th.cat((obsvs, acts, plan), dim=-1)

In [4]:
def critic_step(
        critic:CriticSequenceModel, 
        planner:WholeSequenceModel, 
        verbose, 
        inf_opt_steps, 
        inf_opt_lr, 
        obsvs, 
        acts, 
        rewards, 
        goal_rewards,
        get_critic_inpt, 
        get_planner_inpt,
        expert_trjs,
        ):
    planner_inpt = get_planner_inpt(acts=acts, obsvs=obsvs)
    with th.no_grad():
        plans = planner.forward(planner_inpt)
        plans[expert_trjs] = 0
    plans_param = th.nn.parameter.Parameter(plans.detach())
    plans_module = module_parameter()
    plans_module.register_parameter('param', param=plans_param)
    high_step(critic=critic, plan_enc=plans_module, verbose=verbose, inf_opt_steps=inf_opt_steps, inf_opt_lr=inf_opt_lr, obsvs=obsvs, acts=acts, rewards=rewards, goal_rewards=goal_rewards, get_critic_inpt=get_critic_inpt)


In [5]:
def actor_step(actor:WholeSequenceModel, planner:WholeSequenceModel, obsvs:th.Tensor, acts:th.Tensor, expert_trjs:th.Tensor, get_planner_inpt, get_actor_inpt, verbose):

    planner_inpt = get_planner_inpt(acts=acts, obsvs=obsvs)
    plans = planner.forward(planner_inpt)
    plans[expert_trjs] = 0

    actor_input = get_actor_inpt(plans=plans, obsvs=obsvs)
    actor_result = actor.forward(actor_input)
    loss = ((actor_result.reshape(-1) - acts.reshape(-1))**2).mean()
    actor.optimizer.zero_grad()
    planner.optimizer.zero_grad()
    loss.backward()
    actor.optimizer.step()
    planner.optimizer.step()
    if verbose:
        print('actor________________________________________')
        print(loss)


In [6]:
obsvs = th.zeros([batch_size, seq_len, 3], dtype=th.float32, device=device)
obsvs[1] = 1
acts = th.rand([batch_size, seq_len, 2], dtype=th.float32, device=device)
rewards = th.zeros([batch_size, 1, 1], dtype=th.float32, device=device)
goal_rewards = th.ones_like(rewards)
expert_trjs = th.ones([batch_size], device=device, dtype=th.bool)
expert_trjs[0] = 1

wsm_planner_setup = make_wsm_setup(
        seq_len=seq_len, d_output=3, device=device, weight_decay=weight_decay)
planner = WholeSequenceModel(wsm_planner_setup)

wsm_actor_setup = make_wsm_setup(
        seq_len=seq_len, d_output=2, device=device, weight_decay=weight_decay)
actor = WholeSequenceModel(wsm_actor_setup)

wsm_critic_setup = make_wsm_setup(
        seq_len=seq_len, d_output=1, device=device, weight_decay=weight_decay)
critic = CriticSequenceModel(wsm_critic_setup)

In [7]:
for i in range(1000):
    verbose = i%30==0
    critic_step(
        critic=critic,
        planner=planner,
        verbose=verbose,
        inf_opt_lr=1e-2,
        inf_opt_steps=3,
        obsvs=obsvs,
        acts=acts,
        rewards=rewards,
        goal_rewards=goal_rewards,
        get_critic_inpt=get_critic_inpt,
        get_planner_inpt=get_planner_inpt,
        expert_trjs=expert_trjs
    )
    for j in range(10):
        actor_step(
            actor=actor,
            planner=planner,
            obsvs=obsvs,
            acts=acts,
            expert_trjs=expert_trjs,
            get_actor_inpt=get_actor_inpt,
            get_planner_inpt=get_planner_inpt,
            verbose=verbose
        )

0___________________________________ meta
loss2: 1.5033451318740845
forward: 0.05153676122426987
total: 1.5548819303512573
diff: 0.00046312808990478516
tensor(2.9788, device='cuda:0')
1___________________________________ meta
loss2: 1.3436613082885742
forward: 0.02607967145740986
total: 1.3697409629821777
diff: 0.0003428459167480469
tensor(4.1799, device='cuda:0')
2___________________________________ meta
loss2: 1.2100670337677002
forward: 0.011262440122663975
total: 1.2213294506072998
diff: 0.00029838085174560547
tensor(5.0468, device='cuda:0')
actor________________________________________
tensor(0.3334, device='cuda:0', grad_fn=<MeanBackward0>)
actor________________________________________
tensor(0.1347, device='cuda:0', grad_fn=<MeanBackward0>)
actor________________________________________
tensor(0.1115, device='cuda:0', grad_fn=<MeanBackward0>)
actor________________________________________
tensor(0.1280, device='cuda:0', grad_fn=<MeanBackward0>)
actor_______________________________

KeyboardInterrupt: 

In [None]:

for i in range(3000):
    res = actor.optimizer_step(inputs=obsvs, label=acts)
    if i%100==0:
        print(res)

In [None]:
actor.optimizer.state_dict()

In [None]:
obsvs = th.zeros([batch_size, seq_len, 3], dtype=th.float32, device=device)
obsvs[1] = 1
acts = th.rand([batch_size, seq_len, 2], dtype=th.float32, device=device)

actor = WholeSequenceModel(wsm_actor_setup)

for i in range(2000):
    verbose = i%100==0
    actor_step(actor=actor, planner=planner, obsvs=obsvs, acts=acts, expert_trjs=expert_trjs, get_planner_inpt=get_planner_inpt, get_actor_inpt=get_actor_inpt, verbose=verbose)

In [None]:
result = actor.forward(obsvs)

In [None]:
result

In [None]:
acts

In [None]:
acts

In [None]:
actor.optimizer.step()

In [None]:
list(actor.parameters())[0]