In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from shallow import SubjectDicionaryFCNet

# Initialize the model
model = SubjectDicionaryFCNet(n_chans=22, n_outputs=4)

# Set all weights and biases of fc_layers to ones for consistency
for fc_layer in model.fc_layers.values():
    fc_layer.weight.data.fill_(1.0)
    fc_layer.bias.data.fill_(1.0)

# First, set requires_grad=False for all parameters
for param in model.parameters():
    param.requires_grad = False

# Then, set requires_grad=True for parameters in fc_layers
for fc_layer in model.fc_layers.values():
    for param in fc_layer.parameters():
        param.requires_grad = True

# Verify requires_grad status
print("Parameter requires_grad status:")
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Copy initial weights and biases
initial_weights = {}
initial_biases = {}
for subject_id, fc_layer in model.fc_layers.items():
    initial_weights[subject_id] = fc_layer.weight.data.clone()
    initial_biases[subject_id] = fc_layer.bias.data.clone()

# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)


# Define target outputs
n_outputs = model.fc_layers['subject_1'].out_features
target = torch.zeros(batch_size, n_outputs)

# Use a non-zero target for one specific subject
specific_subject_id = 3  # Subject ID for which we want to induce updates
#fill with 1's
batch[specific_subject_id - 1] = 1
target[specific_subject_id - 1] = torch.tensor([1.0, 2.0, 3.0, 4.0])  # Arbitrary non-zero target

# Set the last time point to contain the subject IDs (from 1 to batch_size)
for i in range(batch_size):
    subject_id = i + 1  # Subject IDs from 1 to 5  
    batch[i, :, -1] = 3 * 1000000

# Define a loss function
criterion = nn.MSELoss()

# Define an optimizer (only parameters with requires_grad=True will be updated)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()

# Zero the gradients
optimizer.zero_grad()

# Forward pass
output = model(batch)

# Compute loss
loss = criterion(output, target)

# Backward pass
loss.backward()

# Verify gradients before optimizer step
print("\nGradients before optimizer step:")
for subject_id, fc_layer in model.fc_layers.items():
    grad_norm = fc_layer.weight.grad.norm().item() if fc_layer.weight.grad is not None else 0
    print(f"Gradient norm for {subject_id}: {grad_norm}")

# Update weights
optimizer.step()

# Check which fc_layers have had their weights updated
for subject_id, fc_layer in model.fc_layers.items():
    weights_changed = (fc_layer.weight.data != initial_weights[subject_id]).float().sum().item()
    biases_changed = (fc_layer.bias.data != initial_biases[subject_id]).float().sum().item()
    print(f"\nWeights changed for {subject_id}: {weights_changed}")
    print(f"Biases changed for {subject_id}: {biases_changed}")

# Assertions to verify subject specificity
for subject_id, fc_layer in model.fc_layers.items():
    weights_changed = (fc_layer.weight.data != initial_weights[subject_id]).float().sum().item()
    biases_changed = (fc_layer.bias.data != initial_biases[subject_id]).float().sum().item()
    if subject_id == f'subject_{specific_subject_id}':
        assert weights_changed > 0, f"Weights for {subject_id} did not change!"
        assert biases_changed > 0, f"Biases for {subject_id} did not change!"
    else:
        assert weights_changed == 0, f"Weights for {subject_id} have changed!"
        assert biases_changed == 0, f"Biases for {subject_id} have changed!"

print("\nTest passed: Only the weights and biases corresponding to the specific subject have changed.")


Parameter requires_grad status:
spatio_temporal.weight: requires_grad=False
spatio_temporal.bias: requires_grad=False
batch_norm.weight: requires_grad=False
batch_norm.bias: requires_grad=False
fc_layers.subject_1.weight: requires_grad=True
fc_layers.subject_1.bias: requires_grad=True
fc_layers.subject_2.weight: requires_grad=True
fc_layers.subject_2.bias: requires_grad=True
fc_layers.subject_3.weight: requires_grad=True
fc_layers.subject_3.bias: requires_grad=True
fc_layers.subject_4.weight: requires_grad=True
fc_layers.subject_4.bias: requires_grad=True
fc_layers.subject_5.weight: requires_grad=True
fc_layers.subject_5.bias: requires_grad=True
fc_layers.subject_6.weight: requires_grad=True
fc_layers.subject_6.bias: requires_grad=True
fc_layers.subject_7.weight: requires_grad=True
fc_layers.subject_7.bias: requires_grad=True
fc_layers.subject_8.weight: requires_grad=True
fc_layers.subject_8.bias: requires_grad=True
fc_layers.subject_9.weight: requires_grad=True
fc_layers.subject_9.bia

