In [6]:
import os
import pytorch_lightning as pl
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as data

import kerasncp as kncp
from kerasncp.torch import LTCCell
from kerasncp import wirings
print('kncp version:', kncp.__version__)

from src.DGAProcess import LearningProcess
from src.DGALoss import DGALoss
from src.DGANet import DGANet
from src.DgaSequence import DgaSequence

kncp version: 2.0.0


In [None]:
class DenoiseSequence(nn.Module):
    """Sequence model built around an LTC cell with a pre-network stage (work in progress).

    Args:
        wiring: kerasncp wiring object handed to ``LTCCell``.
        initialization_ranges: Optional mapping of parameter name to a
            ``(low, high)`` range. Entries are merged on top of the default
            ``{"w": (0.2, 2.0)}``; pass ``None`` or ``{}`` to keep the default.
        in_features: Number of input features, forwarded to ``LTCCell``.
    """

    def __init__(self, wiring, initialization_ranges, in_features):
        super().__init__()

        # Start from the default override and let caller-supplied ranges win.
        ranges = {"w": (0.2, 2.0)}
        if initialization_ranges:
            ranges.update(initialization_ranges)

        # Build the cell from the constructor arguments instead of the
        # notebook-global `ncp_wiring`, which is only defined in a later cell.
        # NOTE(review): confirm that the installed torch LTCCell accepts the
        # `initialization_ranges` keyword.
        self.cell = LTCCell(
            wiring,
            in_features=in_features,
            initialization_ranges=ranges,
        )

        # TODO: the pre-network was left unfinished (dangling assignment, which
        # made the whole cell a SyntaxError). An identity layer keeps the class
        # importable until the real layer is written; if the real pre-network
        # changes the feature count, the `in_features` given to the cell above
        # must be updated to match.
        self.prenet = nn.Identity()

In [None]:

# LightningModule for training an RNNSequence module
class SequenceLearner(pl.LightningModule):
    """Lightning wrapper that fits a sequence model against an MSE objective.

    Args:
        model: Module whose ``forward`` maps an input batch to predictions that
            can be viewed with the target's shape. ``optimizer_step`` also
            calls ``model.rnn_cell.apply_weight_constraints()``, so the model
            must expose that attribute.
        lr: Learning rate passed to the Adam optimizer.
    """

    def __init__(self, model, lr=0.005):
        super().__init__()
        self.lr = lr
        self.model = model

    def _mse_on_batch(self, batch):
        """Run the model on ``batch`` and return the MSE against its targets."""
        inputs, targets = batch
        # Predictions are viewed with the target's shape before comparison.
        predictions = self.model.forward(inputs).view_as(targets)
        criterion = nn.MSELoss()
        return criterion(predictions, targets)

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss; returned under the ``"loss"`` key."""
        train_loss = self._mse_on_batch(batch)
        self.log("train_loss", train_loss, prog_bar=True)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        """Compute, log and return the validation loss."""
        val_loss = self._mse_on_batch(batch)
        self.log("val_loss", val_loss, prog_bar=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        """Delegate to ``validation_step``, so the value is logged as ``val_loss``."""
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        trainable = self.model.parameters()
        return torch.optim.Adam(trainable, lr=self.lr)

    def optimizer_step(
        self,
        current_epoch,
        batch_nb,
        optimizer,
        optimizer_idx,
        closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        """Step the optimizer held at ``optimizer.optimizer``, then constrain the cell weights."""
        inner_optimizer = optimizer.optimizer
        inner_optimizer.step(closure=closure)
        # Weight constraints are applied after every parameter update.
        self.model.rnn_cell.apply_weight_constraints()

In [None]:

# --- Configuration -----------------------------------------------------------
in_features = 2  # data_x below has two channels: a sine and a cosine wave
out_features = 1  # data_y below has a single channel
batch_size = 6
N = 16000  # Length of the time-series

# --- Synthetic data ----------------------------------------------------------
# Input feature is a sine and a cosine wave
phase = np.linspace(0, 3 * np.pi, N)
data_x = np.stack([np.sin(phase), np.cos(phase)], axis=1)
data_x = np.expand_dims(data_x, axis=0).astype(np.float32)  # Add batch dimension
# Target output is a sine with double the frequency of the input signal
data_y = np.sin(np.linspace(0, 6 * np.pi, N)).reshape([1, N, 1]).astype(np.float32)
data_x = torch.Tensor(data_x)
data_y = torch.Tensor(data_y)
print("data_x.size: ", str(data_x.size()))
print("data_y.size: ", str(data_y.size()))

# Fail early if the configured widths drift away from the generated data.
assert data_x.size(-1) == in_features, "in_features does not match data_x"
assert data_y.size(-1) == out_features, "out_features does not match data_y"

dataloader = data.DataLoader(
    data.TensorDataset(data_x, data_y),
    batch_size=batch_size,
    shuffle=True,
    num_workers=4
)

# --- Model -------------------------------------------------------------------
ncp_wiring = kncp.wirings.NCP(
    inter_neurons=20,  # Number of inter neurons
    command_neurons=10,  # Number of command neurons
    # Number of motor neurons. Tied to out_features because SequenceLearner
    # reshapes predictions with view_as(y), which needs matching element counts.
    # NOTE(review): presumably the cell's output width equals the number of
    # motor neurons - confirm against the kerasncp version in use.
    motor_neurons=out_features,
    sensory_fanout=4,  # How many outgoing synapses has each sensory neuron
    inter_fanout=5,  # How many outgoing synapses has each inter neuron
    recurrent_command_synapses=6,  # How many recurrent synapses are in the
    # command neuron layer
    motor_fanin=4,  # How many incoming synapses has each motor neuron
)
# NOTE(review): confirm that the installed torch LTCCell accepts the
# `initialization_ranges` keyword.
ncp_cell = LTCCell(
    ncp_wiring,
    in_features=in_features,  # tells the cell how many input features to expect
    initialization_ranges={
        # Overwrite some of the initialization ranges
        "w": (0.2, 2.0),
    },
)

# Alternative: a fully connected wiring instead of the NCP wiring above
# wiring = kncp.wirings.FullyConnected(8, out_features)
# ltc_cell = LTCCell(wiring, in_features)

# The NCP cell built above is the one that gets trained; the previous code
# referenced `ltc_cell`, whose definition is commented out (NameError).
# NOTE(review): RNNSequence is neither defined nor imported in the visible
# cells - make sure it is available before running this cell.
ltc_sequence = RNNSequence(
    ncp_cell,
)

# --- Training ----------------------------------------------------------------
learn = SequenceLearner(ltc_sequence, lr=0.01)
trainer = pl.Trainer(
    logger=pl.loggers.CSVLogger("log"),
    max_epochs=400,
    progress_bar_refresh_rate=1,
    gradient_clip_val=1,  # Clip gradient to stabilize training
    # Use one GPU when available, otherwise fall back to the CPU default.
    gpus=1 if torch.cuda.is_available() else None,
)

trainer.fit(learn, dataloader)

results = trainer.test(learn, dataloader)
