In [None]:
# --- Install dependencies (Colab) ---
!pip install -q torch torchvision torchaudio scikit-learn matplotlib

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m35.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m88.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m45.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m11.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m34.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m19.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# --- Imports ---
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from google.colab import drive
from sklearn.metrics import mean_squared_error
import numpy as np
import math
import os
from tqdm import tqdm
import matplotlib.pyplot as plt

In [None]:
# --- Mount Google Drive ---
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# --- Constants for screen setup ---
SCREEN_WIDTH_PX = 800
SCREEN_HEIGHT_PX = 600
SCREEN_WIDTH_MM = 518
SCREEN_HEIGHT_MM = 324
VIEWING_DISTANCE_MM = 680

In [None]:
# --- Positional Encoding ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=1200):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

# --- Patch Embedding via Conv ---
class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=1, emb_size=256):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),
            nn.GELU(),
            nn.Conv2d(64, emb_size, kernel_size=3, stride=2, padding=1),
            nn.GELU()
        )

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.proj(x)
        B, E, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        return x

# --- Transformer Encoder Block ---
class TransformerBlock(nn.Module):
    def __init__(self, emb_size=256, heads=4, ff_dim=512, dropout=0.2):
        super().__init__()
        self.attn = nn.MultiheadAttention(emb_size, heads, dropout=dropout, batch_first=True)
        self.ff = nn.Sequential(
            nn.Linear(emb_size, ff_dim),
            nn.GELU(),
            nn.Linear(ff_dim, emb_size)
        )
        self.norm1 = nn.LayerNorm(emb_size, eps=1e-5)
        self.norm2 = nn.LayerNorm(emb_size, eps=1e-5)
        self.drop = nn.Dropout(dropout)

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        x = self.norm1(x + self.drop(attn_out))
        ff_out = self.ff(x)
        x = self.norm2(x + self.drop(ff_out))
        return x

# --- Full Model with Attention Pooling ---
class EEGTransformerRegressor(nn.Module):
    def __init__(self, emb_size=256, num_layers=4):
        super().__init__()
        self.patch_embed = PatchEmbedding(1, emb_size)
        self.pos_encoder = PositionalEncoding(emb_size)
        self.transformer = nn.Sequential(*[
            TransformerBlock(emb_size) for _ in range(num_layers)
        ])
        self.attn_pool = nn.MultiheadAttention(emb_size, num_heads=4, batch_first=True)
        self.regressor = nn.Sequential(
            nn.LayerNorm(emb_size, eps=1e-5),
            nn.Linear(emb_size, 64),
            nn.GELU(),
            nn.Linear(64, 2)
        )

    def forward(self, x):
        x = self.patch_embed(x)
        x = self.pos_encoder(x)
        x = self.transformer(x)
        pooled, _ = self.attn_pool(x, x, x)
        return self.regressor(pooled[:, 0])

# --- Dataset ---
class EEGNPZDataset(Dataset):
    def __init__(self, file_paths, file_fraction=0.75):
        self.samples = []
        for path in file_paths:
            data = np.load(path)
            eeg = data['EEG']
            labels = data['labels'][:, -2:]
            n = int(len(eeg) * file_fraction)
            for x, y in zip(eeg[:n], labels[:n]):
                if not np.isnan(x).any() and not np.isnan(y).any():
                    deg_x, deg_y = pixels_to_degrees(y[0], y[1])
                    if abs(deg_x) < 30 and abs(deg_y) < 30:
                        x = (x - np.mean(x)) / (np.std(x) + 1e-6)
                        self.samples.append((x, y))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        x, y = self.samples[idx]
        return torch.tensor(x, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)

In [None]:
def pixels_to_degrees(x_px, y_px):
    x_mm = (x_px - SCREEN_WIDTH_PX / 2) * (SCREEN_WIDTH_MM / SCREEN_WIDTH_PX)
    y_mm = (y_px - SCREEN_HEIGHT_PX / 2) * (SCREEN_HEIGHT_MM / SCREEN_HEIGHT_PX)
    angle_x = np.degrees(np.arctan2(x_mm, VIEWING_DISTANCE_MM))
    angle_y = np.degrees(np.arctan2(y_mm, VIEWING_DISTANCE_MM))
    return angle_x, angle_y

