In [None]:
import sys
import importlib
import subprocess

def install_if_colab(packages=("torch", "numpy", "matplotlib")):
    """Install any of ``packages`` that are missing, but only when running in Google Colab.

    Args:
        packages: Importable module names to check. Each missing name is passed to
            ``pip install`` unchanged, so only list packages whose import name equals
            their pip distribution name.

    Raises:
        subprocess.CalledProcessError: If ``pip install`` fails.
    """
    # `import importlib` alone does not guarantee the `importlib.util` submodule is loaded.
    import importlib.util

    if "google.colab" not in sys.modules:
        print("Not running in Google Colab. No installation needed.")
        return

    print("Running in Google Colab. Checking required libraries...")
    missing_packages = [pkg for pkg in packages if importlib.util.find_spec(pkg) is None]

    if not missing_packages:
        print("All required libraries are already installed.")
        return

    print(f"Installing missing libraries: {', '.join(missing_packages)}")
    # Install into the interpreter that runs this kernel. Unlike the IPython-only
    # `!pip install`, this is plain Python and raises if the installation fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *missing_packages])


# Everything imported by name in this notebook whose pip name matches its import name.
install_if_colab(("torch", "numpy", "matplotlib", "seaborn", "braindecode", "moabb", "wandb"))


In [None]:
import numpy as np
from braindecode.datasets import MOABBDataset
from braindecode.preprocessing import (
    exponential_moving_standardize,
    preprocess,
    Preprocessor,
)

# All nine subjects (1-9) of the MOABB "BNCI2014_001" dataset are loaded into one dataset.
subject_ids = list(range(1, 10))
dataset = MOABBDataset(dataset_name="BNCI2014_001", subject_ids=subject_ids)

low_cut_hz = 4.0  # low cut frequency for filtering
high_cut_hz = 38.0  # high cut frequency for filtering
# Parameters for exponential moving standardization
factor_new = 1e-3
init_block_size = 1000

transforms = [
    Preprocessor("pick_types", eeg=True, meg=False, stim=False),  # Keep EEG sensors
    Preprocessor(
        lambda data, factor: np.multiply(data, factor),  # Convert from V to uV
        factor=1e6,
    ),
    Preprocessor("filter", l_freq=low_cut_hz, h_freq=high_cut_hz),  # Bandpass filter
    Preprocessor(
        exponential_moving_standardize,  # Exponential moving standardization
        factor_new=factor_new,
        init_block_size=init_block_size,
    ),
]

# Apply the transforms to every recording; `dataset` is used afterwards without
# reassignment, i.e. preprocessing modifies it in place. n_jobs=-1 parallelizes over recordings.
preprocess(dataset, transforms, n_jobs=-1)


In [None]:
from braindecode.preprocessing import create_windows_from_events

# Each window begins half a second before its event marker.
trial_start_offset_seconds = -0.5

# The sample offset is derived from the first recording's rate, so every
# recording must share that same sampling frequency.
sfreq = dataset.datasets[0].raw.info["sfreq"]
assert all(ds.raw.info["sfreq"] == sfreq for ds in dataset.datasets)
trial_start_offset_samples = int(sfreq * trial_start_offset_seconds)

# Cut one window per event: start shifted by the offset above, stop offset left at 0
# (presumably the trial end given by the event annotation — confirm in braindecode docs).
windows_dataset = create_windows_from_events(
    dataset,
    preload=True,
    trial_start_offset_samples=trial_start_offset_samples,
    trial_stop_offset_samples=0,
)

In [None]:
import collapsed_shallow_fbscp as c_shallow

In [None]:
import torch
from braindecode.util import set_random_seeds

# Use the GPU whenever one is present.
cuda = torch.cuda.is_available()
device = "cuda" if cuda else "cpu"
if cuda:
    torch.backends.cudnn.benchmark = True

# Seed every RNG braindecode knows about (including CUDA when used) for reproducibility.
seed = 20200222
set_random_seeds(seed=seed, cuda=cuda)

n_classes = 4
classes = list(range(n_classes))

# Read the model's input size from the first window's (channels, time) shape.
n_channels, input_window_samples = windows_dataset[0][0].shape[:2]

print("n_classes: ", n_classes)
print("n_channels:", n_channels)
print("input_window_samples size:", input_window_samples)

In [None]:
#windows_dataset[0][0].shape

In [None]:
#!dir collapsed_shallow_fbscp

In [None]:
#from models_fbscp import CollapsedShallowNet
# The ShallowFBCSPNet is a `nn.Sequential` model
from shallow_laurits import ShallowFBCSPNet

# Build the custom (local `shallow_laurits`) ShallowFBCSPNet. The channel count comes
# from the data rather than a hard-coded 22, so channel picking upstream stays in sync.
model = ShallowFBCSPNet(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=input_window_samples,
    final_conv_length="auto",
)

# Print the module structure (plain `print`, not a torchinfo summary).
print(model)

# Move the parameters to the same device the training loop sends batches to.
model.to(device)



