# Tutorial for Trials in Elephant

This notebook is a tutorial on using the `Trials` classes in Elephant to handle neuroscience data structured in repeated trials. The `Trials` classes provide a unified API to access and manipulate trial data, regardless of whether the data is stored in a `neo.Block` or as lists of lists.

## Introduction

In neuroscience, trials are repeated runs of an experiment; repeating a run makes it possible to estimate quantities from the data more reliably than from a single run. The `Trials` class in Elephant provides a structured way to handle and access trial data, irrespective of how it is stored. This tutorial covers the basic structure and functionality of the `Trials` class and its subclasses, `TrialsFromBlock` and `TrialsFromLists`.
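
To make the idea of one API over different storage layouts concrete, here is a minimal sketch in plain Python. It is not Elephant's implementation: the classes `ToyTrials`, `ToyTrialsFromNested` and `ToyTrialsFromLists`, the method `get_spiketrains` and the function `total_spike_count` are all hypothetical names chosen for illustration, and spike trains are plain lists of numbers rather than neo objects.

```python
from abc import ABC, abstractmethod


class ToyTrials(ABC):
    """Hypothetical base class: one access API for every storage layout."""

    @property
    @abstractmethod
    def n_trials(self):
        """Number of trials in the container."""

    @abstractmethod
    def get_spiketrains(self, trial_id):
        """Return the spike trains of one trial as a list."""


class ToyTrialsFromNested(ToyTrials):
    """Trials stored as a list of dicts (a stand-in for Block/Segment)."""

    def __init__(self, segments):
        self._segments = segments

    @property
    def n_trials(self):
        return len(self._segments)

    def get_spiketrains(self, trial_id):
        return list(self._segments[trial_id]["spiketrains"])


class ToyTrialsFromLists(ToyTrials):
    """Trials stored as a list of lists of spike trains."""

    def __init__(self, list_of_trials):
        self._trials = list_of_trials

    @property
    def n_trials(self):
        return len(self._trials)

    def get_spiketrains(self, trial_id):
        return list(self._trials[trial_id])


def total_spike_count(trials):
    """Analysis code written once against the shared interface."""
    return sum(
        len(spiketrain)
        for trial_id in range(trials.n_trials)
        for spiketrain in trials.get_spiketrains(trial_id)
    )


# The same two trials, stored in two different layouts.
# Spike times are plain lists of milliseconds here, not neo objects.
nested = ToyTrialsFromNested([
    {"spiketrains": [[10, 20, 30], [15, 25, 35]]},
    {"spiketrains": [[10, 20, 30], [15, 25, 35]]},
])
flat = ToyTrialsFromLists([
    [[10, 20, 30], [15, 25, 35]],
    [[10, 20, 30], [15, 25, 35]],
])

print(total_spike_count(nested), total_spike_count(flat))
```

The point of the design is that `total_spike_count` never needs to know which layout it was given; it only relies on the methods declared in the base class.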

## Imports and Setup

First, we need to import the necessary modules and set up some example data.

In [24]:
import numpy as np
from neo.core import Block, Segment, SpikeTrain, AnalogSignal
from quantities import ms, mV, Hz
from elephant.trials import TrialsFromBlock, TrialsFromLists

# Helper function to create example spike trains and analog signals
def create_example_data():
    """Return two spike trains and two analog signals as lists.

    The spike trains span 0-40 ms. Each analog signal holds 1000 samples
    of Gaussian noise at a 1000 Hz sampling rate.
    """
    spike_times = [np.array([10, 20, 30]) * ms, np.array([15, 25, 35]) * ms]
    spike_trains = [SpikeTrain(spike_time, t_start=0*ms, t_stop=40*ms) for spike_time in spike_times]
    analog_signals = [AnalogSignal(np.random.randn(1000), sampling_rate=1000 * Hz, units=mV) for _ in range(2)]
    return spike_trains, analog_signals


## Example 1: Using `TrialsFromBlock`

The `TrialsFromBlock` subclass handles trial data stored in a `neo.Block`, where each `neo.Segment` within the block represents a trial. In this example, we create a `neo.Block` containing three `neo.Segment` objects and then use `TrialsFromBlock` to access and manipulate the trial data.

In [25]:
# Create example data
spike_trains, analog_signals = create_example_data()

# Create a neo.Block and add Segments for each trial
block = Block()
for i in range(3):
    segment = Segment(name=f'Trial {i}')
    for spike_train in spike_trains:
        segment.spiketrains.append(spike_train)
    for analog_signal in analog_signals:
        segment.analogsignals.append(analog_signal)
    block.segments.append(segment)

# Create a TrialsFromBlock instance
trials_from_block = TrialsFromBlock(block)

# Access trial data
print(f'Number of trials: {trials_from_block.n_trials}')
print(f'Number of spike trains in each trial: {trials_from_block.n_spiketrains_trial_by_trial}')
print(f'Number of analog signals in each trial: {trials_from_block.n_analogsignals_trial_by_trial}')

# Get a specific trial as a Segment
trial_1_segment = trials_from_block.get_trial_as_segment(1)
print(f'Trial 1 Segment: {trial_1_segment}')

# Get all spike trains from the second trial (trial_id=1; trial IDs are zero-based)
all_spike_trains_trial_2 = trials_from_block.get_spiketrains_from_trial_as_list(trial_id=1)
print(f'All spike trains from trial 2: {all_spike_trains_trial_2}')

Number of trials: 3
Number of spike trains in each trial: [2, 2, 2]
Number of analog signals in each trial: [2, 2, 2]
Trial 1 Segment: <neo.core.segment.Segment object at 0x7af0a0e4fac0>
All spike trains from trial 2: [<SpikeTrain(array([10., 20., 30.]) * ms, [0.0 ms, 40.0 ms])>, <SpikeTrain(array([15., 25., 35.]) * ms, [0.0 ms, 40.0 ms])>]
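
Once spike trains can be pulled out trial by trial, trial-averaged quantities are straightforward to estimate. The sketch below computes a trial-averaged firing rate per spike train in plain Python. It is an illustration only: it copies the spike times and the 0-40 ms window from the output above into nested lists instead of using the `Trials` API, and the names `spike_times_ms` and `mean_rates_hz` are made up for this example.

```python
# Spike times in ms, indexed as [trial][unit]. The values mirror the example
# data printed above: three trials, two spike trains each, 0-40 ms long.
spike_times_ms = [
    [[10.0, 20.0, 30.0], [15.0, 25.0, 35.0]],
    [[10.0, 20.0, 30.0], [15.0, 25.0, 35.0]],
    [[10.0, 20.0, 30.0], [15.0, 25.0, 35.0]],
]
t_start_ms, t_stop_ms = 0.0, 40.0
duration_s = (t_stop_ms - t_start_ms) / 1000.0

n_trials = len(spike_times_ms)
n_units = len(spike_times_ms[0])

# Trial-averaged firing rate per unit: mean spike count across trials,
# divided by the trial duration.
mean_rates_hz = []
for unit in range(n_units):
    counts = [len(spike_times_ms[trial][unit]) for trial in range(n_trials)]
    mean_rates_hz.append(sum(counts) / n_trials / duration_s)

print(mean_rates_hz)
```

Because every trial in the tutorial data holds the same spike trains, both units come out at the same rate; with real recordings the spike counts would vary from trial to trial.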


## Example 2: Using `TrialsFromLists`

The `TrialsFromLists` class handles trial data stored as lists of lists, where each inner list contains the `neo.SpikeTrain` and `neo.AnalogSignal` objects of one trial. In this example, we build such a list of lists and then use `TrialsFromLists` to access and manipulate the trial data.

In [26]:
# Create example data
spike_trains, analog_signals = create_example_data()

# Create a list of lists representing trials; each inner list mixes
# spike trains and analog signals
list_of_trials = []
for i in range(3):
    trial = []
    for spike_train in spike_trains:
        # Wrap the spike times in a new SpikeTrain object for each trial
        trial.append(SpikeTrain(spike_train, t_start=0*ms, t_stop=40*ms))
    for analog_signal in analog_signals:
        trial.append(analog_signal)
    list_of_trials.append(trial)

# Create a TrialsFromLists instance
trials_from_lists = TrialsFromLists(list_of_trials)

# Access trial data
print(f'Number of trials: {trials_from_lists.n_trials}')
print(f'Number of spike trains in each trial: {trials_from_lists.n_spiketrains_trial_by_trial}')
print(f'Number of analog signals in each trial: {trials_from_lists.n_analogsignals_trial_by_trial}')

# Get a specific trial as a Segment
trial_1_segment = trials_from_lists.get_trial_as_segment(1)
print(f'Trial 1 Segment: {trial_1_segment}')

# Get all trials as a Block
all_trials_block = trials_from_lists.get_trials_as_block()
print(f'All trials Block: {all_trials_block}')

# Get all trials as a list of Segments and inspect their contents
trials_as_list = trials_from_lists.get_trials_as_list()
print(f'Number of segments in trials_list_from_lists: {len(trials_as_list)}')
for i, segment in enumerate(trials_as_list):
    print(f'Trial {i} Segment: {segment}')
    # Use a separate index (j) so the trial index (i) is not overwritten
    for j, spike_train in enumerate(segment.spiketrains):
        print(f'    Spike train {j} Spike train: {spike_train}')
    for j, analog_signal in enumerate(segment.analogsignals):
        print(f'    Analog signal {j} Analog signal: {analog_signal[:4].flatten()}')

Number of trials: 3
Number of spike trains in each trial: [2, 2, 2]
Number of analog signals in each trial: [2, 2, 2]
Trial 1 Segment: <neo.core.segment.Segment object at 0x7af0a0cbe620>
All trials Block: <neo.core.block.Block object at 0x7af0a0cbe8c0>
Number of segments in trials_list_from_lists: 3
Trial 0 Segment: <neo.core.segment.Segment object at 0x7af0a0e4df90>
    Spike train 0 Spike train: [10. 20. 30.] ms
    Spike train 1 Spike train: [15. 25. 35.] ms
    Analog signal 0 Analog signal: [ 0.69607325 -0.04370697  1.36631824 -0.83783295] mV
    Analog signal 1 Analog signal: [ 0.79122003  0.40844734  1.0958819  -0.83324856] mV
Trial 1 Segment: <neo.core.segment.Segment object at 0x7af0a0e4dea0>
    Spike train 0 Spike train: [10. 20. 30.] ms
    Spike train 1 Spike train: [15. 25. 35.] ms
    Analog signal 0 Analog signal: [ 0.69607325 -0.04370697  1.36631824 -0.83783295] mV
    Analog signal 1 Analog signal: [ 0.79122003  0.40844734  1.0958819  -0.83324856] mV
Trial 2 Segment: 