In [1]:
import numpy as np
import torch
from tensordict import TensorDict, TensorDictBase
from torchrl.envs import EnvBase
from torchrl.data.tensor_specs import OneHot, Composite, BoundedTensorSpec, Categorical, Unbounded, Binary
import itertools
from localcider.sequenceParameters import SequenceParameters
from torchrl.envs.utils import check_env_specs
import math

from torchrl.envs import TransformedEnv, FlattenObservation

In [2]:
class SwapVar(EnvBase): #dtype of tensors?
    def __init__(self):
        super().__init__()
        self.reference_state = 'MASASSSQRGRSGSGNFGGGRGGGFGGNDNFGRGGNFSGRGGFGGSRGGGGYGGSGDGYNGFGNDGSNFGGGGSYNDFGNYNNQSSNFGPMKGGNFGGRSSGGSGGGGQYFAKPRNQGGYGGSSSSSSYGSGRRF'
        self.seq_len = len(self.reference_state)
        self.alphabet = list(set(self.reference_state))
        self.initial_state = 'X'*self.seq_len
        self.alphabet.append('X')
        self.aa_to_idx = {aa: idx for idx, aa in enumerate(self.alphabet)}
        self.idx_to_aa = {idx: aa for idx, aa in enumerate(self.alphabet)}

        self.target = 0.8
        
        self._batch_size = torch.Size()
        
        #observation_spec = OneHot(n=len(self.alphabet),shape=(self.seq_len,len(self.alphabet)),dtype=torch.float,device='cpu')
        observation_spec = Binary(n=self.seq_len*len(self.alphabet),dtype=torch.float)
        self.observation_spec = Composite(observation=observation_spec)

        self.action_spec = Categorical(len(self.alphabet)-1,shape=torch.Size([1]),dtype=torch.int64)
        
        self.reward_spec = Unbounded(shape=torch.Size([1]))

    def fwd_onehot(self, sequence):
        indices = [self.aa_to_idx[aa] for aa in sequence]
        tensor_indices = torch.tensor(indices)
        one_hot = torch.nn.functional.one_hot(tensor_indices, num_classes=len(self.alphabet)).to(torch.float)
        return one_hot.flatten()

    def bwd_onehot(self, onehot):
        onehot = torch.reshape( onehot, (int(len(onehot)/len(self.alphabet)),len(self.alphabet)) )
        indices = torch.argmax(onehot, dim=1).tolist()
        sequence = ''.join([self.idx_to_aa[idx] for idx in indices])
        return sequence

    def reward(self, sequence):
        sequence = sequence[:self.n_step]
        if len(sequence) < self.seq_len:
            return 0, False
        else:
            SeqOb = SequenceParameters(sequence)
            score = SeqOb.get_kappa()
            return -abs(score-self.target), True

    def _reset(self, tensordict):
        if tensordict is not None:
            action = tensordict.get('action')
        else:
            action = self.action_spec.rand()
        #action = self.action_spec.rand()
        self.n_step = 0
        observation = self.fwd_onehot(self.initial_state)
        out_td = TensorDict({'observation': observation, 'action': action},batch_size=torch.Size())
        return out_td

    def _step(self, tensordict: TensorDictBase):
        #print(self.n_step)
        action = tensordict.get('action')
        m = self.alphabet[action]
        #print(m)
        state = list( self.bwd_onehot(tensordict.get('observation').clone()) )
        state[self.n_step] = m
        #print(state)
        state = ''.join(state)
        #print(state)
        self.n_step += 1
        reward, done = self.reward(state)

        next_observation = self.fwd_onehot(state)
        
        td = TensorDict(
            {
                'observation': next_observation,
                'reward': torch.tensor([reward]).float(),
                'done': torch.tensor([done]).bool()
            },
            batch_size=torch.Size()
        )
        return td

    def _set_seed(self, seed: int):
        np.random.seed(seed)
        torch.manual_seed(seed)

In [3]:
env = SwapVar()

In [4]:
env.rollout(5)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([5, 1755]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([5]),
            device=None,
            is_shared=False),
        observation: Tensor(shape=torch.Size([5, 1755]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([5, 1]), device=cpu, dtype=torch.bool, is_shared=Fa

In [70]:
check_env_specs(env)

2024-10-25 18:43:34,488 [torchrl][INFO] check_env_specs succeeded!


In [5]:
from torch import nn
from tensordict.nn.distributions import NormalParamExtractor
from tensordict.nn import TensorDictModule
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.collectors import SyncDataCollector
from torchrl.objectives.value import GAE
from torchrl.objectives import ClipPPOLoss
from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import SamplerWithoutReplacement
from torchrl.data.replay_buffers.storages import LazyTensorStorage

from collections import defaultdict
from tqdm import tqdm

In [6]:
num_cells = 256  # number of cells in each layer i.e. output dim.
lr = 3e-4
max_grad_norm = 1.0

In [7]:
frames_per_batch = 1000
# For a complete training, bring the number of frames up to 1M
total_frames = 10_000

In [8]:
sub_batch_size = 64  # cardinality of the sub-samples gathered from the current data in the inner loop
num_epochs = 10  # optimisation steps per batch of data collected
clip_epsilon = (
    0.2  # clip value for PPO loss: see the equation in the intro for more context.
)
gamma = 0.99
lmbda = 0.95
entropy_eps = 1e-4

In [20]:
device = 'cpu'

actor_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(2*env.action_spec.shape[-1], device=device),
    NormalParamExtractor(),
)

policy_module = TensorDictModule(
    actor_net, in_keys=["observation"], out_keys=["loc", "scale"]
)

actor = ProbabilisticActor(
    module=policy_module,
    spec=env.action_spec,
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
    distribution_kwargs={"min": 0.0, "max": 1.0},
    return_log_prob=True,
)

print("Running policy:", policy_module(env.reset()))

Running policy: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        loc: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        observation: Tensor(shape=torch.Size([1755]), device=cpu, dtype=torch.float32, is_shared=False),
        scale: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


In [21]:
value_net = nn.Sequential(
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(num_cells, device=device),
    nn.Tanh(),
    nn.LazyLinear(1, device=device),
)

value_module = ValueOperator(
    module=value_net,
    in_keys=["observation"],
)

print("Running value:", value_module(env.reset()))

Running value: TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.int64, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([1755]), device=cpu, dtype=torch.float32, is_shared=False),
        state_value: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


