In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import cebra_v2 as cebra2

In [25]:
import importlib

# Re-execute the cebra_v2 submodules in place so later cells pick up source
# edits without restarting the kernel. `criterion` and `model` are used further
# down (the InfoMSE/InfoNCE criteria and cebra2.model.Model), so they are reloaded
# too rather than left stale. `solver` is reloaded last on the assumption that it
# imports from the others — TODO confirm the actual dependency order.
# NOTE: objects created before this cell runs (e.g. `fc_dataset`) keep their old
# class definitions; re-run those cells after reloading.
for _submodule in (cebra2.dataset, cebra2.criterion, cebra2.model, cebra2.solver):
    importlib.reload(_submodule)

<module 'cebra_v2.solver' from '/volatile/aurelien_stumpf_mascles/project/code/cebra_v2/solver.py'>

In [6]:
# Synthetic dataset archive; the directory name suggests 200 subjects, 4 states,
# high noise — TODO confirm. Consider making this path configurable.
SYNTH_DATASET_PATH = (
    '/neurospin/lbi/monkeyfmri/deepstim/database/SYNTHETIC_database/'
    'sub-200_states-4_noise-high_synth/dataset.npz'
)
loaded_arrays = np.load(SYNTH_DATASET_PATH)
print(loaded_arrays.files)

['data25', 'dfc25', 'labels25', 'data30', 'dfc30', 'labels30', 'data35', 'dfc35', 'labels35', 'data40', 'dfc40', 'labels40', 'data45', 'dfc45', 'labels45', 'states']


In [10]:
data25 = loaded_arrays['data25']
# Mean over axis 3 for the first entry along axis 0 only (presumably subject 0 —
# TODO confirm the axis layout of `data25`). Selecting the entry *before*
# reducing is equivalent to `np.mean(data25, axis=3)[0]` but avoids averaging
# every other entry along axis 0 only to discard it.
array = data25[0].mean(axis=2)
fc_dataset = cebra2.dataset.TensorDataset(array)

In [11]:
num_output = 3
normalize = True  # forwarded to cebra2.model.Model; presumably normalizes the embedding — TODO confirm
num_neurons = 82  # flattened input features per sample; assumed to match `array`'s feature dim — TODO confirm

# Hidden widths of the MLP, expressed relative to the embedding size.
hidden_wide = num_output * 30
hidden_narrow = num_output * 10

# Layers listed in construction order (so parameter initialization draws from
# the RNG in the same sequence as building them inline).
encoder_layers = (
    nn.Flatten(start_dim=1, end_dim=-1),
    nn.Linear(num_neurons, hidden_wide),
    nn.GELU(),
    nn.Linear(hidden_wide, hidden_wide),
    nn.GELU(),
    nn.Linear(hidden_wide, hidden_narrow),
    nn.GELU(),
    nn.Linear(hidden_narrow, num_output),
)

model = cebra2.model.Model(
    *encoder_layers,
    num_input=num_neurons,
    num_output=num_output,
    normalize=normalize,
)

In [12]:
def single_session_solver(data_loader, **kwargs):
    """Build (but do not train) a single-session CEBRA solver.

    Args:
        data_loader: Currently unused; kept for call-site compatibility.
        **kwargs: Only the following keys are read:
            distance (str): ``'euclidean'`` (InfoMSE criterion) or
                ``'cosine'`` (InfoNCE criterion).
            temperature: passed to the criterion.
            learning_rate: Adam learning rate.
            model: the ``torch.nn.Module`` to optimize.
            Any other keys are ignored.

    Returns:
        A ``cebra2.solver.SingleSessionSolver``; call ``.fit(loader)`` to train.

    Raises:
        KeyError: if one of the required keys is missing.
        ValueError: if ``distance`` is not ``'euclidean'`` or ``'cosine'``.
    """
    distance = kwargs['distance']
    # Validate up front: otherwise an unknown distance leaves `criterion`
    # unbound and the failure surfaces later as an unrelated error.
    if distance not in ('euclidean', 'cosine'):
        raise ValueError(
            f"Unsupported distance {distance!r}; expected 'euclidean' or 'cosine'."
        )

    model = kwargs['model']
    if distance == 'euclidean':
        criterion = cebra2.criterion.InfoMSE(temperature=kwargs['temperature'])
    else:
        criterion = cebra2.criterion.InfoNCE(temperature=kwargs['temperature'])

    optimizer = torch.optim.Adam(model.parameters(), lr=kwargs['learning_rate'])

    return cebra2.solver.SingleSessionSolver(model=model,
                                             criterion=criterion,
                                             optimizer=optimizer)

