## 1. Check if the model is differentiable

In [56]:
import torch
from torch.nn.parameter import Parameter

class QubeDynamics(torch.nn.Module):
    """Differentiable one-step dynamics model with learnable physical parameters.

    Solves ``M qdd + C(q, qd) = tau`` for ``qdd`` and integrates the state
    ``(th, al, thd, ald)`` forward by ``dt``. Every physical constant except
    gravity is a trainable ``Parameter`` of shape ``(1,)``, so the module can
    be fitted to recorded transitions by gradient descent.
    """

    def __init__(self):
        super().__init__()
        # Gravity is deliberately not trainable. A non-persistent buffer
        # follows the module across devices/dtypes without adding a key to
        # ``state_dict()``.
        self.register_buffer("g", torch.tensor([9.81]), persistent=False)

        # Motor
        self.Rm = Parameter(torch.tensor([8.4]))

        # back-emf constant (V-s/rad)
        self.km = Parameter(torch.tensor([0.042]))

        # Rotary arm
        self.Mr = Parameter(torch.tensor([0.095]))
        self.Lr = Parameter(torch.tensor([0.085]))
        self.Dr = Parameter(torch.tensor([5e-6]))

        # Pendulum link
        self.Mp = Parameter(torch.tensor([0.024]))
        self.Lp = Parameter(torch.tensor([0.129]))
        self.Dp = Parameter(torch.tensor([1e-6]))

    def set_random_params(self):
        """Re-initialise the parameters in place.

        Every parameter is drawn uniformly from ``[0, 0.1)``, then ``Rm`` is
        set to 5. Values are written into the existing ``Parameter`` objects
        (instead of rebinding attributes) so optimizers or other references
        created earlier keep pointing at the live tensors.
        """
        with torch.no_grad():
            for p in self.parameters():
                p.copy_(torch.rand_like(p) / 10)  # most params between 0 and 0.1

            # except for Rm
            self.Rm.fill_(5.0)

    def _init_const(self):
        """Recompute the derived constants ``self._c`` from the current parameters.

        ``self._c`` is a tensor of shape ``(5,)`` that stays attached to the
        autograd graph, so it must be rebuilt whenever the parameters change.
        """
        # Moments of inertia
        Jr = self.Mr * self.Lr ** 2 / 12  # inertia about COM (kg-m^2)
        Jp = self.Mp * self.Lp ** 2 / 12  # inertia about COM (kg-m^2)

        quarter_mp_lp2 = 0.25 * self.Mp * self.Lp ** 2

        # Constants for equations of motion. Concatenating (rather than
        # writing into a preallocated ``torch.zeros(5)``) keeps the result on
        # the parameters' device and avoids in-place autograd writes.
        self._c = torch.cat([
            Jr + self.Mp * self.Lr ** 2,
            quarter_mp_lp2,
            0.5 * self.Mp * self.Lp * self.Lr,
            Jp + quarter_mp_lp2,
            0.5 * self.Mp * self.Lp * self.g,
        ])

    @staticmethod
    def _strip_feature_dim(x, like):
        """Return ``x`` as a tensor whose shape broadcasts with ``like[..., 0]``.

        When ``x`` has as many dimensions as ``like`` its last axis is treated
        as a feature axis and element 0 is taken (``(1,)`` -> scalar,
        ``(B, 1)`` -> ``(B,)``). Python numbers and lower-rank tensors are
        passed through and rely on broadcasting.
        """
        x = torch.as_tensor(x, dtype=like.dtype, device=like.device)
        return x[..., 0] if x.dim() == like.dim() else x

    def forward(self, s, u, dt):
        """Predict the next state after applying action ``u`` for ``dt``.

        Args:
            s: state tensor ``(..., 4)`` ordered ``(th, al, thd, ald)``.
            u: action, shape ``(..., 1)`` (element 0 is used) or broadcastable
                to the batch shape; multiplied by 12 to obtain the voltage.
            dt: time step, shape ``(..., 1)``, a 0-dim tensor, or a Python
                float.

        Returns:
            Tensor with the same shape as ``s`` holding the next state. The
            velocities are updated first and the angles are then advanced with
            the *updated* velocities.
        """
        th, al, thd, ald = s.unbind(dim=-1)
        voltage = self._strip_feature_dim(u, s) * 12
        dt = self._strip_feature_dim(dt, s)

        # need to re-init each time we update params
        self._init_const()
        k0, k1, k2, k3, k4 = self._c.unbind()

        # Scalar (0-dim) views of the parameters so every intermediate has
        # exactly the batch shape of ``s``.
        km, Rm = self.km[0], self.Rm[0]
        Dr, Dp = self.Dr[0], self.Dp[0]

        sin_al = torch.sin(al)
        sin_2al = torch.sin(2 * al)

        # Define mass matrix M = [[a, b], [b, c]]
        a = k0 + k1 * sin_al ** 2
        b = k2 * torch.cos(al)
        c = k3
        d = a * c - b * b  # det(M)

        # Calculate vector [x, y] = tau - C(q, qd)
        trq = km * (voltage - km * thd) / Rm
        c0 = k1 * sin_2al * thd * ald - k2 * sin_al * ald * ald
        c1 = -0.5 * k1 * sin_2al * thd * thd + k4 * sin_al
        x = trq - Dr * thd - c0
        y = -Dp * ald - c1

        # Compute M^{-1} @ [x, y]
        thdd = (c * x - b * y) / d
        aldd = (a * y - b * x) / d

        # Velocities first, then angles with the new velocities.
        ald_next = ald + dt * aldd
        thd_next = thd + dt * thdd
        al_next = al + dt * ald_next
        th_next = th + dt * thd_next

        return torch.stack((th_next, al_next, thd_next, ald_next), dim=-1)

