In [None]:
import os

from google.colab import drive

drive.mount('/content/drive')

# Project root on Google Drive; edit TARGET_DIR for your own account.
TARGET_DIR = "/content/drive/MyDrive/your_path"
PROJECT_NAME = "quick-action-recognition"

# CSV logs and checkpoints are written under this folder (trailing slash included,
# because later code appends file names to it with plain string concatenation).
resultsFolder = f"{TARGET_DIR}/{PROJECT_NAME}/"
os.makedirs(resultsFolder, exist_ok=True)

In [None]:
from google.colab import drive
# Mounts Google Drive again. The first cell already makes this same call; this
# repeat is presumably kept so this cell can be run on its own.
drive.mount('/content/drive')

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import pickle
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.sampler import WeightedRandomSampler
import pandas as pd
import random
from torch.utils.data import Dataset
from torch.amp import autocast, GradScaler
from sklearn.metrics import confusion_matrix, classification_report, f1_score
from torch.optim.lr_scheduler import OneCycleLR

#########################################################
torch.cuda.empty_cache()
torch.backends.cudnn.benchmark = True

batch_size = 16
num_epoch = 200
lr = 3e-3
dropout = 0.2
weight_decay = 5e-5
experiment = "stgcn_onecyclelr"
resultsFile = experiment + ".csv"
state_path = resultsFolder + experiment + ".pth"
path = f"{TARGET_DIR}/{PROJECT_NAME}/data/NTU-RGB-D/x-view/"

resume_training = False  # set to True to continue training from the checkpoint at state_path
resume_epoch = 46


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

features_train = torch.FloatTensor(np.load(path + "small_train_data.npy"))
features_val   = torch.FloatTensor(np.load(path + "small_val_data.npy"))
features_test  = torch.FloatTensor(np.load(path + "small_test_data.npy"))

# Per-channel statistics (dim 1 kept) computed on the training split only and
# reused for val/test, so the evaluation splits do not influence normalisation.
# NOTE(review): if clips are zero-padded in time or in an empty body slot, that
# padding is included in these statistics — confirm this is intended.
mean = features_train.mean(dim=(0, 2, 3, 4), keepdim=True)
std  = features_train.std(dim=(0, 2, 3, 4), keepdim=True) + 1e-5

features_train = (features_train - mean) / std
features_val   = (features_val   - mean) / std
features_test  = (features_test  - mean) / std


# Each label pickle holds a (2, N) structure; row 1 is the class id.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
with open(path + "small_train_label.pkl", "rb") as f:
    labels_train_all = pickle.load(f)
with open(path + "small_val_label.pkl", "rb") as f:
    labels_val_all = pickle.load(f)
with open(path + "small_test_label.pkl", "rb") as f:
    labels_test_all = pickle.load(f)

labels_train = np.array(labels_train_all[1], dtype=int)
labels_val   = np.array(labels_val_all[1], dtype=int)
labels_test  = np.array(labels_test_all[1], dtype=int)


#########################################################
# Relabel classes to 0 .. num_class-1, based on the classes present in train.

actions = np.unique(labels_train)  # sorted unique training class ids
label_map = {k: i for i, k in enumerate(actions)}


def _relabel(labels, split_name):
    """Map raw class ids to their index in ``actions``; raise if an id never occurs in train.

    The raise replaces silently leaving such ids unmapped, where they could
    collide with a remapped index or exceed ``num_class``.
    """
    idx = np.searchsorted(actions, labels)
    found = (idx < len(actions)) & (actions[np.minimum(idx, len(actions) - 1)] == labels)
    if not found.all():
        unseen = np.unique(labels[~found]).tolist()
        raise ValueError(f"{split_name} labels contain classes not present in train: {unseen}")
    return idx


labels_train = _relabel(labels_train, "train")
labels_val   = _relabel(labels_val, "val")
labels_test  = _relabel(labels_test, "test")

num_class = len(actions)

#################### Sampler

# Per-class counts of the relabelled training set (every index 0..num_class-1 occurs).
class_sample_count = np.bincount(labels_train, minlength=num_class)
weight = 1. / class_sample_count
samples_weight = torch.from_numpy(weight[labels_train]).double()
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))

# NOTE(review): the sampler above already draws classes roughly uniformly; applying
# the same inverse-frequency weights in the loss as well corrects the imbalance
# twice — confirm this is intended.
class_weights = torch.tensor(weight, dtype=torch.float32)
class_weights = class_weights / class_weights.mean()
class_weights = class_weights.to(device)