In [19]:
@torch.no_grad()
def get_emissions(model, dataset):
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    model.to(device)
    return model(dataset[torch.arange(len(dataset))].to(device)).cpu().numpy()

def _compute_emissions_single(solver, dataset):
    return get_emissions(solver.model, dataset)

In [26]:
# Training loader over `fc_dataset`; num_steps / batch_size are interpreted by
# cebra2.dataset.Loader (presumably optimizer steps and samples per batch).
FC_NUM_STEPS = 10000
FC_BATCH_SIZE = 512
fc_loader = cebra2.dataset.Loader(
    fc_dataset,
    num_steps=FC_NUM_STEPS,
    batch_size=FC_BATCH_SIZE,
)

In [27]:
# Build the solver. `single_session_solver` reads only `distance`, `temperature`,
# `learning_rate` and `model`; the network architecture (output dimension
# `num_output`) is defined by `model` itself. Passing architecture-style kwargs
# such as an output dimension here would be silently ignored and misleading.
cebra_fc = single_session_solver(
    data_loader=fc_loader,
    distance='cosine',
    temperature=1,
    learning_rate=3e-2,
    model=model,
)

In [28]:
# Train the solver on batches drawn from `fc_loader`.
# NOTE(review): this call raised `IndexError: index 475 is out of bounds for
# dimension 0 with size 475`. The logged reference/positive index tensors differ
# by a constant 15 (e.g. 101 -> 116, 379 -> 394), which suggests the loader draws
# positives as `ref + offset` without restricting `ref` to
# `len(fc_dataset) - offset` — TODO check cebra_v2/dataset.py (Loader).
# Also note `fc_dataset` was created (In [10]) before the module reload (In [25]),
# so it may be an instance of the pre-reload TensorDataset class.
cebra_fc.fit(fc_loader)

tensor([101, 379, 110,  ..., 356, 441, 311]) tensor([116, 394, 125,  ..., 371, 456, 326]) tensor([176, 166, 246, 181, 129,  81, 135, 319, 192, 319, 300, 171, 137, 195,
        387, 358, 425, 450,  34, 351, 295,  90, 243, 225, 466, 193, 296, 233,
        344, 123, 185,  74,  22, 305, 274, 446,  47, 104, 271, 156, 432, 383,
        445, 426, 220,  45, 396, 198, 381, 452,  86, 230, 323, 127, 257, 386,
        320, 396, 279, 374,  92, 155, 451, 320,  49, 388, 366, 217, 265, 472,
        161, 449, 414,  56,  28, 157,  96, 424, 178,  36, 230, 257, 221, 292,
         33, 234,   8,  22, 269,  53, 117, 231,  42, 132, 114, 401, 229, 102,
        197, 341,  25, 453, 257,  84, 441, 332, 143, 186, 129, 309, 400, 117,
         28, 249, 449,  30,  48, 255, 234, 378, 251, 136, 318, 339,  64, 314,
        341, 376, 246, 140, 448, 161, 388, 156, 400,  62, 289, 357,  27,  52,
        252, 444, 344, 453,   1, 461, 151, 404, 428, 191, 275,  70, 287, 353,
        118,  20, 112, 131, 287, 443,  40, 246, 140,

IndexError: index 475 is out of bounds for dimension 0 with size 475