In [22]:
collector = SyncDataCollector(
    env,
    policy_module,
    frames_per_batch=frames_per_batch,
    total_frames=total_frames,
    split_trajs=False,
    device=device,
)

In [23]:
replay_buffer = ReplayBuffer(
    storage=LazyTensorStorage(max_size=frames_per_batch),
    sampler=SamplerWithoutReplacement(),
)

In [24]:
advantage_module = GAE(
    gamma=gamma, lmbda=lmbda, value_network=value_module, average_gae=True
)

loss_module = ClipPPOLoss(
    actor_network=policy_module,
    critic_network=value_module,
    clip_epsilon=clip_epsilon,
    entropy_bonus=bool(entropy_eps),
    entropy_coef=entropy_eps,
    # these keys match by default but we set this for completeness
    critic_coef=1.0,
    loss_critic_type="smooth_l1",
)

optim = torch.optim.Adam(loss_module.parameters(), lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optim, total_frames // frames_per_batch, 0.0
)

In [25]:
logs = defaultdict(list)
pbar = tqdm(total=total_frames)
eval_str = ""

# We iterate over the collector until it reaches the total number of frames it was
# designed to collect:
for i, tensordict_data in enumerate(collector):
    # we now have a batch of data to work with. Let's learn something from it.
    for _ in range(num_epochs):
        # We'll need an "advantage" signal to make PPO work.
        # We re-compute it at each epoch as its value depends on the value
        # network which is updated in the inner loop.
        advantage_module(tensordict_data)
        data_view = tensordict_data.reshape(-1)
        replay_buffer.extend(data_view.cpu())
        for _ in range(frames_per_batch // sub_batch_size):
            subdata = replay_buffer.sample(sub_batch_size)
            loss_vals = loss_module(subdata.to(device))
            loss_value = (
                loss_vals["loss_objective"]
                + loss_vals["loss_critic"]
                + loss_vals["loss_entropy"]
            )

            # Optimization: backward, grad clipping and optimization step
            loss_value.backward()
            # this is not strictly mandatory but it's good practice to keep
            # your gradient norm bounded
            torch.nn.utils.clip_grad_norm_(loss_module.parameters(), max_grad_norm)
            optim.step()
            optim.zero_grad()

    logs["reward"].append(tensordict_data["next", "reward"].mean().item())
    pbar.update(tensordict_data.numel())
    cum_reward_str = (
        f"average reward={logs['reward'][-1]: 4.4f} (init={logs['reward'][0]: 4.4f})"
    )
    logs["step_count"].append(tensordict_data["step_count"].max().item())
    stepcount_str = f"step count (max): {logs['step_count'][-1]}"
    logs["lr"].append(optim.param_groups[0]["lr"])
    lr_str = f"lr policy: {logs['lr'][-1]: 4.4f}"
    if i % 10 == 0:
        # We evaluate the policy once every 10 batches of data.
        # Evaluation is rather simple: execute the policy without exploration
        # (take the expected value of the action distribution) for a given
        # number of steps (1000, which is our ``env`` horizon).
        # The ``rollout`` method of the ``env`` can take a policy as argument:
        # it will then execute this policy at each step.
        with set_exploration_type(ExplorationType.MEAN), torch.no_grad():
            # execute a rollout with the trained policy
            eval_rollout = env.rollout(1000, policy_module)
            logs["eval reward"].append(eval_rollout["next", "reward"].mean().item())
            logs["eval reward (sum)"].append(
                eval_rollout["next", "reward"].sum().item()
            )
            logs["eval step_count"].append(eval_rollout["step_count"].max().item())
            eval_str = (
                f"eval cumulative reward: {logs['eval reward (sum)'][-1]: 4.4f} "
                f"(init: {logs['eval reward (sum)'][0]: 4.4f}), "
                f"eval step-count: {logs['eval step_count'][-1]}"
            )
            del eval_rollout
    pbar.set_description(", ".join([eval_str, cum_reward_str, stepcount_str, lr_str]))

    # We're also using a learning rate scheduler. Like the gradient clipping,
    # this is a nice-to-have but nothing necessary for PPO to work.
    scheduler.step()


  0%|                                                                                                                             | 0/10000 [18:33<?, ?it/s][A


AttributeError: 'Sequential' object has no attribute 'get_dist'