In [27]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import Dataset


from torch.utils.data import DataLoader

from sklearn.model_selection import train_test_split

## Building the 1D-CNN with residual connections.

In [28]:
class ResidualBlock1D(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()

        # First conv layer (uses passed in in_channels, NOT hard-coded)
        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=stride,
            padding=1
        )
        self.bn1 = nn.BatchNorm1d(out_channels)

        # Second conv layer
        self.conv2 = nn.Conv1d(
            in_channels=out_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.bn2 = nn.BatchNorm1d(out_channels)

        # Downsample (skip connection) if needed
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv1d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=1,
                    stride=stride
                ),
                nn.BatchNorm1d(out_channels)
            )
        else:
            self.downsample = None

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Apply skip path if needed
        if self.downsample is not None:
            identity = self.downsample(identity)

        out = out + identity
        out = self.relu(out)
        return out


In [29]:
class ResNet1D(nn.Module):
    def __init__(self, n_channels, n_classes):
        super().__init__()
        
        # Initial convolution "stem"
        self.conv1 = nn.Conv1d(
            in_channels=n_channels,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3
        )
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
        
        # Residual stages
        self.layer1 = self._make_layer(64, 64, num_blocks=2, stride=1)
        self.layer2 = self._make_layer(64, 128, num_blocks=2, stride=2)
        self.layer3 = self._make_layer(128, 256, num_blocks=2, stride=2)
        self.layer4 = self._make_layer(256, 512, num_blocks=2, stride=2)
        
        # Global average pooling over time dimension
        self.global_pool = nn.AdaptiveAvgPool1d(1)  # output: (batch, channels, 1)
        
        # Final classifier
        self.fc = nn.Linear(512, n_classes)

    def _make_layer(self, in_channels, out_channels, num_blocks, stride):
        layers = []
        # First block may change channels/stride
        layers.append(ResidualBlock1D(in_channels, out_channels, stride=stride))
        # Remaining blocks keep same channels/stride=1
        for _ in range(1, num_blocks):
            layers.append(ResidualBlock1D(out_channels, out_channels, stride=1))
        return nn.Sequential(*layers)

    def forward(self, x):
        # x shape: (batch, n_channels, n_times)
        
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        
        # Residual layers
        x = self.layer1(x)  # shape: (batch, 64, T1)
        x = self.layer2(x)  # shape: (batch, 128, T2)
        x = self.layer3(x)  # shape: (batch, 256, T3)
        x = self.layer4(x)  # shape: (batch, 512, T4)
        
        # Global average pooling: average over time dimension
        x = self.global_pool(x)  # (batch, 512, 1)
        x = x.squeeze(-1)        # (batch, 512)
        
        # Classifier
        logits = self.fc(x)      # (batch, n_classes)
        return logits


## Building the training Loop

In [30]:
#Dataset class

class EEGDataset(Dataset):
    def __init__(self, X, y):
        self.X = X.float()
        self.y = y.long()

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

In [31]:
# split X and Y based on subjects 
X_raw = np.load("../data/X_tqwt_wpd.npy")
y_raw = np.load("../data/y_labels.npy")

subject_ids = np.load("../data/subject_ids.npy", allow_pickle=True)
unique_subs = np.unique(subject_ids)

# Not required 
# print("Number of subjects:", len(unique_subs))
# print("Subject IDs:", unique_subs)


In [32]:
train_subs, val_subs = train_test_split(
    unique_subs,
    test_size=0.2,
    random_state=42,
    shuffle=True
)

In [33]:
# Basically ensures that there is no leakage 

train_mask = np.isin(subject_ids, train_subs)
val_mask = np.isin(subject_ids, val_subs)

In [34]:
X = np.load("../data/X_tqwt_wpd.npy")
y = np.load("../data/y_labels.npy")

In [35]:
print("X shape:", X.shape)
print("y shape:", y.shape)
print("subject_ids shape:", subject_ids.shape)

X shape: (16749, 1140)
y shape: (16749,)
subject_ids shape: (16749,)


In [36]:
# convert to torch tensors
X1 = X[train_mask]
y1 = y[train_mask]

X_val_np = X[val_mask]
y_val_np = y[val_mask]


