In [1]:
def train_model(model, criterion, optimizer, train_loader, dataset_params, model_params, traing_params):
    model.train()  # Set model to training mode
    num_epochs = int(traing_params["num_epochs"])   
    reservoir_size = int(model_params["reservoir_size"])   
    input_size = int(model_params["input_size"])
    ReadOut_output_size = int(model_params["ReadOut_output_size"])
    sequence_length = int(int(dataset_params["sequence_length"]) / int(dataset_params["slicing_size"]))
    batch_training = model_params["Batch_Training"]
    Regularization_L2 = model_params["Regularization_L2"]

    for epoch in range(num_epochs):
        running_loss = 0.
        for inputs, labels in train_loader:
            inputs = inputs.float().to(device)
            # print(inputs.shape)
                
            labels = labels.to(device)
            
            optimizer.zero_grad()

            outputs_ESN = model.ESN(inputs)
            outputs = model.ReadOut(outputs_ESN)
            outputs = outputs.squeeze()
            
            B, T, C = outputs.shape                # B=200, T=800, C=25

            # ---- flattenの整合性を取る ----
            # リザバー： [500, B*T]
            outputs_ESN_flatten = outputs_ESN.view(reservoir_size, -1)
        
            # 出力： [C, B*T]
            outputs_flatten = outputs.view(C, -1)
        
            # ラベル：各サンプルの one-hot を T 回繰り返して [B, T, C] → [C, B*T]
            labels_rep = labels.unsqueeze(1).repeat(1, T, 1)  # [B, T, C]
            labels_flatten = labels_rep.view(C, -1)           # [C, B*T]
            # print(outputs_flatten.shape)
            # print(labels_flatten.shape)
        
            loss = criterion(outputs_flatten, labels_flatten)
            
            if batch_training == True:
                model.ReadOut.ridge_regression_update(outputs_ESN_flatten, labels_flatten, model, Regularization_L2)
            else:
                #勾配の計算
                loss.backward()
                #重みの更新
                optimizer.step()
            
            running_loss += loss.item() 

        epoch_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}')
        wandb.log({"loss": epoch_loss})
    
    print('Training complete')
    
def validate_model(model, val_loader, dataset_params, model_params):
    model.eval()  # Set model to evaluation mode
    val_running_loss = 0.0
    val_running_corrects = 0
    
    reservoir_size = int(model_params["reservoir_size"])   
    input_size = int(model_params["input_size"])
    ReadOut_output_size = int(model_params["ReadOut_output_size"])
    sequence_length = int(int(dataset_params["sequence_length"]) / int(dataset_params["slicing_size"]))

    with torch.no_grad():
        for inputs, labels in val_loader:
            # print(inputs.shape)
            inputs = inputs.float().to(device)
            labels = labels.to(device)
            outputs_ESN = model.ESN(inputs)      # [B, 1, T, 500]
            outputs = model.ReadOut(outputs_ESN).squeeze(1)  # [B, T, C]

            B, T, C = outputs.shape

            outputs_flatten = outputs.view(C, -1)
            labels_rep = labels.unsqueeze(1).repeat(1, T, 1)
            labels_flatten = labels_rep.view(C, -1)

            loss = criterion(outputs_flatten, labels_flatten)
            val_running_loss += loss.item() * B

    val_loss = val_running_loss / len(val_loader.dataset)
    val_accuracy = compute_accuracy(outputs_flatten, labels_flatten, sequence_length)
    print(f"Accuracy: {val_accuracy * 100:.2f}%")
    
    print(f'Validation Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}')
    wandb.log({"val_accuracy": val_accuracy, "val_loss": val_loss})
    
def test_model(model, val_loader):
    model.eval()  # Set model to evaluation mode
    val_running_loss = 0.0
    val_running_corrects = 0

    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs = inputs.float().to(device)
            labels = labels.squeeze().to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            _, preds = torch.max(outputs, 1)
#             _, label_preds = torch.max(labels, 1)
            label_preds = labels

            val_running_loss += loss.item() * inputs.size(0)
            print(loss.item())
            val_running_corrects += torch.sum(preds == label_preds)

    val_loss = val_running_loss / len(val_loader.dataset)
    val_accuracy = val_running_corrects.double() / len(val_loader.dataset)
    
    print(f'Test Loss: {val_loss:.4f}, Accuracy: {val_accuracy:.4f}')
    wandb.log({"test_accuracy": val_accuracy, "test_loss": val_loss})