####################### Data Augmentation

def random_augment_skeleton(x, p_shift=0.3, p_scale=0.3, p_noise=0.3):
    """Randomly perturb one skeleton clip (training-time augmentation).

    Two independent perturbations are applied, each with its own probability:

    * a circular shift along the time axis of up to ``max(1, T // 40)`` frames
      (probability ``p_shift``);
    * additive Gaussian noise with sigma 0.005 (probability ``p_noise``).

    ``p_scale`` is accepted for API compatibility but currently unused: the
    global-scaling step is disabled.

    Args:
        x: tensor of shape (C, T, V, M).

    Returns:
        A tensor of the same shape (``x`` itself when neither perturbation fires).
    """
    _, num_frames, _, _ = x.shape  # unpacking also rejects non-4-D input

    if random.random() < p_shift:
        limit = max(1, num_frames // 40)
        offset = random.randint(-limit, limit)
        x = torch.roll(x, shifts=offset, dims=1)  # wrap-around shift along T

    # (random scaling by a factor in [0.9, 1.1] is currently disabled)

    if random.random() < p_noise:
        x = x + torch.randn_like(x) * 0.005

    return x
class SkeletonDataset(Dataset):
    """Map-style dataset pairing each skeleton clip with its label.

    Args:
        features: tensor (N, C, T, V, M); item ``i`` is ``features[i]``.
        labels: 1D tensor (N,).
        augment: when True, every fetched clip goes through
            ``random_augment_skeleton`` (intended for the training split only).
    """

    def __init__(self, features, labels, augment=False):
        self.features, self.labels, self.augment = features, labels, augment

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        clip = self.features[idx]  # (C, T, V, M)
        target = self.labels[idx]
        if self.augment:
            clip = random_augment_skeleton(clip)
        return clip, target

####################### Trainloader and testloader
labels_train_t = torch.LongTensor(labels_train)
labels_val_t   = torch.LongTensor(labels_val)
labels_test_t  = torch.LongTensor(labels_test)

# Only the training split is augmented.
train_dataset = SkeletonDataset(features_train, labels_train_t, augment=True)
val_dataset = SkeletonDataset(features_val, labels_val_t, augment=False)
test_dataset = SkeletonDataset(features_test, labels_test_t, augment=False)

# Settings shared by all three loaders.
_loader_kwargs = dict(
    batch_size=batch_size,
    num_workers=2,
    pin_memory=(device.type == "cuda"),
)

# Training batches come from the class-balancing sampler (hence no shuffle flag);
# the incomplete final batch is dropped.
train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    drop_last=True,
    persistent_workers=True,
    **_loader_kwargs,
)

# Evaluation loaders visit every sample, in order.
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **_loader_kwargs)

# Remove the module-level names. The datasets above still hold references to
# these tensors, so this does not by itself release their memory.
del _loader_kwargs
del features_train, labels_train_t
del features_val, labels_val_t
del features_test, labels_test_t

def get_edge():
    """Return the edge list and centre joint of the 25-joint skeleton graph.

    Returns:
        (edge, center): ``edge`` is a list of 0-based ``(i, j)`` joint pairs —
        the 25 self-loops first, then the 24 bones — and ``center`` is joint 20
        (joint 21 in 1-based numbering).
    """
    joint_count = 25
    # Bones in 1-based joint numbering.
    bones_1based = (
        (1, 2), (2, 21), (3, 21), (4, 3), (5, 21),
        (6, 5), (7, 6), (8, 7), (9, 21), (10, 9),
        (11, 10), (12, 11), (13, 1), (14, 13), (15, 14),
        (16, 15), (17, 1), (18, 17), (19, 18), (20, 19),
        (22, 23), (23, 8), (24, 25), (25, 12),
    )
    edges = [(joint, joint) for joint in range(joint_count)]
    edges.extend((a - 1, b - 1) for a, b in bones_1based)
    centre_joint = 21 - 1
    return (edges, centre_joint)


def get_hop_distance(num_node, edge, max_hop=1):
    """Shortest-path hop distance between every pair of nodes, capped at ``max_hop``.

    Args:
        num_node: number of graph nodes.
        edge: iterable of ``(i, j)`` pairs, treated as undirected.
        max_hop: largest distance to resolve.

    Returns:
        float array (num_node, num_node): the hop count for pairs reachable
        within ``max_hop`` steps, ``np.inf`` otherwise (0 on the diagonal).
    """
    adj = np.zeros((num_node, num_node))
    for a, b in edge:
        adj[a, b] = adj[b, a] = 1

    # reach[d][u, v] is True when a walk of exactly d steps joins u and v.
    reach = np.stack([np.linalg.matrix_power(adj, d) for d in range(max_hop + 1)]) > 0

    dist = np.full((num_node, num_node), np.inf)
    # Largest d first, so smaller distances overwrite larger ones.
    for d in reversed(range(max_hop + 1)):
        dist[reach[d]] = d
    return dist


