In [1]:
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.nn as nn
from torch import optim
from torch.autograd import Variable
import matplotlib.pyplot as plt
import os, pickle
import numpy as np

In [3]:

train_data = datasets.CIFAR10(
    root = 'CIFAR10_train_data',
    train = True,
    transform = ToTensor(),
    download = True,
)
test_data = datasets.CIFAR10(
    root = 'CIFAR10_test_data',
    train = False,
    transform = ToTensor(),
    download = True,
)

loaders = {
    'train': torch.utils.data.DataLoader(train_data,
                                         batch_size=100,
                                         shuffle=True,
                                         num_workers=4),

    'test': torch.utils.data.DataLoader(test_data,
                                        batch_size=100,
                                        shuffle=True,
                                        num_workers=4),
}

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to CIFAR10_train_data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting CIFAR10_train_data/cifar-10-python.tar.gz to CIFAR10_train_data
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to CIFAR10_test_data/cifar-10-python.tar.gz


  0%|          | 0/170498071 [00:00<?, ?it/s]

Extracting CIFAR10_test_data/cifar-10-python.tar.gz to CIFAR10_test_data


In [4]:

class CorrelatedGroupSelector(nn.Module):
    def __init__(self, input_dim, num_groups, group_size, temperature=1.0):
        super().__init__()
        self.input_dim = input_dim
        self.num_groups = num_groups
        self.group_size = group_size
        self.temperature = temperature

        # Learnable logits for group membership: [num_groups, input_dim]
        self.group_logits = nn.Parameter(torch.randn(num_groups, input_dim))

    def forward(self, x):
        """
        x: (batch_size, input_dim)
        Returns:
            grouped_inputs: list of tensors [(batch, group_size), ...]
            selection_mask: (num_groups, input_dim)
        """
        # Gumbel-Softmax over inputs to create a soft group selection
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(self.group_logits)))
        logits = self.group_logits + gumbel_noise
        probs = F.softmax(logits / self.temperature, dim=-1)

        # Use top-k per group to select group members
        topk = torch.topk(probs, self.group_size, dim=-1)
        selection_mask = torch.zeros_like(probs)
        selection_mask.scatter_(1, topk.indices, 1.0)  # hard selection mask

        grouped_inputs = []
        for i in range(self.num_groups):
            group = selection_mask[i] * x  # broadcasted (batch, input_dim)
            grouped_inputs.append(group)

        return grouped_inputs, selection_mask

def compute_group_correlation(group):
    """
    group: (batch_size, input_dim) where non-grouped values are 0
    Returns: scalar correlation score (avg pairwise cosine)
    """
    # Only non-zero cols
    nonzero = (group.abs().sum(0) > 0)
    group_vars = group[:, nonzero]
    if group_vars.shape[1] < 2:
        return torch.tensor(0.0, device=group.device)
    normed = F.normalize(group_vars, dim=0)
    corr = (normed.T @ normed) / normed.shape[0]
    upper = torch.triu(corr, diagonal=1)
    avg_corr = upper.sum() / (nonzero.sum() * (nonzero.sum() - 1) / 2 + 1e-6)
    return avg_corr

class SelectiveGroupModel(nn.Module):
    def __init__(self, input_dim, num_groups=4, group_size=4):
        super().__init__()
        self.selector = CorrelatedGroupSelector(input_dim, num_groups, group_size)
        self.group_mlp = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_dim, 16),
                nn.ReLU(),
                nn.Linear(16, 1)
            ) for _ in range(num_groups)
        ])

    def forward(self, x):
        groups, mask = self.selector(x)

        # Compute correlation per group
        correlations = torch.stack([compute_group_correlation(g) for g in groups])
        _, selected_indices = torch.topk(-correlations, k=2)  # lowest 2 correlations

        # Only update selected groups
        outputs = []
        for i, group in enumerate(groups):
            if i in selected_indices:
                out = self.group_mlp[i](group)
            else:
                with torch.no_grad():
                    out = self.group_mlp[i](group)
            outputs.append(out)

        out = torch.cat(outputs, dim=1)
        return out, correlations

In [5]:
model = SelectiveGroupModel(input_dim=20, num_groups=5, group_size=4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [9]:
model.selector.group_logits.shape

torch.Size([5, 20])

In [13]:
model

SelectiveGroupModel(
  (selector): CorrelatedGroupSelector()
  (group_mlp): ModuleList(
    (0): Sequential(
      (0): Linear(in_features=20, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=1, bias=True)
    )
    (1): Sequential(
      (0): Linear(in_features=20, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=1, bias=True)
    )
    (2): Sequential(
      (0): Linear(in_features=20, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=1, bias=True)
    )
    (3): Sequential(
      (0): Linear(in_features=20, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=1, bias=True)
    )
    (4): Sequential(
      (0): Linear(in_features=20, out_features=16, bias=True)
      (1): ReLU()
      (2): Linear(in_features=16, out_features=1, bias=True)
    )
  )
)

In [None]:
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0

    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Flatten the input images
        inputs_flat = inputs.view(inputs.size(0), -1)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs, corr_reg, scores = model(inputs_flat)

        # Compute loss
        loss = criterion(outputs, labels) - 0.1 * corr_reg  # Adjust the weight of corr_reg as needed

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {running_loss/len(train_loader)}")

In [14]:
import numpy as np
import matplotlib.pyplot as plt
import math
import sys,os, glob
from neo import io
import pandas as pd


In [15]:
base_dirr = '/ems/elsc-labs/segev-i/yoni.leibner/PycharmProjects/Hippocampus_Basu/traces/'
f = os.path.join(base_dirr, ".abf")
print(f)

In [None]:
r = io.AxonIO(f)
bl = r.read_block(lazy=False)