# --- RMSE in degrees of visual angle ---
def rmse_degrees(preds, targets):
    preds_deg = np.array([pixels_to_degrees(x, y) for x, y in preds])
    targets_deg = np.array([pixels_to_degrees(x, y) for x, y in targets])
    return np.sqrt(np.mean((preds_deg - targets_deg) ** 2))
def pixels_to_mm(x_px, y_px):
    x_mm = (x_px - SCREEN_WIDTH_PX / 2) * (SCREEN_WIDTH_MM / SCREEN_WIDTH_PX)
    y_mm = (y_px - SCREEN_HEIGHT_PX / 2) * (SCREEN_HEIGHT_MM / SCREEN_HEIGHT_PX)
    return x_mm, y_mm
def rmse_mm(preds, targets):
    preds_mm = np.array([pixels_to_mm(x, y) for x, y in preds])
    targets_mm = np.array([pixels_to_mm(x, y) for x, y in targets])
    return np.sqrt(np.mean((preds_mm - targets_mm) ** 2))

In [None]:
# --- Training Loop ---
def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs, device, save_dir, patience=5):
    scaler = torch.amp.GradScaler('cuda')
    best_rmse = float('inf')
    train_losses = []
    val_rmses = []
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0

        for batch_idx, (inputs, targets) in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}")):
            inputs, targets = inputs.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast(device_type='cuda'):
                outputs = model(inputs)

            if torch.isnan(outputs).any():
                print(f"❌ NaNs detected in outputs at batch {batch_idx}")
                del inputs, targets, outputs
                torch.cuda.empty_cache()
                continue

            loss = criterion(outputs, targets)

            if not torch.isfinite(loss).all():
                print(f"⚠️ Skipping batch {batch_idx} with non-finite loss: {loss.item()}")
                del inputs, targets, outputs, loss
                torch.cuda.empty_cache()
                continue

            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            running_loss += loss.item() * inputs.size(0)

            del inputs, targets, outputs, loss
            torch.cuda.empty_cache()

        avg_train_loss = running_loss / len(train_loader.dataset)
        train_losses.append(avg_train_loss)

        model.eval()
        preds, gts = [], []
        with torch.no_grad():
            for inputs, targets in val_loader:
                inputs, targets = inputs.to(device), targets.to(device)
                outputs = model(inputs)
                preds.append(outputs.cpu().numpy())
                gts.append(targets.cpu().numpy())

        preds = np.concatenate(preds)
        gts = np.concatenate(gts)
        rmse = np.sqrt(mean_squared_error(gts, preds))
        val_rmses.append(rmse)
        rmse_deg = rmse_degrees(preds, gts)
        rmse_mm_val = rmse_mm(preds, gts)

        scheduler.step()

        if rmse < best_rmse:
            best_rmse = rmse
            patience_counter = 0
            torch.save(model.state_dict(), os.path.join(save_dir, "best_model.pt"))
            print("✅ Saved new best model.")
        else:
            patience_counter += 1
            print(f"⏳ No improvement. Patience: {patience_counter}/{patience}")
            if patience_counter >= patience:
                print("🛑 Early stopping triggered.")
                break

        print(f"Epoch [{epoch+1}/{num_epochs}] | Train Loss: {avg_train_loss:.4f} | Val RMSE: {rmse:.2f}px | {rmse_deg:.2f}° | {rmse_mm_val:.2f}mm")

    plt.figure()
    plt.plot(train_losses, label='Train Loss')
    plt.plot(val_rmses, label='Validation RMSE (px)')
    plt.xlabel('Epoch')
    plt.ylabel('Loss / RMSE')
    plt.legend()
    plt.title('Training and Validation Loss')
    plt.savefig(os.path.join(save_dir, 'loss_plot.png'))
    plt.close()