In [81]:
from furuta_gym.envs.furuta_sim import QubeDynamics as QD

# Dynamics class imported from furuta_gym.envs.furuta_sim.
# NOTE(review): `baseline` is instantiated but never used in this cell.
baseline = QD()
# Model defined in this notebook, with its default parameters.
model = QubeDynamics()

# Take one recorded (state, action, dt, next_state) sample from the dataset.
state, action, dt, next_state = ds[10]

# run model
pred_next_state = model(state, action, dt)

# Print the recorded next state followed by the one-step prediction so the
# two can be compared by eye.
print(next_state)
print(pred_next_state)
# loss = torch.nn.functional.mse_loss(pred_next_state, next_state)
# print(loss)

# TODO put the state update in the model
# NOTE(review): QubeDynamics.forward already returns the integrated next
# state, so this TODO looks done — confirm and remove.

tensor([ 0.2682, -0.2746, 10.2170, -9.8352])
tensor([ 0.2682, -0.2746, 10.2169, -9.8352], grad_fn=<CopySlices>)


In [80]:
# make a dataset
# input is state + action + dt, output is next state
import torch
from typing import Union
from pathlib import Path
import os
from furuta_gym.logging.protobuf.pendulum_state_pb2 import PendulumState
from mcap_protobuf.reader import read_protobuf_messages
from tqdm import tqdm

class MCAPDataset(torch.utils.data.Dataset):
    """In-memory dataset of one-step transitions read from ``.mcap`` logs.

    Each sample is a tuple ``(state, action, dt, next_state)`` of float32
    tensors with shapes ``(4,)``, ``(1,)``, ``(1,)`` and ``(4,)``.
    """

    def __init__(self, root_dir: Union[str, Path]):
        """Parse every ``*.mcap`` file directly inside ``root_dir``.

        Files that fail to parse are reported on stdout and skipped so that a
        single corrupt log does not abort loading.
        """
        root_dir = Path(root_dir)

        # parse the data
        # TODO it's gonna load it all in RAM
        # + have some duplicates
        # but should be ok since this is pretty light < 1MB
        self.samples = []
        # Sorted so that sample indices are reproducible across runs and
        # platforms (os.listdir order is arbitrary).
        for f in tqdm(sorted(os.listdir(root_dir))):
            if not f.endswith(".mcap"):
                continue
            try:
                self.parse_mcap(root_dir / f)
            except Exception as e:
                # Best-effort loading: report and move on. The exception type
                # is included because str(e) is empty for some errors.
                print(f"Error parsing {f}: {type(e).__name__}: {e}")

    @staticmethod
    def _state_tensor(p):
        """Build the ``(4,)`` state tensor from one decoded protobuf message."""
        return torch.tensor(
            [
                p.motor_angle,
                p.pendulum_angle,
                p.motor_angle_velocity,
                p.pendulum_angle_velocity,
            ],
            dtype=torch.float32,
        )

    @staticmethod
    def _make_sample(msg, next_msg):
        """Turn two consecutive log messages into one training sample."""
        next_p = next_msg.proto_msg
        state = MCAPDataset._state_tensor(msg.proto_msg)
        next_state = MCAPDataset._state_tensor(next_p)

        elapsed = (next_msg.log_time - msg.log_time).total_seconds()
        dt = torch.tensor([elapsed], dtype=torch.float32)
        # NOTE(review): the action is read from the *next* message —
        # presumably each message logs the action that led to its state;
        # confirm against the logger.
        action = torch.tensor([next_p.corrected_action], dtype=torch.float32)

        return (state, action, dt, next_state)

    def parse_mcap(self, pth):
        """Append the transitions found in the log file ``pth`` to ``self.samples``.

        Samples are collected first and appended only once the whole file has
        been processed, so a failure mid-file leaves ``self.samples`` untouched.
        """
        msgs = list(read_protobuf_messages(pth, log_time_order=True))
        # NOTE(review): as in the original index range, the pair formed by the
        # last two messages is not used — confirm whether that is intentional.
        parsed = [
            self._make_sample(msg, next_msg)
            for msg, next_msg in zip(msgs[:-2], msgs[1:-1])
        ]
        self.samples.extend(parsed)

    def __len__(self):
        """Return the number of loaded samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the ``(state, action, dt, next_state)`` tuple at ``idx``."""
        return self.samples[idx]

