# This notebook trains a convolutional neural network (CNN) on brain images (NIfTI volumes) to predict grid-cell alignment in real time, using the VR trajectory (orientation) as labels.

In [1]:
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import seaborn as sns
import pandas as pd
import os
import glob

In [2]:
import nibabel as nib
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import KFold, train_test_split

## 1. Load and preprocess the input NIfTI images

In [5]:
# Root folders for the preprocessed imaging data and the behavioural (VR)
# data, followed by the subject IDs to include.
base_data_dir, base_behavioral_dir = (
    r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\preprocessed',
    r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset',
)

subjects = ['s05', 's14']  # Subjects to include

In [6]:
# Dataset pairing each masked NIfTI volume with the VR orientation at the
# volume's (nominal) acquisition time.
class BrainDataset(Dataset):
    """Masked brain volumes paired with interpolated VR orientation labels.

    For every subject, each entry of ``<base_data_dir>/<subject>`` whose name
    starts with ``run`` is treated as one run. Its volumes are read (sorted by
    file name) from ``<run>/masked_outputs_rightEC``. They are matched to the
    first (sorted) ``*<run_base>*.tsv`` file in
    ``<base_behavioral_dir>/<subject>/BehavioralData_<subject>``, where
    ``run_base`` is the run directory name up to its first ``_``.

    Volume ``i`` of a run is assigned time ``i * time_interval``. It is
    labelled with the ``Orientation`` column interpolated at that time against
    the ``Time`` column.

    Args:
        subjects: subject IDs (sub-directory names) to include.
        base_data_dir: root containing one preprocessed directory per subject.
        base_behavioral_dir: root containing the behavioural data per subject.
        time_interval: spacing between consecutive volumes, in the same units
            as the behavioural ``Time`` column.

    Attributes:
        images: list of volumes as float32 numpy arrays.
        labels: one orientation label per volume, wrapped to [0, 360)
            (assumed to be degrees; see ``_interpolate_orientation``).
        mean, std: scalar float32 tensors over all voxels of all volumes,
            used to normalise volumes in ``__getitem__``.

    Runs that fail to load are reported with ``print`` and skipped.
    """

    def __init__(self, subjects, base_data_dir, base_behavioral_dir, time_interval):
        self.images = []
        self.labels = []

        for subject in subjects:
            data_dir = os.path.join(base_data_dir, subject)
            behavioral_dir = os.path.join(base_behavioral_dir, subject, f'BehavioralData_{subject}')
            if not os.path.isdir(data_dir):
                print(f"Data directory not found for subject {subject}: {data_dir}")
                continue

            # Sorted so the sample order (and hence random_split indices) is
            # reproducible; os.listdir order is arbitrary.
            run_dirs = sorted(
                os.path.join(data_dir, d) for d in os.listdir(data_dir) if d.startswith('run')
            )
            if not run_dirs:
                print(f"No runs found for subject {subject} in {data_dir}")
                continue

            for run_dir in run_dirs:
                try:
                    self._load_run(run_dir, behavioral_dir, time_interval)
                except Exception as e:
                    # Best effort: a broken run is reported and skipped, not fatal.
                    print(f"Error processing run {run_dir}: {e}")

        self.mean, self.std = self._compute_stats(self.images)

    def _load_run(self, run_dir, behavioral_dir, time_interval):
        """Load one run's volumes and labels; append both only if the run is usable."""
        image_dir = os.path.join(run_dir, 'masked_outputs_rightEC')
        # float32 halves memory versus get_fdata()'s float64 default;
        # __getitem__ converts to float32 anyway.
        run_images = [
            nib.load(os.path.join(image_dir, f)).get_fdata(dtype=np.float32)
            for f in sorted(os.listdir(image_dir))
        ]
        run_base = os.path.basename(run_dir).split('_')[0]

        # NOTE(review): '*run1*' also matches e.g. 'run10'; sorting only makes
        # the pick deterministic -- confirm the behavioural file naming scheme.
        behavioral_files = sorted(glob.glob(os.path.join(behavioral_dir, f"*{run_base}*.tsv")))
        if not behavioral_files:
            print(f"No behavioral files found for run {run_base}")
            return
        run_behavioral_data = pd.read_csv(behavioral_files[0], sep='\t')

        timestamps = run_behavioral_data['Time'].to_numpy(dtype=float)
        orientations = run_behavioral_data['Orientation'].to_numpy(dtype=float)
        if not run_images or timestamps.size == 0:
            print(f"Skipping run {run_dir} due to missing data")
            return

        labels = self._interpolate_orientation(timestamps, orientations, len(run_images), time_interval)
        # Extend both lists together so images and labels stay index-aligned.
        self.images.extend(run_images)
        self.labels.extend(labels)

    @staticmethod
    def _interpolate_orientation(timestamps, orientations, n_volumes, time_interval):
        """Return ``n_volumes`` orientations sampled at ``i * time_interval``, wrapped to [0, 360)."""
        # np.interp requires increasing sample times.
        order = np.argsort(timestamps, kind='stable')
        timestamps = timestamps[order]
        # Orientations are assumed to be in degrees (the labels printed in this
        # notebook lie in [0, 360)). Unwrapping before interpolating makes a
        # 359 -> 1 step pass through 0 instead of sweeping back through 180.
        unwrapped = np.rad2deg(np.unwrap(np.deg2rad(orientations[order])))
        # arange(n) * dt always yields exactly n points; a float-step
        # arange(0, dt * n, dt) can yield n + 1 and misalign images and labels.
        time_points = np.arange(n_volumes) * time_interval
        return np.interp(time_points, timestamps, unwrapped) % 360.0

    @staticmethod
    def _compute_stats(images):
        """Return (mean, unbiased std) over all voxels of ``images`` as float32 tensors.

        Statistics are accumulated one volume at a time with the pairwise
        (Chan et al.) variance update. This avoids concatenating every volume
        into a single tensor, which would duplicate the whole dataset in
        memory. An empty dataset gets neutral stats (0, 1); a single voxel
        gets std 1.
        """
        count = 0
        mean = 0.0
        m2 = 0.0  # sum of squared deviations from the running mean
        for img in images:
            # Cast via float32 to match the precision used in __getitem__.
            arr = np.asarray(img, dtype=np.float32).astype(np.float64).ravel()
            n_b = arr.size
            if n_b == 0:
                continue
            mean_b = arr.mean()
            m2_b = np.square(arr - mean_b).sum()
            total = count + n_b
            delta = mean_b - mean
            mean += delta * n_b / total
            m2 += m2_b + delta * delta * count * n_b / total
            count = total

        if count == 0:
            return torch.tensor(0.0), torch.tensor(1.0)
        std = float(np.sqrt(m2 / (count - 1))) if count > 1 else 1.0
        return torch.tensor(float(mean), dtype=torch.float32), torch.tensor(std, dtype=torch.float32)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)``: the volume with a leading channel axis, normalised."""
        image = torch.tensor(self.images[idx], dtype=torch.float32).unsqueeze(0)  # Add channel dimension
        image = (image - self.mean) / self.std  # Normalize using dataset-wide stats
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return image, label


In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GridCellCNN(nn.Module):
    """3-D CNN mapping a single-channel volume to a raw (sin, cos) prediction.

    Architecture:
        - Three stages of conv(3x3x3) -> ReLU -> maxpool(2), with 128, 256
          and 512 channels.
        - ``fc1`` (-> 256), dropout(0.5), then ``fc2`` (-> 2).

    ``fc1`` is an ``nn.LazyLinear``: its input size is inferred on the first
    forward pass. Unlike a layer created inside ``forward``, it is a
    registered submodule from construction on. Therefore ``.parameters()``,
    ``.to(device)`` and ``load_state_dict`` all see it.

    Args:
        input_shape: optional spatial shape ``(D, H, W)`` of the input volume.
            When given, a dry run materialises ``fc1`` immediately, so every
            parameter is initialised before an optimizer is built.
    """

    def __init__(self, input_shape=None):
        super().__init__()
        self.conv1 = nn.Conv3d(1, 128, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool3d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv3d(128, 256, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv3d(256, 512, kernel_size=3, stride=1, padding=1)

        self.fc1 = nn.LazyLinear(256)  # in_features inferred on first forward
        self.dropout = nn.Dropout(0.5)  # Regularization
        self.fc2 = nn.Linear(256, 2)  # Predict sine and cosine

        if input_shape is not None:
            with torch.no_grad():
                self(torch.zeros(1, 1, *input_shape))

    def forward(self, x):
        """Map ``x`` of shape (batch, 1, D, H, W) to (batch, 2) raw (sin, cos) outputs."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))

        x = torch.flatten(x, 1)  # (batch, channels * D' * H' * W')
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)