In [None]:
data_dir = "/content/drive/MyDrive/data"
files = [
    "Direction_task_with_dots_synchronised_max.npz",
    "Direction_task_with_dots_synchronised_min.npz",
    "Direction_task_with_processing_speed_synchronised_max.npz",
    "Direction_task_with_processing_speed_synchronised_min.npz",
]
file_paths = [os.path.join(data_dir, f) for f in files]

dataset = EEGNPZDataset(file_paths)
train_size = int(0.8 * len(dataset))
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, len(dataset) - train_size])

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=4, pin_memory=True, persistent_workers=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False, num_workers=4, pin_memory=True, persistent_workers=True)

In [10]:
# --- Execution ---
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = EEGTransformerRegressor().to(device)

    criterion = nn.SmoothL1Loss(beta=10.0)  # beta controls transition to L1
    optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=5)

    save_dir = "/content/drive/MyDrive/cnn_transformer/outliers/Try1"
    os.makedirs(save_dir, exist_ok=True)

    train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, num_epochs=100, device=device, save_dir=save_dir, patience=5)

Epoch 1: 100%|██████████| 925/925 [05:49<00:00,  2.64it/s]


✅ Saved new best model.
Epoch [1/100] | Train Loss: 89.4318 | Val RMSE: 131.75px | 7.03° | 85.30mm


Epoch 2: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [2/100] | Train Loss: 68.9017 | Val RMSE: 126.83px | 6.76° | 82.12mm


Epoch 3: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [3/100] | Train Loss: 65.2511 | Val RMSE: 115.15px | 6.14° | 74.55mm


Epoch 4: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [4/100] | Train Loss: 57.2298 | Val RMSE: 107.94px | 5.77° | 69.89mm


Epoch 5: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [5/100] | Train Loss: 50.2155 | Val RMSE: 103.03px | 5.50° | 66.71mm


Epoch 6: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [6/100] | Train Loss: 47.1611 | Val RMSE: 91.07px | 4.85° | 58.96mm


Epoch 7: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [7/100] | Train Loss: 42.7516 | Val RMSE: 88.13px | 4.70° | 57.06mm


Epoch 8: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [8/100] | Train Loss: 38.9840 | Val RMSE: 84.46px | 4.50° | 54.69mm


Epoch 9: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [9/100] | Train Loss: 37.1292 | Val RMSE: 81.36px | 4.33° | 52.68mm


Epoch 10: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [10/100] | Train Loss: 36.2607 | Val RMSE: 81.78px | 4.36° | 52.95mm


Epoch 11: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [11/100] | Train Loss: 37.1679 | Val RMSE: 81.60px | 4.34° | 52.83mm


Epoch 12: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 3/5
Epoch [12/100] | Train Loss: 36.3218 | Val RMSE: 88.09px | 4.69° | 57.04mm


Epoch 13: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [13/100] | Train Loss: 36.0122 | Val RMSE: 80.69px | 4.30° | 52.24mm


Epoch 14: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [14/100] | Train Loss: 34.4102 | Val RMSE: 78.83px | 4.20° | 51.04mm


Epoch 15: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [15/100] | Train Loss: 33.2366 | Val RMSE: 78.98px | 4.21° | 51.14mm


Epoch 16: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [16/100] | Train Loss: 34.6044 | Val RMSE: 78.73px | 4.19° | 50.97mm


Epoch 17: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [17/100] | Train Loss: 34.0916 | Val RMSE: 79.70px | 4.25° | 51.60mm


Epoch 18: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [18/100] | Train Loss: 33.0014 | Val RMSE: 76.35px | 4.07° | 49.44mm


Epoch 19: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [19/100] | Train Loss: 31.6014 | Val RMSE: 77.51px | 4.13° | 50.18mm