def compute_accuracy(model_output, target, n_taus):
    # モデルの出力と教師ラベルを各データに分割
    split_model_output = torch.split(model_output.squeeze(), n_taus, dim=-1)
    split_target = torch.split(target.squeeze(), n_taus, dim=-1)
    # print("aaaa")
    # print(split_model_output[0].shape)
    # print(split_target[0].shape)
    correct = 0
    total = 0
    
    for pred, true_label in zip(split_model_output, split_target):
        # 最も確率が高いラベルを予測ラベルとして取得
        # print(pred.shape)
        # print(pred)
        # print(true_label.shape)
        # print(true_label)
        count_ones = (true_label == 1).sum().item()
        # print(count_ones)
        histgram_predict = torch.bincount(torch.max(pred, 0)[1])
        _, predicted = torch.max(histgram_predict, 0)
    
        histgram_true_label_idx = torch.bincount(torch.max(true_label, 0)[1])
        _, true_label_idx = torch.max(histgram_true_label_idx, 0)

        # 正解数をカウント
        correct += (predicted == true_label_idx).sum().item()
        total += 1

    # 精度を算出
    accuracy = correct / total
    return accuracy

def model_params_candinate(model_params):
    model_params_combinations = list(itertools.product(*model_params.values()))
    param_dicts = [dict(zip(model_params.keys(), combination)) for combination in model_params_combinations]
    return param_dicts

# モデル構造を辞書型に格納
def model_sturcture_dict(model):
    layers_dict = {}
    for name, module in model.named_modules():
        layers_dict[name] = {
            'type': type(module).__name__,
            'parameters': {p: getattr(module, p) for p in module.__dict__ if not p.startswith('_')}
        }
    # モデル名と初期の引数は削除
    del(layers_dict[''])
    return layers_dict
    

In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import random

# ---------- 環境 ----------
device = torch.device("cpu")  # 形状デバッグはまずCPUでOK

# wandb が未初期化で落ちるのを避ける（最小モック）
class _WandbMock:
    def log(self, d): print("[wandb.log]", d)
wandb = _WandbMock()

# train_model/validate_model が呼ぶ unpack を「そのまま返す」にする
def unpack_batch_for_lyon(batch):
    return batch