In [11]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from shallow import CollapsedShallowNet

# Initialize the model
model = CollapsedShallowNet(n_chans=22, n_outputs=4)

fc_layer = model.fc
# Set all weights and biases of fc to ones for consistency

dummy_input = torch.zeros(1, 22, 1001)  # Adjust shape as per model requirements
model(dummy_input)

nn.init.constant_(fc_layer.weight, 1.0)
nn.init.constant_(fc_layer.bias, 1.0)

# First, set requires_grad=False for all parameters
for param in model.parameters():
    param.requires_grad = False

# Then, set requires_grad=True for parameters in fc
for param in fc_layer.parameters():
    param.requires_grad = True
    
# Verify requires_grad status
print("Parameter requires_grad status:")
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Copy initial weights and biases


initial_weights = {}
initial_biases = {}

initial_weights[1] = fc_layer.weight.data.clone()
initial_biases[1] = fc_layer.bias.data.clone()


# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)

# Define target outputs
n_outputs = fc_layer.out_features
target = torch.zeros(batch_size, n_outputs)

# Use a non-zero target for one specific subject
specific_subject_id = 3  # Subject ID for which we want to induce updates
#fill with 1's
batch[specific_subject_id - 1] = 1
target[specific_subject_id - 1] = torch.tensor([1.0, 2.0, 3.0, 4.0])  # Arbitrary non-zero target

# Set the last time point to contain the subject IDs (from 1 to batch_size)
for i in range(batch_size):
    subject_id = i + 1  # Subject IDs from 1 to 5
    batch[i, :, -1] = 3 * 1000000

# Define a loss function
criterion = nn.MSELoss()

# Define an optimizer (only parameters with requires_grad=True will be updated)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()

# Zero the gradients
optimizer.zero_grad()

# Forward pass
output = model(batch)

# Compute loss
loss = criterion(output, target)

# Backward pass
loss.backward()

# Verify gradients before optimizer step
print("\nGradients before optimizer step:")

grad_norm = fc_layer.weight.grad.norm().item() if fc_layer.weight.grad is not None else 0
print(f"Gradient norm for fc: {grad_norm}")

# Update weights
optimizer.step()

# check total amount of weights
total_weights = fc_layer.weight.data.numel()
total_biases = fc_layer.bias.data.numel()
print(f"\nTotal number of weights: {total_weights}")
print(f"Total number of biases: {total_biases}")


#check how many weights have changed
weights_changed = (fc_layer.weight.data != initial_weights[1]).float().sum().item()
biases_changed = (fc_layer.bias.data != initial_biases[1]).float().sum().item()
print(f"\nWeights changed: {weights_changed}")
print(f"Biases changed: {biases_changed}")



# Assertions to verify that all weights and biases have changed
assert weights_changed > total_weights-100, "Not all weights have changed!"
assert biases_changed == total_biases, "Not all biases have changed!"

print(fc_layer.weight.data[:8])
    



Parameter requires_grad status:
spatio_temporal.weight: requires_grad=False
spatio_temporal.bias: requires_grad=False
batch_norm.weight: requires_grad=False
batch_norm.bias: requires_grad=False
fc.weight: requires_grad=True
fc.bias: requires_grad=True

Gradients before optimizer step:
Gradient norm for fc: 654.693115234375

Total number of weights: 1440
Total number of biases: 4

Weights changed: 1408.0
Biases changed: 4.0
tensor([[0.9957, 1.0033, 0.9962,  ..., 1.0023, 0.9964, 0.9964],
        [0.9956, 1.0033, 0.9961,  ..., 1.0023, 0.9963, 0.9963],
        [0.9956, 1.0033, 0.9961,  ..., 1.0023, 0.9963, 0.9962],
        [0.9955, 1.0033, 0.9960,  ..., 1.0023, 0.9962, 0.9962]])




In [351]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from shallow import SubjectOneHotNet

# Initialize the model
model = SubjectOneHotNet(n_chans=22, n_outputs=4)

# Set all weights and biases of fc_layers to ones for consistency
model.fc_shared.weight.data.fill_(1.0)
model.fc_shared.bias.data.fill_(1.0)