def _normalize_digraph(adjacency):
    """Scale each column of ``adjacency`` by the inverse of its column sum (all-zero columns stay zero)."""
    col_sum = adjacency.sum(axis=0)
    inv = np.zeros_like(col_sum, dtype=float)
    nonzero = col_sum > 0
    inv[nonzero] = 1.0 / col_sum[nonzero]
    return adjacency * inv[None, :]  # == adjacency @ diag(inv)


def get_adjacency(hop_dis, center, num_node, max_hop, dilation, normalize=False):
    """Build the spatial-partition adjacency stack used by the ST-GCN layers.

    For every hop ``h`` in ``range(0, max_hop + 1, dilation)``, each entry
    ``[j, i]`` with ``hop_dis[j, i] == h`` is assigned to one partition by
    comparing ``hop_dis[j, center]`` with ``hop_dis[i, center]``:

    * equal -> "root";
    * ``j`` farther from the centre -> "close";
    * otherwise -> "further".

    Hop 0 contributes ``root``; every other hop contributes ``root + close`` and
    ``further`` as two separate slices.

    Args:
        hop_dis: (num_node, num_node) hop-distance matrix (``np.inf`` = unreachable).
        center: index of the centre joint.
        num_node: number of joints.
        max_hop: largest hop distance used.
        dilation: step between the hop distances used.
        normalize: when True, column-normalise the binary adjacency before
            partitioning (each column then sums to 1 across all slices).
            Default False keeps the previous unnormalised 0/1 entries.

    Returns:
        float array (K, num_node, num_node).
    """
    valid_hop = range(0, max_hop + 1, dilation)

    # Binary adjacency over the hop distances actually used.
    adjacency = np.zeros((num_node, num_node))
    for hop in valid_hop:
        adjacency[hop_dis == hop] = 1
    if normalize:
        adjacency = _normalize_digraph(adjacency)

    # Entry [j, i] of these masks compares joint j's centre distance with joint i's.
    center_dist = hop_dis[:, center]
    same = center_dist[:, None] == center_dist[None, :]
    farther = center_dist[:, None] > center_dist[None, :]
    other = ~same & ~farther

    A = []
    for hop in valid_hop:
        at_hop = hop_dis == hop
        a_root = np.where(at_hop & same, adjacency, 0.0)
        a_close = np.where(at_hop & farther, adjacency, 0.0)
        a_further = np.where(at_hop & other, adjacency, 0.0)
        if hop == 0:
            A.append(a_root)
        else:
            A.append(a_root + a_close)
            A.append(a_further)
    return np.stack(A)


# Graph configuration. `layout` and `strategy` are descriptive only; nothing below
# reads them. (`layout` previously had a trailing comma that made it a 1-tuple.)
layout = 'ntu-rgb+d'
strategy = 'spatial'
max_hop = 1
dilation = 1
num_node = 25
edge, center = get_edge()
hop_dis = get_hop_distance(num_node, edge, max_hop=max_hop)
A = get_adjacency(hop_dis, center, num_node, max_hop, dilation)
# (K, V, V) adjacency stack; registered as a buffer by Model, not trained directly.
A = torch.tensor(A, dtype=torch.float32, requires_grad=False)


#######################################################################

