In [1]:
import torch
import torch.nn as nn
import numpy as np
import torch.utils.data as data
import pytorch_lightning as pl
import matplotlib.pyplot as plt
from ncps.wirings import AutoNCP
from ncps.torch import LTC

from state_representation import ImageDataset, record_resets, CombineTransform, NormalizeTransform, ReshapeTransform
from utils.config import AE_CONFIG

  from .autonotebook import tqdm as notebook_tqdm


# Data collection and preprocessing

In [2]:
IMAGE_PATH = "./state_representation/reset_image_data_agent"
LABEL_PATH = "./state_representation/reset_label_data_agent"

In [None]:
# NOTE: uncomment if new dataset is necessary
# record_resets(IMAGE_PATH, 10000, AE_CONFIG.env)

In [4]:
transform = CombineTransform([
    NormalizeTransform(start=(0, 255), end=(0, 1)),
    ReshapeTransform((128, 128, 3))
])

dataset = ImageDataset(IMAGE_PATH, transform=transform)

In [6]:
# NOTE: uncomment if new dataset is necessary
# labels = []

# blue = torch.tensor([0., 0., 1.])
# for image in dataset:
#     agent = []
#     for i, h in enumerate(image):
#         for j, w in enumerate(h):
#             if w.equal(blue):
#                 agent.append((i, j))
    
#     center = np.round(np.array(agent).mean(axis=0)).astype(int)
#     labels.append(center)

# labels = torch.tensor(np.array(labels))
# torch.save(labels, LABEL_PATH + ".pt")
    

In [4]:
N = 1 # Length of the time-series
data_x = torch.load(IMAGE_PATH + ".pt").float()
# Target output is a sine with double the frequency of the input signal
data_y = torch.load(LABEL_PATH + ".pt").float()
print("data_x.shape: ", str(data_x.shape))
print("data_y.shape: ", str(data_y.shape))
# data_x = torch.Tensor(data_x)
# data_y = torch.Tensor(data_y)
dataloader = data.DataLoader(
    data.TensorDataset(data_x, data_y), batch_size=1, shuffle=True, num_workers=4
)

data_x.shape:  torch.Size([10000, 3, 128, 128])
data_y.shape:  torch.Size([10000, 2])


# Training a Liquid Neural Network

In [5]:
# LightningModule for training a RNNSequence module
class SequenceLearner(pl.LightningModule):
    def __init__(self, model, lr=0.005):
        super().__init__()
        self.model = model
        self.lr = lr

    def training_step(self, batch, batch_idx):
        x, y = batch
        # x = x.reshape((x.shape[0], 128*128*3))
        y_hat, _ = self.model.forward(x)
        y_hat = y_hat.view_as(y)
        loss = nn.MSELoss()(y_hat, y)
        self.log("train_loss", loss, prog_bar=True)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        # x = x.reshape((x.shape[0], 128*128*3))
        y_hat, _ = self.model.forward(x)
        y_hat = y_hat.view_as(y)
        loss = nn.MSELoss()(y_hat, y)

        self.log("val_loss", loss, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Here we just reuse the validation_step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        return torch.optim.Adam(self.model.parameters(), lr=self.lr)

In [6]:
torch.set_float32_matmul_precision('medium')
out_features = 2
in_features = 14 * 14 * 32

wiring = AutoNCP(16, out_features)  # 16 units, 1 motor neuron

model = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=5, stride=2),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2),
    nn.BatchNorm2d(16),
    nn.ReLU(),
    nn.Conv2d(16, 32, kernel_size=3, stride=2),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Flatten(),
    LTC(in_features, wiring, batch_first=True)
)
learn = SequenceLearner(model, lr=0.01)
trainer = pl.Trainer(
    logger=pl.loggers.CSVLogger("log"),
    max_epochs=5,
    gradient_clip_val=1,  # Clip gradient to stabilize training
    accelerator="gpu",
    devices=1
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
# Train the model for 400 epochs (= training steps)
trainer.fit(learn, dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 522 K 
-------------------------------------
421 K     Trainable params
100 K     Non-trainable params
522 K     Total params
2.089     Total estimated model params size (MB)


Epoch 4: 100%|██████████| 10000/10000 [00:51<00:00, 195.72it/s, loss=6.65, v_num=15, train_loss=5.030]      

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 10000/10000 [00:51<00:00, 195.68it/s, loss=6.65, v_num=15, train_loss=5.030]


In [16]:
for i in np.random.randint(0, 10000, size=10, dtype=int):
    sample = data_x[i].unsqueeze(0)
    label = data_y[i]
    prediction = model.forward(sample)
    prediction = prediction[0].detach().numpy()

    print(f"Predicted: {prediction}, True: {label}")

Predicted: [[118.689735 105.957405]], True: tensor([120., 111.])
Predicted: [[100.16912  38.35141]], True: tensor([100.,  40.])
Predicted: [[ 85.449814 118.219376]], True: tensor([ 82., 117.])
Predicted: [[23.72602 77.41425]], True: tensor([25., 79.])
Predicted: [[ 9.97237 30.3592 ]], True: tensor([ 7., 32.])
Predicted: [[10.271697 55.066082]], True: tensor([11., 59.])
Predicted: [[10.593664 22.459108]], True: tensor([10., 23.])
Predicted: [[63.223976 73.96355 ]], True: tensor([64., 74.])
Predicted: [[11.102107 90.29119 ]], True: tensor([ 9., 93.])
Predicted: [[73.82364  32.706394]], True: tensor([72., 31.])