# ---------- 仮想 Dataset（可変長） ----------
class FakeLyonDataset(Dataset):
    def __init__(self, N=12, max_T=80, F=16, C=5, seed=0):
        self.N = N
        self.max_T = max_T
        self.F = F
        self.C = C
        random.seed(seed)

    def __len__(self):
        return self.N

    def __getitem__(self, idx):
        T = random.randint(self.max_T // 2, self.max_T)  # 可変長
        x = torch.randn(T, self.F)                        # (Ti, F)
        label_idx = random.randint(0, self.C - 1)
        y = torch.zeros(self.C)
        y[label_idx] = 1.0                                # one-hot (C,)
        return x, y, T, label_idx

def collate_fake_lyon(batch):
    xs, ys, lens, label_idxs = zip(*batch)
    B = len(xs)
    T = max(lens)
    F = xs[0].shape[1]
    C = ys[0].shape[0]

    # inputs: (B,1,T,F)
    inputs = torch.zeros(B, 1, T, F)
    mask   = torch.zeros(B, T, dtype=torch.bool)
    for i, (x, L) in enumerate(zip(xs, lens)):
        inputs[i, 0, :L, :] = x
        mask[i, :L] = True

    y = torch.stack(list(ys), dim=0)  # (B,C)

    # labels を dict にして、あなたの dict 分岐を必ず通す
    labels = {
        "y": y,
        "mask": mask,
        "lengths": torch.tensor(lens, dtype=torch.long),
        "label_idx": torch.tensor(label_idxs, dtype=torch.long),
    }

    # あえて lengths/mask/label_idx は None で返して、train側で dict から拾う経路をテスト
    return (inputs, labels, None, None, None, {})

# ---------- 仮想 Model（形状だけ合わせる） ----------
class FakeESN(nn.Module):
    def __init__(self, reservoir_size):
        super().__init__()
        self.reservoir_size = reservoir_size

    def forward(self, inputs):
        # inputs: (B,1,T,F) を想定
        B, _, T, _ = inputs.shape
        # ESN状態っぽいテンソルを返す: (B,1,T,H)
        x = torch.randn(B, 1, T, self.reservoir_size, device=inputs.device)
        return x.contiguous()

class FakeReadOut(nn.Module):
    def __init__(self, reservoir_size, C):
        super().__init__()
        self.linear = nn.Linear(reservoir_size, C, bias=False)

    def forward(self, outputs_ESN):
        # outputs_ESN: (B,1,T,H)
        B, one, T, H = outputs_ESN.shape
        z = outputs_ESN.reshape(B*T, H)
        y = self.linear(z).reshape(B, 1, T, -1)
        return y.contiguous()

class FakeModel(nn.Module):
    def __init__(self, reservoir_size, C):
        super().__init__()
        self.ESN = FakeESN(reservoir_size)
        self.ReadOut = FakeReadOut(reservoir_size, C)

# ---------- ここから実行 ----------
C = 5
F = 16
reservoir_size = 32

train_ds = FakeLyonDataset(N=20, max_T=80, F=F, C=C, seed=1)
val_ds   = FakeLyonDataset(N=10, max_T=80, F=F, C=C, seed=2)

train_loader = DataLoader(train_ds, batch_size=4, shuffle=True,  collate_fn=collate_fake_lyon)
val_loader   = DataLoader(val_ds,   batch_size=4, shuffle=False, collate_fn=collate_fake_lyon)

model = FakeModel(reservoir_size=reservoir_size, C=C).to(device)

# criterion はあなたの train_model が引数で受け取るが、validate_model は global を参照しているので両方に合わせる
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

dataset_params = {}
model_params = {
    "reservoir_size": reservoir_size,
    "Batch_Training": False,        # ←最小手間：ridge更新を通さない
    "Regularization_L2": 1e-4,
}
traing_params = {"num_epochs": 1}



In [3]:
# あなたの train_model / validate_model を「そのまま」呼ぶ
train_model(model, criterion, optimizer, train_loader, dataset_params, model_params, traing_params)
validate_model(model, val_loader, dataset_params, model_params)

KeyError: 'input_size'

In [8]:
import torch

def make_encoded_BTC(B, T, C, device="cpu"):
    # 値 = 1_000_000*b + 1_000*t + c  (c を下3桁に固定)
    b = torch.arange(B, device=device).view(B, 1, 1)
    t = torch.arange(T, device=device).view(1, T, 1)
    c = torch.arange(C, device=device).view(1, 1, C)
    return 1_000_000*b + 1_000*t + c  # [B,T,C]

def make_encoded_BTH(B, T, H, device="cpu"):
    # 値 = 1_000_000*b + 1_000*t + h  (h を下3桁に固定)
    b = torch.arange(B, device=device).view(B, 1, 1)
    t = torch.arange(T, device=device).view(1, T, 1)
    h = torch.arange(H, device=device).view(1, 1, H)
    return 1_000_000*b + 1_000*t + h  # [B,T,H]

def flatten_old_BTC(x):  # x: [B,T,C] -> [C, B*T] (旧: permuteなし)
    B, T, C = x.shape
    return x.view(C, -1)

def flatten_new_BTC(x):  # x: [B,T,C] -> [C, B*T] (新: permuteあり)
    B, T, C = x.shape
    return x.permute(2, 0, 1).contiguous().view(C, -1)

def flatten_old_BTH(x):  # x: [B,T,H] -> [H, B*T] (旧: permuteなし)
    B, T, H = x.shape
    return x.view(H, -1)

def flatten_new_BTH(x):  # x: [B,T,H] -> [H, B*T] (新: permuteあり)
    B, T, H = x.shape
    return x.permute(2, 0, 1).contiguous().view(H, -1)

def decode_bt_from_row(flat_row_1d):
    # flat_row_1d: [B*T]
    v = flat_row_1d
    b = (v // 1_000_000).to(torch.int64)
    t = ((v % 1_000_000) // 1_000).to(torch.int64)
    return b, t

def row_check_mod(flat, name, mod=1000, rows=3):
    # 期待: new flatten なら「行iの (value % 1000) は {i} だけ」
    print(f"\n[{name}] shape = {tuple(flat.shape)}")
    R = min(rows, flat.shape[0])
    for i in range(R):
        mods = torch.unique(flat[i] % mod).tolist()
        print(f"  row {i}: unique(value%{mod}) = {mods[:10]}{' ...' if len(mods)>10 else ''}")

def col_order_preview(flat, name, row=0, n=12):
    # 期待: new flatten なら (b,t) = (0,0..T-1),(1,0..T-1),...
    b, t = decode_bt_from_row(flat[row, :n])
    pairs = list(zip(b.tolist(), t.tolist()))
    print(f"\n[{name}] first {n} cols of row{row} decoded as (b,t):\n  {pairs}")

def compare_flatten(B=2, T=5, C=3, H=4, time_stride=1, device="cpu"):
    print(f"=== compare_flatten: B={B}, T={T}, C={C}, H={H}, time_stride={time_stride} ===")

    # --- 擬似 ReadOut 出力 [B,T,C] ---
    out_BTC = make_encoded_BTC(B, T, C, device=device)
    print(out_BTC.shape)
    print(out_BTC)

    # --- 擬似 ESN 状態 [B,T,H] ---
    st_BTH = make_encoded_BTH(B, T, H, device=device)

    # --- time_stride を新実装と同様に適用した版も作る ---
    if time_stride > 1:
        out_BTC_s = out_BTC[:, ::time_stride, :]
        st_BTH_s = st_BTH[:, ::time_stride, :]
    else:
        out_BTC_s = out_BTC
        st_BTH_s = st_BTH

    # =========================
    # 旧 flatten vs 新 flatten
    # =========================

    # (A) outputs: [B,T,C] -> [C,B*T]
    old_out = flatten_old_BTC(out_BTC_s)
    print(old_out.shape)
    print(old_out)
    new_out = flatten_new_BTC(out_BTC_s)
    print(new_out.shape)
    print(new_out)

    row_check_mod(old_out, "OLD outputs_flatten", mod=1000, rows=min(C,3))
    row_check_mod(new_out, "NEW outputs_flatten", mod=1000, rows=min(C,3))

    col_order_preview(old_out, "OLD outputs_flatten", row=0, n=12)
    col_order_preview(new_out, "NEW outputs_flatten", row=0, n=12)

    # (B) reservoir_states: [B,T,H] -> [H,B*T]
    old_st = flatten_old_BTH(st_BTH_s)
    new_st = flatten_new_BTH(st_BTH_s)

    row_check_mod(old_st, "OLD states_flatten", mod=1000, rows=min(H,3))
    row_check_mod(new_st, "NEW states_flatten", mod=1000, rows=min(H,3))

    col_order_preview(old_st, "OLD states_flatten", row=0, n=12)
    col_order_preview(new_st, "NEW states_flatten", row=0, n=12)

    print("\n--- Interpretation ---")
    print("NEW の row i は value%1000 が {i} だけ → 行が “c/h 固定” になっている（意図通り）")
    print("OLD は row i で value%1000 が複数 → 行が c/h 固定になっていない（軸が混線）")
    print("さらに NEW は列が (bごとに t が連続) だが、OLD は列順が崩れやすい")
    print("=> compute_accuracy の split（n_taus=T_eff）前提とも整合しやすいのは NEW")

# 実行例（まずは小さい値で）
compare_flatten(B=2, T=6, C=4, H=5, time_stride=1, device="cpu")

# stride あり版も比較（新実装の time_stride の影響確認）
compare_flatten(B=2, T=6, C=4, H=5, time_stride=2, device="cpu")


=== compare_flatten: B=2, T=6, C=4, H=5, time_stride=1 ===
torch.Size([2, 6, 4])
tensor([[[      0,       1,       2,       3],
         [   1000,    1001,    1002,    1003],
         [   2000,    2001,    2002,    2003],
         [   3000,    3001,    3002,    3003],
         [   4000,    4001,    4002,    4003],
         [   5000,    5001,    5002,    5003]],

        [[1000000, 1000001, 1000002, 1000003],
         [1001000, 1001001, 1001002, 1001003],
         [1002000, 1002001, 1002002, 1002003],
         [1003000, 1003001, 1003002, 1003003],
         [1004000, 1004001, 1004002, 1004003],
         [1005000, 1005001, 1005002, 1005003]]])
torch.Size([4, 12])
tensor([[      0,       1,       2,       3,    1000,    1001,    1002,    1003,
            2000,    2001,    2002,    2003],
        [   3000,    3001,    3002,    3003,    4000,    4001,    4002,    4003,
            5000,    5001,    5002,    5003],
        [1000000, 1000001, 1000002, 1000003, 1001000, 1001001, 1001002, 10010

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
def _to_BTH(outputs_ESN):
    """
    outputs_ESN: [B,1,T,H] or [B,T,H]
    return: states [B,T,H]
    """
    if outputs_ESN.dim() == 4:
        # [B,1,T,H]
        return outputs_ESN[:, 0]
    elif outputs_ESN.dim() == 3:
        # [B,T,H]
        return outputs_ESN
    else:
        raise ValueError(f"Unexpected ESN output shape: {outputs_ESN.shape}")

def _to_BTC(outputs):
    """
    outputs: [B,1,T,C] or [B,T,C]
    return: logits [B,T,C]
    """
    if outputs.dim() == 4:
        # [B,1,T,C]
        return outputs[:, 0]
    elif outputs.dim() == 3:
        # [B,T,C]
        return outputs
    else:
        raise ValueError(f"Unexpected ReadOut output shape: {outputs.shape}")

def _time_slice(x, washout=0, time_stride=1):
    """
    x: [B,T,...]
    """
    if washout < 0: washout = 0
    if time_stride < 1: time_stride = 1
    return x[:, washout::time_stride]

def _prepare_targets_and_indices(labels, B, T_eff, C, device):
    """
    labels の形式を吸収して、以下を返す：
      targets_BTC: [B,T_eff,C] float (ridge / MSE/BCE用)
      target_idx_B: [B] long (accuracy / CE用)
    対応する labels:
      - one-hot: [B,C]
      - index:  [B] or [B,1]
      - per-time one-hot: [B,T,C]（可変長や時刻ラベルがある場合）
      - per-time index:   [B,T]（同上）
    """
    # ---- per-time one-hot [B,T,C] ----
    if labels.dim() == 3 and labels.shape[0] == B and labels.shape[2] == C:
        targets = labels[:, :T_eff].float().to(device)
        target_idx = targets[:, 0].argmax(dim=-1).long()  # 代表として t=0 を使う
        return targets, target_idx

    # ---- per-time index [B,T] ----
    if labels.dim() == 2 and labels.shape[0] == B and labels.shape[1] == T_eff and labels.dtype in (torch.int64, torch.int32):
        target_idx_BT = labels.to(device)
        targets = F.one_hot(target_idx_BT, num_classes=C).float()  # [B,T,C]
        target_idx = target_idx_BT[:, 0].long()
        return targets, target_idx

    # ---- one-hot [B,C] ----
    if labels.dim() == 2 and labels.shape[0] == B and labels.shape[1] == C and labels.dtype.is_floating_point:
        target_idx = labels.argmax(dim=-1).long().to(device)
        targets = labels.to(device).unsqueeze(1).expand(B, T_eff, C).float()
        return targets, target_idx

    # ---- index [B] or [B,1] ----
    if labels.dim() == 1:
        target_idx = labels.long().to(device)
    elif labels.dim() == 2 and labels.shape[0] == B and labels.shape[1] == 1:
        target_idx = labels[:, 0].long().to(device)
    else:
        # ここに来るなら labels 形式が想定外
        raise ValueError(f"Unsupported labels shape/dtype: shape={labels.shape}, dtype={labels.dtype}")

    targets = F.one_hot(target_idx, num_classes=C).float().unsqueeze(1).expand(B, T_eff, C)  # [B,T,C]
    return targets, target_idx

def _loss_time_kept(logits_BTC, labels, criterion):
    """
    logits_BTC: [B,T,C]
    labels: one-hot [B,C] / idx [B] etc.
    criterion に合わせて、Tを保ったまま loss を作る
    返り値:
      loss (scalar)
      targets_BTC (float, [B,T,C])  # ridge 用にも使える
      target_idx_B (long, [B])
    """
    B, T, C = logits_BTC.shape
    targets_BTC, target_idx_B = _prepare_targets_and_indices(labels, B, T, C, logits_BTC.device)

    # CrossEntropyLoss 系なら index を使うのが自然
    if isinstance(criterion, torch.nn.CrossEntropyLoss):
        # CE は [B,C,T] と target [B,T] を受けられる
        logits_BCT = logits_BTC.transpose(1, 2)  # [B,C,T]
        target_BT = target_idx_B[:, None].expand(B, T)  # [B,T]（系列ラベルを各時刻に付与）
        loss = criterion(logits_BCT, target_BT)
        return loss, targets_BTC, target_idx_B

    # MSELoss / BCEWithLogitsLoss などは [B,T,C] のままいける
    loss = criterion(logits_BTC, targets_BTC)
    return loss, targets_BTC, target_idx_B

def sequence_accuracy_from_BTC(logits_BTC, target_idx_B):
    """
    logits_BTC: [B,T,C]
    target_idx_B: [B]
    予測は「各時刻の argmax を one-hot して T方向に加算」→多数決（旧実装のヒストグラム投票に相当）
    """
    B, T, C = logits_BTC.shape
    pred_BT = logits_BTC.argmax(dim=-1)                  # [B,T]
    vote_BC = F.one_hot(pred_BT, num_classes=C).sum(dim=1)  # [B,C]
    pred_B = vote_BC.argmax(dim=-1)                      # [B]
    return (pred_B == target_idx_B).float().mean().item()


In [13]:
a = torch.zeros(2, 3, 4, 5)
print(a[: ,0].shape)

torch.Size([2, 4, 5])