# Set requires_grad=False for all parameters except the fc_shared layer
for param in model.parameters():
    param.requires_grad = False
for param in model.fc_shared.parameters():
    param.requires_grad = True

# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001
n_outputs = model.n_outputs

# Initialize weights and biases tracking for multiple iterations
initial_weights = model.fc_shared.weight.data.clone()
initial_biases = model.fc_shared.bias.data.clone()
weights_change_history = []
biases_change_history = []

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)
target = torch.zeros(batch_size, n_outputs)

# Choose the specific subject to update
specific_subject_id = 2  # Subject ID to test
batch[specific_subject_id - 1] = 1

# Assign subject IDs to the last time point
for i in range(batch_size):
    subject_id = i + 1
    if i == specific_subject_id - 1:
        batch[i, :, -1] = 9 * 1000000
    else:
        batch[i, :, -1] = 2 * 1000000

# Define loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Hook to zero gradients for non-targeted subjects
# def gradient_hook(grad):
#     mask = torch.zeros_like(grad)
#     start_idx = (specific_subject_id - 1) * n_outputs
#     end_idx = start_idx + n_outputs
#     mask[start_idx:end_idx, :] = 1
#     return grad * mask

# model.fc_shared.weight.register_hook(gradient_hook)

# Run multiple training iterations
iterations = 10  # Define the number of training steps
for i in range(iterations):
    # Training step
    model.train()
    optimizer.zero_grad()
    output = model(batch)
    loss = criterion(output, target)
    loss.backward()

    # Record gradients and update
    optimizer.step()

    # Track changed weights and biases
    weights_changed = (model.fc_shared.weight.data != initial_weights).float().sum().item()
    biases_changed = (model.fc_shared.bias.data != initial_biases).float().sum().item()
    weights_change_history.append(weights_changed)
    biases_change_history.append(biases_changed)
    batch[specific_subject_id - 1] = 10 ** i
    batch[1, :, -1] = 2 * 1000000
    #batch[specific_subject_id - 1] = 10 *i

    print(f"Iteration {i + 1}: Weights changed: {weights_changed}, Biases changed: {biases_changed}")

# Display final weight change distribution
print("\nFinal weights and biases after multiple iterations:")
for i in range(36):
    if i % 4 == 0:
        print(model.fc_shared.weight.data[i][:8])


Iteration 1: Weights changed: 2048.0, Biases changed: 8.0
Iteration 2: Weights changed: 2172.0, Biases changed: 8.0
Iteration 3: Weights changed: 2184.0, Biases changed: 8.0
Iteration 4: Weights changed: 2184.0, Biases changed: 8.0
Iteration 5: Weights changed: 2184.0, Biases changed: 8.0
Iteration 6: Weights changed: 2184.0, Biases changed: 8.0
Iteration 7: Weights changed: 2184.0, Biases changed: 8.0
Iteration 8: Weights changed: 2184.0, Biases changed: 8.0
Iteration 9: Weights changed: 2184.0, Biases changed: 8.0
Iteration 10: Weights changed: 2184.0, Biases changed: 8.0

Final weights and biases after multiple iterations:
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1.0836, 0.4521, 0.4649, 0.4327, 0.8600, 0.8812, 0.4191, 0.6643])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from shallow import SubjectAdvIndexFCNet

# Initialize the model
model = SubjectAdvIndexFCNet(n_chans=22, n_outputs=4)

# Set all weights and biases of fc_layers to ones for consistency

model.fc_shared.weight.data.fill_(1.0)
model.fc_shared.bias.data.fill_(1.0)

# First, set requires_grad=False for all parameters
for param in model.parameters():
    param.requires_grad = False

# Then, set requires_grad=True for parameters in fc_layers
for param in model.fc_shared.parameters():
    param.requires_grad = True
    
# Verify requires_grad status
print("Parameter requires_grad status:")
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Copy initial weights and biases

initial_weights = {}
initial_biases = {}

initial_weights[1] = model.fc_shared.weight.data.clone()
initial_biases[1] = model.fc_shared.bias.data.clone()


# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)

# Correct the n_outputs
n_outputs = model.n_outputs  # This ensures n_outputs is 4
target = torch.zeros(batch_size, n_outputs)

# Use a non-zero target for one specific subject
specific_subject_id = 5  # Subject ID for which we want to induce updates
batch[specific_subject_id - 1] = -90000
#batch[specific_subject_id] = 1
target[specific_subject_id - 1] = torch.tensor([1.0, 2.0, 3.0, 4.0])  # Now this works


