In [1]:
import numpy as np
import os 
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [2]:
# (a) Compute the global velocity range once, right after loading the data.
# Only the training split is read here; the same range is later reused for
# the validation set, so validation values never influence the scaling.
all_vels = np.load("./dataset_one_batch/train_vels.npy")  # shape (800,1,70,70)
vel_min = np.min(all_vels)
vel_max = np.max(all_vels)

# (b) Normalization / inverse normalization handled inside the Dataset.
class NormWaveformDataset(Dataset):
    """Paired (waveform, velocity) dataset with on-the-fly normalization.

    Each waveform is min-max scaled per sample to [0, 1]. Velocities are scaled
    with a fixed global range so that [vel_min, vel_max] maps to [0, 1];
    ``denormalize`` applies the inverse mapping to model outputs.

    Args:
        waves_path: path to a ``.npy`` file of waveforms; the first axis indexes samples.
        vels_path: path to a ``.npy`` file of velocities; the first axis indexes samples.
        vel_min: lower end of the global velocity range (ideally from the training split).
        vel_max: upper end of the global velocity range; must be greater than ``vel_min``.

    Raises:
        ValueError: if ``vel_max <= vel_min`` (or either is NaN), or if the two
            arrays contain a different number of samples.
    """

    def __init__(self, waves_path, vels_path, vel_min, vel_max):
        # `not a > b` also rejects NaN bounds, which would otherwise turn every
        # target into NaN.
        if not vel_max > vel_min:
            raise ValueError(
                f"vel_max ({vel_max}) must be greater than vel_min ({vel_min})")
        # Both arrays are loaded eagerly and cast to float32 once, so that
        # __getitem__ only slices memory.
        self.waves = np.load(waves_path).astype(np.float32)
        self.vels = np.load(vels_path).astype(np.float32)
        if len(self.waves) != len(self.vels):
            raise ValueError(
                f"waveform/velocity sample counts differ: "
                f"{len(self.waves)} vs {len(self.vels)}")
        self.vmin, self.vmax = vel_min, vel_max

    def __len__(self):
        """Number of samples (length of the first axis of the waveform array)."""
        return len(self.waves)

    def __getitem__(self, idx):
        """Return ``(waveform, velocity)`` as float32 tensors in normalized space."""
        x = self.waves[idx]
        # Per-sample min-max scaling (a global standardization would also work).
        x_lo, x_hi = x.min(), x.max()
        span = x_hi - x_lo
        if span > 0:
            x = (x - x_lo) / span
        else:
            # Constant trace: plain min-max would compute 0/0 and emit NaNs that
            # poison the loss, so map it to all zeros instead.
            x = np.zeros_like(x)
        y = self.vels[idx]
        # Velocity: [vel_min, vel_max] -> [0, 1]. Values outside the training
        # range (possible on other splits) land outside [0, 1].
        y_norm = (y - self.vmin) / (self.vmax - self.vmin)
        return torch.from_numpy(x).float(), torch.from_numpy(y_norm).float()

    def denormalize(self, y_norm):
        """Map normalized velocities back to the original scale.

        Inverse of the velocity scaling in ``__getitem__``. It accepts NumPy
        arrays or torch tensors. The range is applied as Python floats, so the
        input's float dtype is preserved.
        """
        vmin = float(self.vmin)
        return y_norm * (float(self.vmax) - vmin) + vmin

# (c) Build the datasets and DataLoaders.
# num_workers=0 is recommended inside a notebook. Under the "spawn" start
# method (the default on Windows and macOS), worker processes generally cannot
# unpickle a Dataset class defined interactively in the notebook, so loading
# with workers fails there. On Linux ("fork") this can be raised for speed.
BATCH_SIZE = 16
NUM_WORKERS = 0
# Pinned host memory only speeds up host->GPU copies; enable it only when a
# CUDA device exists (recent PyTorch warns if it is requested without one).
PIN_MEMORY = torch.cuda.is_available()

train_ds = NormWaveformDataset("./dataset_one_batch/train_waves.npy",
                               "./dataset_one_batch/train_vels.npy",
                               vel_min, vel_max)
# Validation reuses the training velocity range so both splits share one scale.
val_ds = NormWaveformDataset("./dataset_one_batch/val_waves.npy",
                             "./dataset_one_batch/val_vels.npy",
                             vel_min, vel_max)
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