In [8]:
import torch.nn.functional as F

def angular_loss(y_pred, y_true, norm_weight=0.0):
    """Cosine-distance loss between predicted and target (sin, cos) vectors.

    Both inputs are L2-normalised along the last dimension, so only their
    direction matters. The loss is ``1 - mean(cosine similarity)``: 0 for a
    perfect prediction, 1 for orthogonal and 2 for opposite directions,
    independent of batch size.

    Args:
        y_pred: raw model outputs whose last dimension holds (sin, cos).
        y_true: target (sin, cos) vectors, same shape as ``y_pred``.
        norm_weight: optional weight of ``mean((||y_pred|| - 1)^2)``, a
            penalty on the *raw* predictions that keeps them near the unit
            circle. Defaults to 0 (no penalty).

    Returns:
        Scalar loss tensor.
    """
    # No norm penalty is applied to the *normalised* tensors: their overall
    # norm is the constant sqrt(batch), so such a term has zero gradient and
    # only shifts the reported loss by a batch-size-dependent amount.
    cosine_similarity = torch.sum(
        F.normalize(y_pred, p=2, dim=-1) * F.normalize(y_true, p=2, dim=-1), dim=-1
    )
    loss = 1 - cosine_similarity.mean()
    if norm_weight:
        loss = loss + norm_weight * ((y_pred.norm(p=2, dim=-1) - 1) ** 2).mean()
    return loss



