# In [None]:
# Import necessary libraries
import json
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torch
from sklearn.preprocessing import StandardScaler
import os
import pickle
import lightning as L
import torch.nn as nn

# Load data function
def load_data(json_paths, key="task_name"):
    """Collect the entries stored under ``key`` from several JSON files.

    Args:
        json_paths: Iterable of paths to JSON files. The top-level object of
            each file must be indexable by ``key``.
        key: Top-level key whose value is appended to the result. Defaults to
            ``"task_name"``, the key this function has always read.

    Returns:
        One flat list holding the items of every file's ``key`` entry, in the
        order the paths were given. Empty when ``json_paths`` is empty.

    Raises:
        OSError: If a path cannot be opened.
        json.JSONDecodeError: If a file does not contain valid JSON.
        KeyError: If a file's top-level mapping has no ``key`` entry.
    """
    data = []
    for path in json_paths:
        # JSON text is UTF-8; be explicit so the result does not depend on
        # the platform's locale encoding.
        with open(path, 'r', encoding='utf-8') as f:
            data.extend(json.load(f)[key])
    return data

# Dataset class
class EEGEyeNetDataset(Dataset):
    """Map-style dataset that builds one (signal, target) pair per record.

    Each record in ``data`` is read as a mapping with the keys ``"input"``
    (paths of pickled arrays), ``"start"`` and ``"length"`` (the window cut
    from every array along its second axis) and ``"output"`` (the target).
    """

    def __init__(self, data, task_type):
        """Store the records, the task type and a scaler instance.

        Args:
            data: Indexable collection of sample records.
            task_type: ``"Regression"`` yields float32 targets; any other
                value yields int64 targets.
        """
        self.task_type = task_type
        self.data = data
        # The scaler is refit inside every __getitem__ call, so it carries no
        # statistics from one sample to the next.
        self.scaler = StandardScaler()

    def __len__(self):
        """Return the number of records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load, window, concatenate and standardize the record at ``idx``.

        Returns:
            A ``(signals, target)`` tuple of tensors: ``signals`` is float32,
            ``target`` is float32 for regression and int64 otherwise.
        """
        record = self.data[idx]
        paths = record["input"]
        begin = record["start"]
        end = begin + record["length"]

        windows = []
        for path in paths:
            # NOTE: pickle.load can execute arbitrary code stored in the file;
            # only point this dataset at trusted recordings.
            with open(path, 'rb') as handle:
                recording = pickle.load(handle)
            # Cut the same column range out of every file.
            # (assumes rows are channels and columns are time -- TODO confirm)
            windows.append(recording[:, begin:end])

        stacked = np.concatenate(windows, axis=1)
        # Fit on the transpose so each row of `stacked` is standardized with
        # statistics taken across its own columns, then transpose back.
        standardized = self.scaler.fit_transform(stacked.T).T

        if self.task_type == "Regression":
            target_dtype = torch.float32
        else:
            target_dtype = torch.long

        features = torch.tensor(standardized, dtype=torch.float32)
        target = torch.tensor(record["output"], dtype=target_dtype)
        return features, target

# Load datasets
train_json = "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_train.json"
val_json = "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_val.json"
test_json = "/itet-stor/maxihuber/deepeye_storage/eegeyenet_tasks/EEGEyeNet_Direction_test.json"

# One list of sample records per split, read in train / val / test order.
train_data, val_data, test_data = (
    load_data([split_json]) for split_json in (train_json, val_json, test_json)
)

task_type = "Regression"  # or "Classification" based on the task

# Wrap every split in a dataset that shares the same task type.
train_dataset, val_dataset, test_dataset = (
    EEGEyeNetDataset(split_data, task_type)
    for split_data in (train_data, val_data, test_data)
)

# All loaders yield one sample per batch; only the training loader shuffles.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=1)
val_loader, test_loader = (
    DataLoader(split_dataset, batch_size=1)
    for split_dataset in (val_dataset, test_dataset)
)

