In [None]:
import sys
import torch
from torch.utils.data import DataLoader
from torch.optim import Adam

sys.path.append('../src')
from config import Config
from datasets import BroderickDataset
from preprocessor import Preprocessor
from utils import prepare_inputs, EEGDataset
from model import EEGAdapterLlamaForCausalLM

In [None]:
# Load the experiment configuration, the Broderick EEG dataset and its preprocessor,
# then turn the preprocessor's 'ALL' entry into model-ready inputs.
config = Config("config/config.yaml")
EEG = BroderickDataset(config)
PROCESSOR = Preprocessor(config, EEG=EEG)
all_entry = PROCESSOR['ALL']
eegs, subjects, inputs, labels = prepare_inputs(config, *all_entry)

In [None]:
# Build the EEG-adapter causal LM from the full config plus the model name and token
# stored under config.llama.
# NOTE(review): if `token` is a real access token, prefer loading it from an environment
# variable instead of keeping it in config.yaml.
llama_settings = config.llama
braindecoder = EEGAdapterLlamaForCausalLM(
    config,
    llama_settings.model_name,
    llama_settings.token,
)

In [None]:
# Inspect which parameters the optimizer will actually update. Trainable tensors are
# listed individually; frozen tensors are only counted so they don't flood the output.
trainable_params = []
frozen_params = []
for name, param in braindecoder.named_parameters():
    (trainable_params if param.requires_grad else frozen_params).append((name, param))

for name, param in trainable_params:
    print(f"Parameter: {name}, Size: {param.size()}")

n_trainable = sum(param.numel() for _, param in trainable_params)
n_frozen = sum(param.numel() for _, param in frozen_params)
print(f"Trainable: {len(trainable_params)} tensors, {n_trainable:,} elements")
print(f"Frozen: {len(frozen_params)} tensors, {n_frozen:,} elements")

In [5]:
# Seed for the DataLoader's shuffle so that Restart & Run All reproduces the same batch order.
SHUFFLE_SEED = 0
BATCH_SIZE = 1

dataset = EEGDataset(eegs, subjects, inputs, labels)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Only hand trainable parameters to Adam; frozen ones (requires_grad=False) never get
# gradients, so registering them with the optimizer serves no purpose.
optimizer = Adam(
    [param for param in braindecoder.parameters() if param.requires_grad],
    lr=config.train.learning_rate,
)

In [None]:
# Sanity-check the collated tensor shapes of one batch before training.
first_batch = next(iter(dataloader), None)
if first_batch is not None:
    (eeg, subject, input_data), label = first_batch
    for caption, tensor in (
        ("eeg.shape", eeg),
        ("subject.shape", subject),
        ("input_ids", input_data),
        ("label_ids", label),
    ):
        print(f"{caption}: {tensor.shape}")

In [None]:
def train(model, dataloader, optimizer, epochs, device, log_every=1):
    """Train ``model`` on EEG-conditioned causal-LM batches.

    Args:
        model: module whose forward accepts ``input_ids``, ``labels``, ``eegs`` and
            ``subject_index`` keyword arguments and returns an object with a ``.loss``.
        dataloader: sized iterable yielding ``((eeg, subject, input_ids), labels)`` batches.
        optimizer: optimizer over (a subset of) ``model``'s parameters.
        epochs: number of passes over ``dataloader``.
        device: device to train on; the model and every batch tensor are moved there.
        log_every: print the per-step loss every ``log_every`` steps. The default 1 prints
            every step; 0 or None disables per-step printing.

    Returns:
        list[float]: the average training loss of each epoch.

    Raises:
        ValueError: if ``dataloader`` is empty, because the average loss would be undefined.
    """
    num_batches = len(dataloader)
    if num_batches == 0:
        raise ValueError("dataloader is empty; nothing to train on")

    # NOTE: the torch.optim docs recommend moving a model to its device *before* building
    # its optimizer. By default Module.to() updates parameters in place, so existing
    # optimizer references stay valid, but moving the model earlier is safer.
    model.to(device)
    model.train()

    epoch_losses = []
    for epoch in range(epochs):
        total_loss = 0.0
        for step, ((eeg, subject, input_data), labels) in enumerate(dataloader, start=1):
            # Every tensor the model consumes must be on the model's device; `subject`
            # was previously left behind, which fails as soon as device is not the CPU.
            eeg = eeg.to(device, dtype=torch.float32)
            subject = subject.to(device)
            input_data = input_data.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(input_ids=input_data, labels=labels, eegs=eeg, subject_index=subject)
            loss = outputs.loss
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            total_loss += step_loss
            if log_every and step % log_every == 0:
                print(step_loss)

        avg_loss = total_loss / num_batches
        epoch_losses.append(avg_loss)
        print(f"Epoch {epoch+1}/{epochs}, Average Loss: {avg_loss:.4f}")

    return epoch_losses

# Training runs on the CPU here; switch to torch.device("cuda") when a GPU with enough
# memory for the model is available.
device = torch.device("cpu")
train(
    braindecoder,
    dataloader,
    optimizer,
    epochs=config.train.epochs,
    device=device,
)