# Build the dataset from the recorded run directory (every .mcap file in it).
ds = MCAPDataset("../data/24adyqqm/")
# Number of (state, action, dt, next_state) samples that were loaded.
print(len(ds))
# Display the first sample as the cell output.
ds[0]


 83%|████████▎ | 165/198 [00:00<00:00, 259.16it/s]

Error parsing ep197_20221201-015414.mcap: [Errno 22] Invalid argument


100%|██████████| 198/198 [00:00<00:00, 254.66it/s]

20773





(tensor([-0.0608,  0.0422, -2.2632,  2.2385]),
 tensor([-0.3159]),
 tensor([0.0200]),
 tensor([-0.1766,  0.1546, -5.7868,  5.6248]))

In [87]:
import wandb
# System identification: fit the QubeDynamics parameters to recorded
# transitions by minimising the one-step prediction error.
config = {
    "epochs": 40,
    "batch_size": 512,
    "lr": 5e-3,
}
with wandb.init(project="furuta", job_type="system_id", config=config) as run:
    # From here on read hyperparameters from the run's config object.
    config = run.config

    # setup dataset
    ds = MCAPDataset("../data/24adyqqm/")

    # setup dataloader
    dl = torch.utils.data.DataLoader(ds, batch_size=config.batch_size, shuffle=True)

    # setup model: start from random parameters so that training has to
    # recover them from the data
    model = QubeDynamics()
    model.set_random_params()
    model.train()

    # setup optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # setup loss
    loss = torch.nn.MSELoss()

    # train
    for epoch in range(config.epochs):
        for state, action, dt, next_state in tqdm(dl):
            # reset gradients
            optimizer.zero_grad()

            # run model one sample at a time and re-assemble the batch.
            # NOTE: this Python-level loop dominates the epoch time.
            pred_next_state = torch.stack(
                [model(s, a, h) for s, a, h in zip(state, action, dt)],
                dim=0,
            )

            # calculate loss
            batch_loss = loss(pred_next_state, next_state)

            # backprop
            batch_loss.backward()

            # update weights
            optimizer.step()

            # Log plain Python scalars rather than live (graph-attached)
            # tensors. Every state_dict entry of the model holds exactly one
            # element, so .item() is valid for each of them.
            to_log = {name: value.item() for name, value in model.state_dict().items()}
            to_log["loss"] = batch_loss.item()
            run.log(to_log)

 14%|█▍        | 42/291 [00:00<00:02, 83.91it/s]

Error parsing ep290_20221201-022044.mcap: 


100%|██████████| 291/291 [00:02<00:00, 101.71it/s]
100%|██████████| 151/151 [02:04<00:00,  1.22it/s]
100%|██████████| 151/151 [01:57<00:00,  1.29it/s]
100%|██████████| 151/151 [02:00<00:00,  1.25it/s]
100%|██████████| 151/151 [01:54<00:00,  1.32it/s]
100%|██████████| 151/151 [01:54<00:00,  1.32it/s]
100%|██████████| 151/151 [01:57<00:00,  1.28it/s]
100%|██████████| 151/151 [02:01<00:00,  1.24it/s]
100%|██████████| 151/151 [01:54<00:00,  1.32it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:52<00:00,  1.34it/s]
100%|██████████| 151/151 [01:52<00:00,  1.34it/s]
100%|██████████| 151/151 [01:52<00:00,  1.34it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:52<00:00,  1.34it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:53<00:00,  1.33it/s]
100%|██████████| 151/151 [01:52<00:00,  1.34it/s]

0,1
Dp,▁▇▇█████████████████████████████████████
Dr,█▇▅▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
Lp,▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇██
Lr,▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▇▇█
Mp,▆████████████▇▇▇▇▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▄▄▃▃▃▂▂▁
Mr,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▇▇▇█
Rm,▆▇███▇▇▆▆▆▆▆▆▆▆▆▆▆▆▅▅▅▅▅▅▅▅▄▄▄▄▄▃▃▂▃▂▂▁▁
km,▄▃▁▁▁▂▂▃▃▃▃▃▃▃▃▃▃▃▄▄▄▄▄▄▄▄▅▅▅▅▅▆▆▆▇▇▇▇██
loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
Dp,-0.0001
Dr,-0.00059
Lp,0.08813
Lr,0.06481
Mp,0.12298
Mr,0.14489
Rm,5.04514
km,0.05107
loss,0.43632
