In [None]:
import os, random, numpy as np, torch, torch.nn as nn, torch.utils.data as data
from tqdm import tqdm
import torch.optim as optim
from torch.cuda.amp import autocast, GradScaler

# ===============================
# Settings
# ===============================
SAVE_PATH = "/content/drive/MyDrive/ntu_best_model.pth"  # best checkpoint location (Google Drive)
# Action codes used as classes; each code's label is its index in this list.
ALLOWED_ACTIONS = [
    'A001', 'A002', 'A008', 'A009', 'A011', 'A012',
    'A028', 'A029', 'A030', 'A041', 'A103', 'A104',
]
SEED = 42
# Seed the Python, NumPy and PyTorch generators (in that order).
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# ===============================
# Dataset (12-label filter + label re-indexing + velocity)
# ===============================
class NTUSkeletonDataset(data.Dataset):
    """NTU RGB+D skeleton dataset yielding joint, bone and velocity streams.

    Each file is a ``.npy`` holding a pickled dict whose ``'skel_body0'``
    entry is a ``(frames, joints, 3)`` coordinate array. Sequences are
    center-cropped or zero-padded to ``T`` frames, centered on joint 0 and
    divided by the mean body width; the bone and velocity streams are
    derived from the normalized joints.

    Args:
        files: paths of the sample files.
        T: number of frames every sample is cropped/padded to.
        allowed_actions: action codes (e.g. ``'A008'``) in label order.
            Defaults to the module-level ``ALLOWED_ACTIONS``.

    Each item is ``(J, B, Vv, y)``: three float32 tensors laid out as
    ``(3, T, joints)`` plus a long scalar label, the index of the action
    code found in the file name.
    """

    # NTU-25 skeleton edges as (child, parent), 1-based joint indices.
    # Every joint except 21 occurs exactly once as the first element, so
    # writing each bone at that index fills all joints without collisions;
    # joint 21 (never a child here) keeps a zero bone.
    BONE_PAIRS = (
        (1, 2), (2, 21), (3, 21), (4, 3), (5, 21), (6, 5), (7, 6), (8, 7),
        (9, 21), (10, 9), (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
        (16, 15), (17, 1), (18, 17), (19, 18), (20, 19), (22, 23), (23, 8),
        (24, 25), (25, 12),
    )

    def __init__(self, files, T=180, allowed_actions=None):
        self.files = files
        self.T = T
        self.allowed_actions = list(
            ALLOWED_ACTIONS if allowed_actions is None else allowed_actions
        )
        self.action2label = {a: i for i, a in enumerate(self.allowed_actions)}
        self.bone_pairs = list(self.BONE_PAIRS)

    def __len__(self):
        return len(self.files)

    def _pad(self, arr):
        """Center-crop (if longer) or zero-pad at the end (if shorter) to ``self.T`` frames."""
        T, V, C = arr.shape
        out = np.zeros((self.T, V, C), dtype=np.float32)
        if T >= self.T:
            # Center crop: keeps the middle of long sequences.
            s = (T - self.T) // 2
            out[:] = arr[s:s + self.T]
        else:
            out[:T] = arr
        return out

    def _normalize(self, J):
        """Center on joint 0 (hip) and divide by the mean body width.

        Per frame, the width is the shoulder distance (joints 5 and 9,
        1-based). Where that distance is non-finite or ~0, the hip distance
        (joints 13 and 17) is used instead. Frames whose width is still
        unusable are left out of the mean, so the all-zero padding frames
        added by ``_pad`` do not shrink the scale. Falls back to 1.0 when no
        frame is usable.
        """
        J = J - J[:, 0:1, :]
        shoulder = np.linalg.norm(J[:, 4, :] - J[:, 8, :], axis=-1)  # (T,), 0-based
        pelvis = np.linalg.norm(J[:, 12, :] - J[:, 16, :], axis=-1)
        denom = np.where(np.isfinite(shoulder) & (shoulder > 1e-6), shoulder, pelvis)
        usable = np.isfinite(denom) & (denom > 1e-6)
        s = float(denom[usable].mean()) if usable.any() else 1.0
        return J / s

    def _bone(self, J):
        """Bone stream: joint minus its parent, stored at the child joint's index."""
        B = np.zeros_like(J)
        for child, parent in self.bone_pairs:
            B[:, child - 1] = J[:, child - 1] - J[:, parent - 1]
        return B

    def _vel(self, J, length=None):
        """Velocity stream: frame-to-frame difference (frame 0 is zero).

        When ``length`` (the number of real, non-padding frames) is given,
        frames from ``length`` on are zeroed. This avoids a spurious jump
        from the last real frame into the zero padding.
        """
        Vv = np.zeros_like(J)
        Vv[1:] = J[1:] - J[:-1]
        if length is not None:
            Vv[length:] = 0
        return Vv

    def __getitem__(self, idx):
        path = self.files[idx]
        # NOTE(review): allow_pickle=True can execute code embedded in the
        # file; only load trusted data.
        d = np.load(path, allow_pickle=True).item()
        raw = d['skel_body0'].astype('float32')      # (T, V, 3)
        n_valid = min(raw.shape[0], self.T)           # real frames after crop/pad
        J = self._normalize(self._pad(raw))
        B = self._bone(J)
        Vv = self._vel(J, n_valid)

        # (T, V, C) -> (C, T, V)
        J, B, Vv = (np.transpose(a, (2, 0, 1)) for a in (J, B, Vv))

        base = os.path.basename(path)
        Acode = next((a for a in self.allowed_actions if a in base), None)
        if Acode is None:
            raise ValueError(f"no allowed action code in file name: {base!r}")
        y = self.action2label[Acode]

        return torch.tensor(J), torch.tensor(B), torch.tensor(Vv), torch.tensor(y, dtype=torch.long)

# ===============================
# Split (random 8:1:1) + cross-subject option
# ===============================
def list_allowed_files(root, allowed_actions=None):
    """Return the sorted ``.npy`` paths in *root* whose file name has an allowed action code.

    The listing is sorted because ``os.listdir`` order is arbitrary; without
    sorting, the seeded splits downstream would not be reproducible. Action
    codes are matched against the file name only, so a directory name that
    happens to contain a code (e.g. ``A001``) cannot pull in other files.

    Args:
        root: directory to scan (non-recursive).
        allowed_actions: action codes to keep; defaults to ``ALLOWED_ACTIONS``.
    """
    actions = ALLOWED_ACTIONS if allowed_actions is None else allowed_actions
    all_files = sorted(os.path.join(root, f) for f in os.listdir(root) if f.endswith('.npy'))
    files = [f for f in all_files if any(a in os.path.basename(f) for a in actions)]
    print(f"‚úÖ ÏÇ¨Ïö©Ìï† Îç∞Ïù¥ÌÑ∞ Í∞úÏàò: {len(files)} / {len(all_files)}")
    return files

def parse_ids(fname):
    """Return the (setup, camera, performer) ids encoded in an NTU file name.

    Names look like ``S001C002P003R001A004.npy``. Each id is the integer
    between its tag letter and the next tag letter.
    """
    stem = os.path.basename(fname)

    def _field(tag, nxt):
        return int(stem.split(tag)[1].split(nxt)[0])

    return tuple(_field(tag, nxt) for tag, nxt in (('S', 'C'), ('C', 'P'), ('P', 'R')))

def split_random(files, train=0.8, val=0.1, test=0.1, seed=SEED):
    """Shuffle *files* and cut them into train / val / test lists.

    The first ``int(len(files) * train)`` shuffled files form the train
    split and the next ``int(len(files) * val)`` form the validation split.
    All remaining files go to test, so the ``test`` fraction itself is not
    consulted. The shuffle uses ``np.random.default_rng(seed)``.
    """
    total = len(files)
    order = np.arange(total)
    np.random.default_rng(seed).shuffle(order)
    train_end = int(total * train)
    val_end = train_end + int(total * val)
    tr, va, te = (
        [files[k] for k in order[lo:hi]]
        for lo, hi in ((0, train_end), (train_end, val_end), (val_end, total))
    )
    print(f"TRAIN: {len(tr)} | VAL: {len(va)} | TEST: {len(te)}")
    return tr, va, te

def split_cross_subject_811(files, seed=SEED):
    """Split *files* 8:1:1 into train / val / test by performer (subject).

    Performer ids (the ``P###`` field) are sorted, shuffled with
    ``np.random.default_rng(seed)`` and cut 80% / 10% (floored), with the
    rest going to test. Every file follows its performer, so no subject
    appears in more than one split.
    """
    # Performer id of every file, parsed once.
    owners = [parse_ids(f)[2] for f in files]
    people = sorted(set(owners))
    np.random.default_rng(seed).shuffle(people)

    # Partition the people, then route each file to its person's split.
    n = len(people)
    n_tr, n_va = int(n * 0.8), int(n * 0.1)
    groups = (
        set(people[:n_tr]),
        set(people[n_tr:n_tr + n_va]),
        set(people[n_tr + n_va:]),
    )
    tr, va, te = ([f for f, p in zip(files, owners) if p in g] for g in groups)

    print(f"[Cross-Subject 8:1:1] TRAIN: {len(tr)} | VAL: {len(va)} | TEST: {len(te)}")
    return tr, va, te

# ===============================
# Model: lightweight CTR-GCN style + dilated TCN + GroupNorm
# ===============================
class JointMix(nn.Module):
    """Learnable V x V mixing across joints.

    A lightweight take on CTR-GCN's channel-spatial mixing idea. The weight
    starts as the identity, so the layer initially passes its input through
    unchanged. Input and output are ``(N, C, T, V)``.
    """
    def __init__(self, V):
        super().__init__()
        identity = torch.eye(V)
        self.W = nn.Parameter(identity)  # (V, V): W[v, w] weights joint v's contribution to joint w

    def forward(self, x):
        # Right-multiplying by W mixes along the last (joint) axis.
        return torch.matmul(x, self.W)

class TBlock(nn.Module):
    """Residual temporal block: dilated (k x 1) conv -> GroupNorm -> SiLU -> Dropout.

    The convolution runs along the time axis only. It is padded so that an
    odd ``k`` keeps the sequence length. The skip path is the identity when
    ``in_ch == out_ch`` and a 1x1 convolution otherwise. ``out_ch`` must be
    divisible by ``groups`` (a GroupNorm requirement).
    """
    def __init__(self, in_ch, out_ch, k=9, dilation=1, groups=8, drop=0.1):
        super().__init__()
        time_pad = dilation * ((k - 1) // 2)
        self.conv = nn.Conv2d(
            in_ch, out_ch,
            kernel_size=(k, 1),
            padding=(time_pad, 0),
            dilation=(dilation, 1),
        )
        self.gn = nn.GroupNorm(groups, out_ch)
        self.act = nn.SiLU(inplace=True)
        self.drop = nn.Dropout(drop)
        self.res = in_ch == out_ch
        if self.res:
            self.proj = nn.Identity()
        else:
            self.proj = nn.Conv2d(in_ch, out_ch, 1)

    def forward(self, x):
        branch = self.drop(self.act(self.gn(self.conv(x))))
        return self.proj(x) + branch

class StreamNet(nn.Module):
    """Single-stream encoder and classifier.

    The layers are: a 1x1 stem, then three stages of (JointMix -> two
    TBlocks), then global average pooling over time and joints, then a
    linear head. Input is ``(N, in_ch, T, V)`` and output is
    ``(N, num_classes)`` logits. ``width`` gives the channel count of each
    stage.
    """
    def __init__(self, in_ch=3, V=25, num_classes=12, width=(64,128,256)):
        super().__init__()
        c1, c2, c3 = width
        self.stem = nn.Conv2d(in_ch, c1, 1)
        # Stage 1: keeps width c1.
        self.mix1 = JointMix(V)
        self.t1a = TBlock(c1, c1, dilation=1)
        self.t1b = TBlock(c1, c1, dilation=2)
        # Stage 2: c1 -> c2.
        self.mix2 = JointMix(V)
        self.t2a = TBlock(c1, c2, dilation=2)
        self.t2b = TBlock(c2, c2, dilation=4)
        # Stage 3: c2 -> c3.
        self.mix3 = JointMix(V)
        self.t3a = TBlock(c2, c3, dilation=2)
        self.t3b = TBlock(c3, c3, dilation=4)
        self.head = nn.Linear(c3, num_classes)

    def forward(self, x):
        h = self.stem(x)
        for layer in (self.mix1, self.t1a, self.t1b,
                      self.mix2, self.t2a, self.t2b,
                      self.mix3, self.t3a, self.t3b):
            h = layer(h)
        pooled = h.mean(dim=(2, 3))  # global average over (T, V)
        return self.head(pooled)

class ThreeStream(nn.Module):
    """Joint / bone / velocity classifier with learned late fusion.

    Three independent ``StreamNet`` encoders map their ``(N, 3, T, V)``
    inputs to per-class logits. The concatenated logits are fused by a
    linear layer into the final ``(N, num_classes)`` prediction.

    Args:
        num_classes: number of output classes.
        V: number of skeleton joints, passed to every stream.
    """
    def __init__(self, num_classes=12, V=25):
        super().__init__()
        self.j = StreamNet(in_ch=3, V=V, num_classes=num_classes)  # joint stream
        self.b = StreamNet(in_ch=3, V=V, num_classes=num_classes)  # bone stream
        self.v = StreamNet(in_ch=3, V=V, num_classes=num_classes)  # velocity stream
        self.fc = nn.Linear(num_classes * 3, num_classes)

    def forward(self, J, B, Vv):
        pj, pb, pv = self.j(J), self.b(B), self.v(Vv)
        return self.fc(torch.cat([pj, pb, pv], dim=-1))

# ===============================
# Evaluation / training loops
# ===============================
@torch.no_grad()
def evaluate(model, loader, loss_fn, device, desc="VAL"):
    """Evaluate *model* on *loader*, print a summary line and return ``(avg_loss, accuracy)``.

    Each batch's loss is weighted by its batch size. This makes ``avg_loss``
    the per-sample mean even when the last batch is smaller (assuming
    *loss_fn* returns a batch mean, as ``nn.CrossEntropyLoss`` does by
    default). Mixed precision is enabled only on CUDA. An empty loader
    yields ``(0.0, 0.0)``.

    Args:
        model: called as ``model(J, B, Vv)`` and expected to return logits.
        loader: iterable of ``(J, B, Vv, y)`` batches.
        loss_fn: criterion applied as ``loss_fn(logits, y)``.
        device: device string or ``torch.device`` to move batches to.
        desc: label used in the printed summary.
    """
    model.eval()
    dev_type = torch.device(device).type
    total_loss, correct, n = 0.0, 0, 0
    for J, B, Vv, y in loader:
        J, B, Vv, y = J.to(device), B.to(device), Vv.to(device), y.to(device)
        with torch.autocast(device_type=dev_type, enabled=(dev_type == 'cuda')):
            logits = model(J, B, Vv)
            loss = loss_fn(logits, y)
        bs = y.size(0)
        total_loss += loss.item() * bs
        correct += (logits.argmax(1) == y).sum().item()
        n += bs
    avg = total_loss / max(1, n)
    acc = correct / max(1, n)
    print(f"üìè {desc}  loss={avg:.4f}  acc={acc:.2%}")
    return avg, acc

def train(
    data_folder="/content/drive/MyDrive/ntu_rgbd_npy_backup/raw_npy",
    epochs=25, batch_size=4, accum_steps=4, T=180, use_cross_subject=False
):
    """Train a ThreeStream model on NTU skeleton files and report test metrics.

    Steps:
    - Mount Google Drive (requires Colab).
    - Split the files in ``data_folder``: randomly 8:1:1, or by subject when
      ``use_cross_subject`` is set.
    - Train with AdamW and label-smoothed cross-entropy. Gradients are
      accumulated over ``accum_steps`` mini-batches and clipped at norm 1.0.
      Mixed precision is used on CUDA only. The learning rate follows a
      linear warmup then a cosine decay.
    - Save the weights with the lowest validation loss to ``SAVE_PATH``, then
      reload them for the final test evaluation.
    """
    # Mount Google Drive (Colab only) so the data and checkpoint paths exist.
    from google.colab import drive
    drive.mount('/content/drive', force_remount=True)
    os.makedirs(os.path.dirname(SAVE_PATH), exist_ok=True)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    use_amp = device == 'cuda'  # autocast / GradScaler are only enabled on CUDA
    pin = use_amp

    files = list_allowed_files(data_folder)
    if use_cross_subject:
        tr_files, va_files, te_files = split_cross_subject_811(files)
    else:
        tr_files, va_files, te_files = split_random(files)

    train_set = NTUSkeletonDataset(tr_files, T=T)
    val_set = NTUSkeletonDataset(va_files, T=T)
    test_set = NTUSkeletonDataset(te_files, T=T)

    train_loader = data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin)
    val_loader = data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)
    test_loader = data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0, pin_memory=pin)

    num_classes = len(ALLOWED_ACTIONS)
    model = ThreeStream(num_classes=num_classes).to(device)
    opt = optim.AdamW(model.parameters(), lr=5e-4, weight_decay=0.05)
    loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
    scaler = GradScaler(enabled=use_amp)

    # Linear warmup over the first 10% of epochs (at least one epoch), then
    # cosine decay. Both are counted in optimizer updates: the scheduler is
    # stepped right after each optimizer update, never between accumulated
    # micro-batches.
    n_batches = len(train_loader)
    updates_per_epoch = max(1, -(-n_batches // accum_steps))  # ceil division
    total_updates = epochs * updates_per_epoch
    warm = max(1, int(0.1 * epochs)) * updates_per_epoch

    def lr_lambda(step):
        if step < warm:
            return (step + 1) / warm
        progress = (step - warm) / max(1, total_updates - warm)
        return 0.5 * (1 + np.cos(np.pi * progress))
    sched = optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    best_val = float('inf')
    saved = False
    for epoch in range(1, epochs + 1):
        model.train()
        running = 0.0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}")
        opt.zero_grad(set_to_none=True)
        for i, batch in enumerate(pbar):
            J, B, Vv, y = (t.to(device, non_blocking=True) for t in batch)
            with torch.autocast(device_type=device, enabled=use_amp):
                logits = model(J, B, Vv)
                loss = loss_fn(logits, y) / accum_steps
            scaler.scale(loss).backward()

            # Update every accum_steps micro-batches AND on the last batch, so
            # a trailing partial accumulation is applied rather than being
            # zeroed at the start of the next epoch.
            if (i + 1) % accum_steps == 0 or (i + 1) == n_batches:
                scaler.unscale_(opt)
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad(set_to_none=True)
                sched.step()

            batch_loss = loss.item() * accum_steps
            running += batch_loss
            pbar.set_postfix(loss=f"{batch_loss:.4f}")

        train_loss = running / max(1, n_batches)

        val_loss, val_acc = evaluate(model, val_loader, loss_fn, device, "VAL")
        print(f"Epoch {epoch} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | val_acc={val_acc:.2%}")

        if val_loss < best_val:
            best_val = val_loss
            torch.save(model.state_dict(), SAVE_PATH)
            saved = True
            print(f"üíæ Best updated ‚Üí {SAVE_PATH} (val_loss={best_val:.4f})")

    # If validation loss never improved on inf (e.g. it was NaN every epoch),
    # save the final weights so the test below does not load a stale
    # checkpoint left over from an earlier run.
    if not saved:
        torch.save(model.state_dict(), SAVE_PATH)

    # Final test
    model.load_state_dict(torch.load(SAVE_PATH, map_location=device))
    test_loss, test_acc = evaluate(model, test_loader, loss_fn, device, "TEST")
    print(f"‚úÖ Test done | loss={test_loss:.4f} acc={test_acc:.2%}")

if __name__ == "__main__":
    # Default Colab run: random 8:1:1 split; effective batch = 4 x 4 accumulation steps.
    run_config = {
        "data_folder": "/content/drive/MyDrive/ntu_rgbd_npy_backup/raw_npy",
        "epochs": 25,
        "batch_size": 4,
        "accum_steps": 4,
        "T": 180,
        "use_cross_subject": False,
    }
    train(**run_config)