# Set the last time point to contain the subject IDs (from 1 to batch_size)
for i in range(batch_size):
    subject_id = i + 1 
    if i == specific_subject_id - 1:
        batch[i, :, -1] = 4 * 1000000
    else: batch[i, :, -1] = specific_subject_id * 1000000

# Define a loss function
criterion = nn.MSELoss()

# Define an optimizer (only parameters with requires_grad=True will be updated)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()

# Zero the gradients
optimizer.zero_grad()

# Forward pass
output = model(batch)

# Compute loss
loss = criterion(output, target)

# Backward pass
loss.backward()

# Verify gradients before optimizer step
print("\nGradients before optimizer step:")

grad_norm = model.fc_shared.weight.grad.norm().item() if model.fc_shared.weight.grad is not None else 0
print(f"Gradient norm for fc: {grad_norm}")

# Update weights
optimizer.step()

# check total amount of weights
total_weights = model.fc_shared.weight.data.numel()
total_biases = model.fc_shared.bias.data.numel()
print(f"\nTotal number of weights: {total_weights}")
print(f"Total number of biases: {total_biases}")


#check how many weights have changed
weights_changed = (model.fc_shared.weight.data != initial_weights[1]).float().sum().item()
biases_changed = (model.fc_shared.bias.data != initial_biases[1]).float().sum().item()
print(f"\nWeights changed: {weights_changed}")
print(f"Biases changed: {biases_changed}")


# Assertions to verify that all weights and biases have changed
assert (weights_changed < (total_weights / 9)*3) and  (weights_changed > (total_weights / 9)), "The interval of weights changed is not correct!"

# exactly8 biases are changed

assert biases_changed == 8, "The number of biases does not correspond!"



# Running with 2 different subjects, I expect arounf 2 / 9 of the weights to change as well as 8 / 36 biases


for i in range(36):
    if i % 4 == 0:
        print(model.fc_shared.weight.data[i][:8])


Parameter requires_grad status:
spatio_temporal.weight: requires_grad=False
spatio_temporal.bias: requires_grad=False
batch_norm.weight: requires_grad=False
batch_norm.bias: requires_grad=False
fc_shared.weight: requires_grad=True
fc_shared.bias: requires_grad=True

Gradients before optimizer step:
Gradient norm for fc: 205.77197265625

Total number of weights: 12960
Total number of biases: 36

Weights changed: 2072.0
Biases changed: 8.0
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([0.9200, 1.0000, 1.0000, 0.9200, 1.0000, 1.0000, 0.9200, 0.9200])
tensor([0.9770, 0.9830, 0.9830, 0.9770, 0.9940, 0.9890, 0.9830, 0.9940])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([1., 1., 1., 1., 1., 1., 1., 1.])


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from shallow import CollapsedShallowNet

# Initialize the model
model = CollapsedShallowNet(n_chans=22, n_outputs=4)

# Get the spatio_temporal layer
spatio_temporal_layer = model.spatio_temporal

# Set all weights and biases of spatio_temporal_layer to ones for consistency
nn.init.constant_(spatio_temporal_layer.weight, 1.0)
if spatio_temporal_layer.bias is not None:
    nn.init.constant_(spatio_temporal_layer.bias, 1.0)

# First, set requires_grad=False for all parameters
for param in model.parameters():
    param.requires_grad = False

# Then, set requires_grad=True for parameters in spatio_temporal_layer
for param in spatio_temporal_layer.parameters():
    param.requires_grad = True
    
for param in model.batch_norm.parameters():
    param.requires_grad = True

# Verify requires_grad status
print("Parameter requires_grad status:")
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Copy initial weights and biases
initial_weights = {}
initial_biases = {}
initial_weights[1] = spatio_temporal_layer.weight.data.clone()
if spatio_temporal_layer.bias is not None:
    initial_biases[1] = spatio_temporal_layer.bias.data.clone()

# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)

# Use a non-zero input for one specific subject to induce updates
specific_subject_id = 3  # Subject ID for which we want to induce updates


#make each channel of the specific subject 1 be different numbers
for i in range(n_chans):
    batch[specific_subject_id - 1, i] = i + 1000


# Define target outputs
n_outputs = model.fc.out_features
target = torch.zeros(batch_size, n_outputs)
target[specific_subject_id - 1] = torch.tensor([1.0, 2.0, 3.0, 4.0])  # Arbitrary non-zero target