X_train_np = torch.tensor(X1, dtype=torch.float32)
y_train_np = torch.tensor(y1, dtype=torch.long)

X_val_np = torch.tensor(X_val_np, dtype=torch.float32)
y_val_np = torch.tensor(y_val_np, dtype=torch.long)

### Lets save the files once... No need to do this if you already have the files downloaded

In [37]:
# np.save("../data/X_train.npy", X_train_np)
# np.save("../data/y_train.npy", y_train_np)
# np.save("../data/X_val.npy", X_val_np)
# np.save("../data/y_val.npy", y_val_np)

In [38]:
# even though they were converted to tensors before saving they need to be converted back becsuse when loading them with np it becomes an array again.
X_train = np.load("../data/X_train.npy")
y_train = np.load("../data/y_train.npy")
X_val   = np.load("../data/X_val.npy")
y_val   = np.load("../data/y_val.npy")


print("Before reshaping input to DataLoader:")
print(X_train.shape)
print("Raw X_val shape:",   X_val.shape)

X_train = X_train.reshape(len(X_train), 1, -1)
X_val   = X_val.reshape(len(X_val), 1, -1)

print("after reshaping input to DataLoader:")
print(X_train.shape)
print("Raw X_val shape:",   X_val.shape)


X_train = torch.tensor(X_train, dtype=torch.float32)
y_train = torch.tensor(y_train, dtype=torch.long)
X_val = torch.tensor(X_val, dtype=torch.float32)
y_val = torch.tensor(y_val, dtype=torch.long)


Before reshaping input to DataLoader:
(13305, 1140)
Raw X_val shape: (3444, 1140)
after reshaping input to DataLoader:
(13305, 1, 1140)
Raw X_val shape: (3444, 1, 1140)


In [39]:
#Creating Dataloaders 

train_dataset = EEGDataset(X_train, y_train)
train_loader  = DataLoader(train_dataset, batch_size=32, shuffle=True)
                           
val_dataset = EEGDataset(X_val, y_val)
val_loader   = DataLoader(val_dataset,   batch_size=32, shuffle=False)

In [40]:
# Safety Check
for xb, yb in train_loader:
    print("TRAIN BATCH SHAPE:", xb.shape)
    break

for xb, yb in val_loader:
    print("VAL BATCH SHAPE:", xb.shape)
    break

print("X_train shape from file:", X_train.shape)
print("X_val shape from file:", X_val.shape)

print("First element type:", type(X_train[0]))
print("First element shape:", getattr(X_train[0], 'shape', None))

TRAIN BATCH SHAPE: torch.Size([32, 1, 1140])
VAL BATCH SHAPE: torch.Size([32, 1, 1140])
X_train shape from file: torch.Size([13305, 1, 1140])
X_val shape from file: torch.Size([3444, 1, 1140])
First element type: <class 'torch.Tensor'>
First element shape: torch.Size([1, 1140])


In [41]:
#initialze the CNN

n_channels = 1      # or however many EEG channels we have
n_classes = 2        # ADHD vs Control

model = ResNet1D(n_channels=n_channels, n_classes=n_classes)

In [42]:
#letting the computer know what piece of hardware to run the training 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

## Helper functions


In [43]:
# Accuracy calculations

def accuracy_from_logits(logits, y):
    preds = torch.argmax(logits, dim=1)
    correct = (preds == y).sum().item()
    total = y.size(0)
    return correct / total

In [44]:
#Training loop

def train_one_epoch(model, train_loader, criterion, optimizer, device):
    model.train()  # put model in "training mode"
    running_loss = 0.0
    correct = 0
    total = 0
    
    for X, y in train_loader:
        print("TRAIN LOOP SHAPE:", X.shape)
        break
        
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        # 1. Forward pass
        logits = model(X)

        # 2. Compute loss
        loss = criterion(logits, y)

        # 3. Zero out old gradients
        optimizer.zero_grad()

        # 4. Compute gradients
        loss.backward()

        # 5. Update weights
        optimizer.step()

        # Track training accuracy & loss
        running_loss += loss.item() * X.size(0)

        _, predicted = torch.max(logits, dim=1)
        correct += (predicted == y).sum().item()
        total += y.size(0)

    epoch_loss = running_loss / total
    epoch_acc = correct / total

    return epoch_loss, epoch_acc