class ConvTemporalGraphical(nn.Module):
    r"""Graph convolution used inside each ST-GCN unit.

    A ``(t_kernel_size, 1)`` convolution first expands the channels to
    ``out_channels * kernel_size``. The result is split into ``kernel_size``
    partitions, and each partition is aggregated over the joints with the
    matching slice of the adjacency stack ``A``.

    Args:
        in_channels (int): channels of the input sequence
        out_channels (int): channels produced by the graph convolution
        kernel_size (int): number of adjacency partitions K (spatial kernel size)
        t_kernel_size (int, optional): temporal kernel size of the inner conv. Default: 1
        t_stride (int, optional): temporal stride of the inner conv. Default: 1
        t_padding (int, optional): temporal zero-padding on both sides. Default: 0
        t_dilation (int, optional): temporal dilation of the inner conv. Default: 1
        bias (bool, optional): add a learnable bias to the inner conv. Default: ``True``

    Shape:
        - Input[0]: graph sequence, :math:`(N, in\_channels, T_{in}, V)`
        - Input[1]: adjacency stack, :math:`(K, V, V)` with :math:`K` equal to ``kernel_size``
        - Output[0]: graph sequence, :math:`(N, out\_channels, T_{out}, V)`
        - Output[1]: the adjacency stack passed in, returned unchanged

        where :math:`N` is the batch size, :math:`T_{in}/T_{out}` the input/output
        sequence length and :math:`V` the number of graph nodes.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 t_kernel_size=1,
                 t_stride=1,
                 t_padding=0,
                 t_dilation=1,
                 bias=True):
        super().__init__()

        self.kernel_size = kernel_size
        # One conv produces the features for all K partitions at once.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * kernel_size,
            kernel_size=(t_kernel_size, 1),
            padding=(t_padding, 0),
            stride=(t_stride, 1),
            dilation=(t_dilation, 1),
            bias=bias,
        )

    def forward(self, x, A):
        assert A.size(0) == self.kernel_size

        feats = self.conv(x)
        batch, k_times_c, frames, joints = feats.size()
        feats = feats.view(batch, self.kernel_size, k_times_c // self.kernel_size, frames, joints)
        # Sum over partitions k and source joints v: out[n, c, t, w] = sum_k,v feats[n,k,c,t,v] * A[k,v,w]
        aggregated = torch.einsum('nkctv,kvw->nctw', feats, A)
        return aggregated.contiguous(), A


######################################################################

class st_gcn(nn.Module):
    r"""One ST-GCN unit: graph conv -> temporal conv -> residual add -> ReLU.

    Args:
        in_channels (int): channels of the input sequence
        out_channels (int): channels produced by the unit
        kernel_size (tuple): ``(temporal_kernel, spatial_kernel)``; the temporal
            kernel must be odd
        stride (int, optional): temporal stride. Default: 1
        dropout (float, optional): dropout probability at the end of the
            temporal branch. Default: 0
        residual (bool, optional): add a residual connection. Default: ``True``

    Shape:
        - Input[0]: graph sequence, :math:`(N, in\_channels, T_{in}, V)`
        - Input[1]: adjacency stack, :math:`(K, V, V)` with :math:`K` == ``kernel_size[1]``
        - Output[0]: graph sequence, :math:`(N, out\_channels, T_{out}, V)`
        - Output[1]: the adjacency stack passed in, unchanged
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dropout=0,
                 residual=True):
        super().__init__()

        assert len(kernel_size) == 2
        assert kernel_size[0] % 2 == 1
        t_kernel, s_kernel = kernel_size
        # "Same" temporal padding for the odd kernel; no padding along joints.
        t_pad = ((t_kernel - 1) // 2, 0)

        self.gcn = ConvTemporalGraphical(in_channels, out_channels, s_kernel)

        # Temporal branch normalised with GroupNorm (8 groups), so out_channels must
        # be divisible by 8. NOTE(review): the residual projection below still uses
        # BatchNorm2d; switching it would change checkpoint keys.
        self.tcn = nn.Sequential(
            nn.GroupNorm(8, out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, (t_kernel, 1), (stride, 1), t_pad),
            nn.GroupNorm(8, out_channels),
            nn.Dropout(dropout, inplace=True),
        )

        shape_changes = (in_channels != out_channels) or (stride != 1)
        if not residual:
            self.residual = lambda x: 0
        elif not shape_changes:
            self.residual = lambda x: x
        else:
            # 1x1 projection matching the main branch's channels and temporal stride.
            self.residual = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=(stride, 1)),
                nn.BatchNorm2d(out_channels),
            )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x, A):
        shortcut = self.residual(x)
        out, A = self.gcn(x, A)
        return self.relu(self.tcn(out) + shortcut), A


######################################################################