# Set the last time point to contain the subject IDs (from 1 to batch_size)
for i in range(batch_size):
    subject_id = i + 1  # Subject IDs from 1 to 5
    batch[i, :, -1] = 3 * 1000000

# Define a loss function
criterion = nn.CrossEntropyLoss()

# Define an optimizer (only parameters with requires_grad=True will be updated)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()

# Zero the gradients
optimizer.zero_grad()

# Forward pass
output = model(batch)

# Compute loss
loss = criterion(output, target)

# Backward pass
loss.backward()

# Verify gradients before optimizer step
print("\nGradients before optimizer step:")
grad_norm_st = spatio_temporal_layer.weight.grad.norm().item() if spatio_temporal_layer.weight.grad is not None else 0
grad_norm_bn = model.batch_norm.weight.grad.norm().item() if model.batch_norm.weight.grad is not None else 0
print(f"Gradient norm for spatio_temporal_layer: {grad_norm_st}")
print(f"Gradient norm for batch_norm layer: {grad_norm_bn}")

# Update weights
optimizer.step()

# Check total amount of weights and biases
total_weights = spatio_temporal_layer.weight.data.numel()
total_biases = spatio_temporal_layer.bias.data.numel() if spatio_temporal_layer.bias is not None else 0
print(f"\nTotal number of weights: {total_weights}")
print(f"Total number of biases: {total_biases}")

# Check how many weights and biases have changed
weights_changed = (spatio_temporal_layer.weight.data != initial_weights[1]).float().sum().item()
biases_changed = 0
if spatio_temporal_layer.bias is not None:
    biases_changed = (spatio_temporal_layer.bias.data != initial_biases[1]).float().sum().item()
print(f"\nWeights changed: {weights_changed}")

# Assertions to verify that all weights and biases have changed
assert weights_changed > total_weights-6000, "Not all weights have changed!"




Parameter requires_grad status:
spatio_temporal.weight: requires_grad=True
spatio_temporal.bias: requires_grad=True
batch_norm.weight: requires_grad=True
batch_norm.bias: requires_grad=True
fc.weight: requires_grad=False
fc.bias: requires_grad=False

Gradients before optimizer step:
Gradient norm for spatio_temporal_layer: 0.022053884342312813
Gradient norm for batch_norm layer: 0.09037645906209946

Total number of weights: 22000
Total number of biases: 40

Weights changed: 19888.0


In [79]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from shallow import SubjectOneHotConvNet

# Initialize the model
n_chans = 22
n_outputs = 4
num_subjects = 9
model = SubjectOneHotConvNet(n_chans=n_chans, n_outputs=n_outputs, num_subjects=num_subjects)

# Initialize the spatio_temporal layer's weights to ones for consistency
nn.init.constant_(model.spatio_temporal.weight, 1.0)
if model.spatio_temporal.bias is not None:
    nn.init.constant_(model.spatio_temporal.bias, 0.0)

# Copy initial weights for comparison after training
initial_weights = model.spatio_temporal.weight.data.clone()

# Set requires_grad=True only for the spatio_temporal layer's parameters
for param in model.parameters():
    param.requires_grad = False
for param in model.spatio_temporal.parameters():
    param.requires_grad = True

# Prepare input data for two subjects
batch_size = 3  
n_times = 1001

batch = torch.zeros(batch_size, n_chans, n_times)  # Random non-zero input data
for i in range(n_chans):
    batch[0, i] = i + 1000  # Subject 3
    batch[1, i] = (i + 1000) * 10  # Subject 5

# Set the last time point to contain the subject IDs
subject_ids = [3, 5, 2]  # Subject IDs present in the batch (between 1 and num_subjects)
for i in range(batch_size):
    batch[i, :, -1] = subject_ids[i] * 1000000

# Define target outputs
target = torch.randn(batch_size, n_outputs)  # Random target output

# Define a loss function and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()
optimizer.zero_grad()
output = model(batch)
loss = criterion(output, target)
loss.backward()


# Update weights
optimizer.step()

# Compare weights before and after the update
updated_weights = model.spatio_temporal.weight.data
weight_diff = updated_weights - initial_weights

# Reshape weights to separate them by subject
weight_shape = model.spatio_temporal.weight.shape  # (out_channels, in_channels, H, W)
out_channels = weight_shape[0]
in_channels = weight_shape[1]
kernel_height = weight_shape[2]
kernel_width = weight_shape[3]

