In [None]:
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from torch.utils.data import DataLoader, TensorDataset

In [None]:
# ---- Load sorted spikes and bag-level metadata for recording DA056 ----
# NOTE(review): machine-specific absolute path; consider an env var or a relative DATA_DIR.
DATA_DIR = Path("/Users/marco/Local_Sorting/DA056 Sorted")

DA056_bags = pd.read_parquet(DATA_DIR / "spike_bags.parquet")
DA056_spikes = pd.read_parquet(DATA_DIR / "spikes.parquet")

bags = DA056_bags
spikes = DA056_spikes

# Attach bag-level metadata to every spike row.
#  - validate="many_to_one": raises if (bag_id, unit_id) is not unique in `bags`,
#    which would otherwise silently duplicate spike rows.
#  - indicator=True: adds a temporary `_merge` column used to detect unmatched spikes.
df = spikes.merge(
    bags[["bag_id", "unit_id", "spike_type", "alignment",
          "mean_waveform", "n_spikes", "channel"]],
    on=["bag_id", "unit_id"],
    how="left",
    validate="many_to_one",
    indicator=True,
)

# A spike with no matching bag would get spike_type = NaN and be silently labelled 0 below.
n_unmatched = int((df["_merge"] == "left_only").sum())
assert n_unmatched == 0, f"{n_unmatched} spikes have no matching (bag_id, unit_id) in bags"
df = df.drop(columns="_merge")

# Binary label: 1 for spikes whose bag is tagged "good_units", else 0.
df["label"] = (df["spike_type"] == "good_units").astype(int)


In [None]:
# ---- Prepare data: stack per-spike waveforms into X and labels into y ----
SEED = 42
TEST_FRACTION = 0.2
BATCH_SIZE = 128
# Labels are attached per (bag_id, unit_id) by the merge, so a random spike-level split
# puts spikes of the same unit in both train and test and inflates test metrics.
# Splitting on this column keeps every unit entirely on one side of the split.
# NOTE(review): assumes generalisation to unseen units is the goal -- confirm.
# Set to None to reproduce the original spike-level stratified split.
SPLIT_GROUP_COL = "unit_id"

# Seeds torch's global RNG: weight init, dropout masks and DataLoader shuffling.
torch.manual_seed(SEED)

X = np.stack(df["waveform"].to_numpy()).astype(np.float32)
y = df["label"].to_numpy().astype(np.float32)
assert X.ndim == 2, "expected one equal-length 1-D waveform per spike"

if SPLIT_GROUP_COL is None:
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_FRACTION, random_state=SEED, stratify=y
    )
else:
    groups = df[SPLIT_GROUP_COL].to_numpy()
    # The first fold of a stratified group K-fold is a ~TEST_FRACTION hold-out that keeps
    # class balance as close to the full data as the group constraint allows.
    splitter = StratifiedGroupKFold(
        n_splits=round(1 / TEST_FRACTION), shuffle=True, random_state=SEED
    )
    train_idx, test_idx = next(splitter.split(X, y, groups=groups))
    X_train, X_test = X[train_idx], X[test_idx]
    y_train, y_test = y[train_idx], y[test_idx]
    assert np.intersect1d(groups[train_idx], groups[test_idx]).size == 0, "group leaked across split"

X_train_t = torch.from_numpy(X_train)
X_test_t  = torch.from_numpy(X_test)
y_train_t = torch.from_numpy(y_train)
y_test_t  = torch.from_numpy(y_test)

train_ds = TensorDataset(X_train_t, y_train_t)
test_ds  = TensorDataset(X_test_t,  y_test_t)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader  = DataLoader(test_ds,  batch_size=BATCH_SIZE)

# ---- Device: CUDA if present, else Metal (Apple silicon), else CPU ----
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print("Using device:", device)

# ---- Model input size = number of samples per waveform ----
input_dim = X_train.shape[1]