In [9]:
# Train function
def train_model(model, train_loader, optimizer, num_epochs=5, loss_fn=None):
    """Train ``model`` to predict (sin, cos) of angle labels given in degrees.

    Args:
        model: module mapping a batch of images to (batch, 2) outputs.
        train_loader: iterable of ``(images, labels)`` batches supporting
            ``len()``; labels are angles in degrees.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: number of passes over ``train_loader``.
        loss_fn: ``loss_fn(outputs, targets)``; defaults to ``angular_loss``.

    Returns:
        List with the mean training loss of each epoch.
    """
    if loss_fn is None:
        loss_fn = angular_loss
    epoch_losses = []
    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for images, labels in train_loader:
            optimizer.zero_grad()
            outputs = model(images)  # Predict sin and cos
            # Labels are in degrees, but torch.sin/cos expect radians.
            radians = torch.deg2rad(labels)
            labels_sin_cos = torch.stack([torch.sin(radians), torch.cos(radians)], dim=1)
            loss = loss_fn(outputs, labels_sin_cos)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        epoch_loss = running_loss / max(len(train_loader), 1)
        epoch_losses.append(epoch_loss)
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}')
    return epoch_losses



## 2. Load the VR trajectory labels and build the training/validation datasets

In [10]:
subjects = ['s05', 's14']
base_data_dir = r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\preprocessed'
base_behavioral_dir = r'C:\Users\sdabiri\OneDrive - Georgia Institute of Technology\BMED 8803 - Stat ML for Neural data\Project\Small_Dataset'
time_interval = 1.5  # time between consecutive volumes, in behavioural 'Time' units

# Create the dataset
dataset = BrainDataset(subjects, base_data_dir, base_behavioral_dir, time_interval)
print(f"Dataset size: {len(dataset)}")

# Split into training and validation datasets
from torch.utils.data import random_split, DataLoader

SPLIT_SEED = 42  # fixed seed -> the same train/val partition on every run
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SPLIT_SEED)
)
# NOTE(review): consecutive volumes of a run are likely strongly
# autocorrelated, so a random per-volume split puts near-duplicate samples in
# both sets and may inflate validation scores -- consider splitting by run or
# by subject instead.

# Create DataLoaders
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False)


Dataset size: 6747


In [12]:
from torchvision.transforms import Compose, RandomRotation, RandomHorizontalFlip

# Candidate data-augmentation pipeline.
# NOTE(review): `transform` is not passed to BrainDataset or to any DataLoader
# in this notebook, so it currently has no effect.
# NOTE(review): torchvision's 2-D transforms operate on the last two tensor
# dimensions. For the (1, 96, 96, 20) volumes seen here, that is the
# (96, 20) plane -- confirm this is the intended plane before wiring it in.
transform = Compose([
    RandomRotation(10),  # Rotate by up to 10 degrees
    RandomHorizontalFlip(),  # Flip images horizontally
])


In [13]:
# Sanity check: show the shape and labels of the first batch of each loader.
for split_name, loader in (("Train", train_loader), ("Validation", val_loader)):
    for images, labels in loader:
        print(f"{split_name} batch shape: {images.shape}, Labels: {labels}")
        break


Train batch shape: torch.Size([4, 1, 96, 96, 20]), Labels: tensor([  0.6849, 240.6274, 299.6090, 317.1288])
Validation batch shape: torch.Size([4, 1, 96, 96, 20]), Labels: tensor([265.3368,  65.7651,  46.5793, 344.9536])


