In [1]:
from config import Config
from datasets import BroderickDataset
from preprocessor import Preprocessor
from utils import prepare_inputs, EEGDataset
from model import EEGAdapterLlamaForCausalLM

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

  from .autonotebook import tqdm as notebook_tqdm
Hostname dhcp-10-29-160-100.dyn.MIT.EDU not defined in /conf/study_paths/study_paths.yaml. Using default paths.


In [2]:
# Load the project configuration, the Broderick EEG dataset, and a preprocessor bound to it.
config = Config("config/config.yaml")
EEG = BroderickDataset(config)
PROCESSOR = Preprocessor(config, EEG=EEG)

# The preprocessor's 'ALL' entry is star-unpacked into prepare_inputs, which yields the
# eegs / subjects / inputs / labels that EEGDataset consumes further down.
# NOTE(review): 'ALL' presumably means every subject/stimulus — confirm in Preprocessor.
all_split = PROCESSOR['ALL']
eegs, subjects, inputs, labels = prepare_inputs(config, *all_split)

datasets/broderick/EEG/
datasets/broderick/Stimuli/Text/
Retrieving S01...
Processing S01...
datasets/brennan_hale/EEG/
datasets/brennan_hale/Stimuli/Text/
Retrieving S01...
datasets/broderick/EEG/
datasets/broderick/Stimuli/Text/
Retrieving S01...


In [3]:
# Build the EEG-adapter Llama model; the saved output shows it loading checkpoint shards.
# NOTE(review): the access token comes from the config file — keep that file out of
# version control, or source the token from the environment instead.
llama_cfg = config.llama
braindecoder = EEGAdapterLlamaForCausalLM(config, llama_cfg.model_name, llama_cfg.token)

Loading checkpoint shards: 100%|██████████| 2/2 [00:58<00:00, 29.04s/it]


In [4]:
# Report how many parameter elements are trainable vs. frozen *before* the long
# per-parameter listing, whose tail tends to be truncated in the saved output.
named_params = list(braindecoder.named_parameters())
n_trainable = sum(p.numel() for _, p in named_params if p.requires_grad)
n_frozen = sum(p.numel() for _, p in named_params if not p.requires_grad)
print(f"Trainable elements: {n_trainable:,} | Frozen elements: {n_frozen:,}")

for name, param in named_params:
    if param.requires_grad:
        print(f"Parameter: {name}, Size: {param.size()}")
    else:
        print(f"Frozen Parameter: {name}, Size: {param.size()}")

Frozen Parameter: encoder.merger.heads, Size: torch.Size([270, 2048])
Frozen Parameter: encoder.initial_linear.0.weight, Size: torch.Size([270, 270, 1])
Frozen Parameter: encoder.initial_linear.0.bias, Size: torch.Size([270])
Frozen Parameter: encoder.subject_layers.weights, Size: torch.Size([33, 270, 270])
Frozen Parameter: encoder.final.0.weight, Size: torch.Size([640, 320, 1])
Frozen Parameter: encoder.final.0.bias, Size: torch.Size([640])
Frozen Parameter: encoder.final.2.weight, Size: torch.Size([640, 1024, 1])
Frozen Parameter: encoder.final.2.bias, Size: torch.Size([1024])
Frozen Parameter: encoder.encoders.meg.sequence.0.0.weight, Size: torch.Size([320, 270, 3])
Frozen Parameter: encoder.encoders.meg.sequence.0.0.bias, Size: torch.Size([320])
Frozen Parameter: encoder.encoders.meg.sequence.0.1.weight, Size: torch.Size([320])
Frozen Parameter: encoder.encoders.meg.sequence.0.1.bias, Size: torch.Size([320])
Frozen Parameter: encoder.encoders.meg.sequence.1.0.weight, Size: torch.S

In [5]:
# Wrap the prepared sequences in a Dataset and serve them one example per batch.
dataset = EEGDataset(eegs, subjects, inputs, labels)

# Seed the shuffle so Restart & Run All visits batches in the same order every time.
SEED = 42
shuffle_generator = torch.Generator().manual_seed(SEED)
dataloader = DataLoader(dataset, batch_size=1, shuffle=True, generator=shuffle_generator)

# Give Adam only the trainable parameters; the listing above shows the encoder
# weights are frozen (requires_grad=False), so they would never be updated anyway.
# Adam raises ValueError here if nothing is trainable, surfacing that misconfiguration early.
trainable_params = [p for p in braindecoder.parameters() if p.requires_grad]
optimizer = Adam(trainable_params, lr=config.train.learning_rate)

In [6]:
# Peek at a single batch to sanity-check tensor shapes before training.
# An empty loader yields None, in which case nothing is printed.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    (eeg, subject, input_data), label = first_batch
    print(f'eeg.shape: {eeg.shape}')
    print(f'subject.shape: {subject.shape}')
    print(f'input_ids: {input_data.shape}')
    print(f'label_ids: {label.shape}')

eeg.shape: torch.Size([1, 61, 360])
subject.shape: torch.Size([1])
input_ids: torch.Size([1, 100])
label_ids: torch.Size([1, 100])


In [7]:
def train(model, dataloader, optimizer, epochs, device):
    """Train ``model`` for ``epochs`` passes over ``dataloader``.

    Each batch is unpacked as ``((eeg, subject, input_ids), labels)``. The model is
    called as ``model(input_ids=..., labels=..., eegs=..., subject_index=...)`` and
    must return an object whose ``.loss`` is a scalar tensor.

    Args:
        model: module to train; it is put in train mode and moved to ``device``.
        dataloader: iterable of ``((eeg, subject, input_ids), labels)`` batches.
        optimizer: optimizer over (some of) ``model``'s parameters.
        epochs: number of full passes over ``dataloader``.
        device: ``torch.device`` that the model and every batch tensor are moved to.

    Returns:
        list[float]: mean loss of each epoch (``nan`` for an epoch with no batches).
    """
    model.train()
    model.to(device)
    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        num_batches = 0
        for (eeg, subject, input_data), labels in dataloader:
            # Keep every model input on the model's device (subject included).
            eeg = eeg.to(device).float()
            subject = subject.to(device)
            input_data = input_data.to(device)
            labels = labels.to(device)
            # NOTE(review): the labels tensor in this cell's saved output pads with 0
            # (not -100) after the target tokens; if 0 is not a real target, those
            # positions contribute to the loss — confirm in prepare_inputs.

            optimizer.zero_grad()

            outputs = model(input_ids=input_data, labels=labels, eegs=eeg, subject_index=subject)
            loss = outputs.loss
            # Sync the scalar to the host once per step and reuse it.
            step_loss = loss.item()
            total_loss += step_loss
            num_batches += 1
            print(step_loss)

            loss.backward()

            optimizer.step()

        # Divide by the batches actually seen: avoids ZeroDivisionError on an empty
        # loader and does not require the iterable to support len().
        avg_loss = total_loss / num_batches if num_batches else float("nan")
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")
    return epoch_losses

# Run training on the CPU; `train` moves the model and each batch onto this device.
num_epochs = config.train.epochs
device = torch.device("cpu")
train(braindecoder, dataloader, optimizer, epochs=num_epochs, device=device);

tensor([[-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 6452,  287,
         2307, 8023, 1497,  817,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0]])
18.26792335510254