# Reshape weights to (num_subjects, num_kernels, in_channels, kernel_height, kernel_width)
weights_reshaped = updated_weights.view(num_subjects, model.num_kernels, in_channels, kernel_height, kernel_width)
initial_weights_reshaped = initial_weights.view(num_subjects, model.num_kernels, in_channels, kernel_height, kernel_width)

# Compute the number of weights changed for each subject
weights_changed_per_subject = ((weights_reshaped != initial_weights_reshaped).float().sum(dim=(1,2,3,4))).numpy()

# Print the number of weights changed per subject
for subj_idx in range(num_subjects):
    print(f"Weights changed for subject {subj_idx+1}: {weights_changed_per_subject[subj_idx]}")

# Verify that only the weights for the specific subjects have changed
total_weights_per_subject = weights_reshaped[0].numel()
print(f"Total weights per subject: {total_weights_per_subject}")

# Assertions
for subj_idx in range(num_subjects):
    if subj_idx + 1 in subject_ids and subj_idx + 1 != 2:
        assert weights_changed_per_subject[subj_idx] > 0, f"Weights for subject {subj_idx+1} did not change but should have!"
    else:
        assert weights_changed_per_subject[subj_idx] == 0, f"Weights for subject {subj_idx+1} changed but should not have!"

print("\nTest passed: Only the weights for the specific subjects have changed.")


Weights changed for subject 1: 0.0
Weights changed for subject 2: 0.0
Weights changed for subject 3: 17600.0
Weights changed for subject 4: 0.0
Weights changed for subject 5: 15400.0
Weights changed for subject 6: 0.0
Weights changed for subject 7: 0.0
Weights changed for subject 8: 0.0
Weights changed for subject 9: 0.0
Total weights per subject: 22000

Test passed: Only the weights for the specific subjects have changed.


In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Assuming SubjectSpecificConvNet is defined as before
class SubjectSpecificConvNet(nn.Module):
    def __init__(self, n_chans, n_outputs, n_times=1001, dropout=0.5, num_kernels=40, 
                 kernel_size=25, pool_size=100, num_subjects=9):
        super(SubjectSpecificConvNet, self).__init__()
        self.num_subjects = num_subjects

        # Create a dictionary of spatio-temporal convolutional layers, one per subject
        self.spatio_temporal_layers = nn.ModuleDict({
            f'subject_{i+1}': nn.Conv2d(n_chans, num_kernels, (1, kernel_size))
            for i in range(num_subjects)
        })

        self.pool = nn.AvgPool2d((1, pool_size))
        self.batch_norm = nn.BatchNorm2d(num_kernels)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.LazyLinear(n_outputs)

    def forward(self, x):
        # Extract subject IDs from the last time point of the first channel
        subject_ids = x[:, 0, -1] / 1000000  # Assuming subject IDs are stored scaled in the last time point
        x = x[:, :, :-1]  # Remove the last time point (which contains the subject ID)

        # Add dimension for Conv2d
        x = torch.unsqueeze(x, dim=2)  # Shape: (batch_size, n_chans, 1, n_times)

        # Prepare a list to collect outputs
        conv_outputs = []

        for i in range(x.size(0)):  # Loop over batch size
            subject_id = int(subject_ids[i].item())
            # Select the appropriate convolutional layer
            conv_layer = self.spatio_temporal_layers[f'subject_{subject_id}']
            # Apply the convolutional layer to the i-th sample
            xi = x[i].unsqueeze(0)  # Add batch dimension
            xi = conv_layer(xi)
            conv_outputs.append(xi)

        # Stack the outputs along the batch dimension
        x = torch.cat(conv_outputs, dim=0)

        x = F.elu(x)
        x = self.batch_norm(x)
        x = self.pool(x)
        x = x.view(x.size(0), -1)  # Flatten
        x = self.dropout(x)
        x = self.fc(x)
        return x

# Initialize the model
model = SubjectSpecificConvNet(n_chans=22, n_outputs=4)

# Fill spatio-temporal layer weights and biases with 1.0
for st_layer in model.spatio_temporal_layers.values():
    st_layer.weight.data.fill_(1.0)
    st_layer.bias.data.fill_(1.0)

# Set requires_grad=False for all parameters
for param in model.parameters():
    param.requires_grad = False

