In [1]:
# [0] Environment & Device

import os
import json
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch import nn

# Use CUDA when PyTorch reports a usable GPU; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


Device: cuda


In [2]:
# [1] Load train_dataset_recovered.pt and inspect

# pathlib, consistent with the checkpoint / output paths used in later cells.
dataset_dir = Path("torch_datasets")
train_dataset_path = dataset_dir / "train_dataset_recovered.pt"

# weights_only=False is required because the file stores a pickled Dataset
# object (its `.tensors` attribute is read below), not just plain tensors/dicts.
# Full unpickling can execute arbitrary code: only load this file from a
# trusted source.
train_dataset = torch.load(
    train_dataset_path,
    map_location="cpu",
    weights_only=False,
)

# Fail fast with a clear message if the file does not hold a TensorDataset-like
# (X, y) pair, instead of an opaque AttributeError / unpacking error.
if not hasattr(train_dataset, "tensors"):
    raise TypeError(
        "Expected a TensorDataset-like object with `.tensors`, "
        f"got {type(train_dataset).__name__}"
    )
X_train, y_train = train_dataset.tensors

# One integer class label per trial; later cells index X_train and y_train
# with the same trial indices.
if y_train.ndim != 1 or y_train.shape[0] != X_train.shape[0]:
    raise ValueError(
        "Expected 1-D labels aligned with X_train; got "
        f"X {tuple(X_train.shape)}, y {tuple(y_train.shape)}"
    )

print("Loaded train dataset from:", train_dataset_path)
print("X_train:", X_train.shape, "| dtype:", X_train.dtype, "| device:", X_train.device)
print("y_train:", y_train.shape, "| dtype:", y_train.dtype, "| device:", y_train.device)
print("Label counts (train):", torch.bincount(y_train).tolist())


Loaded train dataset from: torch_datasets\train_dataset_recovered.pt
X_train: torch.Size([4665, 1, 22, 1001]) | dtype: torch.float32 | device: cpu
y_train: torch.Size([4665]) | dtype: torch.int64 | device: cpu
Label counts (train): [1166, 1167, 1166, 1166]


In [None]:
# [2] Import TCFormer class and load best weights + config

from tcformer_module import TCFormer 

ckpt_dir = Path("checkpoints_tcformer")
state_path = ckpt_dir / "tcformer_state_dict.pth"
config_path = ckpt_dir / "tcformer_config.json"

# Hyper-parameters saved alongside the weights. Keys read here: "n_channels",
# "n_classes", and "model_args" (extra keyword arguments forwarded to TCFormer).
with open(config_path, "r", encoding="utf-8") as f:
    cfg = json.load(f)

n_channels = cfg["n_channels"]
n_classes = cfg["n_classes"]
model_args = cfg["model_args"]

model = TCFormer(
    n_channels=n_channels,
    n_classes=n_classes,
    **model_args,
).to(device)