#only run to check if functional. and last i ran it was functional
#train_one_epoch(model, train_loader, criterion, optimizer, device)

In [45]:
#Valid loop with no gradient Updates 
# Very similar to the training loop, except this one sets the model to eval and its accompanied by other

def validate(model, val_loader, criterion, device):
    model.eval()  # evaluation mode
    running_loss = 0.0
    correct = 0
    total = 0

    # Do NOT track gradients
    with torch.no_grad():
        for X, y in val_loader:
            print("VAL LOOP SHAPE:", X.shape)
            break
        for X, y in val_loader:
            X, y = X.to(device), y.to(device)

            logits = model(X)
            loss = criterion(logits, y)

            running_loss += loss.item() * X.size(0)
            _, predicted = torch.max(logits, dim=1)
            correct += (predicted == y).sum().item()
            total += y.size(0)

    val_loss = running_loss / total
    val_acc = correct / total

    return val_loss, val_acc

#only run to check if functional. and last i ran it was functional
#validate(model, val_loader, criterion, device)

In [46]:
def train_model(model, train_loader, val_loader, device, epochs=20, lr=1e-3):

    history = {
        "train_loss": [],
        "train_acc":  [],
        "val_loss":   [],
        "val_acc":    []
    }

    for epoch in range(epochs):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )

        val_loss, val_acc = validate(
            model, val_loader, criterion, device
        )

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"Train Loss: {train_loss:.4f} Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}"
        )

    return history


## Training the architectiure


In [47]:
#Training the resnet architeciure on Raw data 

resnet_raw = ResNet1D( n_channels=X_train.shape[1] , n_classes=2).to(device)

history_resnet_raw = train_model(
    model=resnet_raw,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=20,
    lr=1e-3
)

# Saving for Saliency later
torch.save(resnet_raw.state_dict(), "resnet_eeg.pth")
# Can be loaded with the following code 
# model.load_state_dict(torch.load("resnet_eeg.pth"))

TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])


KeyboardInterrupt: 

## Compute saliency 

In [48]:
def compute_saliency(model, X_batch, y_batch, device):
    model.eval()
    
    X_batch = X_batch.to(device)
    y_batch = y_batch.to(device)
    X_batch.requires_grad = True  # Enable gradient wrt input

    # Forward pass
    logits = model(X_batch)
    loss = F.cross_entropy(logits, y_batch)

    # Backward pass
    loss.backward()

    # Gradient wrt input
    saliency = X_batch.grad.detach().abs()  # absolute gradient

    return saliency.cpu().numpy()


## Global saliency vector 

In [49]:
all_saliencies = []

for Xb, yb in val_loader:
    sal = compute_saliency(model, Xb, yb, device)
    all_saliencies.append(sal)

all_saliencies = np.concatenate(all_saliencies, axis=0)   # (N_val, 1, 1140)

# Average across samples and channel
global_saliency = all_saliencies.mean(axis=0).squeeze(0)  # (1140,)


In [50]:
#Normalize
global_saliency = global_saliency / global_saliency.max()

In [51]:
# Ceating a binary mask 
threshold = 0.2  # keep top 80% important features or however much we wanna keep
mask = (global_saliency > threshold).astype(np.float32) # does this not return a boolean ?? 


print("Mask shape:", mask.shape)  # (1140,)

Mask shape: (1140,)


## Handling the Masked data 

In [52]:
#Applying saliency mask to the dataset, actually just uses simple multiplication

# Convert tensors back to numpy so we can multiply with mask
X_train_np = X_train.cpu().numpy().reshape(len(X_train), 1140)
X_val_np   = X_val.cpu().numpy().reshape(len(X_val), 1140)

#applying the mask
X_train_masked = X_train_np * mask
X_val_masked   = X_val_np * mask

#reshaping back to conv1d format 
X_train_masked = X_train_masked.reshape(len(X_train_masked), 1, 1140)
X_val_masked   = X_val_masked.reshape(len(X_val_masked),   1, 1140)

In [55]:
# Creating the data loaders for the masked data  
y_train_masked = y_train.clone()
y_val_masked   = y_val.clone()