# LightningModule class for fine-tuning
class FineTuningModel(L.LightningModule):
    """Frozen encoder followed by a trainable linear head.

    The encoder's parameters have ``requires_grad`` switched off and only the
    head's parameters are handed to the optimizer, so training updates the
    head alone.

    Args:
        encoder: Module exposing ``encoder_embed_dim`` and returning a 4-tuple
            whose first element is the token embedding tensor.
        task_type: ``"Regression"`` selects a single-output head with MSE
            loss; any other value selects a classification head with
            cross-entropy loss.
        learning_rate: Learning rate of the Adam optimizer.
        num_classes: Number of output classes for classification. When left
            as ``None`` it is derived, as before, from the distinct
            ``'output'`` values of the module-level ``train_data``. Unused
            for regression.
    """

    def __init__(self, encoder, task_type, learning_rate=1e-4, num_classes=None):
        super().__init__()
        self.encoder = encoder
        # Freeze the encoder; see configure_optimizers for the trained part.
        for param in self.encoder.parameters():
            param.requires_grad = False

        self.task_type = task_type
        self.learning_rate = learning_rate

        if task_type == "Regression":
            self.head = nn.Linear(encoder.encoder_embed_dim, 1)
            self.criterion = nn.MSELoss()
        else:
            if num_classes is None:
                # Backward-compatible fallback: count the distinct labels of
                # the module-level training records.
                num_classes = len({d['output'] for d in train_data})
            self.head = nn.Linear(encoder.encoder_embed_dim, num_classes)
            self.criterion = nn.CrossEntropyLoss()

    def forward(self, x):
        """Return the head's output for the first token of the embedding."""
        x_emb, _, _, _ = self.encoder(x)
        x_emb = x_emb[:, 0]  # Assuming the first token is the cls token
        return self.head(x_emb)

    def _compute_loss(self, batch):
        """Run the model on ``batch`` and return the criterion value."""
        x, y = batch
        y_hat = self(x)
        if (
            self.task_type == "Regression"
            and y_hat.shape != y.shape
            and y_hat.numel() == y.numel()
        ):
            # The regression head ends in a size-1 dimension. If the targets
            # arrive without it, MSELoss would broadcast predictions against
            # targets and, for batches larger than one, average over every
            # prediction/target pair. Align the shapes element for element.
            y_hat = y_hat.reshape(y.shape)
        return self.criterion(y_hat, y)

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the training loss."""
        loss = self._compute_loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (``val_loss``) and return the validation loss."""
        loss = self._compute_loss(batch)
        self.log('val_loss', loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over the head's parameters only."""
        return torch.optim.Adam(self.head.parameters(), lr=self.learning_rate)

# Load the checkpoint
chkpt_path = "/itet-stor/maxihuber/net_scratch/checkpoints/977598/epoch=0-step=32807-val_loss=133.55.ckpt"
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(chkpt_path, map_location=torch.device('cpu'))
state_dict = checkpoint['state_dict']

# Initialize the encoder and load the state dict
encoder = ModularMaskedAutoencoderViTRoPE(channel_names_path='path_to_channel_names.json').encoder
# Select keys by the leading "encoder." prefix and strip only that prefix.
# A substring test plus a global replace would also pick up keys that merely
# contain "encoder" and would mangle names with an inner "encoder." segment.
encoder.load_state_dict({
    k[len("encoder."):]: v
    for k, v in state_dict.items()
    if k.startswith("encoder.")
})

# Instantiate the fine-tuning model
fine_tuning_model = FineTuningModel(encoder, task_type)

# Train the model
# `gpus=` was removed from Trainer in Lightning 2.0; select the device through
# `accelerator` / `devices` instead (one GPU when available, otherwise CPU).
trainer = L.Trainer(
    max_epochs=10,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
)
trainer.fit(fine_tuning_model, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Evaluate the model
# Trainer.validate takes `dataloaders`, not `val_dataloaders`.
trainer.validate(fine_tuning_model, dataloaders=val_loader)
