In [7]:
import cebra
import torch
from torch.utils.data import DataLoader, TensorDataset

# Initialize the dataset
hippocampus_pos = cebra.datasets.init('rat-hippocampus-single-achilles')

# Prepare Dataset
neural_data = hippocampus_pos.neural[:, None, :]  # Reshape to [10178, 1, 120]
continuous_index = hippocampus_pos.continuous_index.numpy()
dataset = TensorDataset(torch.tensor(neural_data, dtype=torch.float32), torch.tensor(continuous_index, dtype=torch.float32))
dataloader = DataLoader(dataset, batch_size=512, shuffle=False)

  dataset = TensorDataset(torch.tensor(neural_data, dtype=torch.float32), torch.tensor(continuous_index, dtype=torch.float32))


In [8]:
import torch.nn as nn

class Skip(nn.Module):
    def __init__(self, *modules, crop=(1, 1)):
        super().__init__()
        self.module = nn.Sequential(*modules)
        self.crop = slice(crop[0], -crop[1] if isinstance(crop[1], int) and crop[1] > 0 else None)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        skip = self.module(inp)
        return inp[..., self.crop] + skip

class Offset10Model(nn.Module):
    def __init__(self, num_neurons, num_units, num_output, normalize=True):
        super().__init__()
        layers = [
            nn.Conv1d(num_neurons, num_units, 2),
            nn.GELU(),
            Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()),
            nn.Conv1d(num_units, num_output, 3)
        ]
        self.net = nn.Sequential(*layers)
        self.normalize = normalize
        if normalize:
            self.norm = nn.LayerNorm(num_output)

    def forward(self, x):
        x = self.net(x)
        if self.normalize:
            x = x.transpose(1, 2)  # Move the output_dim to the last dimension for LayerNorm
            x = self.norm(x)
            x = x.transpose(1, 2)  # Move it back
        return x

# Model Initialization
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Offset10Model(num_neurons=1, num_units=128, num_output=3, normalize=False).to(device)

In [None]:
def contrastive_loss(features, labels, temperature=0.07):
    batch_size = features.shape[0]
    features = features.view(batch_size, -1)  # Flatten features to [batch_size, output_dim]

    labels = labels.contiguous().view(-1, 1)
    mask = torch.eq(labels, labels.T).float().to(device)

    anchor_dot_contrast = torch.div(torch.matmul(features, features.T), temperature)

    logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
    logits = anchor_dot_contrast - logits_max.detach()

    mask = mask.repeat(1, batch_size)
    logits_mask = torch.scatter(
        torch.ones_like(mask),
        1,
        torch.arange(batch_size).view(-1, 1).to(device),
        0
    )
    mask = mask * logits_mask

    exp_logits = torch.exp(logits) * logits_mask
    log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True))

    mean_log_prob_pos = (mask * log_prob).sum(1) / mask.sum(1)

    loss = -mean_log_prob_pos
    loss = loss.mean()

    return loss

In [6]:
import torch.optim as optim

optimizer = optim.Adam(model.parameters(), lr=3e-4)

model.train()
for epoch in range(10000):
    for data, target in dataloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data).squeeze(-1)  # Remove the last dimension
        loss = contrastive_loss(output, target)
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item()}")

# Save the model
torch.save(model.state_dict(), "cebra_behavior_model.pt")

RuntimeError: The size of tensor a (512) must match the size of tensor b (786432) at non-singleton dimension 1