class Model(nn.Module):
    r"""Spatial temporal graph convolutional network (ST-GCN) classifier.

    Args:
        in_channels (int): number of channels in the input data
        num_class (int): number of classes for the classification task
        A (Tensor): fixed adjacency stack of shape ``(K, V, V)``, registered as a buffer
        edge_importance_weighting (bool): if ``True``, learn a per-layer
            multiplicative mask on ``A``
        dropout (float): accepted for interface compatibility but currently
            unused — the per-layer dropout rates are fixed in ``__init__``

    Shape:
        - Input: :math:`(N, in\_channels, T_{in}, V_{in}, M_{in})`
        - Output: :math:`(N, num\_class)`, where
            :math:`N` is the batch size,
            :math:`T_{in}` the input sequence length,
            :math:`V_{in}` the number of graph nodes,
            :math:`M_{in}` the number of instances (bodies) per frame.
    """

    def __init__(self, in_channels, num_class, A,
                 edge_importance_weighting, dropout):
        super().__init__()

        self.register_buffer('A', A)

        # build networks
        spatial_kernel_size = A.size(0)
        temporal_kernel_size = 9
        kernel_size = (temporal_kernel_size, spatial_kernel_size)
        # Normalises each (joint, channel) pair over time; see _encode for the layout.
        self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))

        channels = [64, 64, 64, 128, 128, 256]

        self.st_gcn_networks = nn.ModuleList((
            st_gcn(in_channels, channels[0], kernel_size, 1, dropout=0.1, residual=False),
            st_gcn(channels[0], channels[1], kernel_size, 1, dropout=0.2),
            st_gcn(channels[1], channels[2], kernel_size, 1, dropout=0.3),
            st_gcn(channels[2], channels[3], kernel_size, 2, dropout=0.3),
            st_gcn(channels[3], channels[4], kernel_size, 2, dropout=0.3),
            st_gcn(channels[4], channels[5], kernel_size, 2, dropout=0.3),
        ))

        # initialize parameters for edge importance weighting
        if edge_importance_weighting:
            self.edge_importance = nn.ParameterList([
                nn.Parameter(torch.ones(self.A.size()))
                for _ in self.st_gcn_networks
            ])
        else:
            self.edge_importance = [1] * len(self.st_gcn_networks)

        last_channels = channels[-1]
        self.fcn = nn.Conv2d(last_channels, num_class, kernel_size=1)

    def _encode(self, x):
        """Apply data_bn and the ST-GCN stack to a raw (N, C, T, V, M) batch.

        Returns:
            (features, N, M): ``features`` has shape (N*M, C', T', V), one row
            per (sample, body) pair.
        """
        N, C, T, V, M = x.size()
        # data_bn expects (N*M, V*C, T): one BN channel per (joint, coordinate) pair.
        x = x.permute(0, 4, 3, 1, 2).contiguous()
        x = x.view(N * M, V * C, T)
        x = self.data_bn(x)
        x = x.view(N, M, V, C, T)
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        x = x.view(N * M, C, T, V)

        for gcn, importance in zip(self.st_gcn_networks, self.edge_importance):
            x, _ = gcn(x, self.A * importance)
        return x, N, M

    def forward(self, x):
        x, N, M = self._encode(x)

        # Global average over (T', V), then average over the M bodies.
        x = F.avg_pool2d(x, x.size()[2:])
        x = x.view(N, M, -1, 1, 1).mean(dim=1)

        # prediction
        x = self.fcn(x)
        return x.view(x.size(0), -1)

    def extract_feature(self, x):
        """Return per-location class scores and backbone features.

        Returns:
            (output, feature): ``output`` is (N, num_class, T', V, M) and
            ``feature`` is (N, C', T', V, M).
        """
        x, N, M = self._encode(x)

        _, c, t, v = x.size()
        feature = x.view(N, M, c, t, v).permute(0, 2, 3, 4, 1)

        # prediction at every (t, v) location, without pooling
        x = self.fcn(x)
        output = x.view(N, M, -1, t, v).permute(0, 2, 3, 4, 1)

        return output, feature


#########################################################################

