In [1]:
import os
import glob
import random 
import itertools

import numpy as np
import pandas as pd

from pathlib import Path
from scipy import signal

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import Dataset, DataLoader, ConcatDataset, random_split, Subset

from sklearn.svm import SVC
from sklearn import preprocessing
from sklearn.model_selection import KFold
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import accuracy_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, OneHotEncoder

from esn_model import ESN, ReadOut

# ヘルパー
from rc_timeseries_helpers import (
    infer_device_from_model,
    extract_states_time_major,
    extract_logits_time_major,
    apply_time_selection,
    prepare_time_distributed_targets,
    compute_loss_time_kept,
    sequence_accuracy_majority_vote
)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Module-level LabelEncoder instance (not referenced by the other visible cells).
label_encoder = preprocessing.LabelEncoder()

In [2]:
import wandb

# W&B credentials: never hard-code the API key in the notebook. Any key that was
# ever committed here should be treated as leaked and revoked. The key is read
# from the WANDB_API_KEY environment variable. When it is unset, key=None is
# passed and wandb.login() uses its own default lookup instead (e.g. an existing
# ~/.netrc entry).
WANDB_API_KEY = os.environ.get("WANDB_API_KEY")

wandb.login(key=WANDB_API_KEY)

[34m[1mwandb[0m: Currently logged in as: [33mnekodaisuki169[0m ([33mdoctor_thesis_material[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /Users/sota/.netrc


True

In [3]:
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset

class TactileSequenceDataset(Dataset):
    """Tactile time-series classification dataset backed by per-sample ``.npy`` files.

    Directory layout::

        root_dir/
            <class_name_a>/*.npy
            <class_name_b>/*.npy
            ...

    Every ``.npy`` file must hold an array of shape ``(T, n_taxels, 3)``. Class
    indices follow the sorted order of the class-directory names.

    Args:
        root_dir: Directory containing one sub-directory per class.
        dataset_params: Configuration dict. Keys read here:
            ``seq_start`` / ``seq_end`` (required): half-open time window
            ``[seq_start, seq_end)``. It must satisfy
            ``0 <= seq_start < seq_end <= T``, where ``T`` is taken from the
            first file.
            ``mmap_mode`` (optional): forwarded to ``np.load``.
            ``channel_mask`` (optional): boolean array of shape
            ``(n_taxels, 3)``. When given, it takes precedence over the index
            options below.
            ``taxel_indices`` / ``axis_indices`` (optional): integer indices of
            the taxels / axes to keep. They default to all taxels / all 3 axes.

    Items are ``(x, y)``. ``x`` is a float32 tensor of shape
    ``(1, seq_len, feature_dim)``, where ``feature_dim`` is the number of
    selected channels (use it as the ESN ``input_size``). ``y`` is the class
    index as a ``torch.long`` scalar tensor.

    Raises:
        RuntimeError: no class directories or no ``.npy`` files were found.
        ValueError: invalid time window, array shape, or channel selection
            (at construction time or, for a bad file, in ``__getitem__``).
    """

    def __init__(self, root_dir, dataset_params):
        super().__init__()
        self.root_dir = Path(root_dir)
        self.seq_start = int(dataset_params["seq_start"])
        self.seq_end   = int(dataset_params["seq_end"])
        self.dtype     = torch.float32
        self.mmap_mode = dataset_params.get("mmap_mode", None)

        # Channel-selection options (normalized further below).
        self.taxel_indices = dataset_params.get("taxel_indices", None)
        self.axis_indices  = dataset_params.get("axis_indices", None)
        self.channel_mask  = dataset_params.get("channel_mask", None)  # (n_taxels, 3) bool

        # Enumerate class directories; sorting makes label indices deterministic.
        self.class_names = sorted(d.name for d in self.root_dir.iterdir() if d.is_dir())
        if not self.class_names:
            raise RuntimeError(f"No class directories under {self.root_dir}")

        self.class_to_idx = {c: i for i, c in enumerate(self.class_names)}
        self.num_classes  = len(self.class_names)

        self.samples = [
            (npy_path, self.class_to_idx[class_name])
            for class_name in self.class_names
            for npy_path in sorted((self.root_dir / class_name).glob("*.npy"))
        ]
        if not self.samples:
            raise RuntimeError(f"No .npy files found under {self.root_dir}")

        self.labels = [lab for _, lab in self.samples]

        # Shape check on the first file; __getitem__ re-checks every file it loads.
        arr0 = np.load(self.samples[0][0], mmap_mode=self.mmap_mode)
        if arr0.ndim != 3:
            raise ValueError(f"Expected (T, n_taxels, 3), got shape={arr0.shape}")
        T, n_taxels, axes = arr0.shape
        if axes != 3:
            raise ValueError(f"Last dim must be 3, got {axes}")
        if self.seq_end > T:
            raise ValueError(f"seq_end({self.seq_end}) > T({T}). Please adjust.")
        # A negative start would slice relative to the end of the array (so the
        # slice length no longer equals seq_len), and an empty window breaks the
        # reshape in __getitem__. Reject both up front.
        if self.seq_start < 0 or self.seq_start >= self.seq_end:
            raise ValueError(
                f"Require 0 <= seq_start < seq_end, got seq_start={self.seq_start}, seq_end={self.seq_end}"
            )

        self.original_T  = T
        self.n_taxels    = n_taxels
        self.seq_len     = self.seq_end - self.seq_start

        # ---- Normalize the channel selection & compute feature_dim ----
        if self.channel_mask is not None:
            m = np.asarray(self.channel_mask, dtype=bool)
            if m.shape != (self.n_taxels, 3):
                raise ValueError(f"channel_mask must be shape {(self.n_taxels,3)} but got {m.shape}")
            self.flat_mask = m.reshape(-1)  # (n_taxels*3,), row-major: taxel-major, axis-minor
            self.selected_channels = int(self.flat_mask.sum())
            if self.selected_channels == 0:
                raise ValueError("channel_mask selects 0 channels.")
        else:
            self.taxel_indices = self._normalize_indices(self.taxel_indices, self.n_taxels, "taxel_indices")
            self.axis_indices  = self._normalize_indices(self.axis_indices, 3, "axis_indices")
            self.selected_channels = int(len(self.taxel_indices) * len(self.axis_indices))

        # This is the input_size the ESN must be built with.
        self.feature_dim = self.selected_channels

        print("=== TactileSequenceDataset initialized ===")
        print(f"root_dir     : {self.root_dir}")
        print(f"num_classes  : {self.num_classes}")
        print(f"num_samples  : {len(self.samples)}")
        print(f"original T   : {self.original_T}")
        print(f"seq range    : [{self.seq_start}, {self.seq_end}) -> seq_len={self.seq_len}")
        print(f"n_taxels     : {self.n_taxels}")
        print(f"selected_channels : {self.selected_channels}  (taxel*axis after selection)")
        print(f"feature_dim  : {self.feature_dim}")
        if self.channel_mask is not None:
            print("selection    : channel_mask used")
        else:
            print(f"selection    : taxel_indices={self.taxel_indices.tolist()} axis_indices={self.axis_indices.tolist()}")
        print("class_to_idx : ", self.class_to_idx)
        print("=========================================")

    @staticmethod
    def _normalize_indices(indices, upper, name):
        """Return ``indices`` as a 1-D int array within ``[0, upper)``; ``None`` selects all."""
        if indices is None:
            return np.arange(upper, dtype=int)
        arr = np.asarray(indices, dtype=int).reshape(-1)
        if arr.size == 0:
            # .min()/.max() would otherwise fail with an opaque numpy error.
            raise ValueError(f"{name} selects 0 channels.")
        if arr.min() < 0 or arr.max() >= upper:
            raise ValueError(f"{name} out of range 0..{upper-1}: {arr}")
        return arr

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        npy_path, label_idx = self.samples[idx]

        arr = np.load(npy_path, mmap_mode=self.mmap_mode)  # expected (T, n_taxels, 3)
        # __init__ only inspected the first file. A shorter or differently shaped
        # file could otherwise be reshaped into a silently wrong layout below.
        if arr.ndim != 3 or arr.shape[1:] != (self.n_taxels, 3) or arr.shape[0] < self.seq_end:
            raise ValueError(
                f"{npy_path}: expected shape (T>={self.seq_end}, {self.n_taxels}, 3), got {arr.shape}"
            )
        seq = np.asarray(arr[self.seq_start:self.seq_end], dtype=np.float32)  # (seq_len, n_taxels, 3)

        # Channel selection
        if self.channel_mask is not None:
            seq2 = seq.reshape(self.seq_len, -1)[:, self.flat_mask]         # (L, selected_channels)
        else:
            sel = seq[:, self.taxel_indices, :][:, :, self.axis_indices]    # (L, k_taxels, k_axes)
            seq2 = sel.reshape(self.seq_len, -1)                            # (L, k_taxels*k_axes)

        x = torch.from_numpy(seq2).to(self.dtype).unsqueeze(0)  # (1, seq_len, feature_dim)
        y = torch.tensor(label_idx, dtype=torch.long)           # class index
        return x, y



In [4]:
from torch.utils.data import Subset, DataLoader
from sklearn.model_selection import StratifiedKFold

def _unwrap_subset(ds):
    """Subset(Subset(...)) でも base dataset と base index 列に展開して返す"""
    idxs = list(range(len(ds)))
    while isinstance(ds, Subset):
        idxs = [ds.indices[i] for i in idxs]
        ds = ds.dataset
    return ds, idxs

def _get_stratify_labels(ds):
    """Return one class label per item of ``ds`` for use as StratifiedKFold ``y``.

    Lookup order on the unwrapped base dataset:
    1. the ``labels`` attribute,
    2. the second element of each ``samples`` entry,
    3. the ``targets`` attribute.
    The labels are then mapped through the Subset index chain. If none of these
    attributes exist, every item is loaded via ``ds[i]``, which is slow. In that
    case a floating-point tensor label is treated as one-hot (argmax), an
    integer tensor as a class index (its first element), and anything else is
    passed through ``int()``.
    """
    base, base_idxs = _unwrap_subset(ds)

    attribute_readers = (
        ("labels", lambda b: list(b.labels)),
        ("samples", lambda b: [lab for _, lab in b.samples]),
        ("targets", lambda b: list(b.targets)),
    )
    for attr_name, read_labels in attribute_readers:
        if hasattr(base, attr_name):
            base_labels = read_labels(base)
            return [base_labels[i] for i in base_idxs]

    # Fallback (slow): infer each label from __getitem__; y may be an index or one-hot.
    def _as_class_index(y):
        if not torch.is_tensor(y):
            return int(y)
        if y.dtype.is_floating_point:  # one-hot
            return int(y.argmax().item())
        return int(y.view(-1)[0].item())  # class index

    return [_as_class_index(ds[i][1]) for i in range(len(ds))]

def create_cross_validation_dataloaders(dataset, dataset_params, traing_params):
    """Build stratified k-fold (train, validation) DataLoader pairs.

    - ``dataset`` may be a Dataset, a Subset, or nested Subsets.
    - Stratification labels come from ``_get_stratify_labels`` (the ``labels``
      attribute is preferred).
    - No normalization is applied; items are read from the dataset as-is.
    - Every loader is full-batch: the batch size equals the subset length, so
      ``dataset_params["batch_size"]`` is ignored here.

    Args:
        dataset: dataset to split.
        dataset_params: only ``seed`` is read, as a fallback seed.
        traing_params: requires ``n_splits``. Its optional ``seed`` (falling back
            to ``dataset_params["seed"]``, then 0) seeds the fold shuffling.

    Returns:
        list of ``(train_loader, val_loader)``, one per fold. Train loaders
        shuffle; validation loaders do not.
    """
    n_splits = int(traing_params["n_splits"])
    fallback_seed = dataset_params.get("seed", 0)
    seed = int(traing_params.get("seed", fallback_seed))

    stratify_labels = _get_stratify_labels(dataset)
    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    folds = splitter.split(range(len(dataset)), stratify_labels)

    def _full_batch(subset, shuffle):
        # Keep the existing "full batch" policy: one batch holding the whole subset.
        return DataLoader(subset, len(subset), shuffle=shuffle)

    return [
        (
            _full_batch(Subset(dataset, train_idx), True),
            _full_batch(Subset(dataset, val_idx), False),
        )
        for train_idx, val_idx in folds
    ]


def prepare_datasets(dataset_params, traing_params, data_dir):
    """Build the cross-validation loaders and the held-out test loader.

    The dataset is split once into a cross-validation pool and a test set of
    ``int(len(dataset) * traing_params["testdata_ratio"])`` samples. The split
    is seeded with the same seed ``create_cross_validation_dataloaders`` uses:
    ``traing_params["seed"]``, else ``dataset_params["seed"]``, else 0. This
    keeps the held-out set reproducible across runs.

    Returns:
        ``(cross_val_loaders, test_loader)``. The first is a list of
        ``(train_loader, val_loader)`` pairs over the pool. The second is a
        DataLoader over the test subset with batch size
        ``dataset_params["batch_size"]``.
    """
    testdata_ratio = float(traing_params["testdata_ratio"])
    batch_size = int(dataset_params["batch_size"])
    seed = int(traing_params.get("seed", dataset_params.get("seed", 0)))

    # Build the dataset (no augmentation is applied here).
    dataset = TactileSequenceDataset(data_dir, dataset_params)

    # Split into the cross-validation pool and the test set.
    test_size = int(len(dataset) * testdata_ratio)
    train_size = len(dataset) - test_size
    # Without an explicit generator, random_split draws from PyTorch's global RNG,
    # so the held-out test set would change on every run.
    split_generator = torch.Generator().manual_seed(seed)
    crossval_dataset, test_dataset = random_split(
        dataset, [train_size, test_size], generator=split_generator
    )

    cross_val_loaders = create_cross_validation_dataloaders(crossval_dataset, dataset_params, traing_params)

    test_loader = DataLoader(test_dataset, batch_size, shuffle=False)

    return cross_val_loaders, test_loader

In [8]:
def train_model(model, criterion, optimizer, train_loader, dataset_params, model_params, traing_params):
    """Train ``model`` for ``traing_params["num_epochs"]`` epochs on ``train_loader``.

    The update mode is chosen by ``model_params["Batch_Training"]``:

    * ``True``: ridge-regression fit of the readout.
      - Reservoir states and one-hot targets are flattened to ``X [H, B*T]`` and
        ``Y [C, B*T]``.
      - They are passed to ``model.ReadOut.ridge_regression_update`` together
        with ``model_params["Regularization_L2"]``.
      - Loss and accuracy are then recomputed with the updated readout.
      - ``optimizer`` is not used in this mode.
    * ``False``: one optimizer step per batch, backpropagating the loss returned
      by ``compute_loss_time_kept``.

    Before the loss is computed, ``dataset_params["washout_steps"]`` (default 0)
    and ``dataset_params["time_stride"]`` (default 1) are applied to both states
    and logits. The per-epoch, sample-weighted mean loss and accuracy are
    printed. They are also logged to wandb as ``loss`` / ``train_accuracy``;
    wandb errors are ignored.
    """
    model.train()

    num_epochs = int(traing_params["num_epochs"])
    batch_training = bool(model_params["Batch_Training"])
    Regularization_L2 = float(model_params["Regularization_L2"])

    device = infer_device_from_model(model, fallback="cpu")

    washout_steps = int(dataset_params.get("washout_steps", 0))
    time_stride   = int(dataset_params.get("time_stride", 1))

    for epoch in range(num_epochs):
        loss_sum = 0.0
        acc_sum  = 0.0
        n_seen   = 0

        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)
            labels = labels.to(device)

            # ---- forward (the ESN is run only once per batch) ----
            esn_out = model.ESN(inputs)         # [B,1,T,H] or [B,T,H]
            ro_out  = model.ReadOut(esn_out)    # [B,1,T,C] or [B,T,C]

            # ---- normalize shapes to time-major ----
            states_BTH = extract_states_time_major(esn_out)  # [B,T,H]
            logits_BTC = extract_logits_time_major(ro_out)   # [B,T,C]

            if states_BTH.shape[1] != logits_BTC.shape[1]:
                raise ValueError(f"T mismatch: states T={states_BTH.shape[1]}, logits T={logits_BTC.shape[1]}")

            # ---- washout / stride ----
            states_BTH = apply_time_selection(states_BTH, washout_steps=washout_steps, time_stride=time_stride)
            logits_BTC = apply_time_selection(logits_BTC, washout_steps=washout_steps, time_stride=time_stride)

            B, T_eff, H = states_BTH.shape
            _, _, C = logits_BTC.shape

            # ---- build targets (used by the ridge update) ----
            # NOTE: the loss returned here comes from the pre-update readout, so it
            # is discarded in batch_training mode.
            _, targets_onehot_BTC, target_index_B = compute_loss_time_kept(
                logits_time_major=logits_BTC,
                labels=labels,
                criterion=criterion
            )

            # ---- update ----
            if batch_training:
                # ridge update requires 2D: X [H, B*T], Y [C, B*T]
                X = states_BTH.permute(2, 0, 1).contiguous().view(H, -1)         # [H, B*T]
                Y = targets_onehot_BTC.permute(2, 0, 1).contiguous().view(C, -1) # [C, B*T]
                model.ReadOut.ridge_regression_update(X, Y, model, Regularization_L2)

                # Important: recompute the logits with the *updated* ReadOut so the
                # reported metrics reflect the post-regression fit.
                ro_out_after  = model.ReadOut(esn_out)
                logits_after  = extract_logits_time_major(ro_out_after)
                logits_after  = apply_time_selection(logits_after, washout_steps=washout_steps, time_stride=time_stride)

                # Loss / accuracy from the post-regression logits.
                loss, _, _ = compute_loss_time_kept(
                    logits_time_major=logits_after,
                    labels=labels,
                    criterion=criterion
                )
                acc = sequence_accuracy_majority_vote(logits_after, target_index_B)

            else:
                # Backprop training (original behavior).
                # NOTE(review): this recomputes the same loss as the call above; the
                # earlier result could be reused.
                loss, _, target_index_B = compute_loss_time_kept(
                    logits_time_major=logits_BTC,
                    labels=labels,
                    criterion=criterion
                )
                optimizer.zero_grad(set_to_none=True)
                loss.backward()
                optimizer.step()

                acc = sequence_accuracy_majority_vote(logits_BTC, target_index_B)

            # Weight per-batch metrics by batch size so the epoch mean is per-sample.
            loss_sum += float(loss.item()) * B
            acc_sum  += float(acc) * B
            n_seen   += B

        epoch_loss = loss_sum / max(n_seen, 1)
        epoch_acc  = acc_sum  / max(n_seen, 1)

        print(f"Epoch {epoch+1}/{num_epochs} | Loss: {epoch_loss:.4f} | Acc: {epoch_acc*100:.2f}%")
        try:
            wandb.log({"loss": epoch_loss, "train_accuracy": epoch_acc})
        except Exception:
            # Logging is best-effort: any wandb error is ignored.
            pass

    print("Training complete")


def validate_model(model, val_loader, dataset_params, model_params, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    The forward pass matches ``train_model``: ``model.ESN`` → ``model.ReadOut``
    → time-major tensors → washout/stride (``dataset_params["washout_steps"]``,
    default 0; ``dataset_params["time_stride"]``, default 1). Loss comes from
    ``compute_loss_time_kept`` and accuracy from majority voting.
    ``model_params`` is accepted for signature symmetry but not used.

    Returns:
        ``(val_loss, val_acc)``: sample-weighted means over the loader (0.0 when
        the loader is empty). Both are also printed and logged to wandb as
        ``val_loss`` / ``val_accuracy``; wandb errors are ignored.
    """
    model.eval()

    device = infer_device_from_model(model, fallback="cpu")
    selection = dict(
        washout_steps=int(dataset_params.get("washout_steps", 0)),
        time_stride=int(dataset_params.get("time_stride", 1)),
    )

    total_loss, total_acc, total_count = 0.0, 0.0, 0

    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            batch_inputs = batch_inputs.float().to(device)
            batch_labels = batch_labels.to(device)

            reservoir_out = model.ESN(batch_inputs)
            readout_out = model.ReadOut(reservoir_out)

            states = extract_states_time_major(reservoir_out)
            logits = extract_logits_time_major(readout_out)
            if states.shape[1] != logits.shape[1]:
                raise ValueError(f"T mismatch: states T={states.shape[1]}, logits T={logits.shape[1]}")

            states = apply_time_selection(states, **selection)
            logits = apply_time_selection(logits, **selection)
            batch_size = logits.shape[0]

            loss, _, target_index = compute_loss_time_kept(
                logits_time_major=logits,
                labels=batch_labels,
                criterion=criterion,
            )
            acc = sequence_accuracy_majority_vote(logits, target_index)

            total_loss += float(loss.item()) * batch_size
            total_acc += float(acc) * batch_size
            total_count += batch_size

    denom = max(total_count, 1)
    val_loss = total_loss / denom
    val_acc = total_acc / denom

    print(f"Validation | Loss: {val_loss:.4f} | Acc: {val_acc*100:.2f}%")
    try:
        wandb.log({"val_loss": val_loss, "val_accuracy": val_acc})
    except Exception:
        pass

    return val_loss, val_acc
    
def test_model(model, val_loader, criterion=None, dataset_params=None):
    """Evaluate ``model`` on a held-out loader and log ``test_accuracy`` / ``test_loss``.

    This follows the same time-distributed path as ``validate_model``:
    ``model.ESN`` → ``model.ReadOut`` → time-major logits → washout/stride →
    ``compute_loss_time_kept`` loss and majority-vote accuracy. It therefore
    scores sequences the same way training and validation do. It depends only
    on its arguments and the model's device, not on notebook globals.

    Args:
        model: model exposing ``.ESN`` and ``.ReadOut`` (as used by train/validate).
        val_loader: loader over the test subset.
        criterion: loss passed to ``compute_loss_time_kept``. Defaults to
            ``nn.MSELoss()``, the criterion the experiment cell trains with.
        dataset_params: optional dict. ``washout_steps`` (default 0) and
            ``time_stride`` (default 1) are read from it.

    Returns:
        ``(test_loss, test_accuracy)`` as sample-weighted means. Both are NaN,
        and nothing is logged, when the loader yields no samples
        (e.g. ``testdata_ratio == 0``).
    """
    if criterion is None:
        criterion = nn.MSELoss()
    params = dataset_params or {}
    washout_steps = int(params.get("washout_steps", 0))
    time_stride = int(params.get("time_stride", 1))
    device = infer_device_from_model(model, fallback="cpu")

    model.eval()  # Set model to evaluation mode
    loss_sum = 0.0
    acc_sum = 0.0
    n_seen = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.float().to(device)
            labels = labels.to(device)

            esn_out = model.ESN(inputs)
            logits_BTC = extract_logits_time_major(model.ReadOut(esn_out))
            logits_BTC = apply_time_selection(logits_BTC, washout_steps=washout_steps, time_stride=time_stride)

            loss, _, target_index_B = compute_loss_time_kept(
                logits_time_major=logits_BTC,
                labels=labels,
                criterion=criterion
            )
            acc = sequence_accuracy_majority_vote(logits_BTC, target_index_B)

            B = logits_BTC.shape[0]
            loss_sum += float(loss.item()) * B
            acc_sum += float(acc) * B
            n_seen += B

    if n_seen == 0:
        # An empty test set (testdata_ratio == 0) has no meaningful metrics.
        print("Test | no samples in loader; skipping")
        return float("nan"), float("nan")

    test_loss = loss_sum / n_seen
    test_acc = acc_sum / n_seen

    print(f'Test Loss: {test_loss:.4f}, Accuracy: {test_acc:.4f}')
    try:
        wandb.log({"test_accuracy": test_acc, "test_loss": test_loss})
    except Exception as exc:
        # Best-effort logging (e.g. no active run), but report why it was skipped.
        print(f"wandb.log skipped: {exc}")
    return test_loss, test_acc

def compute_accuracy(model_output, target, n_taus):
    """Majority-vote accuracy over sequences concatenated along the last axis.

    ``model_output`` and ``target`` are squeezed and cut into consecutive
    chunks of ``n_taus`` columns along the last dimension, one chunk per
    sequence. Within a chunk:
    - each column's class is the argmax over dim 0;
    - the sequence's class is the most frequent column class (ties resolved by
      ``torch.argmax``).
    ``target`` is expected to be one-hot in the same layout.

    Returns:
        The fraction of chunks whose predicted class equals the target class.
    """
    def _majority_class(chunk):
        # Per-column argmax over classes, then the most frequent of those classes.
        return torch.bincount(chunk.argmax(dim=0)).argmax()

    # Split the model output and the target into per-sequence chunks.
    pred_chunks = torch.split(model_output.squeeze(), n_taus, dim=-1)
    target_chunks = torch.split(target.squeeze(), n_taus, dim=-1)

    correct = 0
    total = 0
    for pred, true_label in zip(pred_chunks, target_chunks):
        correct += int((_majority_class(pred) == _majority_class(true_label)).item())
        total += 1

    return correct / total

def model_params_candinate(model_params):
    """Expand a hyperparameter grid into one dict per combination.

    ``model_params`` maps each parameter name to a list of candidate values.
    The result is the Cartesian product of those lists, in ``itertools.product``
    order, with each combination as a ``{name: value}`` dict.
    """
    names = list(model_params.keys())
    return [
        dict(zip(names, values))
        for values in itertools.product(*model_params.values())
    ]

# モデル構造を辞書型に格納
def model_sturcture_dict(model):
    """Describe every submodule of ``model`` as a plain dict (used for the wandb config).

    Keys are the qualified names from ``model.named_modules()``; the root
    module (name ``''``) is dropped. Each value has:
    - ``type``: the submodule's class name;
    - ``parameters``: every public (non-underscore) instance attribute of the
      submodule.
    """
    structure = {
        name: {
            'type': type(module).__name__,
            'parameters': {attr: getattr(module, attr) for attr in vars(module) if not attr.startswith('_')},
        }
        for name, module in model.named_modules()
    }
    # Drop the root entry (the model itself and its constructor arguments).
    del structure['']
    return structure
    

In [6]:
# from torch.utils.data import Subset

# def base_indices(ds):
#     idxs = list(range(len(ds)))
#     while isinstance(ds, Subset):
#         idxs = [ds.indices[i] for i in idxs]
#         ds = ds.dataset
#     return ds, idxs

# def paths_from_loader(loader):
#     base, idxs = base_indices(loader.dataset)
#     return set(str(base.samples[i][0]) for i in idxs)

# for fold, (tr, va) in enumerate(cross_val_loaders):
#     inter = paths_from_loader(tr) & paths_from_loader(va)
#     print(f"fold={fold} overlap_paths={len(inter)}")

NameError: name 'cross_val_loaders' is not defined

In [10]:
# ---- Experiment configuration ----
dataset_params = {"seq_start" : 400, "seq_end" : 1200, "sequence_length": 800, "slicing_size" : 1, "augmentation_factor": 0, "batch_size" : 32, "Onehot_Encoding" : None, "augmentation_mu" : 0, "augmentation_sigma" : 0, "augmentation_shift" : 1, "time_stride":1,  "taxel_indices": [0], "axis_indices": [0, 1, 2]}
model_params = {"reservoir_size" : [100],"input_size" : [None], "channel_size" : [1],  "reservoir_weights_scale" : [1], "input_weights_scale" : [1, 0.01, 10], "spectral_radius" : [0.9],"reservoir_density" : [0.02], "leak_rate" : [0.3, 0.5, 0.9], "Batch_Training" : [True], "ReadOut_output_size" : [None], "Regularization_L2" : [0.5]}
# "seed" drives the CV fold shuffling and the RNG reset before each candidate config.
training_params = {"num_epochs" : 1, "learning_rate" : 0.01, "weight_decay" : 1e-2, "testdata_ratio" : 0, "n_splits" : 5, "seed" : 0}

assert dataset_params["sequence_length"] == dataset_params["seq_end"] - dataset_params["seq_start"], \
    "sequence_length must match the [seq_start, seq_end) window"


def _seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs.

    This makes model initialization reproducible, assuming the ESN draws its
    random weights from one of these RNGs (TODO: confirm in esn_model).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)


# Expand every hyperparameter combination into its own dict.
model_params = model_params_candinate(model_params)

# Training dataset
data_dir = "./normalized_dataset/20251215_161803/"
cross_val_loaders, test_loader = prepare_datasets(dataset_params, training_params, data_dir)

# Derive the input size (number of selected taxel/axis channels) and the number
# of classes. This must happen before wandb.init so both appear in the logged
# config. The dataset already built by prepare_datasets is reused rather than
# scanning the data directory a second time.
_meta_ds, _ = _unwrap_subset(cross_val_loaders[0][0].dataset)
dataset_params["input_size"] = int(_meta_ds.feature_dim)
dataset_params["ReadOut_output_size"] = int(_meta_ds.num_classes)

seed = int(training_params["seed"])

# Grid search: evaluate every candidate model configuration with k-fold CV.
for each_model_params in model_params:

    # Override the size placeholders before wandb.init.
    each_model_params["input_size"] = int(dataset_params["input_size"])
    each_model_params["ReadOut_output_size"] = int(dataset_params["ReadOut_output_size"])

    # Every config starts from the same RNG state so candidates are compared fairly.
    _seed_everything(seed)

    # Instance used only to describe the architecture / optimizer in the run config.
    model = ESN(each_model_params, training_params, dataset_params).to(device)
    model_sturcture = model_sturcture_dict(model)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = float(training_params["learning_rate"]), weight_decay = float(training_params["weight_decay"]))

    config_dictionary = {
    "dataset": data_dir,
    "dataset_params" : dataset_params,
    "architecture": model.__class__.__name__,
    "model_params" : each_model_params,
    "model_sturcture" : model_sturcture,
    "traing_params" : training_params,
    "criterion" : str(criterion),
    "optimizer" : str(optimizer),
    }

    wandb.init(project="uskin_test_NewESN", config=config_dictionary)
    try:
        # k-fold cross-validation loop: a fresh model per fold.
        for fold, (train_loader, val_loader) in enumerate(cross_val_loaders):
            print(f'FOLD {fold}')
            print('--------------------------------')
            model = ESN(each_model_params, training_params, dataset_params).to(device)
            criterion = nn.MSELoss()
            optimizer = optim.Adam(model.parameters(), lr = float(training_params["learning_rate"]), weight_decay = float(training_params["weight_decay"]))
            train_model(model, criterion, optimizer, train_loader, dataset_params, each_model_params, training_params)
            validate_model(model, val_loader, dataset_params, each_model_params, criterion)

        print('CrossVaridation Finished')
        print('--------------------------------')
        # Test-set evaluation is intentionally not run here: with testdata_ratio = 0
        # the test loader is empty.
    finally:
        # Always close the run, even if a fold raises, so the next wandb.init starts cleanly.
        wandb.finish()

=== TactileSequenceDataset initialized ===
root_dir     : normalized_dataset/20251215_161803
num_classes  : 25
num_samples  : 250
original T   : 1255
seq range    : [400, 1200) -> seq_len=800
n_taxels     : 16
selected_channels : 3  (taxel*axis after selection)
feature_dim  : 3
selection    : taxel_indices=[0] axis_indices=[0, 1, 2]
class_to_idx :  {'01_table_cover': 0, '02_fur_scarf': 1, '03_washing_towel': 2, '04_carpet1': 3, '05_bubble_wrap': 4, '06_fleece_scarf': 5, '07_knit_hat1': 6, '08_body_towel1': 7, '09_body_towel2': 8, '10_carpet2': 9, '11_work_gloves': 10, '12_knit_hat2': 11, '13_toilet_mat1': 12, '14_floor_mat': 13, '15_sponge1': 14, '16_printed_tatami': 15, '17_cushion1': 16, '18_mop': 17, '19_toilet_mat2': 18, '20_fleece_sock': 19, '21_cushion': 20, '22_carpet3': 21, '23_fleece_mat': 22, '24_carpet4': 23, '25_sponge2': 24}
=== TactileSequenceDataset initialized ===
root_dir     : normalized_dataset/20251215_161803
num_classes  : 25
num_samples  : 250
original T   : 1255


FOLD 0
--------------------------------
predicted_index_B + tensor([14, 12,  0,  8,  8,  4, 18, 11, 20, 11, 23,  4, 22, 23, 21, 21, 17, 10,
        10,  3, 21, 19, 23, 18,  7, 12,  5, 13, 12, 17, 13, 19, 12, 24,  7,  1,
        10, 23, 12, 17, 20, 17, 16,  8, 21,  5, 19,  9, 17, 17, 10, 14, 17, 15,
         9,  7, 20,  5, 18, 20, 12, 11, 13,  5,  9, 18,  7, 13, 15,  1, 22, 23,
        11, 24, 22, 10,  1,  4, 20,  4,  0,  0, 17, 20, 15, 12, 14, 13, 11, 24,
         7, 19,  5, 14, 13, 10,  4, 17,  8, 12,  0, 18, 11,  6, 20,  3, 21,  4,
         7, 12, 20, 15, 12, 12,  0, 20,  8,  3,  1, 17,  0, 19, 15,  7, 19,  4,
         6, 22,  2, 24, 17, 21,  9, 10,  3,  5, 24, 11,  0, 12, 18, 16, 24, 23,
         8, 16,  8, 16, 10, 21, 13,  1,  3, 16, 19, 17,  9,  1,  0,  3,  4, 22,
        15, 18,  9, 12, 14, 24, 17, 21,  3,  5, 22, 18,  1, 23, 13,  5,  6,  9,
        17, 15, 17, 16,  3, 20, 22, 23, 19, 20, 11,  9,  8,  1,  7, 16, 16, 24,
        15, 22])
target_index_per_sequence + tensor([14,  6,

VBox(children=(Label(value='0.001 MB of 0.028 MB uploaded\r'), FloatProgress(value=0.04329494035376388, max=1.…

0,1
loss,▇▄▄▁█
train_accuracy,▆▆▅▁█
val_accuracy,▁██▆█
val_loss,█▁▄▅▅

0,1
loss,0.02136
train_accuracy,0.94
val_accuracy,0.92
val_loss,0.02164


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011169687500000691, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([ 4, 10, 19, 23, 12,  3,  5, 23, 10, 18, 11, 16, 21,  9, 17,  9, 20, 12,
        12,  7, 12, 24,  4,  1, 22,  7,  5, 10, 17, 20, 16, 10, 11, 21, 22, 17,
        13, 11,  1, 20, 19, 18,  3, 21, 24,  5, 21,  1,  4, 13,  8, 13, 16,  5,
        23, 20,  5, 23,  8, 19, 24, 15, 16, 19,  0, 22,  0, 16,  1, 22, 13,  7,
         1, 11,  8,  4, 17, 20, 12, 17, 15, 19, 24,  0, 20,  3, 18, 12, 12, 21,
        10,  9, 14, 18,  5, 11,  8, 24, 20, 12,  5, 20,  8,  1, 20, 24,  0, 16,
         0,  4, 22, 17, 12, 13, 20, 12, 17, 18,  8,  8, 23, 13,  4,  7, 20, 10,
         4, 12, 11, 22, 17, 17, 23, 15, 10, 19, 12, 23, 18, 16, 10, 24,  7, 10,
         4, 16, 21, 20,  9,  0, 15, 17, 19,  3,  0,  7, 12, 21,  3, 17, 12, 20,
        20, 17, 15,  7, 17, 13,  7,  9, 17, 15, 23, 11, 18, 19,  3, 24, 15,  9,
        12, 20, 12, 17, 11,  8, 18, 17, 15, 21,  5, 22,  1, 22,  9,  1,  9,  3,
        13,  0])
target_index_per_sequence + tensor([ 4, 10,

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,▅▄█▄▁
train_accuracy,▁▂▁█▃
val_accuracy,▁█▆█▅
val_loss,█▁▄▇▃

0,1
loss,0.02172
train_accuracy,0.89
val_accuracy,0.88
val_loss,0.022


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168902777778081, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([10,  5, 10,  1,  4,  1, 19,  0, 23, 22, 14, 23, 17, 18, 16, 12, 11, 21,
        19, 12, 18, 20,  4,  8, 10, 21, 19, 21,  0, 20,  0, 11,  9, 23, 17,  8,
        12, 18, 17,  4, 23,  4, 13,  8, 12,  3, 18, 10, 12,  8, 10, 18, 16, 24,
        20, 17, 12,  5, 18, 13, 19, 16,  7, 22, 15,  5,  1, 15, 16,  5,  4,  0,
         3, 23,  5, 20,  7, 11, 19, 23, 17, 17,  5,  8, 11,  0,  9, 12,  1, 12,
         1, 24,  1, 11, 13, 20, 15,  3, 17, 17, 17, 24,  7, 20,  7, 10,  9, 12,
        10,  3, 21, 20, 11, 17, 19,  3, 12, 16, 20,  8, 13, 22, 21,  9, 20,  4,
        11, 12, 12,  7, 24, 22, 20, 13,  3, 17,  9, 22,  3, 24,  4, 17, 24,  7,
        21, 15,  0, 17, 18, 20, 15, 15, 24, 22, 17, 22, 16, 20,  1,  9, 13, 10,
        19, 24, 21, 17, 23, 23,  5,  0, 20, 17,  5, 12, 13,  3, 15, 12,  9, 18,
         8,  7, 12,  1,  4,  8, 16, 22, 11, 20, 12, 19, 21, 16, 15, 20,  0, 13,
         9,  7])
target_index_per_sequence + tensor([10,  5,

VBox(children=(Label(value='0.001 MB of 0.018 MB uploaded\r'), FloatProgress(value=0.06523760330578512, max=1.…

0,1
loss,▁█▇▂▅
train_accuracy,▅▆█▁▁
val_accuracy,▁▆▅█▅
val_loss,█▁▆▅▅

0,1
loss,0.02277
train_accuracy,0.875
val_accuracy,0.88
val_loss,0.02304


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011156914355554666, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([16, 16, 16, 16, 11, 22,  0, 16,  4,  4,  0, 16, 15,  0, 15, 24, 15, 16,
        16,  0,  0,  4,  4,  4, 15, 16,  5, 16, 11,  4, 16, 16,  0, 24, 16, 16,
         4, 22,  4, 15,  4, 24, 15,  0, 15,  0,  4, 24, 24, 15,  0, 11, 16, 15,
         4, 15,  4, 15, 16,  4, 15, 11,  0, 15, 11, 16, 16,  5, 16,  4,  4, 16,
        16,  0, 15,  0, 15,  4,  0, 11,  0,  0,  0, 24, 16,  0,  0, 22,  0, 16,
         5,  4,  0, 16,  4, 24, 16,  4, 24, 11, 11, 11, 24,  4, 22, 15, 15, 24,
         0, 22, 11, 15, 16, 11, 16, 15, 16, 22, 16,  0,  4, 11,  0, 15, 16, 16,
         0, 16, 16, 24, 24, 15,  4,  4, 16, 16, 11, 15, 15, 15, 24, 16,  4, 16,
        16, 22, 16,  4, 24,  0,  4, 24,  4, 15, 11, 11,  4, 16,  4, 24, 16,  4,
        16, 16, 24,  0, 16, 15, 11, 11, 11, 16,  4, 16, 15,  0, 15,  0,  4,  4,
         0, 11, 16,  0, 24, 11, 16, 16, 11,  4, 16, 16, 15,  0, 16,  4, 16, 16,
         0,  0])
target_index_per_sequence + tensor([16, 16,

VBox(children=(Label(value='0.001 MB of 0.028 MB uploaded\r'), FloatProgress(value=0.04328455396003975, max=1.…

0,1
loss,▁▇▆▆█
train_accuracy,▇▁▆█▅
val_accuracy,▁███▆
val_loss,█▁▁▅▃

0,1
loss,0.03609
train_accuracy,0.205
val_accuracy,0.2
val_loss,0.03609


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011141190744444885, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([15, 15,  0, 16, 11,  4, 11,  4, 15, 11, 11,  4, 16, 16, 24,  0, 15, 16,
        24,  4,  0,  4,  0,  4, 24, 24,  0, 15, 16, 22, 16, 16,  4, 24, 11, 22,
         0, 16,  4, 16,  4, 16, 16, 15,  4, 15, 24,  4, 24, 24, 16, 15, 11,  0,
         4,  0, 16,  0, 24,  0, 24, 11,  4,  4,  4, 11,  0,  0, 16, 22,  0, 16,
         4, 16,  0, 16, 16, 15, 11,  0,  0, 16, 15, 15,  4,  0,  0, 11, 24, 16,
        16,  4, 24, 24,  0,  4, 11, 15, 16, 11,  4, 15, 16, 24, 15, 15,  5,  0,
        24, 15, 22, 15, 24, 11,  4, 16, 16, 11, 16, 16,  0, 15,  0,  0, 15, 16,
         0, 16, 16, 16,  4,  0, 22, 22, 16, 16, 15, 16,  4,  4, 24, 15, 11,  0,
         0, 16,  0, 16, 24,  4,  4, 16, 16, 15,  4,  0,  4,  4, 16, 15, 15,  0,
        16, 16,  0, 11, 15, 16,  4, 16, 15,  4,  4, 11, 16, 22, 11, 15, 16,  4,
        24,  5, 16, 16,  5,  4, 11,  0, 16, 15,  0, 24, 16, 16, 16, 16, 15, 16,
        16, 16])
target_index_per_sequence + tensor([15,  5,

VBox(children=(Label(value='0.001 MB of 0.028 MB uploaded\r'), FloatProgress(value=0.04328752099256263, max=1.…

0,1
loss,▁█▆▅▅
train_accuracy,█▁██▅
val_accuracy,▁██▆▃
val_loss,█▁▁▄▂

0,1
loss,0.0361
train_accuracy,0.215
val_accuracy,0.2
val_loss,0.03611


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011128443511111325, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([16, 16, 16,  0,  4, 11,  4, 24, 16, 11, 15, 15, 16,  0,  4, 11, 15,  0,
        24,  0, 24,  0, 16, 22, 24,  4, 16, 16, 24, 22,  0,  0, 15,  4, 16,  4,
        16,  0, 16, 16, 11,  4, 15,  0,  4, 16, 15, 11, 16,  0, 16,  0, 16,  0,
         4,  0, 15,  0, 16, 15, 11, 15, 24,  4,  5, 16, 24, 24, 16,  4, 15,  0,
         4, 24,  4,  4,  4, 24, 24, 16, 16,  4, 15,  0, 22, 24, 24, 11, 24,  4,
         5, 11, 16,  4,  0,  4,  4, 16, 16, 15, 16,  0, 15,  4, 16,  4, 24, 16,
        15, 16,  4, 16, 16, 16,  0, 16, 16, 16, 11, 16, 24, 11,  0,  5, 24,  4,
        15,  0,  0,  0, 16, 15, 16, 16, 16, 16, 11, 15, 16, 16,  4,  4, 16,  0,
        16, 15, 16, 16,  4, 16, 24, 16,  0,  4, 15, 15, 11,  0,  0, 15, 16,  4,
        22, 24,  4, 15,  0, 16, 11, 24, 15,  0, 22,  4,  4, 16, 16,  4, 24, 24,
        15, 15, 22, 11,  0, 15, 16, 15,  4, 16, 24,  0,  0,  0, 15, 16,  4, 15,
        16, 15])
target_index_per_sequence + tensor([16, 16,

VBox(children=(Label(value='0.001 MB of 0.028 MB uploaded\r'), FloatProgress(value=0.04328603742545754, max=1.…

0,1
loss,▁▆█▂▅
train_accuracy,█▁▇▆▃
val_accuracy,▁██▅█
val_loss,█▁▂▄▂

0,1
loss,0.03612
train_accuracy,0.215
val_accuracy,0.24
val_loss,0.03613


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011117926855556308, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([24, 24, 23, 12, 23, 14, 21,  1,  7, 21, 16,  9, 17, 23, 19,  1, 21, 24,
        22, 14, 18, 11,  9,  0, 24,  7, 23,  5,  7, 18, 17, 16,  0,  8, 20,  1,
        16,  5, 23, 22, 23, 20, 17, 18, 20,  7, 11, 10, 23, 17,  7, 23,  9,  1,
        21,  8,  4,  6,  4, 23, 22, 16, 13, 17, 15,  9,  9, 15, 14, 24,  6,  4,
        15, 24, 18,  5, 13,  1, 12, 10, 23, 11, 22,  0, 10,  4,  8, 13, 10, 16,
        10,  0, 17, 24, 23,  6, 14,  8,  9,  8, 11, 24,  8, 13,  5, 13, 17, 15,
         6,  1, 19, 12, 15, 14, 21, 19,  9, 20, 17, 11, 19,  4, 23, 11, 22, 12,
         0, 20, 18, 22,  8, 18,  6,  6,  5, 19, 23, 18, 17, 13,  5, 19, 19, 17,
        12, 17, 22, 17, 14, 12,  7, 21, 13, 23,  6, 17, 21,  1,  5, 14, 24, 15,
         4,  4, 11,  0, 11, 16, 12,  8, 20, 20,  0, 19, 22,  0,  5,  6, 15, 21,
        20, 16, 10, 17,  9, 17, 18,  0, 17,  1,  4, 15, 10, 12, 23, 16, 17, 13,
        10, 10])
target_index_per_sequence + tensor([24, 24,

VBox(children=(Label(value='0.001 MB of 0.016 MB uploaded\r'), FloatProgress(value=0.07483557504295787, max=1.…

0,1
loss,▃▂██▁
train_accuracy,██▃▆▁
val_accuracy,▃█▁▁▃
val_loss,▄▁▆█▃

0,1
loss,0.02288
train_accuracy,0.89
val_accuracy,0.88
val_loss,0.02341


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011168007866665953, max=1.0…

FOLD 0
--------------------------------
predicted_index_B + tensor([ 5, 13, 10, 21,  7, 20, 23,  1, 22, 19,  8, 11,  6,  9, 22,  4,  0, 12,
        11, 24, 15,  9,  8, 12,  8, 12,  7, 16, 23, 24,  6, 12, 21,  8, 17, 20,
        23, 17, 11,  9, 18, 24, 21, 20,  5, 14, 11, 13,  1, 17, 18, 24, 13, 16,
         4, 19, 15,  7, 11,  4, 20, 23, 15, 22, 23, 13, 23,  0, 12, 10, 15,  1,
         8, 20,  5, 12, 10, 17,  4, 22, 18, 17, 10, 19,  1, 23, 22,  0,  1,  9,
        11, 20, 11, 23, 10, 14,  9, 23,  8, 13,  1, 13, 19, 22,  8, 12, 10,  5,
        17, 24, 21,  5,  1, 16, 23, 23, 10, 19, 18, 14,  5, 24, 17,  6, 15,  9,
        23, 19, 12, 15, 18, 23, 10,  8, 13, 17,  6,  8, 24, 11, 24, 20, 16, 24,
        18,  0, 17, 17,  1,  0, 23,  6, 21, 17, 17, 20, 21, 17, 10, 14, 17, 14,
         0, 16,  7,  0, 15, 22, 15,  6,  5, 16,  4, 19, 13,  4,  4,  0, 17, 16,
        17, 18, 21, 22,  9, 19, 21, 18, 17,  4, 24, 16, 17, 23,  5,  6, 17, 14,
        17,  9])
target_index_per_sequence + tensor([ 5, 13,

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
loss,█▃█▁▅
train_accuracy,▁▁▁█▅
val_accuracy,▁█▃▁▁
val_loss,▆▁██▅

0,1
loss,0.02427
train_accuracy,0.89
val_accuracy,0.86
val_loss,0.02472


FOLD 0
--------------------------------
predicted_index_B + tensor([18, 17,  4, 11, 13, 13, 17, 17, 16,  1, 16,  8, 10, 13, 22, 21, 15, 12,
        24,  9, 13, 11,  5, 16, 16, 23,  7,  6,  1, 14, 23,  8, 22, 23, 12, 11,
         8, 17, 15, 19, 19, 23, 18,  0,  4,  0, 17, 20,  5, 12, 11, 23, 21, 21,
         4, 13, 15,  6, 17, 20,  1,  0,  8, 14,  4, 17, 11,  6, 24, 23, 19,  9,
        15, 18, 17, 17, 20,  5, 17, 15,  9,  1, 14, 23,  7,  7, 21, 17, 20, 10,
        21,  9, 17,  4, 11,  9,  7, 22, 22, 17,  1, 12, 12,  8, 10, 12, 20,  0,
         5,  1, 21,  8,  7, 14, 17,  6, 15, 10, 12, 21, 10, 20,  9, 19, 17,  6,
         0, 21, 14, 16, 18, 24, 20, 20,  8, 17,  5, 16,  8, 24, 23, 18, 22, 23,
         1, 11, 14, 13, 15,  9, 23,  5,  0, 18, 17,  4, 19, 23, 22,  6, 22, 23,
         8, 22, 19, 17,  0, 10,  0,  1, 23, 24, 24, 19, 15, 23, 24, 10,  4, 18,
        19, 12, 24, 15,  9, 13, 10,  5, 19, 16, 17, 24, 24,  5, 16, 11, 10, 18,
        23,  4])
target_index_per_sequence + tensor([18, 17,

VBox(children=(Label(value='0.001 MB of 0.028 MB uploaded\r'), FloatProgress(value=0.043290488431876606, max=1…

0,1
loss,▁█▆▄▅
train_accuracy,█▁▅▅█
val_accuracy,▁█▁█▁
val_loss,▁▄▄█▅

0,1
loss,0.02563
train_accuracy,0.88
val_accuracy,0.86
val_loss,0.02614