## 3. Set up cross-validation (not implemented yet; the single random 80/20 split from section 2 is used instead)

## 4. Build and train the CNN:

In [14]:
# Initialize model, optimizer and learning-rate scheduler.
model = GridCellCNN()

# Dry run with one real sample BEFORE building the optimizer. GridCellCNN
# infers fc1's input size on its first forward pass; if the optimizer is
# created first, fc1's weights may never be handed to Adam and never trained.
with torch.no_grad():
    model(train_dataset[0][0].unsqueeze(0))  # (1, 1, D, H, W)

optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
# Divide the learning rate by 10 after 3 epochs without validation improvement.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

In [15]:
# Early stopping: allow up to num_epochs (deliberately generous) and stop after
# `patience` consecutive epochs without a new best validation loss.
num_epochs, patience = 50, 5
best_val_loss, wait = float('inf'), 0

In [None]:
# Per-epoch metrics, used by the plots below.
train_losses = []
val_losses = []
train_accuracies = []
val_accuracies = []

ACC_TOLERANCE_DEG = 10  # a prediction counts as correct within this many degrees


def _labels_to_sin_cos(labels_deg):
    """Map angle labels in degrees to (batch, 2) (sin, cos) targets."""
    radians = torch.deg2rad(labels_deg)  # torch.sin/cos expect radians
    return torch.stack([torch.sin(radians), torch.cos(radians)], dim=1)


def _count_within_tolerance(outputs, labels_deg, tol_deg=ACC_TOLERANCE_DEG):
    """Count predictions whose circular distance to the label is <= tol_deg."""
    predicted = torch.rad2deg(torch.atan2(outputs[:, 0], outputs[:, 1])) % 360
    diff = torch.abs(predicted - labels_deg) % 360
    # Shortest way around the circle: 355 vs 5 degrees is 10, not 350.
    circular_diff = torch.minimum(diff, 360 - diff)
    return (circular_diff <= tol_deg).sum().item()


for epoch in range(num_epochs):
    for param_group in optimizer.param_groups:
        print(f"Learning Rate: {param_group['lr']}")

    # Training phase
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for images, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(images)
        loss = angular_loss(outputs, _labels_to_sin_cos(labels))
        loss.backward()
        optimizer.step()

        # Metrics (detached: no autograd graph needed for bookkeeping)
        running_loss += loss.item()
        correct += _count_within_tolerance(outputs.detach(), labels)
        total += labels.size(0)
    train_losses.append(running_loss / max(len(train_loader), 1))
    train_accuracies.append(100 * correct / max(total, 1))

    # Validation phase
    model.eval()
    val_running_loss = 0.0
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for images, labels in val_loader:
            outputs = model(images)
            loss = angular_loss(outputs, _labels_to_sin_cos(labels))
            val_running_loss += loss.item()
            val_correct += _count_within_tolerance(outputs, labels)
            val_total += labels.size(0)
    avg_val_loss = val_running_loss / max(len(val_loader), 1)
    val_losses.append(avg_val_loss)
    val_accuracies.append(100 * val_correct / max(val_total, 1))

    # Adjust learning rate based on validation loss
    scheduler.step(avg_val_loss)

    # Print metrics for the epoch
    print(f"Epoch [{epoch+1}/{num_epochs}], "
          f"Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accuracies[-1]:.2f}%, "
          f"Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accuracies[-1]:.2f}%")

    # Early stopping; checkpoint the best model seen so far.
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        wait = 0
        torch.save(model.state_dict(), 'best_model.pth')  # Save best model
    else:
        wait += 1
        if wait >= patience:
            print(f"Stopping early at epoch {epoch+1}")
            break



Learning Rate: 0.001


In [None]:
import matplotlib.pyplot as plt

# One figure per metric, each showing the training and validation curves
# over epochs: first the losses, then the accuracies.
curve_specs = (
    (train_losses, val_losses, 'Training Loss', 'Validation Loss',
     'Loss', 'Training and Validation Loss'),
    (train_accuracies, val_accuracies, 'Training Accuracy', 'Validation Accuracy',
     'Accuracy (%)', 'Training and Validation Accuracy'),
)
for train_curve, val_curve, train_label, val_label, y_label, fig_title in curve_specs:
    plt.figure(figsize=(10, 5))
    plt.plot(train_curve, label=train_label)
    plt.plot(val_curve, label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(fig_title)
    plt.legend()
    plt.show()