def weights_init(m):
    """Initialise conv and batch-norm layers in place; intended for ``Module.apply``.

    * ``Conv1d``/``Conv2d``: weight ~ N(0, 0.02), bias = 0 (if present).
    * BatchNorm (1d/2d/3d/Sync): weight ~ N(1, 0.02), bias = 0, when the layer
      is affine. Non-affine layers have no parameters and are skipped instead
      of crashing.
    * Every other module (GroupNorm, Linear, ...) is left untouched.
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d)):
        nn.init.normal_(m.weight, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)):
        # weight/bias are None when affine=False
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)


#########################################################################

def accuracy(output, labels):
    """Fraction of rows whose arg-max over dim 1 equals the label, as a 0-d float tensor."""
    hits = output.argmax(dim=1).eq(labels)
    return hits.float().sum() / labels.size(0)


##########################################################################


# Build the network on the target device, then initialise its weights there
# (Module.apply returns the module itself, so `model` is the same object).
model = Model(
    in_channels=3,
    num_class=num_class,
    A=A,
    edge_importance_weighting=True,
    dropout=dropout,
).to(device).apply(weights_init)

class FocalLoss(nn.Module):
    """Multi-class focal loss with optional class weighting.

    Per sample: ``loss_i = alpha_{y_i} * (1 - p_i) ** gamma * CE_i``, where
    ``p_i = softmax(logits_i)[y_i]`` and ``CE_i = -log p_i``.

    Args:
        alpha: per-class weight tensor of shape (num_class,), a scalar (Python
            number or 0-d tensor) applied to every sample, or ``None`` for no
            weighting.
        gamma: focusing parameter; ``gamma=0`` with ``alpha=None`` reduces to
            plain cross-entropy.
        reduction: ``"mean"`` (plain average over samples) or ``"sum"``; any
            other value returns the per-sample losses.
    """

    def __init__(self, alpha=None, gamma=2.0, reduction="mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        # Unweighted per-sample CE, so that exp(-CE) is the true-class probability.
        # (Passing class weights into cross_entropy here would make pt = p ** alpha_y.)
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce_loss)
        loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            if torch.is_tensor(self.alpha) and self.alpha.dim() > 0:
                # Per-class weights: take the weight of each sample's target class.
                weights = self.alpha.to(device=loss.device, dtype=loss.dtype)
                loss = weights[targets] * loss
            else:
                loss = self.alpha * loss

        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "sum":
            return loss.sum()
        return loss


#criterion = nn.CrossEntropyLoss()
criterion = FocalLoss(alpha=class_weights, gamma=2.0)


optimizer = optim.AdamW(
    model.parameters(),
    lr=lr,  # placeholder: OneCycleLR below overwrites it with max_lr / div_factor
    # NOTE(review): the config block defines weight_decay = 5e-5, but 1e-4 is the
    # value actually used here — confirm which one is intended.
    weight_decay=1e-4,
)

# AMP gradient scaler; GradScaler takes a device-type string ("cuda" / "cpu").
scaler = GradScaler(device.type)

# train_one_epoch() steps the optimizer — and therefore OneCycleLR — once per
# gradient-accumulation group, not once per batch. The schedule is therefore
# sized in optimizer steps; otherwise it would only progress through about 1/4
# of its cycle.
steps_per_epoch = len(train_loader)
grad_accum_steps = 4  # must equal accumulation_steps inside train_one_epoch()
optimizer_steps_per_epoch = -(-steps_per_epoch // grad_accum_steps)  # ceil division
scheduler = OneCycleLR(
    optimizer,
    max_lr=lr,
    total_steps=num_epoch * optimizer_steps_per_epoch,
    pct_start=0.2,
    anneal_strategy='cos',  # cosine annealing
    div_factor=25.0,  # initial_lr = max_lr / div_factor
    final_div_factor=1e4,  # min_lr = initial_lr / final_div_factor
)

if resume_training:
    print(f"==> Resuming training from epoch {resume_epoch} using saved state...")
    checkpoint = torch.load(state_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scaler.load_state_dict(checkpoint['scaler_state_dict'])
    # Restore the LR-schedule position; older checkpoints do not contain it.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        print("==> Checkpoint has no scheduler state; OneCycleLR restarts from its first step.")
    resume_epoch = checkpoint['epoch']

#########################################################
# train / val / test loops

def train_one_epoch(epoch, history_train):
    """Train the global ``model`` for one epoch and append the result to the log.

    Uses autocast + ``scaler`` and accumulates gradients over
    ``accumulation_steps`` micro-batches. ``scheduler`` is stepped once per
    optimizer step. Reads the module-level ``model``, ``optimizer``, ``scaler``,
    ``scheduler``, ``criterion``, ``train_loader`` and ``device``.

    Args:
        epoch: epoch number, used for printing and logging.
        history_train: DataFrame of earlier epochs (columns epoch/loss/acc).

    Returns:
        ``history_train`` with this epoch appended; it is also written to
        ``train_<resultsFile>`` inside ``resultsFolder``.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    # Keep in sync with the optimizer-steps-per-epoch used to size OneCycleLR.
    accumulation_steps = 4
    num_batches = len(train_loader)
    if num_batches == 0:
        raise ValueError("train_loader produced no batches")

    # Micro-batches from index `full_end` on form the last, possibly smaller, group.
    full_end = num_batches - num_batches % accumulation_steps
    tail_size = num_batches - full_end
    amp_device = "cuda" if device.type == "cuda" else "cpu"

    loss_sum = 0.0
    acc_sum = 0.0
    count = 0

    optimizer.zero_grad(set_to_none=True)

    for step, (features_batch, labels_batch) in enumerate(train_loader):
        features_batch = features_batch.to(device, non_blocking=True)
        labels_batch = labels_batch.to(device, non_blocking=True)

        # Divide by the real size of this accumulation group, so the accumulated
        # gradient is a mean over its micro-batches even for the short tail group.
        group_size = accumulation_steps if step < full_end else tail_size

        with autocast(amp_device):
            output = model(features_batch)
            loss_train = criterion(output, labels_batch)

        # Accumulate gradients.
        scaler.scale(loss_train / group_size).backward()

        # Optimizer step at the end of every accumulation group.
        if (step + 1) % accumulation_steps == 0 or (step + 1) == num_batches:
            scaler.unscale_(optimizer)  # so clipping sees the true gradient norm
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            scaler.step(optimizer)
            scaler.update()
            scheduler.step()  # OneCycleLR advances once per optimizer step
            optimizer.zero_grad(set_to_none=True)

        loss_sum += loss_train.item()
        acc_sum += accuracy(output, labels_batch).item()
        count += 1

    # drop_last=True on train_loader keeps batches equal-sized, so a plain mean is exact.
    loss_avg = loss_sum / count
    acc_avg = acc_sum / count
    lr_current = optimizer.param_groups[0]['lr']

    print(f"[Train] Epoch {epoch:03d} | Loss: {loss_avg:.4f} | Acc: {acc_avg:.4f} | LR: {lr_current:.5f}")

    new_row = pd.DataFrame([{
        'epoch': epoch,
        'loss': loss_avg,
        'acc': acc_avg
    }])
    history_train = pd.concat([history_train, new_row], ignore_index=True)
    history_train.to_csv(os.path.join(resultsFolder, 'train_' + resultsFile), index=False)
    return history_train