# Set requires_grad=True for parameters in spatio_temporal_layers
for st_layer in model.spatio_temporal_layers.values():
    for param in st_layer.parameters():
        param.requires_grad = True

print("Parameter requires_grad status:")
for name, param in model.named_parameters():
    print(f"{name}: requires_grad={param.requires_grad}")

# Copy initial weights and biases
initial_weights = {}
initial_biases = {}
for subject_id, st_layer in model.spatio_temporal_layers.items():
    initial_weights[subject_id] = st_layer.weight.data.clone()
    initial_biases[subject_id] = st_layer.bias.data.clone()

# Prepare input data
batch_size = 5
n_chans = 22
n_times = 1001

# Create a batch with zeros
batch = torch.zeros(batch_size, n_chans, n_times)

# Define target outputs
n_outputs = model.fc.out_features
target = torch.zeros(batch_size, n_outputs)

# Use a non-zero input for one specific subject
specific_subject_id = 3  # Subject ID for which we want to induce updates
batch[specific_subject_id - 1] = 1  # Fill with ones for the specific subject

# Set the last time point to contain the subject IDs (from 1 to batch_size)
for i in range(batch_size):
    subject_id = i + 1  # Subject IDs from 1 to 5
    batch[i, :, -1] = subject_id * 1000000  # Multiply by 1,000,000 to match extraction in forward()

# Define a loss function
criterion = nn.MSELoss()

# Define an optimizer (only parameters with requires_grad=True will be updated)
optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)

# Training step
model.train()

# Zero the gradients
optimizer.zero_grad()

# Forward pass
output = model(batch)

# Compute loss
loss = criterion(output, target)

# Backward pass
loss.backward()

# Verify gradients before optimizer step
print("\nGradients before optimizer step:")
for subject_id, st_layer in model.spatio_temporal_layers.items():
    grad_norm = st_layer.weight.grad.norm().item() if st_layer.weight.grad is not None else 0
    print(f"Gradient norm for {subject_id}: {grad_norm}")

# Update weights
optimizer.step()

# Check which spatio_temporal_layers have had their weights updated
for subject_id, st_layer in model.spatio_temporal_layers.items():
    weights_changed = not torch.allclose(st_layer.weight.data, initial_weights[subject_id], atol=1e-6)
    biases_changed = not torch.allclose(st_layer.bias.data, initial_biases[subject_id], atol=1e-6)
    print(f"\nWeights changed for {subject_id}: {weights_changed}")
    print(f"Biases changed for {subject_id}: {biases_changed}")

# Assertions to verify subject specificity
for subject_id, st_layer in model.spatio_temporal_layers.items():
    weights_changed = not torch.allclose(st_layer.weight.data, initial_weights[subject_id], atol=1e-6)
    biases_changed = not torch.allclose(st_layer.bias.data, initial_biases[subject_id], atol=1e-6)
    if subject_id == f'subject_{specific_subject_id}':
        assert weights_changed, f"Weights for {subject_id} did not change!"
        assert biases_changed, f"Biases for {subject_id} did not change!"
    else:
        assert not weights_changed, f"Weights for {subject_id} have changed!"
        assert not biases_changed, f"Biases for {subject_id} have changed!"

print("\nTest passed: Only the weights and biases corresponding to the specific subject have changed.")


Parameter requires_grad status:
spatio_temporal_layers.subject_1.weight: requires_grad=True
spatio_temporal_layers.subject_1.bias: requires_grad=True
spatio_temporal_layers.subject_2.weight: requires_grad=True
spatio_temporal_layers.subject_2.bias: requires_grad=True
spatio_temporal_layers.subject_3.weight: requires_grad=True
spatio_temporal_layers.subject_3.bias: requires_grad=True
spatio_temporal_layers.subject_4.weight: requires_grad=True
spatio_temporal_layers.subject_4.bias: requires_grad=True
spatio_temporal_layers.subject_5.weight: requires_grad=True
spatio_temporal_layers.subject_5.bias: requires_grad=True
spatio_temporal_layers.subject_6.weight: requires_grad=True
spatio_temporal_layers.subject_6.bias: requires_grad=True
spatio_temporal_layers.subject_7.weight: requires_grad=True
spatio_temporal_layers.subject_7.bias: requires_grad=True
spatio_temporal_layers.subject_8.weight: requires_grad=True
spatio_temporal_layers.subject_8.bias: requires_grad=True
spatio_temporal_layers.s

AssertionError: Weights for subject_3 did not change!