X_train_masked = torch.tensor(X_train_masked, dtype=torch.float32)
X_val_masked   = torch.tensor(X_val_masked,   dtype=torch.float32)

#creating the masked dataset
train_dataset_masked = EEGDataset(X_train_masked, y_train_masked)
val_dataset_masked   = EEGDataset(X_val_masked,   y_val_masked)

#dataLoaders
train_loader_masked = DataLoader(train_dataset_masked, batch_size=32, shuffle=True)
val_loader_masked   = DataLoader(val_dataset_masked,   batch_size=32, shuffle=False)


In [56]:
# converting masked data to tensors 
X_train_masked = torch.tensor(X_train_masked, dtype=torch.float32)
y_train_masked = y_train.clone()      # same labels
X_val_masked   = torch.tensor(X_val_masked,   dtype=torch.float32)
y_val_masked   = y_val.clone()

  X_train_masked = torch.tensor(X_train_masked, dtype=torch.float32)
  X_val_masked   = torch.tensor(X_val_masked,   dtype=torch.float32)


In [57]:
# creatging new dataset and dataloaders for the data 
train_dataset_masked = EEGDataset(X_train_masked, y_train_masked)
val_dataset_masked   = EEGDataset(X_val_masked,   y_val_masked)

train_loader_masked = DataLoader(train_dataset_masked, batch_size=32, shuffle=True)
val_loader_masked   = DataLoader(val_dataset_masked,   batch_size=32, shuffle=False)


In [58]:
# Checking shapes for troubleshooting purposes 
for xb, yb in train_loader_masked:
    print("MASKED TRAIN BATCH:", xb.shape)
    break

for xb, yb in val_loader_masked:
    print("MASKED VAL BATCH:", xb.shape)
    break


MASKED TRAIN BATCH: torch.Size([32, 1, 1140])
MASKED VAL BATCH: torch.Size([32, 1, 1140])


## 1Dimentional implemeentation 

In [59]:
# Eeg net 1Dimentional implemeentation 