class WaveformMLP(nn.Module):
    """Feed-forward binary classifier over a flat waveform feature vector.

    Layout: one ``Linear -> ReLU -> Dropout`` stage per hidden width, then a final
    ``Linear`` to a single logit. The defaults reproduce the original 256 -> 128
    network with identical layer indices in ``self.net`` (so existing state_dicts
    still load) and the same parameter-initialisation order.

    Args:
        in_dim: number of input features per sample.
        hidden_dims: widths of the hidden layers, in order.
        dropout: dropout probability applied after each hidden ReLU.
    """

    def __init__(self, in_dim, hidden_dims=(256, 128), dropout=0.2):
        super().__init__()
        layers = []
        prev_dim = in_dim
        for width in hidden_dims:
            layers += [nn.Linear(prev_dim, width), nn.ReLU(), nn.Dropout(dropout)]
            prev_dim = width
        layers.append(nn.Linear(prev_dim, 1))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits (no sigmoid applied).

        The trailing size-1 output dimension is squeezed, so an input of shape
        ``(batch, in_dim)`` yields logits of shape ``(batch,)``.
        """
        return self.net(x).squeeze(-1)

# ---- Model, loss and optimiser ----
N_EPOCHS = 20
LEARNING_RATE = 1e-3
DECISION_THRESHOLD = 0.5  # cut-off on sigmoid(logit) for predicting class 1

model = WaveformMLP(input_dim).to(device)

# Weight positive samples by the train-set negative/positive ratio to offset class imbalance.
n_pos = int((y_train == 1).sum())
n_neg = int((y_train == 0).sum())
if n_pos == 0:
    raise ValueError("Training split has no positive ('good_units') samples; pos_weight is undefined.")
pos_weight = torch.tensor([n_neg / n_pos], device=device, dtype=torch.float32)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# ---- Training loop (bare-bones: no validation split or early stopping) ----
for epoch in range(N_EPOCHS):
    model.train()
    total_loss = 0.0
    for xb, yb in train_loader:
        xb = xb.to(device)
        yb = yb.to(device)

        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()

        # criterion returns a batch mean; re-weight by batch size to get a per-sample epoch mean
        total_loss += loss.item() * xb.size(0)

    avg_loss = total_loss / len(train_ds)
    print(f"Epoch {epoch+1}, train loss: {avg_loss:.4f}")

# ---- Evaluation on the held-out split ----
model.eval()
all_preds = []
all_targets = []

with torch.no_grad():
    for xb, yb in test_loader:
        xb = xb.to(device)
        logits = model(xb)
        probs = torch.sigmoid(logits)
        preds = (probs > DECISION_THRESHOLD).float().cpu().numpy()
        all_preds.append(preds)
        all_targets.append(yb.numpy())

y_pred = np.concatenate(all_preds)
y_true = np.concatenate(all_targets)

print("Confusion matrix:")
print(confusion_matrix(y_true, y_pred))
print("\nClassification report:")
print(classification_report(y_true, y_pred))


Using device: mps
Epoch 1, train loss: 0.6128
Epoch 2, train loss: 0.5643
Epoch 3, train loss: 0.5481
Epoch 4, train loss: 0.5393
Epoch 5, train loss: 0.5306
Epoch 6, train loss: 0.5274
Epoch 7, train loss: 0.5228
Epoch 8, train loss: 0.5184
Epoch 9, train loss: 0.5166
Epoch 10, train loss: 0.5114
Epoch 11, train loss: 0.5087
Epoch 12, train loss: 0.5074
Epoch 13, train loss: 0.5062
Epoch 14, train loss: 0.5044
Epoch 15, train loss: 0.4999
Epoch 16, train loss: 0.4983
Epoch 17, train loss: 0.4980
Epoch 18, train loss: 0.4941
Epoch 19, train loss: 0.4928
Epoch 20, train loss: 0.4930
Confusion matrix:
[[23128  5375]
 [  818  5635]]

Classification report:
              precision    recall  f1-score   support

         0.0       0.97      0.81      0.88     28503
         1.0       0.51      0.87      0.65      6453

    accuracy                           0.82     34956
   macro avg       0.74      0.84      0.76     34956
weighted avg       0.88      0.82      0.84     34956