def evaluate(epoch, history, loader, split_name="Val"):
    """Evaluate the global ``model`` on ``loader``, print and log the metrics.

    Loss and accuracy are per-sample averages over the whole split, so a
    shorter final batch (``drop_last=False``) is weighted correctly.

    Args:
        epoch: epoch number used in the printout and the log row.
        history: DataFrame of earlier rows for this split.
        loader: DataLoader yielding ``(features, labels)`` batches.
        split_name: label for the printout; its lower-case form prefixes the
            CSV file name (``val_<resultsFile>``, ``test_<resultsFile>``).

    Returns:
        (history, accuracy, macro_f1), where ``history`` has the new row
        (epoch, loss, acc, macro_f1) appended.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    amp_device = "cuda" if device.type == "cuda" else "cpu"

    loss_sum = 0.0
    n_samples = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for features_batch, labels_batch in loader:
            features_batch = features_batch.to(device, non_blocking=True)
            labels_batch = labels_batch.to(device, non_blocking=True)

            with autocast(amp_device):
                output = model(features_batch)
                loss_val = criterion(output, labels_batch)

            batch_n = labels_batch.size(0)
            # Assumes criterion returns a per-batch mean (FocalLoss default); weight by batch size.
            loss_sum += loss_val.item() * batch_n
            n_samples += batch_n
            all_preds.append(output.argmax(dim=1).cpu())
            all_labels.append(labels_batch.cpu())

    if n_samples == 0:
        raise ValueError(f"{split_name} loader produced no samples")

    loss_avg = loss_sum / n_samples
    all_preds = torch.cat(all_preds).numpy()
    all_labels = torch.cat(all_labels).numpy()
    acc_avg = float((all_preds == all_labels).mean())
    macro_f1 = f1_score(all_labels, all_preds, average='macro')

    print(f"[{split_name:4}] Epoch {epoch:03d} | Loss: {loss_avg:.4f} | "
          f"Acc: {acc_avg:.4f} | MacroF1: {macro_f1:.4f}")

    new_row = pd.DataFrame([{
        'epoch': epoch,
        'loss': loss_avg,
        'acc': acc_avg,
        'macro_f1': float(macro_f1),
    }])
    history = pd.concat([history, new_row], ignore_index=True)

    filename_prefix = split_name.lower()  # 'val', 'test'
    history.to_csv(os.path.join(resultsFolder, f'{filename_prefix}_' + resultsFile), index=False)
    return history, acc_avg, macro_f1


#####################################################################
# Per-class performance report
def evaluate_per_class(model, loader, split_name="Test"):
    """Print the confusion matrix and per-class precision/recall/F1 of ``model`` on ``loader``.

    Rows of the confusion matrix are true classes, columns are predictions,
    over the class indices 0 .. num_class-1. Inference runs under autocast,
    as in training. Returns nothing.
    """
    model.eval()
    amp_device = "cuda" if device.type == "cuda" else "cpu"
    pred_chunks, label_chunks = [], []

    with torch.no_grad():
        for batch_x, batch_y in loader:
            batch_x = batch_x.to(device, non_blocking=True)
            batch_y = batch_y.to(device, non_blocking=True)

            with autocast(amp_device):
                logits = model(batch_x)

            pred_chunks.append(logits.argmax(dim=1).cpu())
            label_chunks.append(batch_y.cpu())

    y_pred = torch.cat(pred_chunks).numpy()
    y_true = torch.cat(label_chunks).numpy()
    class_ids = list(range(num_class))

    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    print(f"\n[{split_name}] Confusion Matrix (rows=true, cols=pred):\n{cm}")

    print(f"\n[{split_name}] Classification Report:")
    print(classification_report(y_true, y_pred, labels=class_ids, digits=4))

######################################################################

# Main training: train + validate every epoch, keep the best-validation
# checkpoint, then evaluate it on the test split.

if resume_training:
    # The checkpoint holds the best epoch, so any logged epochs after it are about
    # to be re-run; drop them to avoid duplicate rows in the CSV logs.
    history_train = pd.read_csv(os.path.join(resultsFolder, 'train_' + resultsFile))
    history_train = history_train[history_train['epoch'] <= resume_epoch].reset_index(drop=True)
    history_val   = pd.read_csv(os.path.join(resultsFolder, 'val_' + resultsFile))
    history_val   = history_val[history_val['epoch'] <= resume_epoch].reset_index(drop=True)
    history_test  = pd.DataFrame({'epoch': [], 'loss': [], 'acc': []})
else:
    history_train = pd.DataFrame({'epoch': [], 'loss': [], 'acc': []})
    history_val   = pd.DataFrame({'epoch': [], 'loss': [], 'acc': []})
    history_test  = pd.DataFrame({'epoch': [], 'loss': [], 'acc': []})


# Start below any attainable macro-F1, so the first epoch always writes the
# checkpoint that the final test below loads unconditionally.
best_val_f1 = float('-inf')
if resume_training:
    # Restore the best score, so a resumed run only overwrites the checkpoint on a real improvement.
    best_val_f1 = torch.load(state_path, map_location='cpu').get('best_val_f1', best_val_f1)

# === Early stopping ===
early_stop_patience = 10
no_improve_count = 0

start_epoch = resume_epoch + 1 if resume_training else 1

for epoch in tqdm(range(start_epoch, num_epoch + 1)):
    history_train = train_one_epoch(epoch, history_train)
    history_val, val_acc, val_macro_f1 = evaluate(epoch, history_val, val_loader, split_name="Val")

    # Save the best model (by validation macro-F1).
    if val_macro_f1 > best_val_f1:
        best_val_f1 = float(val_macro_f1)  # plain float so weights_only torch.load can read it
        no_improve_count = 0
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            # OneCycleLR position, so a resumed run continues the same LR schedule.
            'scheduler_state_dict': scheduler.state_dict(),
            'best_val_f1': best_val_f1,
            'epoch': epoch,
        }, state_path)
        print(f"==> Best model updated at epoch {epoch} (Val MacroF1={best_val_f1:.4f})")
    else:
        no_improve_count += 1
        print(f"No improvement for {no_improve_count} epoch(s)")

        if no_improve_count >= early_stop_patience:
            print(f"\nEarly stopping at epoch {epoch} "
                  f"(no improvement for {early_stop_patience} consecutive epochs)")
            break

# Final test, using the best-validation model.
print("\n==> Testing with best validation model...")
best_model = Model(in_channels=3, num_class=num_class, A=A,
                   edge_importance_weighting=True, dropout=dropout).to(device)
checkpoint = torch.load(state_path, map_location=device)
best_model.load_state_dict(checkpoint['model_state_dict'])

# evaluate() reads the global `model`; point it at the best weights.
model = best_model
history_test, test_acc, test_macro_f1 = evaluate(
    0, history_test, test_loader, split_name="Test"
)
print(f"\nFinal Test Accuracy: {test_acc:.4f} | Final Test MacroF1: {test_macro_f1:.4f}")
evaluate_per_class(model, test_loader, split_name="Test")