class EEGNet1D(nn.Module):
    def __init__(
        self,
        n_classes: int = 2,
        Chans: int = 1,        # number of input channels
        Samples: int = 1140,   # length of the time series
        F1: int = 8,
        D: int = 2,
        kernel_length: int = 64,
        dropout: float = 0.25
    ):
        super().__init__()
        self.n_classes = n_classes
        self.Chans = Chans
        self.Samples = Samples

        # 1) Temporal convolution
        self.conv_temporal = nn.Conv1d(
            in_channels=Chans,
            out_channels=F1,
            kernel_size=kernel_length,
            padding=kernel_length // 2,
            bias=False
        )
        self.bn1 = nn.BatchNorm1d(F1)

        # 2) Depthwise convolution
        #    Each filter operates on its own channel (groups=F1), multiplied by D
        self.conv_depthwise = nn.Conv1d(
            in_channels=F1,
            out_channels=F1 * D,
            kernel_size=kernel_length,
            padding=kernel_length // 2,
            groups=F1,
            bias=False
        )
        self.bn2 = nn.BatchNorm1d(F1 * D)
        self.pool1 = nn.AvgPool1d(kernel_size=4)
        self.dropout1 = nn.Dropout(dropout)

        # 3) Separable convolution
        #    depthwise (groups=F1*D) + pointwise (1x1 conv)
        self.conv_separable_depth = nn.Conv1d(
            in_channels=F1 * D,
            out_channels=F1 * D,
            kernel_size=16,
            padding=16 // 2,
            groups=F1 * D,
            bias=False
        )
        self.conv_separable_point = nn.Conv1d(
            in_channels=F1 * D,
            out_channels=F1 * D * 2,  # F2 = 2 * F1 * D
            kernel_size=1,
            bias=False
        )
        self.bn3 = nn.BatchNorm1d(F1 * D * 2)
        self.pool2 = nn.AvgPool1d(kernel_size=8)
        self.dropout2 = nn.Dropout(dropout)

        # 4) Global average pooling + classifier
        self.global_pool = nn.AdaptiveAvgPool1d(1)
        self.classifier = nn.Linear(F1 * D * 2, n_classes)

        self.elu = nn.ELU()

    def forward(self, x):
        """
        x: (batch_size, 1, 1140)
        """
        # 1) Temporal conv
        x = self.conv_temporal(x)
        x = self.bn1(x)
        x = self.elu(x)

        # 2) Depthwise conv
        x = self.conv_depthwise(x)
        x = self.bn2(x)
        x = self.elu(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        # 3) Separable conv
        x = self.conv_separable_depth(x)
        x = self.conv_separable_point(x)
        x = self.bn3(x)
        x = self.elu(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        # 4) Global pooling + classifier
        x = self.global_pool(x)     # (batch, channels, 1)
        x = x.squeeze(-1)           # (batch, channels)
        logits = self.classifier(x) # (batch, n_classes)
        return logits

## Training both datasets 

In [62]:
#Train EEgnet on raw data 
eegnet_raw = EEGNet1D(n_classes= 2, Chans=1, Samples=1140).to(device)

history_raw = train_model(
    model=eegnet_raw,
    train_loader=train_loader,
    val_loader=val_loader,
    device=device,
    epochs=20,
    lr=1e-3
)

# Saving for Saliency later
torch.save(history_raw.state_dict(), "raw_EEGNet.pth")
# Can be loaded with the following code 
# model.load_state_dict(torch.load("raw_EEGNet.pth"))

TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 1/20 | Train Loss: 0.6920 Acc: 0.5440 | Val Loss: 0.6794 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 2/20 | Train Loss: 0.6922 Acc: 0.5439 | Val Loss: 0.6784 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 3/20 | Train Loss: 0.6924 Acc: 0.5433 | Val Loss: 0.6786 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 4/20 | Train Loss: 0.6918 Acc: 0.5442 | Val Loss: 0.6789 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 5/20 | Train Loss: 0.6922 Acc: 0.5436 | Val Loss: 0.6785 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])
VAL LOOP SHAPE: torch.Size([32, 1, 1140])
Epoch 6/20 | Train Loss: 0.6923 Acc: 0.5439 | Val Loss: 0.6779 Acc: 0.6115
TRAIN LOOP SHAPE: torch.Size([32, 

KeyboardInterrupt: 

In [61]:
#Train EEg net on masked data 
eegnet_masked = EEGNet1D(n_classes= 2, Chans=1, Samples=1140).to(device)

history_masked = train_model(
    model=eegnet_masked,
    train_loader=train_loader_masked,
    val_loader=val_loader_masked,
    device=device,
    epochs=20,
    lr=1e-3
)

# Saving for Saliency later
torch.save(history_masked.state_dict(), "masked_EEGNet.pth")
# Can be loaded with the following code 
# model.load_state_dict(torch.load("masked_EEGNet.pth"))

TRAIN LOOP SHAPE: torch.Size([32, 1, 1140])


IndexError: Target 1 is out of bounds.

## Simple evaluation

In [66]:
resnet_loaded = ResNet1D(n_channels=1, n_classes=2).to(device)
resnet_loaded.load_state_dict(torch.load("resnet_eeg.pth", map_location=device))
resnet_loaded.eval()

ResNet1D(
  (conv1): Conv1d(1, 64, kernel_size=(7,), stride=(2,), padding=(3,))
  (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool1d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): ResidualBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn2): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
    )
    (1): ResidualBlock1D(
      (conv1): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn1): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv1d(64, 64, kernel_size=(3,), stride=(1,), padding=(1,))
      (bn2): BatchNo

In [67]:
def evaluate_model(model, val_loader, device):
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for X, y in val_loader:
            X = X.to(device)
            y = y.to(device)

            logits = model(X)
            preds = torch.argmax(logits, dim=1)

            correct += (preds == y).sum().item()
            total   += y.size(0)

    return correct / total

In [68]:
acc_resnet = evaluate_model(resnet_loaded, val_loader, device)
acc_eegnet_raw = evaluate_model(eegnet_raw_loaded, val_loader, device)
acc_eegnet_masked = evaluate_model(eegnet_masked_loaded, val_loader_masked, device)

In [69]:
print("Accuracy acore for the resnet arch", acc_resnet)
print("Accuracy acore for the Raw eegnet data",acc_eegnet_raw)
print("Accuracy acore for the masked eegnet data",acc_eegnet_masked)

0.40534262485482
