In [None]:
import os
import re
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from transformers import ViTFeatureExtractor, ViTForImageClassification, TrainingArguments, Trainer
from sklearn.metrics import accuracy_score
from PIL import Image

# Checkpoint identifier handed to the `from_pretrained` calls further down.
MODEL_NAME = "google/vit-base-patch16-224"
# Directory scanned for `sample*.npz` chunk files.
DATA_DIR = "/content/gdrive/MyDrive/eeg_8x8_chunks"

NUM_CLASSES = 2
# Per-device batch size, used for both training and evaluation.
BATCH_SIZE = 32
# NOTE(review): EPOCHS is not referenced anywhere in the visible code.
EPOCHS = 10

# Raw class code stored in a chunk's `y` entry -> contiguous training label.
LABEL_MAP = {4: 0, 5: 1}
# Raw class codes that are kept; chunks carrying any other code are skipped.
VALID_CLASSES = set(LABEL_MAP)

def extract_subject_id(fname):
    """Return the subject identifier embedded in a chunk file name.

    The first ``S`` followed by three digits (e.g. ``S001``) wins. If no such
    token exists and the name starts with ``sample``, the first four
    characters of the second underscore-separated field are used instead.

    Args:
        fname: File name (not a full path) to inspect.

    Returns:
        The subject identifier, or ``"UNK"`` when none can be derived.
    """
    match = re.search(r"S\d{3}", fname)
    if match:
        return match.group(0)
    if fname.startswith("sample"):
        fields = fname.split("_")
        # Require a non-empty second field: an empty string is not a usable
        # ID and would otherwise pass callers' `!= "UNK"` filters.
        if len(fields) > 1 and fields[1]:
            return fields[1][:4]
    return "UNK"

def get_subject_ids():
    """Return the sorted, de-duplicated subject IDs found in ``DATA_DIR``.

    Only entries named ``sample*.npz`` are considered, and names for which
    ``extract_subject_id`` yields ``"UNK"`` are ignored.
    """
    candidates = (
        extract_subject_id(name)
        for name in os.listdir(DATA_DIR)
        if name.startswith("sample") and name.endswith(".npz")
    )
    return sorted({sid for sid in candidates if sid != "UNK"})

def _slice(lst, a, b):
    a = min(a, len(lst)); b = min(b, len(lst))
    return lst[a:b]

def load_chunks_by_subject_ids(subject_ids):
    """Load every usable chunk belonging to the given subjects.

    A file belongs to subject ``sid`` when its name contains ``_{sid}R``.
    A chunk is kept only if its ``y`` entry is in ``VALID_CLASSES`` and its
    ``X`` entry is 3-D with trailing shape ``(8, 8)``.

    Args:
        subject_ids: Iterable of subject identifiers to load.

    Returns:
        ``(X_all, y_all)``: the kept chunks stacked along a new leading axis
        as float32, and the matching labels (after ``LABEL_MAP``) as an int
        array.

    Raises:
        RuntimeError: If no chunk survives the filters.
    """
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order reproducible across runs and machines.
    files = sorted(
        f for f in os.listdir(DATA_DIR)
        if f.endswith(".npz") and f.startswith("sample")
    )
    X_all, y_all = [], []
    for sid in subject_ids:
        sid_token = f"_{sid}R"
        for fname in files:
            if sid_token not in fname:
                continue
            path = os.path.join(DATA_DIR, fname)
            # NOTE(security): allow_pickle=True can execute arbitrary code when
            # loading untrusted files; kept because the stored format may need
            # it. Only point DATA_DIR at trusted data.
            # The context manager closes the archive handle; entries are read
            # into memory on access, so X and y remain valid afterwards.
            with np.load(path, allow_pickle=True) as data:
                X, y = data["X"], data["y"]
            if isinstance(y, np.ndarray):
                y = y.item()
            if y not in VALID_CLASSES:
                continue
            if X.ndim != 3 or X.shape[1] != 8 or X.shape[2] != 8:
                continue
            X_all.append(X.astype(np.float32))
            y_all.append(LABEL_MAP[y])
    if not X_all:
        raise RuntimeError("No samples after subject-based loading.")
    X_all = np.stack(X_all)
    y_all = np.array(y_all, dtype=int)
    return X_all, y_all

# Subject-level split: each subject lands in exactly one of train/dev/test.
# Every split starts where the previous one ends, so no subject between the
# splits is silently dropped (the earlier bounds 70/71 and 86/87 skipped the
# subjects at indices 70 and 86). The test split [87, 103) is unchanged.
# Subjects at index 103 and beyond, if any, are still left out.
subject_ids = get_subject_ids()
_TRAIN_END, _DEV_END, _TEST_END = 71, 87, 103
train_ids = _slice(subject_ids, 0, _TRAIN_END)
dev_ids = _slice(subject_ids, _TRAIN_END, _DEV_END)
test_ids = _slice(subject_ids, _DEV_END, _TEST_END)