In [None]:
# Split the windows by recording session: "0train" is used for fitting and
# "1test" is the held-out evaluation session.
splitted = windows_dataset.split("session")
train_set, test_set = splitted["0train"], splitted["1test"]

from torch.nn import Module
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader

#lr = 1e-4
#weight_decay = 1e-4
#batch_size = 64
#n_epochs = 200


In [None]:
#train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
#progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))


#from collections import defaultdict

#counting_dict = defaultdict(int)  # Initialize class counter

#for batch_idx, (X, y, _) in progress_bar:
#   X, y = X.to(device), y.to(device)  # Move to device if needed
    
    # Count occurrences of each class in y
#    for value in y:
#        counting_dict[int(value.item())] += 1  # Convert tensor to int and update count

# Print class frequencies
#print("Class counts:", dict(counting_dict))


In [None]:

from tqdm import tqdm
# Define a method for training one epoch


def train_one_epoch(
        dataloader: DataLoader, model: Module, loss_fn, optimizer,
        scheduler: "LRScheduler | None", epoch: int, device, print_batch_stats=True
):
    """Train ``model`` for one full pass over ``dataloader``.

    Args:
        dataloader: Yields ``(X, y, _)`` batches; the third element is ignored.
        model: Network to train; it is switched to training mode.
        loss_fn: Callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Stepped once after the epoch; pass ``None`` to skip scheduling.
        epoch: Current epoch number, shown in the progress-bar description.
        device: Device each batch is moved to before the forward pass.
        print_batch_stats: Show a tqdm progress bar with the latest batch loss.

    Returns:
        ``(mean_batch_loss, accuracy)``: the mean of per-batch losses and the
        fraction (0-1) of correctly classified samples.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    n_batches = len(dataloader)
    if n_batches == 0:
        raise ValueError("train_one_epoch received an empty dataloader")

    model.train()  # Set the model to training mode
    train_loss, correct = 0.0, 0

    # Only wrap in tqdm when stats are requested (same pattern as test_model).
    batches = enumerate(dataloader)
    if print_batch_stats:
        batches = tqdm(batches, total=n_batches)

    for batch_idx, (X, y, _) in batches:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        pred = model(X)
        loss = loss_fn(pred, y)
        loss.backward()
        optimizer.step()  # update the model weights

        batch_loss = loss.item()
        train_loss += batch_loss
        correct += (pred.argmax(1) == y).sum().item()

        if print_batch_stats:
            batches.set_description(
                f"Epoch {epoch}, Batch {batch_idx + 1}/{n_batches}, Loss: {batch_loss:.6f}"
            )

    # The learning rate is updated once per epoch, not per batch.
    if scheduler is not None:
        scheduler.step()

    return train_loss / n_batches, correct / len(dataloader.dataset)


In [None]:
from collections import defaultdict
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix
import numpy as np

@torch.no_grad()
def test_model(dataloader: DataLoader, model: torch.nn.Module, loss_fn, print_batch_stats=True):
    device = next(model.parameters()).device  # Get model device
    size = len(dataloader.dataset)
    n_batches = len(dataloader)
    model.eval()  # Switch to evaluation mode
    test_loss, correct = 0, 0

    # Initialize dictionaries for per-class tracking
    class_correct = defaultdict(int)
    class_total = defaultdict(int)

    # Lists to store true and predicted labels for confusion matrix
    all_preds = []
    all_targets = []

    if print_batch_stats:
        progress_bar = tqdm(enumerate(dataloader), total=len(dataloader))
    else:
        progress_bar = enumerate(dataloader)

    for batch_idx, (X, y, _) in progress_bar:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        batch_loss = loss_fn(pred, y).item()

        test_loss += batch_loss
        correct += (pred.argmax(1) == y).sum().item()

        # Store predictions and true labels for confusion matrix
        all_preds.append(pred.argmax(1).cpu())
        all_targets.append(y.cpu())

        # Compute per-class accuracy
        preds_labels = pred.argmax(1)
        for label, pred_label in zip(y, preds_labels):
            class_total[label.item()] += 1
            class_correct[label.item()] += (label == pred_label).item()

        if print_batch_stats:
            progress_bar.set_description(
                f"Batch {batch_idx + 1}/{len(dataloader)}, Loss: {batch_loss:.6f}"
            )

    # Convert lists to tensors
    all_preds = torch.cat(all_preds)
    all_targets = torch.cat(all_targets)

    # Compute per-class accuracy
    class_accuracies = {
        cls: (class_correct[cls] / class_total[cls]) * 100 if class_total[cls] > 0 else 0
        for cls in class_total
    }

    # Compute overall accuracy
    test_loss /= n_batches
    overall_accuracy = (correct / size) * 100

    # Print per-class accuracy
    print("\nClass-wise Accuracy:")
    for cls, acc in class_accuracies.items():
        print(f"  Class {cls}: {acc:.2f}%")

    print(f"Test Accuracy: {overall_accuracy:.1f}%, Test Loss: {test_loss:.6f}\n")

    return test_loss, overall_accuracy, class_accuracies, all_preds, all_targets


In [None]:
from braindecode.models import ShallowFBCSPNet 

In [None]:
# Print the architecture of `model`, the network built earlier from
# `shallow_laurits.ShallowFBCSPNet` (the braindecode import just above does not change it).
print(model)

In [None]:
# Reference braindecode ShallowFBCSPNet (imported just above), built with the same
# input/output sizes as `model` so the two architectures can be compared side by side.
# The channel count comes from the data instead of a hard-coded 22.
model2 = ShallowFBCSPNet(
    n_chans=n_channels,
    n_outputs=n_classes,
    n_times=input_window_samples,
    final_conv_length="auto",
)
print(model2)

In [None]:
import wandb
# Authenticate with Weights & Biases before `wandb.init` is called in the training cell.
# Credentials are supplied interactively/by the environment, never hard-coded here.
wandb.login()

In [None]:
import torch
import wandb
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.nn import CrossEntropyLoss
import numpy as np

# Train `model` on the training session, evaluate it on the evaluation session after
# every epoch, and log the metrics to Weights & Biases.
# NOTE(review): re-running this cell keeps training the same `model` object;
# re-create the model first for an independent run.

# Initialize Weights & Biases
wandb.init(project="Master Thesis", name="Shallow Accuracy")

# Define hyperparameters
# NOTE(review): lr=0.1 is unusually high for AdamW (the commented-out setup earlier
# in this notebook used 1e-4) — confirm this is intentional.
lr = 0.1
weight_decay = 1e-4
batch_size = 124  # Start with 124
n_epochs = 100

# Log hyperparameters to wandb
wandb.config.update({
    "learning_rate": lr,
    "weight_decay": weight_decay,
    "batch_size": batch_size,
    "epochs": n_epochs
})

# Define optimizer and scheduler.
# train_one_epoch calls scheduler.step() once per epoch, so T_max is measured in epochs.
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs - 1)

# Define loss function
loss_fn = CrossEntropyLoss()

# Create DataLoaders
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_set, batch_size=batch_size)

# Predictions & targets of the most recent evaluation. After the loop they belong to
# the final model only (previously every epoch was appended, so the confusion matrix
# mixed all epochs' predictions together).
all_preds, all_targets = [], []

# Training loop
for epoch in range(1, n_epochs + 1):
    print(f"Epoch {epoch}/{n_epochs}: ", end="")

    train_loss, train_accuracy = train_one_epoch(
        train_loader, model, loss_fn, optimizer, scheduler, epoch, device
    )

    # test_model already prints the class-wise accuracy, so it is not repeated here.
    test_loss, test_accuracy, class_accuracies, epoch_preds, epoch_targets = test_model(
        test_loader, model, loss_fn
    )
    all_preds, all_targets = epoch_preds, epoch_targets

    # Log results to wandb. The logged learning rate is the one after this epoch's
    # scheduler step, i.e. the rate the next epoch will use.
    wandb.log({
        "epoch": epoch,
        "train_loss": train_loss,
        "train_accuracy": train_accuracy * 100,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
        "learning_rate": scheduler.get_last_lr()[0],
        **{f"class_{class_idx}_accuracy": acc for class_idx, acc in class_accuracies.items()}
    })

    print(
        f"Train Accuracy: {100 * train_accuracy:.2f}%, "
        f"Average Train Loss: {train_loss:.6f}, "
        f"Test Accuracy: {test_accuracy:.2f}%, "
        f"Average Test Loss: {test_loss:.6f}\n"
    )

# Convert the final epoch's predictions (CPU tensors from test_model) to NumPy arrays.
all_preds = np.asarray(all_preds)
all_targets = np.asarray(all_targets)

# Save predictions & true labels for later use (confusion matrix)
wandb.log({"all_preds": all_preds.tolist(), "all_targets": all_targets.tolist()})

wandb.finish()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

# Use the predictions/labels the training cell left in memory. `wandb.run` is None once
# `wandb.finish()` has been called, so the logged history cannot be read back from it here.
all_preds = np.asarray(all_preds).ravel()
all_targets = np.asarray(all_targets).ravel()

# Pin the label set so the matrix is always n_classes x n_classes, even if some class
# is never predicted or never present in the targets.
cm = confusion_matrix(all_targets, all_preds, labels=classes)
class_labels = [f"Class {i}" for i in classes]

fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    cm, annot=True, fmt="d", cmap="Blues",
    xticklabels=class_labels, yticklabels=class_labels, ax=ax,
)
ax.set(xlabel="Predicted Label", ylabel="True Label", title="Confusion Matrix")
plt.show()

In [None]:
import math  # NOTE(review): not used in this cell
# Persist the trained `model`. Despite the file names, `model` was built from
# `shallow_laurits.ShallowFBCSPNet`, not braindecode's class.
# 1) The whole pickled module: loading it back requires that same class to be importable.
# 2) Only the state_dict: restore it into a freshly built model via `load_state_dict`.
torch.save(model, "braindecode_model_temponly.pth")
torch.save(model.state_dict(), "braindecode_model_temponly_state.pth")
