In [None]:
import os
import glob
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from scipy.signal import resample
from sklearn.model_selection import train_test_split
from tqdm import tqdm

# =========================================================
# 0. Settings (edit only this section to match your machine)
# =========================================================
# Parent folder that holds all of the data below.
BASE_DIR = r"C:\Users\fujiw\OneDrive\デスクトップ\ECG_ResNet"

# Locations derived from BASE_DIR. Each sub-path is relative to BASE_DIR
# (an absolute path also works, since os.path.join keeps it as-is).
# Rename the sub-paths if your folder names differ:
#   STAGE2_DIR : folder containing the input .npy files
#   CSV_DIR    : folder containing the ground-truth .csv files
#   TRAIN_META : path of train.csv
#   SAVE_DIR   : output folder for model checkpoints (.pth)
STAGE2_DIR, CSV_DIR, TRAIN_META, SAVE_DIR = (
    os.path.join(BASE_DIR, sub_path)
    for sub_path in ("stage2", "train_csvs", "train.csv", "models")
)

# Training hyper-parameters; adjust to your hardware.
# Lower BATCH_SIZE to 32 if the GPU runs out of memory.
BATCH_SIZE, EPOCHS, PATIENCE, LR = 64, 200, 20, 1e-3

# =========================================================
# 1. Dataset Class
# =========================================================
class ECGDatasetRam(Dataset):
    """ECG dataset that pre-loads every usable sample into RAM.

    For every ``*.npy`` file in ``npy_dir`` whose id (the file name up to its
    first ``-``) appears in ``df['id']``:

    * the 12-channel input is rebuilt from the packed 13-row array,
    * ``<csv_dir>/<id>.csv`` is read as the target,
    * input, target and a validity mask are resampled to ``target_len``.

    Unusable files are skipped. Expected skips (wrong array layout, missing
    target CSV) are silent. Unexpected errors are kept in ``self.load_errors``
    as ``(npy_path, repr(exception))`` pairs and summarised after loading, so
    they are not hidden.

    Each item is ``(input, target, mask, original_len)``:

    * ``input`` is float32 ``(12, target_len)``.
    * ``target`` and ``mask`` are float32 ``(n_csv_columns, target_len)``.
    * ``original_len`` is the signal length before resampling, as a long tensor.
    """

    def __init__(self, df, npy_dir, csv_dir, target_len=5000):
        self.target_len = target_len
        self.samples = []
        # (npy_path, repr(exception)) for files skipped due to an unexpected error.
        self.load_errors = []

        if not os.path.exists(npy_dir):
            raise FileNotFoundError(f"Directory not found: {npy_dir}")

        target_ids = set(df['id'].astype(str).tolist())
        # Sorted: glob order is filesystem-dependent, and a fixed order makes
        # the sample order (e.g. the first validation batch) reproducible.
        all_files = sorted(glob.glob(os.path.join(npy_dir, "*.npy")))

        print(f"Scanning files for {len(target_ids)} IDs in {npy_dir}...")
        file_list = []
        for fpath in all_files:
            # The sample id is the part of the file name before the first '-'.
            file_id = os.path.basename(fpath).split('-')[0]
            if file_id in target_ids:
                file_list.append((fpath, file_id))

        print(f"Found {len(file_list)} files. Loading into RAM...")

        # Load everything up front so __getitem__ is a cheap lookup.
        for fpath, sample_id in tqdm(file_list, desc="Pre-loading Data"):
            processed_data = self.process_one_file(fpath, sample_id, csv_dir)
            if processed_data is not None:
                self.samples.append(processed_data)

        print(f"Successfully loaded {len(self.samples)} samples into RAM.")
        skipped = len(file_list) - len(self.samples)
        if skipped:
            print(f"Skipped {skipped} file(s); "
                  f"{len(self.load_errors)} of them failed with an unexpected error.")
            for path, err in self.load_errors[:5]:
                print(f"  {path}: {err}")

    def process_one_file(self, npy_path, sample_id, csv_dir):
        """Load, reconstruct and resample one sample.

        Args:
            npy_path: path of the packed input array.
            sample_id: id used to locate ``<csv_dir>/<sample_id>.csv``.
            csv_dir: folder containing the target CSV files.

        Returns:
            ``(input, target, mask, original_len)``, or ``None`` if the sample
            is unusable. Unexpected errors are appended to ``self.load_errors``
            (when that list exists) rather than being silently discarded.
        """
        try:
            # Check for the target first so files without ground truth are
            # skipped without loading the (large) input array.
            csv_path = os.path.join(csv_dir, f"{sample_id}.csv")
            if not os.path.exists(csv_path):
                return None

            data = np.load(npy_path)
            # Validate the layout BEFORE reading shape[1]; a 1-D array would
            # otherwise raise IndexError.
            if data.ndim != 2 or data.shape[0] != 13:
                return None
            data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
            original_len = data.shape[1]

            # Rebuild 12 channels: rows 0-3 carry signal values and rows 9-12
            # carry, per time step, the channel index (0-11) each value belongs
            # to. Rows 4-8 are not used. Out-of-range indices are ignored, and a
            # later row overwrites an earlier one at the same channel/time step.
            reconstructed = np.zeros((12, original_len), dtype=np.float32)
            for i in range(4):
                sig_row = data[i]
                id_row = data[9 + i]
                for uid in np.unique(id_row):
                    if 0 <= uid <= 11:
                        mask_ch = id_row == uid
                        reconstructed[int(uid), mask_ch] = sig_row[mask_ch]

            # Target: one CSV column per channel, transposed to (columns, time).
            target_vals = pd.read_csv(csv_path).values.T
            # 1.0 where a target value exists, 0.0 where it is NaN.
            mask_data = (~np.isnan(target_vals)).astype(np.float32)
            target_data = np.nan_to_num(target_vals, nan=0.0, posinf=0.0, neginf=0.0)

            # Resample along time to the common length.
            if reconstructed.shape[1] != self.target_len:
                input_final = resample(reconstructed, self.target_len, axis=1)
            else:
                input_final = reconstructed

            if target_data.shape[1] != self.target_len:
                target_final = resample(target_data, self.target_len, axis=1)
                mask_final = resample(mask_data, self.target_len, axis=1)
            else:
                target_final = target_data
                mask_final = mask_data

            # FFT resampling can ring, so the mask is re-binarised at 0.5.
            input_final = np.nan_to_num(input_final, nan=0.0).astype(np.float32)
            target_final = np.nan_to_num(target_final, nan=0.0).astype(np.float32)
            mask_final = (np.nan_to_num(mask_final, nan=0.0) > 0.5).astype(np.float32)

            return (input_final, target_final, mask_final, original_len)

        except Exception as exc:  # best-effort: one bad file must not abort loading
            errors = getattr(self, "load_errors", None)
            if errors is not None:
                errors.append((npy_path, repr(exc)))
            return None

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the cached sample ``idx`` as torch tensors."""
        input_arr, target_arr, mask_arr, orig_len = self.samples[idx]
        return (torch.from_numpy(input_arr),
                torch.from_numpy(target_arr),
                torch.from_numpy(mask_arr),
                torch.tensor(orig_len, dtype=torch.long))

# =========================================================
# 2. Model
# =========================================================
class ResNet1d_UNet(nn.Module):
    """1-D U-Net style encoder/decoder: 12 input channels -> 12 output channels.

    Four stride-2 conv stages downsample the signal. Three decoder stages each
    upsample (``interpolate`` with its default mode) to the length of the
    matching encoder output, concatenate it on the channel axis, and convolve.
    A last stage upsamples back to the input length, and a 1x1 conv produces
    the output, so the output length equals the input length. Despite the
    name, ``forward`` contains no residual additions.

    The attribute names (enc1..enc4, dec1..dec4, final) and the layer order
    inside each stage determine the state_dict keys of saved checkpoints.
    """

    @staticmethod
    def _conv_bn_relu(in_ch, out_ch, kernel, stride, padding):
        # Conv1d -> BatchNorm1d -> ReLU, registered as Sequential indices 0/1/2.
        return nn.Sequential(
            nn.Conv1d(in_ch, out_ch, kernel, stride, padding),
            nn.BatchNorm1d(out_ch),
            nn.ReLU(),
        )

    def __init__(self):
        super().__init__()
        stage = self._conv_bn_relu
        # Encoder (stride 2 roughly halves the length at every stage).
        self.enc1 = stage(12, 64, 7, 2, 3)
        self.enc2 = stage(64, 128, 3, 2, 1)
        self.enc3 = stage(128, 256, 3, 2, 1)
        self.enc4 = stage(256, 512, 3, 2, 1)
        # Decoder: input channels = upsampled features + skip connection.
        self.dec4 = stage(512 + 256, 256, 3, 1, 1)
        self.dec3 = stage(256 + 128, 128, 3, 1, 1)
        self.dec2 = stage(128 + 64, 64, 3, 1, 1)
        self.dec1 = stage(64, 32, 3, 1, 1)
        self.final = nn.Conv1d(32, 12, 1)

    def forward(self, x):
        upsample = torch.nn.functional.interpolate
        skip1 = self.enc1(x)
        skip2 = self.enc2(skip1)
        skip3 = self.enc3(skip2)
        feats = self.enc4(skip3)
        # Upsample to the skip's length, concatenate on channels, then convolve.
        for decode, skip in ((self.dec4, skip3), (self.dec3, skip2), (self.dec2, skip1)):
            feats = decode(torch.cat([upsample(feats, size=skip.shape[2]), skip], dim=1))
        feats = self.dec1(upsample(feats, size=x.shape[2]))
        return self.final(feats)

# =========================================================
# 3. Training Loop (local-machine version)
# =========================================================
def _masked_mse_terms(outputs, targets, masks):
    """Return ``(sum of masked squared errors, sum of mask values)``.

    Points where ``masks`` is 0 contribute nothing to the error sum.
    """
    sq_err = nn.functional.mse_loss(outputs, targets, reduction='none') * masks
    return sq_err.sum(), masks.sum()


def _train_one_epoch(model, loader, optimizer, scaler, device, use_amp, epoch):
    """Run one optimisation pass over ``loader`` and return the epoch's masked MSE.

    The returned value is total squared error / total valid points over the
    whole epoch, so a smaller final batch is not over-weighted.
    """
    model.train()
    err_total = 0.0
    count_total = 0.0
    train_pbar = tqdm(loader, desc=f"Training   (Epoch {epoch})")

    for inputs, targets, masks, _ in train_pbar:
        inputs, targets, masks = inputs.to(device), targets.to(device), masks.to(device)

        optimizer.zero_grad()
        # With use_amp False, autocast is disabled and the scaler is a
        # pass-through, so the same code serves both CPU and GPU runs.
        with torch.amp.autocast(device.type, enabled=use_amp):
            outputs = model(inputs)
            err_sum, count = _masked_mse_terms(outputs, targets, masks)
            loss = err_sum / (count + 1e-8)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        err_total += err_sum.item()
        count_total += count.item()
        train_pbar.set_postfix({'loss': f"{loss.item():.6f}"})

    return err_total / (count_total + 1e-8)


def _validate(model, loader, device, use_amp):
    """Evaluate ``model`` on ``loader``; return ``(masked MSE, debug dict)``.

    The loss is aggregated like in ``_train_one_epoch``. The debug dict holds
    the first batch's input/target/output/mask/length as NumPy arrays, or is
    ``None`` when the loader yields no batches.
    """
    model.eval()
    err_total = 0.0
    count_total = 0.0
    debug_batch_data = None

    with torch.no_grad():
        for i, (inputs, targets, masks, orig_lens) in enumerate(loader):
            inputs, targets, masks = inputs.to(device), targets.to(device), masks.to(device)

            with torch.amp.autocast(device.type, enabled=use_amp):
                outputs = model(inputs)
                err_sum, count = _masked_mse_terms(outputs, targets, masks)

            err_total += err_sum.item()
            count_total += count.item()

            if i == 0:
                debug_batch_data = {
                    'input': inputs.cpu().float().numpy(),
                    'target': targets.cpu().float().numpy(),
                    'output': outputs.cpu().float().numpy(),
                    'mask': masks.cpu().float().numpy(),
                    'length': orig_lens.numpy()
                }

    return err_total / (count_total + 1e-8), debug_batch_data


def run_training():
    """Train ResNet1d_UNet using the settings at the top of this file.

    Splits the unique ids of TRAIN_META 90/10 into train/validation
    (``random_state=42``) and loads both splits into RAM. It then trains with
    Adam, ReduceLROnPlateau and a masked MSE loss, using mixed precision when
    CUDA is available. Whenever the validation loss improves, the weights are
    saved to ``SAVE_DIR/best_resnet1d_unet.pth`` and the first validation
    batch to ``SAVE_DIR/best_validation_debug.npz``. Training stops after
    PATIENCE epochs without improvement.

    Raises:
        FileNotFoundError: if TRAIN_META or the .npy folder does not exist.
        RuntimeError: if a split ends up with no usable samples.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using Device: {device}")

    os.makedirs(SAVE_DIR, exist_ok=True)

    if not os.path.exists(TRAIN_META):
        raise FileNotFoundError(f"Metadata file not found: {TRAIN_META}")

    df = pd.read_csv(TRAIN_META)
    unique_ids = df['id'].unique()
    train_ids, val_ids = train_test_split(unique_ids, test_size=0.1, random_state=42)

    train_df = df[df['id'].isin(train_ids)].reset_index(drop=True)
    val_df = df[df['id'].isin(val_ids)].reset_index(drop=True)

    datasets = {}
    for split_name, split_df in (("Training", train_df), ("Validation", val_df)):
        print(f"Initializing {split_name} Dataset...")
        dataset = ECGDatasetRam(split_df, npy_dir=STAGE2_DIR, csv_dir=CSV_DIR)
        # Fail fast: an empty split would otherwise crash later with a
        # ZeroDivisionError when averaging the epoch loss.
        if len(dataset) == 0:
            raise RuntimeError(
                f"No usable {split_name.lower()} samples were loaded; "
                f"check STAGE2_DIR ({STAGE2_DIR}) and CSV_DIR ({CSV_DIR})."
            )
        datasets[split_name] = dataset

    # num_workers=0 is the most stable choice on Windows; with the data
    # already in RAM, worker processes would gain little.
    train_loader = DataLoader(datasets["Training"], batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    val_loader = DataLoader(datasets["Validation"], batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

    model = ResNet1d_UNet().to(device)
    optimizer = optim.Adam(model.parameters(), lr=LR)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=7)

    # AMP only when a GPU is available. A disabled GradScaler is a
    # pass-through (scale() returns the loss, step() calls optimizer.step()).
    use_amp = torch.cuda.is_available()
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    save_path = os.path.join(SAVE_DIR, "best_resnet1d_unet.pth")
    debug_path = os.path.join(SAVE_DIR, "best_validation_debug.npz")
    best_loss = float('inf')
    patience_counter = 0

    print("Training Started...")
    for current_epoch in range(1, EPOCHS + 1):
        print(f"\n{'='*20} Epoch {current_epoch}/{EPOCHS} {'='*20}")

        avg_train_loss = _train_one_epoch(
            model, train_loader, optimizer, scaler, device, use_amp, current_epoch
        )
        avg_val_loss, debug_batch_data = _validate(model, val_loader, device, use_amp)

        current_lr = optimizer.param_groups[0]['lr']
        print(f"Result: [Epoch {current_epoch}] Train Loss: {avg_train_loss:.6f} | Val Loss: {avg_val_loss:.6f} | LR: {current_lr:.1e}")

        scheduler.step(avg_val_loss)

        if avg_val_loss < best_loss:
            print(f"score improved: {best_loss:.6f} --> {avg_val_loss:.6f}")
            best_loss = avg_val_loss
            patience_counter = 0

            torch.save(model.state_dict(), save_path)
            if debug_batch_data is not None:
                np.savez(debug_path, **debug_batch_data)
        else:
            patience_counter += 1
            print(f"No improvement. Patience: {patience_counter}/{PATIENCE}")
            if patience_counter >= PATIENCE:
                print("Early Stopping!")
                break

    print("All Finished!")

if __name__ == "__main__":
    run_training()