In [None]:
import json
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import spikeinterface.extractors as se
import spikeinterface.preprocessing as spre
import spikeinterface.widgets as sw

# Local module: the generator is defined in "spikeinterface_generator.py"
from spikeinterface_generator import SpikeInterfaceGenerator


%matplotlib widget

### Load the NP2 dataset and preprocess it

In [None]:
# Load the raw NP2 recording saved by Open Ephys.
# The folder can be overridden through the NP2_DATA_FOLDER environment variable so the
# notebook is not tied to one machine; the default is the original author's local path.
folder_path = os.environ.get(
    "NP2_DATA_FOLDER",
    "/home/alessio/Documents/data/allen/npix-open-ephys/595262_2022-02-22_16-47-26/",
)
recording = se.read_openephys(folder_path)

In [None]:
# Preprocessing: band-pass filter the raw traces (SpikeInterface default band, no
# frequencies passed), then z-score the filtered signal.
rec_f = spre.bandpass_filter(recording=recording)
rec_norm = spre.zscore(recording=rec_f)

In [None]:
# Interactive viewer of the band-pass filtered traces (rec_f), using the ipywidgets backend.
# NOTE(review): newer SpikeInterface releases rename this function to `sw.plot_traces`;
# confirm against the installed version.
sw.plot_timeseries(rec_f, backend="ipywidgets")

### Test SpikeInterfaceGenerator behavior

In [None]:
# rec_norm is already z-scored by spre.zscore, so the generator's own z-scoring is
# switched off to avoid normalizing the data twice.
si_generator = SpikeInterfaceGenerator(
    rec_norm,
    zscore=False,
    batch_size=10,
)

In [None]:
# Fetch the first batch from the generator as an (input, output) pair to inspect its structure.
input_0, output_0 = si_generator[0]

In [None]:
# Batch size the generator was configured with (passed as batch_size=10 at construction).
si_generator.batch_size

In [None]:
# Shape of the first input batch.
input_0.shape

In [None]:
# Shape of the first output batch (presumably the training target — confirm in the generator).
output_0.shape

### Run a small training job

In [None]:
from deepinterpolation.trainor_collection import core_trainer
from deepinterpolation.network_collection import unet_single_ephys_1024
from deepinterpolation.generic import ClassLoader
from pathlib import Path

In [None]:
# Training and test segments, in seconds of recording time.
TRAIN_START_S, TRAIN_END_S = 0.0, 0.1
TEST_START_S, TEST_END_S = 20.0, 20.05


def seconds_to_frame(seconds: float, sampling_frequency: float) -> int:
    """Return the sample index closest to ``seconds`` at ``sampling_frequency``.

    Rounds rather than truncating with a bare ``int()``: float products such as
    ``0.29 * 100`` evaluate to ``28.999999999999996``, which ``int()`` would turn
    into 28 and silently drop a frame.
    """
    return int(round(seconds * sampling_frequency))


fs = rec_norm.sampling_frequency
start_frame_training = seconds_to_frame(TRAIN_START_S, fs)
end_frame_training = seconds_to_frame(TRAIN_END_S, fs)
start_frame_test = seconds_to_frame(TEST_START_S, fs)
end_frame_test = seconds_to_frame(TEST_END_S, fs)
# The test segment must not overlap the data the model is trained on.
assert end_frame_training <= start_frame_test, "training and test segments overlap"

In [None]:
# Training (adapted from deepinterpolation's Training class).
output_folder = Path("test_training")
output_folder.mkdir(parents=True, exist_ok=True)


def _write_json(params: dict, path: Path) -> Path:
    """Serialize ``params`` to ``path`` as JSON and return ``path``.

    ClassLoader and core_trainer below are configured through JSON files on disk,
    so each parameter dict is written out before it is used.
    """
    with open(path, "w") as f:
        json.dump(params, f, indent=2)
    return path


# Both generators read the already z-scored recording, so the generator's own
# z-scoring is disabled.
training_data_generator = SpikeInterfaceGenerator(
    rec_norm,
    zscore=False,
    start_frame=start_frame_training,
    end_frame=end_frame_training,
)
test_data_generator = SpikeInterfaceGenerator(
    rec_norm,
    zscore=False,
    start_frame=start_frame_test,
    end_frame=end_frame_test,
)

# Network topology parameters; "name" is the network topology in deepinterpolation's collection.
network_params = {
    "type": "network",
    "name": "unet_single_ephys_1024",
}
network_json_path = _write_json(network_params, output_folder / "network_params.json")
network_obj = ClassLoader(network_json_path)
data_network = network_obj.find_and_build()(network_json_path)

# Trainer parameters. "loss" is set once, and "model_string" is derived from the
# network actually built above so the two cannot drift apart.
training_params = {"loss": "mean_absolute_error"}
training_params["model_string"] = f"{network_params['name']}_{training_params['loss']}"
training_params.update(
    output_dir=str(output_folder),
    run_uid="first_test",  # run identifier passed to the trainer
    nb_gpus=1,
    type="trainer",
    steps_per_epoch=10,
    period_save=5,
    apply_learning_decay=0,
    nb_times_through_data=1,
    learning_rate=0.0001,
    # NOTE(review): presumably should agree with the generator's pre/post frame setting — confirm.
    pre_post_frame=1,
    nb_workers=1,
)
training_json_path = _write_json(training_params, output_folder / "training_params.json")

training_class = core_trainer(
    training_data_generator, test_data_generator, data_network, training_json_path
)

print("created objects for training")
training_class.run()

print("training job finished - finalizing output model")
training_class.finalize()

### Test inference on the held-out test segment

In [None]:
# First batch of the test generator: (model input, original data).
# The second element is presumably the target the network learns to reproduce — confirm in the generator.
sample_input, original_data = test_data_generator[0]

In [None]:
# Denoise the first test batch with the trained model.
output = training_class.local_model.predict(sample_input)
output_data = test_data_generator.reshape_output(output)
# Flatten the original batch to (-1, n_channels) for plotting. The channel count comes from
# rec_norm, the recording the generator reads, so this stays correct if channels are ever
# selected or removed during preprocessing.
# NOTE(review): the model output goes through reshape_output() while the original data is
# reshaped by hand; confirm both produce the same channel ordering.
input_data = original_data.squeeze().reshape(-1, rec_norm.get_num_channels())

In [None]:
# Side-by-side comparison of the original and denoised test batch.
# Both panels share one symmetric colour scale centred on zero. That way the diverging
# "RdGy_r" colormap maps 0 to its neutral midpoint, and equal colours mean equal amplitudes
# across panels (independent autoscaling would break both properties).
vlim = max(np.max(np.abs(input_data)), np.max(np.abs(output_data)))
fig, axs = plt.subplots(ncols=2, constrained_layout=True)
# Assumes output_data has the same (samples, channels) layout as input_data — confirm.
panels = (
    (axs[0], input_data, "Original"),
    (axs[1], output_data, "Denoised (model prediction)"),
)
for ax, data, title in panels:
    # aspect="auto" so each panel fills its axes regardless of batch length vs channel count.
    im = ax.imshow(
        data.T, origin="lower", cmap="RdGy_r", vmin=-vlim, vmax=vlim, aspect="auto"
    )
    ax.set(title=title, xlabel="sample in batch", ylabel="channel")
fig.colorbar(im, ax=axs, label="amplitude");