Epoch 20: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [20/100] | Train Loss: 30.7279 | Val RMSE: 75.78px | 4.04° | 49.06mm


Epoch 21: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [21/100] | Train Loss: 33.9997 | Val RMSE: 76.13px | 4.05° | 49.29mm


Epoch 22: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [22/100] | Train Loss: 32.1897 | Val RMSE: 75.78px | 4.03° | 49.06mm


Epoch 23: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [23/100] | Train Loss: 30.0959 | Val RMSE: 74.59px | 3.97° | 48.29mm


Epoch 24: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [24/100] | Train Loss: 28.9778 | Val RMSE: 74.97px | 3.99° | 48.54mm


Epoch 25: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [25/100] | Train Loss: 28.3042 | Val RMSE: 73.82px | 3.93° | 47.79mm


Epoch 26: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [26/100] | Train Loss: 30.5782 | Val RMSE: 77.46px | 4.13° | 50.15mm


Epoch 27: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [27/100] | Train Loss: 29.0073 | Val RMSE: 75.67px | 4.03° | 48.99mm


Epoch 28: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 3/5
Epoch [28/100] | Train Loss: 27.9978 | Val RMSE: 74.09px | 3.95° | 47.97mm


Epoch 29: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [29/100] | Train Loss: 26.7288 | Val RMSE: 72.83px | 3.88° | 47.15mm


Epoch 30: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [30/100] | Train Loss: 26.0476 | Val RMSE: 72.48px | 3.86° | 46.93mm


Epoch 31: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [31/100] | Train Loss: 28.5942 | Val RMSE: 74.28px | 3.95° | 48.09mm


Epoch 32: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [32/100] | Train Loss: 27.9446 | Val RMSE: 73.23px | 3.90° | 47.41mm


Epoch 33: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 3/5
Epoch [33/100] | Train Loss: 26.5660 | Val RMSE: 73.40px | 3.91° | 47.53mm


Epoch 34: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 4/5
Epoch [34/100] | Train Loss: 25.0953 | Val RMSE: 73.09px | 3.89° | 47.32mm


Epoch 35: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [35/100] | Train Loss: 23.8702 | Val RMSE: 72.18px | 3.84° | 46.73mm


Epoch 36: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [36/100] | Train Loss: 26.4805 | Val RMSE: 74.32px | 3.96° | 48.12mm


Epoch 37: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [37/100] | Train Loss: 27.4224 | Val RMSE: 72.63px | 3.87° | 47.03mm


Epoch 38: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 3/5
Epoch [38/100] | Train Loss: 25.1186 | Val RMSE: 72.50px | 3.86° | 46.94mm


Epoch 39: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [39/100] | Train Loss: 22.9626 | Val RMSE: 72.15px | 3.84° | 46.71mm


Epoch 40: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [40/100] | Train Loss: 22.1687 | Val RMSE: 71.70px | 3.82° | 46.42mm


Epoch 41: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [41/100] | Train Loss: 25.0521 | Val RMSE: 72.43px | 3.86° | 46.89mm


Epoch 42: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


✅ Saved new best model.
Epoch [42/100] | Train Loss: 23.2630 | Val RMSE: 71.25px | 3.79° | 46.13mm


Epoch 43: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 1/5
Epoch [43/100] | Train Loss: 22.2588 | Val RMSE: 72.53px | 3.86° | 46.96mm


Epoch 44: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 2/5
Epoch [44/100] | Train Loss: 20.9991 | Val RMSE: 72.21px | 3.84° | 46.76mm


Epoch 45: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 3/5
Epoch [45/100] | Train Loss: 20.2004 | Val RMSE: 71.77px | 3.82° | 46.47mm


Epoch 46: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 4/5
Epoch [46/100] | Train Loss: 23.0403 | Val RMSE: 73.38px | 3.91° | 47.51mm


Epoch 47: 100%|██████████| 925/925 [05:49<00:00,  2.65it/s]


⏳ No improvement. Patience: 5/5
🛑 Early stopping triggered.