# Load the arrays for each split; the loader raises RuntimeError if a split
# ends up with no samples.
X_train, y_train = load_chunks_by_subject_ids(train_ids)
X_dev, y_dev = load_chunks_by_subject_ids(dev_ids)
X_test, y_test = load_chunks_by_subject_ids(test_ids)

# Report the feature/label array shapes of every split.
print("Data Shapes:")
for _tag, _features, _labels in (
    ("Train:", X_train, y_train),
    ("Dev:", X_dev, y_dev),
    ("Test:", X_test, y_test),
):
    print(_tag, _features.shape, _labels.shape)

# The checkpoint's preprocessing config supplies the per-channel mean/std
# used for normalisation below.
feature_extractor = ViTFeatureExtractor.from_pretrained(MODEL_NAME)

_normalize = transforms.Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)
# PIL image -> tensor -> 224x224 -> normalised tensor.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Resize((224, 224)), _normalize]
)

class EEGImageDataset(Dataset):
    """Dataset that renders each chunk as a 3-channel image for the model.

    Each item of ``X`` is averaged over its first axis, copied into three
    identical channels, converted to an 8-bit PIL image and passed through
    ``transform``.

    Args:
        X: Indexable collection of chunks (``X[idx]`` must be array-like).
        y: Indexable collection of integer labels aligned with ``X``.
        transform: Callable applied to the PIL image; its result is returned
            under ``"pixel_values"``.
    """

    def __init__(self, X, y, transform):
        self.X = X
        self.y = y
        self.transform = transform

    def __len__(self):
        return len(self.X)

    @staticmethod
    def _to_rgb_uint8(chunk):
        """Collapse ``chunk`` over axis 0 and return an (H, W, 3) uint8 array.

        Values are scaled by 255, i.e. the data is presumably expected to lie
        in [0, 1] (TODO confirm against the chunk-generation step). The result
        is clipped to [0, 255] before the cast, because casting out-of-range
        floats to uint8 wraps around or is undefined rather than saturating.
        """
        plane = torch.tensor(chunk).float().mean(dim=0, keepdim=True)
        hwc = plane.repeat(3, 1, 1).permute(1, 2, 0).numpy()
        return np.clip(hwc * 255, 0, 255).astype(np.uint8)

    def __getitem__(self, idx):
        rgb = self._to_rgb_uint8(self.X[idx])
        pixel_values = self.transform(Image.fromarray(rgb))
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": label}

# Wrap each split in a dataset; all three share the same transform.
train_dataset, dev_dataset, test_dataset = (
    EEGImageDataset(features, labels, transform)
    for features, labels in (
        (X_train, y_train),
        (X_dev, y_dev),
        (X_test, y_test),
    )
)

# The classification head must match the task: LABEL_MAP only produces labels
# 0 and 1, so the head is sized with NUM_CLASSES rather than a hard-coded 11.
# ignore_mismatched_sizes=True lets the checkpoint's original head be replaced
# by a freshly initialised one of the new size.
model = ViTForImageClassification.from_pretrained(
    MODEL_NAME,
    num_labels=NUM_CLASSES,
    ignore_mismatched_sizes=True,
)

# Fine-tuning configuration for the Trainer.
training_args = TrainingArguments(
    # Where checkpoints and logs go; save_total_limit caps how many
    # checkpoints are kept on disk, and no external reporting is used.
    output_dir="./vit_pretrained_results",
    logging_dir="./logs",
    save_total_limit=1,
    report_to="none",
    # Optimisation schedule.
    # NOTE(review): EPOCHS is defined at the top of the file but the literal
    # 40 is used here; confirm which value is intended.
    num_train_epochs=40,
    learning_rate=5e-5,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    # Evaluate, checkpoint and log once per epoch, then reload the checkpoint
    # with the best "accuracy" when training finishes.
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

def compute_metrics(pred):
    """Return ``{"accuracy": ...}`` for one evaluation pass.

    ``pred.predictions`` holds per-class scores; the arg-max along axis 1 is
    compared against ``pred.label_ids``.
    """
    scores, references = pred.predictions, pred.label_ids
    predicted = np.argmax(scores, axis=1)
    return {"accuracy": accuracy_score(references, predicted)}

# Train on the training split and evaluate on the dev split each time
# evaluation runs.
# NOTE(review): recent transformers releases rename `tokenizer` to
# `processing_class`; confirm against the installed version.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=feature_extractor,
    compute_metrics=compute_metrics,
    eval_dataset=dev_dataset,
    train_dataset=train_dataset,
)

# Fine-tune the model.
trainer.train()

# Score the held-out test split and show the resulting metrics dict.
print("\nFinal Evaluation")
test_results = trainer.evaluate(eval_dataset=test_dataset)
print(test_results)

# Persist the fine-tuned weights and config to a local directory.
model.save_pretrained(save_directory="vit_pretrained_model")
print("Saved pretrained ViT model.")