# The checkpoint is passed straight to load_state_dict, so it is presumably a
# plain state_dict (parameter name -> tensor). The restricted weights-only
# unpickler handles that, refuses arbitrary objects (safer), and avoids the
# warning recent PyTorch versions emit when weights_only is left unspecified.
state_dict = torch.load(state_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode for the cells below

print("Loaded model from:", state_path)
print("n_channels:", n_channels, "| n_classes:", n_classes)


Loaded model from: checkpoints_tcformer\tcformer_state_dict.pth
n_channels: 22 | n_classes: 4


In [4]:
# [3] DataLoader for train set (no shuffle to keep index order)
#
# Iteration order must match dataset order: the next cell maps positions in
# the loader output back to global trial indices.

batch_size = 64
loader_kwargs = dict(batch_size=batch_size, shuffle=False)

train_loader = DataLoader(train_dataset, **loader_kwargs)

n_train_batches = len(train_loader)
print("Number of train batches:", n_train_batches)


Number of train batches: 73


In [7]:
# [4] Run inference on train set and collect correctly classified trial indices
#
# Predictions and labels are gathered in loader order and concatenated. Because
# train_loader iterates without shuffling, position i of the concatenated
# arrays is global trial index i of train_dataset.

# Re-assert eval mode: if anything switched the model to train mode since it was
# loaded, layers such as dropout/batch-norm would silently change predictions.
model.eval()

pred_chunks = []
label_chunks = []
with torch.no_grad():
    for batch_X, batch_y in train_loader:
        logits = model(batch_X.to(device))
        pred_chunks.append(logits.argmax(dim=1).cpu())
        label_chunks.append(batch_y.cpu())

if pred_chunks:
    train_preds = torch.cat(pred_chunks)
    train_true = torch.cat(label_chunks)
    is_correct = train_preds == train_true
    # Positions of correct predictions are the global trial indices (ordered).
    correct_indices = torch.nonzero(is_correct, as_tuple=True)[0].numpy()
    correct_labels = train_true[is_correct].numpy()
else:
    # Explicit int64 so the empty case has the same dtype as the non-empty path
    # (NumPy's default `int` can be 32-bit on some platforms, e.g. Windows).
    correct_indices = np.array([], dtype=np.int64)
    correct_labels = np.array([], dtype=np.int64)

total_trials = len(train_dataset)
num_correct = len(correct_indices)
acc = num_correct / total_trials if total_trials > 0 else 0.0

print(f"Total trials      : {total_trials}")
print(f"Correct trials    : {num_correct}")
print(f"Accuracy(%) : {acc*100:.2f}")


Total trials      : 4665
Correct trials    : 4138
Accuracy(%) : 88.70


In [6]:
# [5] Print class serials of correctly classified trials (with optional indices)

n_preview = 50  # number of (index, class) pairs to list individually

print(f"=== First {n_preview} correctly classified trials (index, class) ===")
preview_pairs = zip(correct_indices[:n_preview], correct_labels[:n_preview])
for row, (trial_idx, trial_cls) in enumerate(preview_pairs):
    print(f"{row:03d}: idx={int(trial_idx)}, class={int(trial_cls)}")

# Full ordered label sequence of the correct trials. This is the sequence the
# next cell searches for target_serial; it has one entry per correct trial, so
# it can be long.
print("\n=== Class serial list of correctly classified trials ===")
print("class_serials =")
print(correct_labels.tolist())

print("\nPer-class correct counts:")
if len(correct_labels) > 0:
    unique, counts = np.unique(correct_labels, return_counts=True)
    for c, cnt in zip(unique, counts):
        print(f"class {c}: {cnt}")
else:
    print("No correctly classified trials.")


=== First 50 correctly classified trials (index, class) ===
000: idx=1, class=0
001: idx=2, class=3
002: idx=3, class=0
003: idx=4, class=2
004: idx=5, class=1
005: idx=6, class=3
006: idx=7, class=3
007: idx=8, class=3
008: idx=9, class=1
009: idx=10, class=2
010: idx=11, class=0
011: idx=12, class=2
012: idx=13, class=2
013: idx=14, class=0
014: idx=15, class=2
015: idx=16, class=2
016: idx=17, class=1
017: idx=18, class=1
018: idx=19, class=1
019: idx=20, class=0
020: idx=21, class=3
021: idx=22, class=2
022: idx=23, class=1
023: idx=24, class=3
024: idx=25, class=1
025: idx=26, class=0
026: idx=27, class=3
027: idx=29, class=2
028: idx=30, class=2
029: idx=31, class=2
030: idx=32, class=3
031: idx=33, class=3
032: idx=34, class=3
033: idx=35, class=3
034: idx=36, class=0
035: idx=37, class=2
036: idx=38, class=1
037: idx=39, class=0
038: idx=40, class=3
039: idx=41, class=1
040: idx=42, class=0
041: idx=44, class=1
042: idx=45, class=0
043: idx=46, class=1
044: idx=48, class=3
045:

In [None]:
# [6] Find correctly classified trials with desired class serial
#
# Greedy left-to-right subsequence match: for each class in target_serial, take
# the earliest correctly classified trial of that class that comes after the
# previously chosen one.

# Class sequence to locate
target_serial = [2, 0, 2, 1, 2, 2, 1, 2, 0, 2, 3]

labels = correct_labels
indices = correct_indices

if len(labels) == 0:
    raise RuntimeError("No correctly classified trials available.")

selected_global_indices = []
selected_classes = []
cursor = 0  # first position in `labels` still available for matching

for wanted in target_serial:
    if cursor >= len(labels):
        raise ValueError(f"Not enough correctly classified trials to match class {wanted}.")
    hits = np.flatnonzero(labels[cursor:] == wanted)
    if hits.size == 0:
        raise ValueError(f"Cannot find class {wanted} in correct_labels starting from position {cursor}.")
    pos = cursor + int(hits[0])
    selected_global_indices.append(int(indices[pos]))
    selected_classes.append(int(labels[pos]))
    cursor = pos + 1

print("Selected global trial indices:")
print(selected_global_indices)
print("\nSelected class serial:")
print(selected_classes)

assert selected_classes == target_serial, "Selected class serial does not match target_serial."
print("\nClass serial matches target_serial.")


Selected global trial indices:
[4, 11, 12, 17, 22, 29, 38, 49, 57, 61, 69]

Selected class serial:
[2, 0, 2, 1, 2, 2, 1, 2, 0, 2, 3]

Class serial matches target_serial.


In [10]:
# [7] Save selected trials:
#     - time-series + class labels → .pt
#     - index + class serial + meta → JSON
#     (saved in a dedicated folder)

import json
from pathlib import Path
import torch

# Prerequisites (defined by earlier cells):
# - target_serial, selected_global_indices, selected_classes
# - X_train: [N, 1, 22, T] (CPU tensor), y_train: per-trial labels

if len(selected_global_indices) == 0:
    raise RuntimeError("No selected trials available. Run [6] first.")

# Guard against stale kernel state: the selection must still agree with the
# labels of the currently loaded dataset before anything is written to disk.
stored_labels = [int(v) for v in y_train[selected_global_indices].tolist()]
if stored_labels != [int(c) for c in selected_classes]:
    raise RuntimeError(
        "selected_classes disagree with y_train at selected_global_indices; "
        "re-run the inference and selection cells."
    )

save_dir = Path("selected_trials")
save_dir.mkdir(parents=True, exist_ok=True)

# --------------------------------------------------
# 1) Save time series + class labels as .pt
# --------------------------------------------------
# timeseries: [K, 1, 22, T]. Indexing with a list of indices creates a new
# tensor holding only the K selected trials.
selected_timeseries = X_train[selected_global_indices]

pt_payload = {
    "timeseries": selected_timeseries,                 # torch.Tensor
    "labels": torch.tensor(selected_classes, dtype=torch.long),
    "trial_indices": torch.tensor(selected_global_indices, dtype=torch.long),
}

pt_path = save_dir / "selected_timeseries_with_labels.pt"
torch.save(pt_payload, pt_path)

# --------------------------------------------------
# 2) Save indices + class serial + metadata as JSON
# --------------------------------------------------
meta = {
    "target_serial": target_serial,
    "trial_indices": [int(i) for i in selected_global_indices],
    "class_serial": [int(c) for c in selected_classes],
    "pt_file": str(pt_path),
    "timeseries_shape": {
        "num_trials": int(selected_timeseries.shape[0]),
        "num_channels_dim": int(selected_timeseries.shape[1]),  # usually 1
        "num_eeg_channels": int(selected_timeseries.shape[2]),  # 22
        "time_length": int(selected_timeseries.shape[3]),
    },
}

json_path = save_dir / "selected_meta.json"
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(meta, f, indent=2)

print("Saved .pt file   :", pt_path)
print("Saved JSON meta  :", json_path)
print("Timeseries shape :", tuple(selected_timeseries.shape))
print("Class serial     :", selected_classes)


Saved .pt file   : selected_trials\selected_timeseries_with_labels.pt
Saved JSON meta  : selected_trials\selected_meta.json
Timeseries shape : (11, 1, 22, 1001)
Class serial     : [2, 0, 2, 1, 2, 2, 1, 2, 0, 2, 3]
