### STEP 0. 환경 구축하기

---

- VTT 모델에 맞는 환경을 구축하기 위하여 필요한 패키지를 설치합니다.

In [None]:
# github에서 데이터 불러오기
# !git clone https://github.com/hwk0702/KBS2024-VTT.git
# %cd ./KBS2024-VTT/src
%cd ../src

# !pip install -r ../requirements.txt

In [None]:
import os  # 운영 체제와의 상호작용을 위한 모듈
import sys  # 시스템 특정 파라미터와 함수를 위한 모듈
import random  # 난수 생성기를 위한 모듈
import math  # 수학적 함수와 상수를 위한 모듈
import time  # 시간 관련 함수를 위한 모듈
import json  # JSON 데이터 파싱과 생성을 위한 모듈
import yaml  # YAML 파일 파싱과 생성을 위한 모듈
import warnings  # 경고 메시지를 관리하기 위한 모듈
warnings.filterwarnings("ignore")  # 경고 메시지를 무시하도록 설정

import argparse  # 명령행 옵션, 인자 파싱 모듈
from datetime import datetime  # 날짜와 시간 관련 모듈

import numpy as np # 다차원 배열과 연산을 다루는 모듈
import pandas as pd  # 데이터 조작 및 분석을 위한 라이브러리
from sklearn.preprocessing import MinMaxScaler, StandardScaler  # 데이터 전처리를 위한 스케일링 모듈

import matplotlib.pyplot as plt  # 데이터 시각화를 위한 라이브러리
# %matplotlib inline  # 주피터 노트북에서 인라인으로 그래프를 표시

from easydict import EasyDict  # 딕셔너리의 속성을 점 표기법으로 접근할 수 있도록 하는 모듈
from omegaconf import OmegaConf  # 설정 파일 관리를 위한 모듈

import torch  # 딥러닝을 위한 라이브러리
from torch import optim, einsum  # 최적화 및 다차원 배열 연산을 위한 모듈
import torch.nn as nn  # 신경망 구축을 위한 모듈
from torch.utils.data import Dataset  # 데이터셋 구성을 위한 모듈
from torch.nn.utils import weight_norm  # 신경망 가중치 정규화를 위한 모듈

from einops import rearrange, repeat  # 배열 재구성 및 반복을 위한 모듈

from layers.Transformer_Enc import PreNorm, FeedForward  # 사용자 정의 트랜스포머 인코더 모듈

from utils.utils import set_seed, log_setting, version_build  # 유틸리티 함수들
from utils.tools import EarlyStopping, adjust_learning_rate, check_graph  # 학습 조기 종료 및 학습률 조정 도구
from data_provider.waferdataset import get_dataloader  # 데이터 로더 함수
from model import build_model  # 모델 빌드 함수

# Report the interpreter / PyTorch versions and choose the compute device.
print(f"Python version:[{sys.version}].")
print(f"PyTorch version:[{torch.__version__}].")
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(f"device:[{device}].")  # 'cuda:0' here means a GPU is available and will be used

모델에 필요한 argument들을 설정합니다.

In [None]:
# Run arguments: dataloader, training, loss and runtime/hardware settings.
args = EasyDict(
    # dataloader
    dataname='vtt_step45_big_ref',
    subdataname=None,
    window_size=30,
    slide_size=1,
    # training
    train=True,
    test=True,
    resume=None,
    model='VTTSAT',
    epochs=20,
    patience=5,
    # loss
    loss='mse',
    # runtime / hardware
    seed=72,
    use_gpu=True,
    gpu=0,
    use_multi_gpu=False,
    devices='0',
    configure='./config.yaml',
    batch_size=16,
)

# Load the YAML configuration file referenced by args.configure.
with open(args.configure) as config_file:
    config = OmegaConf.load(config_file)

# Show the loaded configuration with interpolations resolved.
print(OmegaConf.to_yaml(config, resolve=True))

### STEP 9. XAI

<img src="../image/XAI.png" width="700">

In [None]:
from scipy.special import softmax
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import pyplot as plt
import matplotlib.colors as mcl
%matplotlib inline

from utils.utils import load_model

In [None]:
import h5py
import numpy as np

def load_inference_result(h5_path):
    """Read the inference output file written by ``main.py --test``.

    Args:
        h5_path: path of the HDF5 file (inference_result_with_metadata.h5).

    Returns:
        tuple: (preds, actuals, dist_per_feature, lotids, wafer_numbers,
        step_nums, masks). ``lotids`` and ``wafer_numbers`` are converted with
        ``astype(str)``; ``dist_per_feature`` is None when the file has no
        'dist_per_window_per_features' dataset.

    NOTE(review): if 'lotid' / 'wafer_number' are variable-length string
    datasets, h5py returns bytes objects and ``astype(str)`` would yield
    "b'...'" strings — confirm the dataset dtype.
    """
    dist_key = 'dist_per_window_per_features'
    with h5py.File(h5_path, 'r') as h5:
        preds, actuals = h5['pred'][:], h5['actual'][:]
        lotids = h5['lotid'][:].astype(str)
        wafer_numbers = h5['wafer_number'][:].astype(str)
        step_nums, masks = h5['step_num'][:], h5['mask'][:]
        dist_per_feature = h5[dist_key][:] if dist_key in h5 else None

    return preds, actuals, dist_per_feature, lotids, wafer_numbers, step_nums, masks

# Requires the file exported by running main.py with --test (inference_result_with_metadata.h5).
h5_path = './logs/vtt_all_step/VTTSAT/version5/results/inference_result_with_metadata.h5'
preds, actuals, dist_per_feature, lotids, wafer_numbers, step_nums, masks = load_inference_result(h5_path)

# One mask row per sample, aligned with the (sample, time) axes of preds.
# The shape is derived from preds instead of the previously hard-coded
# (16808, 350), so the cell works for any number of samples / sequence length
# and still fails loudly if the mask size does not match.
masks = masks.reshape(preds.shape[0], preds.shape[1])

# Names for feature index d of preds / actuals (used as plot titles and score columns).
feature_names = [
    'APC_Position',
    'APC_Pressure',
    'Gas1_Monitor',
    'Gas6_Monitor',
    'Mat_Irms',
    'Mat_Vrms',
    'Mat_VC1_Position',
    'Mat_VC2_Position',
    'SourcePwr_Read',
    'Temp',
    'Wall_Temp_Monitor'
]

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
import seaborn as sns

sns.set_style("darkgrid")

def visualize_overlay_all_wafers_subplots(
    preds, actuals, lotids, wafer_numbers, feature_names, masks,
    output_path='overlay_all_wafers_subplots.png', num_wafer_to_plot=1000, seed=42
):
    """Overlay actual vs. predicted sequences of many wafers, one subplot per feature.

    Each row of ``preds`` / ``actuals`` is one wafer's whole sequence. At most
    ``num_wafer_to_plot`` rows are sampled without replacement (reproducibly
    from ``seed``). For every feature, the valid part (``mask == 1``) of each
    selected sequence is drawn semi-transparently: actual in blue, predicted
    in magenta.

    The grid has 4 columns and at least 3 rows (more when there are over 12
    features); unused axes are removed. The figure is saved to
    ``output_path`` (missing parent folders are created) and then closed.

    Args:
        preds, actuals: arrays indexed as [row, time, feature]; D = preds.shape[2].
        lotids, wafer_numbers: one entry per row; their count defines the
            sampling population.
        feature_names: subplot titles, at least D entries.
        masks: array indexed as [row, time]; 1 marks a valid time step.
        output_path: destination image path.
        num_wafer_to_plot: maximum number of rows to overlay.
        seed: seed of the private random generator used for sampling.
    """
    n_samples = len(lotids)
    # A private RandomState yields exactly the same draw as seeding the global
    # NumPy RNG, without altering global random state for later notebook cells.
    rng = np.random.RandomState(seed)
    selected_indices = rng.choice(n_samples, size=min(num_wafer_to_plot, n_samples), replace=False)

    # Valid time positions depend only on the row, so compute them once per row.
    valid_steps = [np.where(masks[idx] == 1)[0] for idx in selected_indices]

    D = preds.shape[2]
    n_cols = 4
    n_rows = max(3, -(-D // n_cols))  # ceil(D / n_cols), never fewer than 3 rows
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 3 * n_rows),
                             sharex=False, squeeze=False)
    axes = axes.flatten()

    for d in range(D):
        ax = axes[d]

        for i, (idx, time_indices) in enumerate(zip(selected_indices, valid_steps)):
            # Label only the first line of each colour so the legend has two entries.
            ax.plot(actuals[idx, :, d][time_indices], color='blue', alpha=0.2, linewidth=0.8,
                    label='Actual' if i == 0 else None, zorder=0)
            ax.plot(preds[idx, :, d][time_indices], color='magenta', alpha=0.2, linewidth=0.8,
                    label='Predicted' if i == 0 else None, zorder=1)

        ax.set_title(feature_names[d], fontsize=10, fontweight='bold')
        # NOTE(review): a fixed 0.1 tick step assumes values on a roughly unit
        # scale (e.g. after scaling) — confirm for unscaled data.
        ax.yaxis.set_major_locator(MultipleLocator(0.1))

        if d == 0:
            ax.legend(fontsize=8)

    for j in range(D, len(axes)):
        fig.delaxes(axes[j])

    fig.suptitle(f'Overlay of {len(selected_indices)} Wafers per Sensor', fontsize=16)
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])

    # savefig does not create missing folders.
    out_dir = os.path.dirname(output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(output_path, bbox_inches='tight')
    plt.close(fig)
    print(f"✅ 저장 완료: {output_path}")


In [None]:
# Overlay every sample: the wafer count is taken from the data instead of being hard-coded.
visualize_overlay_all_wafers_subplots(
    preds=preds,
    actuals=actuals,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    feature_names=feature_names,
    masks=masks,
    output_path='./logs/vtt_all_step/VTTSAT/version5/plots/overlay_all_wafers_subplots.png',
    num_wafer_to_plot=len(preds),
)

## step별 VTT 스코어 계산

In [None]:
import numpy as np
import pandas as pd
def compute_vtt_score_per_stepnum(preds, actuals, masks, lotids, wafer_numbers, step_nums, feature_names):
    """Per-wafer, per-feature VTT scores, grouped by step number.

    Each row of ``preds`` / ``actuals`` is one wafer's whole sequence. The
    score of feature d is the *negative sum* of absolute reconstruction errors
    over the time steps whose mask equals 1:

        score[d] = -sum_t |pred[t, d] - actual[t, d]|   for t with mask[t] == 1

    Despite the historical "MAE" naming, the value is neither averaged nor
    positive. It is always <= 0, and values closer to 0 mean smaller error.
    Rows with no valid time step are skipped.

    Args:
        preds, actuals: arrays of shape (N, T, D).
        masks: array of shape (N, T); 1 marks a valid time step.
        lotids, wafer_numbers: length-N sequences joined into "lotid-wafer" ids.
        step_nums: length-N sequence of step numbers.
        feature_names: at least D names, used as score column names.

    Returns:
        pd.DataFrame: columns [step, unique_id, feature_names[0..D-1]], rows
        ordered by ascending step and then by original row order. The columns
        are present even when no row qualifies.
    """
    lotids = np.array(lotids).astype(str)
    wafer_numbers = np.array(wafer_numbers).astype(str)
    unique_ids = np.char.add(lotids, np.char.add("-", wafer_numbers))
    step_nums = np.array(step_nums)

    D = preds.shape[2]
    score_columns = [feature_names[d] for d in range(D)]
    result_records = []

    for step in sorted(np.unique(step_nums)):
        # Restrict to the rows of this step.
        step_mask = step_nums == step
        step_preds = preds[step_mask]
        step_actuals = actuals[step_mask]
        step_masks = masks[step_mask]
        step_ids = unique_ids[step_mask]

        for i, uid in enumerate(step_ids):
            valid_idx = step_masks[i] == 1
            if not np.any(valid_idx):
                continue

            # Negative sum (not mean) of absolute errors per feature: shape (D,).
            scores = -np.sum(np.abs(step_preds[i][valid_idx] - step_actuals[i][valid_idx]), axis=0)
            record = {'step': int(step), 'unique_id': uid}
            record.update(zip(score_columns, scores))
            result_records.append(record)

    # Explicit columns keep the CSV header even when no record was produced.
    return pd.DataFrame.from_records(result_records, columns=['step', 'unique_id', *score_columns])


vtt_stepwise_scores_df = compute_vtt_score_per_stepnum(
    preds=preds,
    actuals=actuals,
    masks=masks,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    step_nums=step_nums,
    feature_names=feature_names
)

# to_csv does not create missing folders, so make sure the target exists first.
stepwise_csv_path = './logs/vtt_all_step/VTTSAT/version5/plots/vtt_scores_per_wafer.csv'
os.makedirs(os.path.dirname(stepwise_csv_path), exist_ok=True)
vtt_stepwise_scores_df.to_csv(stepwise_csv_path, index=False)

## step 구분 없이 스코어 계산

In [None]:
def compute_vtt_score_all_wafer(preds, actuals, masks, lotids, wafer_numbers, feature_names):
    """Per-wafer, per-feature VTT scores over all rows, without grouping by step.

    Each row of ``preds`` / ``actuals`` is one wafer's whole sequence. The
    score of feature d is the *negative sum* of absolute reconstruction errors
    over the time steps whose mask equals 1:

        score[d] = -sum_t |pred[t, d] - actual[t, d]|   for t with mask[t] == 1

    Despite the historical "MAE" naming, the value is neither averaged nor
    positive. It is always <= 0, and values closer to 0 mean smaller error.
    Rows with no valid time step are skipped.

    Args:
        preds, actuals: arrays of shape (N, T, D).
        masks: array of shape (N, T); 1 marks a valid time step.
        lotids, wafer_numbers: length-N sequences joined into "lotid-wafer" ids.
        feature_names: at least D names, used as score column names.

    Returns:
        pd.DataFrame: columns [unique_id, feature_names[0..D-1]] in original
        row order. The columns are present even when no row qualifies.
    """
    lotids = np.array(lotids).astype(str)
    wafer_numbers = np.array(wafer_numbers).astype(str)
    unique_ids = np.char.add(lotids, np.char.add("-", wafer_numbers))

    D = preds.shape[2]
    score_columns = [feature_names[d] for d in range(D)]
    result_records = []

    for i, uid in enumerate(unique_ids):
        valid_idx = masks[i] == 1
        if not np.any(valid_idx):
            continue

        # Negative sum (not mean) of absolute errors per feature: shape (D,).
        scores = -np.sum(np.abs(preds[i][valid_idx] - actuals[i][valid_idx]), axis=0)
        record = {'unique_id': uid}
        record.update(zip(score_columns, scores))
        result_records.append(record)

    # Explicit columns keep the CSV header even when no record was produced.
    return pd.DataFrame.from_records(result_records, columns=['unique_id', *score_columns])


vtt_all_scores_df = compute_vtt_score_all_wafer(
    preds=preds,
    actuals=actuals,
    masks=masks,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    feature_names=feature_names
)

# to_csv does not create missing folders, so make sure the target exists first.
allstep_csv_path = './logs/vtt_all_step/VTTSAT/version5/results/vtt_scores_per_wafer_allstep.csv'
os.makedirs(os.path.dirname(allstep_csv_path), exist_ok=True)
vtt_all_scores_df.to_csv(allstep_csv_path, index=False)

# XAI

In [None]:
# Run arguments for the XAI analysis (resumes experiment version 5).
args = EasyDict(
    # dataloader
    dataname='vtt_all_step',
    subdataname=None,
    window_size=30,
    slide_size=1,
    # training
    train=True,
    test=True,
    resume=5,
    model='VTTSAT',
    epochs=20,
    patience=5,
    # loss
    loss='mse',
    # runtime / hardware
    seed=72,
    use_gpu=True,
    gpu=0,
    use_multi_gpu=False,
    devices='0',
    configure='./config.yaml',
    batch_size=16,
)

# Load the YAML configuration file referenced by args.configure.
with open(args.configure) as config_file:
    config = OmegaConf.load(config_file)

# Show the loaded configuration with interpolations resolved.
print(OmegaConf.to_yaml(config, resolve=True))

In [None]:
# Logs for this dataset/model pair live under <log_dir>/<dataname>/<model>;
# version_build presumably picks (or, with resume, reuses) a versioned save folder — see utils.utils.
logdir = os.path.join(config.log_dir, f'{args.dataname}/{args.model}')
savedir = version_build(logdir=logdir, is_train=args.train, resume=args.resume)

# Multi-GPU: strip spaces from the comma-separated device list, parse the ids,
# and treat the first one as the primary GPU.
if args.use_gpu and args.use_multi_gpu:
    args.devices = ''.join(args.devices.split(' '))
    args.device_ids = [int(dev_id) for dev_id in args.devices.split(',')]
    args.gpu = args.device_ids[0]

# Fix the random seed (see utils.utils.set_seed).
set_seed(args.seed)

# Hyper-parameters of the selected model come from its section of the config.
model_params = config[args.model]

# Optional: lighter experiment setting
# model_params.n_layer = 1
# model_params.n_head = 2

### STEP 1. 데이터 준비하기

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

import torch.nn.functional as F
from sklearn.preprocessing import MinMaxScaler, StandardScaler
import numpy as np
from tqdm import tqdm
import os
import h5py

"""
GDN에 적합하도록 WaferDataset을 변환한 데이터셋 클래스
dataset_preprocessing.ipynb에서 생성하여, h5 파일 이용
"""
class WaferDataset(Dataset):
    """Fixed-length wafer sequences loaded from HDF5 files (one sample per file).

    The files are those produced by dataset_preprocessing.ipynb. Each file's
    'data' array is zero-padded (or truncated) along its first axis to
    ``max_len`` steps, and a float mask marks real (1.0) vs. padded (0.0)
    steps. The 'lotids', 'wafer_numbers' and 'step_num' arrays of every file
    are appended to per-dataset metadata lists.

    Args:
        file_list: HDF5 file paths, in the order samples should appear.
        model_type: when 'reconstruction', each item also carries an
            'answer' entry that is the input itself.
        max_len: target sequence length (default 350).
    """

    def __init__(self, file_list, model_type='reconstruction', max_len=350):
        self.file_list = file_list
        self.model_type = model_type
        self.max_seq_len = max_len

        self.data, self.labels = [], []  # labels is never filled in this version
        self.masks, self.lengths = [], []
        self.lotids, self.wafer_numbers, self.step_nums = [], [], []

        for file_path in tqdm(file_list, desc="Loading and merging data"):
            with h5py.File(file_path, 'r') as h5:
                sequence = h5['data'][:].astype(float)
                n_steps, n_channels = sequence.shape[0], sequence.shape[1]

                if n_steps >= self.max_seq_len:
                    fixed = sequence[:self.max_seq_len]
                else:
                    # Right-pad with zeros up to max_seq_len.
                    fixed = np.zeros((self.max_seq_len, n_channels))
                    fixed[:n_steps] = sequence
                valid = (np.arange(self.max_seq_len) < n_steps).astype(np.float32)

                self.data.append(torch.tensor(fixed, dtype=torch.float32))
                self.masks.append(torch.tensor(valid, dtype=torch.float32))
                self.lengths.append(min(n_steps, self.max_seq_len))

                # NOTE(review): extend() keeps metadata aligned with self.data only if
                # each file stores exactly one entry per array — confirm with the
                # preprocessing notebook.
                self.lotids.extend(h5['lotids'][:].astype(str))
                self.wafer_numbers.extend(h5['wafer_numbers'][:].astype(str))
                self.step_nums.extend(h5['step_num'][:])

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return the padded sequence, its mask/length and the sample's metadata."""
        sample = self.data[idx]
        item = dict(
            given=sample,                           # (max_seq_len, C)
            mask=self.masks[idx],                   # (max_seq_len,)
            length=self.lengths[idx],               # int
            lotid=self.lotids[idx],
            wafer_number=self.wafer_numbers[idx],
            step_num=self.step_nums[idx],
        )
        if self.model_type == 'reconstruction':
            item['answer'] = sample
        return item

def _list_h5_files(directory):
    """Return the sorted paths of all '.h5' files directly inside *directory*."""
    # Sorted so sample order does not depend on filesystem listing order.
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.endswith(".h5")
    )


def get_dataloader(data_info, loader_params: dict):
    """Build train / (optional) validation / test DataLoaders of WaferDataset.

    No random split is done: every split is read from its own directory.

    Args:
        data_info (dict): must contain 'train_dir' and 'test_dir', plus
            'val_dir' when validation is enabled.
        loader_params (dict): must contain 'batch_size'; 'use_val'
            (default False) enables the validation loader.

    Returns:
        tuple: (train_loader, val_loader or None, test_loader). Only the
        train loader shuffles.
    """
    loader_kwargs = dict(
        batch_size=loader_params['batch_size'],
        num_workers=0,
        # Pinned memory only helps host->GPU copies; without CUDA it just warns.
        pin_memory=torch.cuda.is_available(),
        drop_last=False,
    )

    if loader_params.get('use_val', False):
        val_dataset = WaferDataset(_list_h5_files(data_info['val_dir']))
        val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    else:
        val_loader = None

    train_dataset = WaferDataset(_list_h5_files(data_info['train_dir']))
    test_dataset = WaferDataset(_list_h5_files(data_info['test_dir']))

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    return train_loader, val_loader, test_loader


In [None]:
# The dataset's config section holds its train/val/test directories.
data_info = config[args.dataname]

trainloader, validloader, testloader = get_dataloader(
    data_info=data_info,
    loader_params=config['loader_params'],
)


# 모델 code 실행

In [None]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A (1, max_len, d_model) table is computed once and registered as the
    buffer ``pe``. A buffer moves with ``.to(device)`` and is stored in the
    state_dict, but is not a trainable parameter. ``forward`` returns the
    first ``x.size(1)`` positions.

    Args:
        d_model: encoding size. Odd sizes are supported (the last sine
            column has no cosine partner).
        max_len: maximum sequence length that can be encoded.
    """

    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        pe = torch.zeros(max_len, d_model).float()

        position = torch.arange(0, max_len).float().unsqueeze(1)
        # 1 / 10000^(2i / d_model), computed in log space.
        div_term = (torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)).exp()

        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])

        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        """Return encodings of shape (1, x.size(1), d_model) for broadcasting over the batch."""
        return self.pe[:, :x.size(1)]
    
class CausalConv1d(nn.Module):
    """1-D convolution whose output at step t only depends on inputs at steps <= t.

    The underlying Conv1d pads both ends by (kernel_size - 1) * dilation, and
    the trailing padded outputs are dropped. The output therefore has the
    same length as the input. Input and output layout: (batch, channels, length).
    """

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, groups=1):
        super().__init__()
        padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels, out_channels, kernel_size,
            padding=padding,
            dilation=dilation,
            groups=groups
        )

    def forward(self, x):
        out = self.conv(x)
        trim = self.conv.padding[0]
        # With kernel_size == 1 there is no padding, and out[..., :-0] would be empty.
        if trim > 0:
            out = out[:, :, :-trim]
        return out

In [None]:
import torch
import torch.nn as nn
from torch import einsum
from einops import rearrange, repeat
import math
from math import sqrt
import numpy as np

class VariableAttention(nn.Module):
    """Multi-head self-attention across variables (features) at every time step.

    ``x`` is unpacked as (b, t, f, d). Attention runs among the ``f``
    variables independently for each (batch, time) pair, and the output has
    the same shape as ``x``.

    Args:
        dim: size of the last axis of ``x`` (and of the output).
        heads: number of attention heads.
        dim_head: per-head size of queries/keys/values.
        dropout: dropout applied to the attention weights.
    """

    def __init__(
            self,
            dim,
            heads=8,
            dim_head=16,
            dropout=0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_attn=False):
        """Return (out, weights).

        ``weights`` is the pre-dropout attention map of shape
        (b*t, heads, f, f) when ``use_attn`` is true, otherwise None.
        ``use_attn`` defaults to False, matching TemporalAttention.
        """
        b, t, f, d = x.shape
        x = rearrange(x, 'b t f d -> (b t) f d')
        h = self.heads
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda m: rearrange(m, 'b n (h d) -> b h n d', h=h), (q, k, v))
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = sim.softmax(dim=-1)
        weights = attn if use_attn else None
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=h)
        out = self.to_out(out)
        out = rearrange(out, '(b t) f d -> b t f d', t=t)
        return out, weights

class TemporalAttention(nn.Module):
    """Multi-head self-attention across time, independently for every variable.

    ``x`` is unpacked as (b, t, f, d). Attention runs along ``t`` separately
    for each (batch, feature) pair, and the output has the same shape as ``x``.

    Args:
        dim: size of the last axis of ``x`` (and of the output).
        heads: number of attention heads.
        dim_head: per-head size of queries/keys/values.
        dropout: dropout applied to the attention weights.
    """

    def __init__(self, dim, heads=8, dim_head=16, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, use_attn=False, mask=None):
        """Return (out, weights).

        Args:
            x: tensor unpacked as (b, t, f, d).
            use_attn: if true, also return the pre-dropout attention map
                of shape (b*f, heads, t, t).
            mask: optional (b, t) tensor; time steps where it equals 0 are
                excluded as attention keys.
        """
        b, t, f, d = x.shape
        x = rearrange(x, 'b t f d -> (b f) t d')  # one sequence per (batch, feature)
        h = self.heads

        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda m: rearrange(m, 'b n (h d) -> b h n d', h=h), (q, k, v))

        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale  # (b*f, h, t, t)

        if mask is not None:
            # (b, t) -> (b*f, 1, 1, t): same batch-major order as x, broadcast over heads/queries.
            mask = repeat(mask, 'b t -> (b f) t', f=f)
            mask = rearrange(mask, 'b j -> b 1 1 j')
            # Most negative finite value of the score dtype: acts like -inf after
            # softmax, but unlike a literal -1e9 it cannot overflow float16 (AMP).
            sim = sim.masked_fill(mask == 0, torch.finfo(sim.dtype).min)

        attn = sim.softmax(dim=-1)
        weights = attn if use_attn else None
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h=h)
        out = self.to_out(out)
        out = rearrange(out, '(b f) t d -> b t f d', f=f)
        return out, weights

class VariableTemporalAttention(nn.Module):
    """Variable attention and temporal attention computed in parallel, then fused.

    ``x`` is unpacked as (b, t, f, d). The variable branch attends across the
    ``f`` variables at each time step, and the temporal branch attends across
    the ``t`` steps for each variable. The two outputs are concatenated on
    the last axis and projected back to ``dim`` by ``self.linear``.

    Args:
        dim: size of the last axis of ``x`` (and of the output).
        heads: number of attention heads (shared by both branches).
        dim_head: per-head size of queries/keys/values.
        dropout: dropout applied to each branch's attention weights.
    """

    def __init__(
            self,
            dim,
            heads=8,
            dim_head=16,
            dropout=0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.variable_to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.variable_to_out = nn.Linear(inner_dim, dim)
        self.variable_dropout = nn.Dropout(dropout)

        self.temporal_to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.temporal_to_out = nn.Linear(inner_dim, dim)
        self.temporal_dropout = nn.Dropout(dropout)

        self.linear = nn.Linear(dim * 2, dim)

    def _attend(self, q, k, v, dropout, use_attn):
        """Scaled dot-product attention over axis 2 of (batch, heads, n, dim_head) tensors.

        Returns (context, weights); weights is the pre-dropout attention map
        if ``use_attn`` is true, otherwise None.
        """
        sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attn = sim.softmax(dim=-1)
        weights = attn if use_attn else None
        out = einsum('b h i j, b h j d -> b h i d', dropout(attn), v)
        return out, weights

    def forward(self, x, use_attn=False):
        """Return (out, variable_weights, temporal_weights); weights are None unless ``use_attn``."""
        b, t, f, d = x.shape
        h = self.heads

        def split_heads(m):
            return rearrange(m, 'b n (h d) -> b h n d', h=h)

        # Variable branch: attend across the f variables for every (batch, time).
        vx = rearrange(x, 'b t f d -> (b t) f d')
        vq, vk, vv = map(split_heads, self.variable_to_qkv(vx).chunk(3, dim=-1))
        vout, vweights = self._attend(vq, vk, vv, self.variable_dropout, use_attn)
        vout = rearrange(vout, 'b h n d -> b n (h d)', h=h)
        vout = self.variable_to_out(vout)
        vout = rearrange(vout, '(b t) f d -> b t f d', t=t)

        # Temporal branch: attend across the t steps for every (batch, variable).
        tx = rearrange(x, 'b t f d -> (b f) t d')
        tq, tk, tv = map(split_heads, self.temporal_to_qkv(tx).chunk(3, dim=-1))
        tout, tweights = self._attend(tq, tk, tv, self.temporal_dropout, use_attn)
        tout = rearrange(tout, 'b h n d -> b n (h d)', h=h)
        tout = self.temporal_to_out(tout)
        tout = rearrange(tout, '(b f) t d -> b t f d', f=f)

        out = self.linear(torch.cat([vout, tout], dim=-1))
        return out, vweights, tweights

#### VTT-SAT

In [None]:
class VTTSAT(nn.Module):
    """VTT-SAT: stacked variable-attention -> temporal-attention -> feed-forward blocks.

    Each feature of the input window is expanded into a token made of the
    raw value plus three depthwise causal-convolution views of it. Tokens
    are embedded to ``hidden_size``, combined with a sinusoidal positional
    encoding, and passed through ``n_layer`` residual blocks, whose
    sub-layers are wrapped in PreNorm. The result is projected back to
    ``feature_num`` values per time step.

    Args:
        params: mapping with keys 'feature_num', 'n_layer', 'n_head',
            'hidden_size', 'attn_pdrop', 'resid_pdrop' and 'time_emb'.
            'time_emb' is the input size of ``value_embedding``. It must
            equal 1 + the number of causal convolutions (i.e. 4).
    """

    def __init__(self, params):
        super(VTTSAT, self).__init__()
        self.feature_num = params['feature_num']
        self.num_transformer_blocks = params['n_layer']
        self.num_heads = params['n_head']
        self.embedding_dims = params['hidden_size']
        self.attn_dropout = params['attn_pdrop']
        self.ff_dropout = params['resid_pdrop']
        self.time_emb = params['time_emb']

        # Positional encoding and value (token) embedding.
        self.position_embedding = PositionalEmbedding(d_model=self.embedding_dims)
        self.value_embedding = nn.Linear(self.time_emb, self.embedding_dims)

        # Depthwise (groups=feature_num) causal convolutions with growing receptive fields.
        kernel_sizes = [4, 8, 16]
        dilations = [1, 2, 3]
        self.causal_convs = nn.ModuleList([
            CausalConv1d(self.feature_num, self.feature_num, kernel_size=k, dilation=d, groups=self.feature_num)
            for k, d in zip(kernel_sizes, dilations)
        ])

        # Transformer blocks: (variable attention, temporal attention, feed-forward).
        self.transformer_layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(self.embedding_dims, VariableAttention(self.embedding_dims,
                                                               heads=self.num_heads,
                                                               dim_head=self.embedding_dims,
                                                               dropout=self.attn_dropout)),
                PreNorm(self.embedding_dims, TemporalAttention(self.embedding_dims,
                                                               heads=self.num_heads,
                                                               dim_head=self.embedding_dims,
                                                               dropout=self.attn_dropout)),
                PreNorm(self.embedding_dims, FeedForward(self.embedding_dims, dropout=self.ff_dropout))
            ]) for _ in range(self.num_transformer_blocks)
        ])

        self.dropout = nn.Dropout(self.ff_dropout)
        self.mlp_head = nn.Linear(self.feature_num * self.embedding_dims, self.feature_num)

    def forward(self, x, use_attn=False, mask=None):
        """Reconstruct the input window.

        Args:
            x: tensor unpacked as (b, w, f).
            use_attn: if true, attention maps are collected and stacked over layers.
            mask: optional (b, w) tensor forwarded to temporal attention;
                steps equal to 0 are not attended to.

        Returns:
            (out, [variable_attn_weights, temporal_attn_weights]) with out of
            shape (b, w, f). The weights are tensors stacked over layers when
            ``use_attn`` is true, otherwise lists of None.
        """
        variable_attn_weights = []
        temporal_attn_weights = []
        _, _, f = x.shape

        # Causal convolutions run along time, so use the (b, f, w) layout.
        x = rearrange(x, 'b w f -> b f w')
        conv_outputs = [conv(x) for conv in self.causal_convs]

        # Raw signal + every conv view form the token of each (time, feature).
        x = torch.stack([x, *conv_outputs], dim=-1)
        x = rearrange(x, 'b f w d -> b w f d')
        x = self.value_embedding(x)

        position_emb = self.position_embedding(x)
        position_emb = repeat(position_emb, 'b t d -> b t f d', f=f)
        x = x + position_emb
        x = self.dropout(x)
        h = x

        for vattn, tattn, ff in self.transformer_layers:
            x, weights = vattn(h, use_attn=use_attn)
            h = x + h
            variable_attn_weights.append(weights)
            x, weights = tattn(h, use_attn=use_attn, mask=mask)
            h = x + h
            temporal_attn_weights.append(weights)
            # Feed-forward reads the residual stream (h), not just the last attention output.
            x = ff(h)
            h = x + h

        h = rearrange(h, 'b w f d -> b w (f d)')
        h = self.mlp_head(h)

        if use_attn:
            variable_attn_weights = torch.stack(variable_attn_weights, dim=0)
            temporal_attn_weights = torch.stack(temporal_attn_weights, dim=0)

        return h, [variable_attn_weights, temporal_attn_weights]

#### VTT-PAT

In [None]:
class VTTPAT(nn.Module):
    """VTT-PAT: stacked blocks of parallel variable/temporal attention.

    Input tokens are built as in VTT-SAT: the raw value plus three depthwise
    causal-convolution views per feature, embedded to ``hidden_size`` with a
    sinusoidal positional encoding. Each block is a PreNorm-wrapped
    VariableTemporalAttention with a residual connection. The result is
    projected back to ``feature_num`` values per time step.

    Args:
        params: mapping with keys 'feature_num', 'n_layer', 'n_head',
            'hidden_size', 'attn_pdrop', 'resid_pdrop' and 'time_emb'.
            'time_emb' is the input size of ``value_embedding``. It must
            equal 1 + the number of causal convolutions (i.e. 4).
    """

    def __init__(self, params):
        super(VTTPAT, self).__init__()
        self.feature_num = params['feature_num']
        self.num_transformer_blocks = params['n_layer']
        self.num_heads = params['n_head']
        self.embedding_dims = params['hidden_size']
        self.attn_dropout = params['attn_pdrop']
        self.ff_dropout = params['resid_pdrop']
        self.time_emb = params['time_emb']

        # Positional encoding and value (token) embedding.
        self.position_embedding = PositionalEmbedding(d_model=self.embedding_dims)
        self.value_embedding = nn.Linear(self.time_emb, self.embedding_dims)

        # Depthwise (groups=feature_num) causal convolutions with growing receptive fields.
        kernel_sizes = [4, 8, 16]
        dilations = [1, 2, 3]
        self.causal_convs = nn.ModuleList([
            CausalConv1d(self.feature_num, self.feature_num, kernel_size=k, dilation=d, groups=self.feature_num)
            for k, d in zip(kernel_sizes, dilations)
        ])

        # Transformer blocks: one parallel variable/temporal attention each.
        self.transformer_layers = nn.ModuleList([
            PreNorm(self.embedding_dims,
                    VariableTemporalAttention(self.embedding_dims,
                                              heads=self.num_heads,
                                              dim_head=self.embedding_dims,
                                              dropout=self.attn_dropout))
            for _ in range(self.num_transformer_blocks)
        ])

        self.dropout = nn.Dropout(self.ff_dropout)
        self.mlp_head = nn.Linear(self.feature_num * self.embedding_dims, self.feature_num)

    def forward(self, x, use_attn=False):
        """Reconstruct the input window.

        Args:
            x: tensor unpacked as (b, w, f).
            use_attn: if true, per-layer attention maps are collected.

        Returns:
            (out, [variable_attn_weights, temporal_attn_weights]) with out of
            shape (b, w, f). Each weights entry is a per-layer list, holding
            None when ``use_attn`` is false.
        """
        variable_attn_weights = []
        temporal_attn_weights = []
        _, _, f = x.shape

        # Causal convolutions run along time, so use the (b, f, w) layout.
        x = rearrange(x, 'b w f -> b f w')
        conv_outputs = [conv(x) for conv in self.causal_convs]

        # Raw signal + every conv view form the token of each (time, feature).
        x = torch.stack([x, *conv_outputs], dim=-1)
        x = rearrange(x, 'b f w d -> b w f d')
        x = self.value_embedding(x)

        # Add the positional encoding (broadcast over batch and features).
        position_emb = self.position_embedding(x)
        position_emb = repeat(position_emb, 'b t d -> b t f d', f=f)
        x = x + position_emb
        x = self.dropout(x)
        h = x

        # Residual transformer blocks.
        for attn in self.transformer_layers:
            x, vweights, tweights = attn(h, use_attn=use_attn)
            h = x + h
            variable_attn_weights.append(vweights)
            temporal_attn_weights.append(tweights)

        # Output projection back to one value per feature.
        h = rearrange(h, 'b w f d -> b w (f d)')
        h = self.mlp_head(h)
        return h, [variable_attn_weights, temporal_attn_weights]

### STEP 6. Build model 모듈

In [None]:
from models import VTTPAT, VTTSAT
from utils.tools import EarlyStopping, adjust_learning_rate, check_graph
from utils.utils import load_model, CheckPoint
from utils.metrics import bf_search

import torch
import torch.nn as nn
from torch import optim
from einops import rearrange

import os
import time

import matplotlib.pyplot as plt
import warnings
import numpy as np
import json
import pdb
from tqdm import tqdm
import h5py
warnings.filterwarnings('ignore')

class build_model():
    """
    Build a model and train, validate, test, or interpret it.

    Parameters
    ----------
    args : namespace-like
        run arguments, accessed by attribute (e.g. ``args.model``, ``args.use_gpu``)
    params : namespace-like
        model hyper parameters (e.g. ``params.optim``, ``params.lr``)
    savedir : str
        save directory for checkpoints and results

    Attributes
    ------------
    args, params, savedir : as above
    device : torch.device
        device the model is placed on
    model : nn.Module
        the built model (wrapped in ``nn.DataParallel`` for multi-GPU runs)

    Methods
    -------
    _build_model()
        Select the model you want to train or test and build it (+ multi gpu setting)
    _acquire_device()
        Choose the GPU/CPU device
    _select_optimizer()
        Select the optimizer ('adamw', 'adam' or 'sgd')
    _select_criterion()
        Select the element-wise criterion ('mse' or 'mae', reduction='none')
    valid()
        Validation loss and per-window scores
    train()
        Train with checkpointing / early stopping, then reload the best checkpoint
    test()
        Score a test loader and save dist.npy / pred.npy
    inference()
        Score a labelled loader and return (dist, attack)
    interpret()
        Collect scores and prior/post attention maps
    inference_unlabeled()
        Score an unlabelled loader and save the results with metadata to HDF5
    """

    def __init__(self, args, params, savedir):
        super().__init__()
        self.args = args
        self.params = params
        self.savedir = savedir
        self.device = self._acquire_device()
        self.model = self._build_model().to(self.device)

    def _build_model(self):
        """Instantiate the network named by ``args.model``; wrap it in DataParallel for multi-GPU."""
        model_dict = {
            'VTTSAT': VTTSAT,
            'VTTPAT': VTTPAT,
        }

        model = model_dict[self.args.model].Model(self.params).float()

        if self.args.use_multi_gpu and self.args.use_gpu:
            print('using multi-gpu')
            model = nn.DataParallel(model, device_ids=self.args.device_ids)
        return model

    def _acquire_device(self):
        """Return the torch.device to use; restricts CUDA_VISIBLE_DEVICES when on GPU."""
        if self.args.use_gpu:
            # NOTE(review): once CUDA_VISIBLE_DEVICES is restricted to a single GPU, that GPU is
            # exposed as cuda:0, so 'cuda:{gpu}' with gpu != 0 only works if CUDA was already
            # initialised before this call — confirm for the intended setup.
            os.environ["CUDA_VISIBLE_DEVICES"] = str(
                self.args.gpu) if not self.args.use_multi_gpu else self.args.devices
            device = torch.device('cuda:{}'.format(self.args.gpu))
            print('Use GPU: cuda:{}'.format(self.args.gpu))
        else:
            device = torch.device('cpu')
            print('Use CPU')
        return device

    def _unwrapped_model(self):
        """Return the underlying nn.Module (unwrapping nn.DataParallel if present)."""
        return self.model.module if isinstance(self.model, nn.DataParallel) else self.model

    def _load_resume_checkpoint(self):
        """
        Load checkpoint version ``args.resume`` from ``savedir`` into the model.

        Also sets ``self.args.lr`` to the stored learning rate and returns the stored
        best metrics (used to seed the CheckPoint tracker).
        """
        print(f'resume version{self.args.resume}')
        weights, _start_epoch, self.args.lr, best_metrics = load_model(resume=self.args.resume,
                                                                       logdir=self.savedir)
        # load_state_dict copies values into the existing parameters (device handled
        # by the copy); it does not take a map_location argument.
        self._unwrapped_model().load_state_dict(weights)
        return best_metrics

    def _select_optimizer(self):
        """Return the optimizer named by ``params.optim``; raise ValueError for unknown names."""
        optimizers = {'adamw': optim.AdamW, 'adam': optim.Adam, 'sgd': optim.SGD}
        if self.params.optim not in optimizers:
            raise ValueError(f"Unsupported optimizer {self.params.optim!r}; "
                             f"expected one of {sorted(optimizers)}")
        return optimizers[self.params.optim](self.model.parameters(), lr=self.params.lr)

    def _select_criterion(self):
        """Return the element-wise loss named by ``args.loss`` (reduction='none')."""
        criteria = {'mse': nn.MSELoss, 'mae': nn.L1Loss}
        if self.args.loss not in criteria:
            raise ValueError(f"Unsupported loss {self.args.loss!r}; "
                             f"expected one of {sorted(criteria)}")
        # reduction='none' keeps per-element errors so callers can mask / score them.
        return criteria[self.args.loss](reduction='none')

    def valid(self, valid_loader, criterion, epoch):
        """
        validation

        Parameters
        ----------
        valid_loader : iterable of dict batches
            given : torch.Tensor (shape=(batch size, window size, num of features))
                input time-series data
            answer : torch.Tensor (shape=(batch size, window size, num of features))
                target time-series data
        criterion : callable
            element-wise loss (reduction='none')
        epoch : int
            unused

        Return
        ------
        total_loss : float
            average over batches of each batch's mean loss
        valid_score : list of np.ndarray
            per batch, the loss averaged over the feature axis (axis 2)

        The model is switched back to train mode before returning.
        """
        total_loss = []
        valid_score = []
        self.model.eval()
        with torch.no_grad():
            for batch in valid_loader:
                batch_x = batch['given'].float().to(self.device)
                batch_y = batch['answer'].float().to(self.device)

                # VTT models return (prediction, attention weights).
                if self.args.model in ['VTTSAT', 'VTTPAT']:
                    output, _ = self.model(batch_x)
                else:
                    output = self.model(batch_x)

                loss = criterion(output, batch_y)
                valid_score.append(np.mean(loss.cpu().detach().numpy(), axis=2))
                total_loss.append(torch.mean(loss).item())

        total_loss = np.average(total_loss)
        self.model.train()
        return total_loss, valid_score

    def train(self, train_loader, valid_loader, test_loader, use_val=False, alpha=.5, beta=.5):
        """
        training

        Parameters
        ----------
        train_loader : iterable of dict batches
            given : torch.Tensor (shape=(batch size, window size, num of features))
                input time-series data
            answer : torch.Tensor (shape=(batch size, window size, num of features))
                target time-series data
            mask : torch.Tensor (shape=(batch size, window size))
                multiplied into the element-wise loss (presumably 1 = valid step,
                0 = padding — confirm against the dataset)
        valid_loader
            same 'given' / 'answer' batches; only used when ``use_val`` is True
        test_loader
            currently unused (per-epoch test inference is commented out)
        use_val : bool
            monitor the validation loss (True) or the train loss (False) for
            checkpointing and early stopping
        alpha, beta : float
            unused

        Return
        ------
        history : dict
            {'train_loss': [...], 'validation_loss': [...]}, one value per finished epoch
            ('validation_loss' stays empty when ``use_val`` is False). The best checkpoint
            is loaded back into ``self.model`` before returning.
        """
        best_metrics = None
        if self.args.resume is not None:
            best_metrics = self._load_resume_checkpoint()

        # set checkpoint
        ckp = CheckPoint(logdir=self.savedir,
                         last_metrics=best_metrics,
                         metric_type='loss')

        train_steps = len(train_loader)
        early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)

        model_optim = self._select_optimizer()
        criterion = self._select_criterion()

        history = {'train_loss': [], 'validation_loss': []}

        for epoch in range(self.args.epochs):
            train_loss = []

            self.model.train()
            epoch_time = time.time()

            with tqdm(total=len(train_loader), desc=f"Epoch {epoch+1}/{self.args.epochs}") as pbar:
                for batch in train_loader:
                    batch_x = batch['given'].float().to(self.device)
                    batch_y = batch['answer'].float().to(self.device)
                    batch_mask = batch['mask'].float().to(self.device)

                    model_optim.zero_grad()

                    if self.args.model in ['VTTSAT', 'VTTPAT']:
                        output, _ = self.model(batch_x, use_attn=False, mask=batch_mask)
                    else:
                        output = self.model(batch_x)

                    # Loss computation and masking
                    loss_raw = criterion(output, batch_y)                   # (B, T, D)
                    loss_masked = loss_raw * batch_mask.unsqueeze(-1)      # (B, T, D)
                    # Mean over valid *elements*: each valid time step contributes D feature
                    # errors, so divide by (#valid steps * D). This keeps the train loss on the
                    # same per-element scale as valid()'s torch.mean; the clamp avoids 0/0 on
                    # an all-padding batch.
                    n_valid = (batch_mask.sum() * loss_raw.shape[-1]).clamp(min=1)
                    loss = loss_masked.sum() / n_valid                     # scalar
                    train_loss.append(loss.item())

                    # backprop
                    loss.backward()
                    model_optim.step()

                    pbar.update(1)

            train_loss_avg = np.average(train_loss)
            # Record before the early-stopping check so both history lists stay the same length.
            history['train_loss'].append(train_loss_avg)

            # Inference on the labelled test set
            # dist, attack = self.inference(test_loader, epoch)

            # result save
            # folder_path = os.path.join(self.savedir, 'results', f'epoch_{epoch}')
            # os.makedirs(folder_path, exist_ok=True)
            # visual = check_graph(dist, attack, piece=4)
            # visual.savefig(os.path.join(folder_path, f'graph.png'))
            # np.save(os.path.join(folder_path, f'dist.npy'), dist)
            # np.save(os.path.join(folder_path, f'attack.npy'), attack)

            if use_val:
                valid_loss, _ = self.valid(valid_loader, criterion, epoch)
                print('=' * 100)
                print(f"[Epoch {epoch + 1}] "
                      f"Time: {time.time() - epoch_time:.2f}s | "
                      f"Steps: {train_steps} | "
                      f"Train Loss: {train_loss_avg:.8f} | "
                      f"Val Loss: {valid_loss:.8f}")
                history['validation_loss'].append(valid_loss)
                monitored_loss = valid_loss
            else:
                print('=' * 100)
                print(f"[Epoch {epoch + 1}] "
                      f"Time: {time.time() - epoch_time:.2f}s | "
                      f"Steps: {train_steps} | "
                      f"Train Loss: {train_loss_avg:.8f}")
                # No validation set: checkpoint / early-stop on the train loss.
                monitored_loss = train_loss_avg

            ckp.check(epoch=epoch + 1, model=self.model, score=monitored_loss,
                      lr=model_optim.param_groups[0]['lr'])

            if early_stopping.validate(monitored_loss):
                print("Early stopping")
                break
            # adjust_learning_rate(model_optim, epoch + 1, self.params)

        # Reload the best checkpoint. weights_only=False unpickles arbitrary objects, so only
        # load checkpoints written by this code. map_location makes a checkpoint saved on
        # another device loadable here.
        best_model_path = os.path.join(self.savedir, f'{ckp.best_epoch}.pth')
        checkpoint = torch.load(best_model_path, map_location=self.device, weights_only=False)
        self._unwrapped_model().load_state_dict(checkpoint['weight'])

        return history

    def test(self, valid_loader, test_loader, alpha=.5, beta=.5):
        """
        test

        Parameters
        ----------
        valid_loader
            unused
        test_loader
            given : torch.Tensor (shape=(batch size, window size, num of features))
                input time-series data
            answer : torch.Tensor (shape=(batch size, window size, num of features))
                target time-series data
        alpha, beta : float
            unused

        Returns
        -------
        history : dict
            currently always empty (the labelled scoring below is commented out).
            Side effect: saves ``results/dist.npy`` (per-window score, flattened) and
            ``results/pred.npy`` (predictions) under ``savedir``.
        """
        if self.args.resume is not None and self.args.train is not True:
            self._load_resume_checkpoint()

        # dist, attack, pred = [], [], []
        dist, pred = [], []
        history = dict()
        criterion = self._select_criterion()

        self.model.eval()
        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['given'].float().to(self.device)
                batch_y = batch['answer'].float().to(self.device)
                # batch_attack = batch['attack'].reshape(-1, batch['attack'].shape[-1]).numpy()

                if self.args.model in ['VTTSAT', 'VTTPAT']:
                    predictions, _ = self.model(batch_x)
                else:
                    predictions = self.model(batch_x)

                score = criterion(predictions, batch_y).cpu().detach().numpy()
                pred.append(predictions.cpu().detach().numpy())
                dist.append(np.mean(score, axis=2))
                # attack.append(batch_attack)

        dist = np.concatenate(dist).flatten()
        # attack = np.concatenate(attack).flatten()
        pred = np.concatenate(pred)

        # # score
        # scores = dist.copy()
        # K = [0, 1, 2, 3, 4, 5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100]
        # f1_values = []

        # print(f'Threshold start: {np.percentile(scores, 90):.4f} end: {np.percentile(scores, 99):.4f}')

        # for k in K:
        #     scores = dist.copy()
        #     [f1, precision, recall, _, _, _, _, roc_auc, _, _], threshold = bf_search(scores, attack,
        #                                                                         start=np.percentile(scores, 50),
        #                                                                         end=np.percentile(scores, 99),
        #                                                                         step_num=1000,
        #                                                                         K=k,
        #                                                                         verbose=False)

        #     f1_values.append(f1)
        #     print(f"K: {k} precision: {precision:.4f} recall: {recall:.4f} f1: {f1:.4f} AUROC: {roc_auc:.4f}")
        #     history.setdefault(f'precision_{k}', []).append(precision)
        #     history.setdefault(f'recall_{k}', []).append(recall)
        #     history.setdefault(f'f1_{k}', []).append(f1)
        #     history.setdefault(f'roc_auc', []).append(roc_auc)

        # auc = sum(0.5 * (f1_values[i] + f1_values[i + 1]) * (K[i + 1] - K[i]) for i in range(len(K) - 1)) / 100
        # print(f'PA%K AUC: {auc}')

        # visual = check_graph(dist, attack, piece=4, threshold=threshold)
        # figure_path = os.path.join(self.savedir, 'fig')
        # os.makedirs(figure_path, exist_ok=True)
        # visual.savefig(os.path.join(figure_path, f'whole.png'))

        # result save
        folder_path = os.path.join(self.savedir, 'results')
        os.makedirs(folder_path, exist_ok=True)

        np.save(os.path.join(folder_path, f'dist.npy'), dist)
        # np.save(os.path.join(folder_path, f'attack.npy'), attack)
        np.save(os.path.join(folder_path, f'pred.npy'), pred)

        return history

    def inference(self, test_loader, epoch, alpha=.5, beta=.5):
        """
        inference

        Parameters
        ----------
        test_loader
            given : torch.Tensor (shape=(batch size, window size, num of features))
                input time-series data
            answer : torch.Tensor (shape=(batch size, window size, num of features))
                target time-series data
            attack : torch.Tensor
                attack labels (reshaped to 2-D, then flattened)
        epoch : int
            unused
        alpha, beta : float
            unused

        Returns
        -------
        dist : np.ndarray
            per-window loss averaged over features, flattened
        attack : np.ndarray
            attack labels, flattened
        """
        dist = []
        attack = []
        criterion = self._select_criterion()

        self.model.eval()

        with torch.no_grad():
            for batch in test_loader:
                batch_x = batch['given'].float().to(self.device)
                batch_y = batch['answer'].float().to(self.device)

                # Call the module (not .forward) so registered hooks run.
                if self.args.model in ['VTTSAT', 'VTTPAT']:
                    predictions, _ = self.model(batch_x)
                else:
                    predictions = self.model(batch_x)

                score = criterion(predictions, batch_y).cpu().detach().numpy()
                dist.append(np.mean(score, axis=2))
                attack.append(batch['attack'].reshape(-1, batch['attack'].shape[-1]).numpy())

        dist = np.concatenate(dist).flatten()
        attack = np.concatenate(attack).flatten()

        return dist, attack

    def interpret(self, test_loader):
        """
        interpret

        Runs every batch through the model twice (input -> prediction -> re-prediction)
        and collects the attention weights of both passes.

        Parameters
        ----------
        test_loader
            given : torch.Tensor (shape=(batch size, window size, num of features))
                input time-series data
            answer : torch.Tensor (shape=(batch size, window size, num of features))
                target time-series data
            attack : torch.Tensor, optional
                attack labels; ``None`` is recorded when the batch has no 'attack' key

        Returns
        -------
        scores : list of np.ndarray
            per batch, element-wise loss between the first-pass prediction and 'answer'
        labels : list
            per batch, the 'attack' labels as np.ndarray (or None)
        tss : list of int
            batch indices
        vattns, tattns : list
            per batch, ``[prior_layers, post_layers]`` of variable / temporal attention
            weights as np.ndarray
        """
        if self.args.resume is not None:
            self._load_resume_checkpoint()

        criterion = self._select_criterion()
        scores, labels, tss, vattns, tattns = [], [], [], [], []

        self.model.eval()
        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                batch_x = batch['given'].float().to(self.device)
                batch_y = batch['answer'].float().to(self.device)

                # First pass on the input, second pass on the model's own reconstruction.
                predictions, [prior_vattn, prior_tattn] = self.model(batch_x, use_attn=True)
                _, [post_vattn, post_tattn] = self.model(predictions, use_attn=True)

                score = criterion(predictions, batch_y).cpu().detach().numpy()
                scores.append(score)
                label = batch.get('attack')
                labels.append(label.numpy() if torch.is_tensor(label) else label)
                tss.append(batch_idx)

                vattns.append([
                    [vt.cpu().detach().numpy() for vt in prior_vattn],
                    [vt.cpu().detach().numpy() for vt in post_vattn]
                ])
                tattns.append([
                    [tt.cpu().detach().numpy() for tt in prior_tattn],
                    [tt.cpu().detach().numpy() for tt in post_tattn]
                ])

        return scores, labels, tss, vattns, tattns

    def inference_unlabeled(self, test_loader):
        """
        Inference for unlabeled WaferDataset.

        Saves prediction, actual, masked per-feature error, mask and wafer metadata to
        ``<savedir>/results/inference_result_with_metadata.h5``.

        Returns
        -------
        preds, actuals, dist_per_feature : np.ndarray, each (N, T, D)
        lotids, wafer_numbers : np.ndarray of bytes (dtype 'S')
        step_nums : np.ndarray
        """
        preds = []
        actuals = []
        dist_per_feature = []

        lotids = []
        wafer_numbers = []
        step_nums = []
        masks = []

        criterion = self._select_criterion()  # must be element-wise (reduction='none')

        self.model.eval()
        with torch.no_grad():
            with tqdm(total=len(test_loader), desc="Inference (per-feature + metadata)", leave=True) as pbar:
                for batch in test_loader:
                    batch_x = batch['given'].float().to(self.device)        # (B, T, D)
                    batch_y = batch['answer'].float().to(self.device)       # (B, T, D)
                    batch_mask = batch['mask'].float().to(self.device)      # (B, T)

                    # Model forward with mask (if VTTSAT or VTTPAT)
                    if self.args.model in ['VTTSAT', 'VTTPAT']:
                        predictions, _ = self.model(batch_x, mask=batch_mask)
                    else:
                        predictions = self.model(batch_x)

                    # Reconstruction error with masking applied
                    raw_loss = criterion(predictions, batch_y)             # (B, T, D)
                    masked_loss = raw_loss * batch_mask.unsqueeze(-1)      # (B, T, D)
                    dist_per_feature.append(masked_loss.cpu().numpy())

                    # Collect results
                    preds.append(predictions.cpu().numpy())
                    actuals.append(batch_y.cpu().numpy())
                    lotids.extend(batch['lotid'])
                    wafer_numbers.extend(batch['wafer_number'])
                    step_nums.extend(batch['step_num'])
                    # Keep one (B, T) array per batch so the concatenation below is (N, T).
                    masks.append(batch['mask'].cpu().numpy())

                    pbar.update(1)

        preds = np.concatenate(preds, axis=0)               # (N, T, D)
        actuals = np.concatenate(actuals, axis=0)           # (N, T, D)
        dist_per_feature = np.concatenate(dist_per_feature, axis=0)  # (N, T, D)
        masks = np.concatenate(masks, axis=0)               # (N, T)

        # Stored as bytes (dtype 'S'); NOTE: 'S' encodes ASCII, non-ASCII ids would raise.
        lotids = np.array(lotids, dtype='S')
        wafer_numbers = np.array(wafer_numbers, dtype='S')
        step_nums = np.array(step_nums)

        # Save
        folder_path = os.path.join(self.savedir, 'results')
        os.makedirs(folder_path, exist_ok=True)
        save_path = os.path.join(folder_path, 'inference_result_with_metadata.h5')

        with h5py.File(save_path, 'w') as f:
            f.create_dataset('actual', data=actuals, compression='gzip')
            f.create_dataset('pred', data=preds, compression='gzip')
            f.create_dataset('dist_per_window_per_features', data=dist_per_feature, compression='gzip')
            f.create_dataset('lotid', data=lotids)
            f.create_dataset('wafer_number', data=wafer_numbers)
            f.create_dataset('step_num', data=step_nums)
            f.create_dataset('mask', data=masks)

        print(f"[✓] HDF5 with metadata saved to: {save_path}")
        return preds, actuals, dist_per_feature, lotids, wafer_numbers, step_nums


### STEP 9. XAI

<img src="../image/XAI.png" width="700">

In [None]:
from scipy.special import softmax
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib import pyplot as plt
import matplotlib.colors as mcl
%matplotlib inline

from utils.utils import load_model

- 모델 체크포인트 로드

In [None]:
# Rebuild the model wrapper and restore checkpoint version 5 from savedir.
args.resume = 5
model = build_model(args, model_params, savedir)
weights, _, _, _ = load_model(resume=args.resume,
                              logdir=savedir)
# Saved weights are the plain module's state_dict (build_model loads them into
# .module when wrapped); unwrap DataParallel the same way here.
target_module = model.model.module if isinstance(model.model, nn.DataParallel) else model.model
target_module.load_state_dict(weights)

In [None]:
def get_input_by_uid(preds, actuals, lotids, wafer_numbers, masks, target_uid):
    """
    Return every window (input data and mask) that belongs to one wafer UID.

    Args:
        preds: (N, T, D) predictions — not used; kept for call compatibility.
        actuals: (N, T, D) actual values; these are returned as the model input.
        lotids, wafer_numbers: length-N sequences/arrays. str, bytes (e.g. dtype 'S'
            arrays read back from HDF5) and numeric values are all converted to str
            before building the UID.
        masks: (N, T) array-like
        target_uid: string of the form 'lotid-wafer_number'

    Returns:
        input_windows: torch.Tensor of shape (B, T, D), float32
        input_masks: torch.Tensor of shape (B, T), float32

    Raises:
        ValueError: if no window matches target_uid.
    """
    # np.char.add cannot mix str with bytes or numbers, so normalise both to str first.
    lot_str = np.asarray(lotids).astype(str)
    wafer_str = np.asarray(wafer_numbers).astype(str)
    unique_ids = np.char.add(lot_str, np.char.add("-", wafer_str))

    indices = np.flatnonzero(unique_ids == target_uid)
    if indices.size == 0:
        raise ValueError(f"UID '{target_uid}'에 해당하는 window 없음")

    input_windows = np.asarray(actuals)[indices]  # the actual values are the model input
    input_masks = np.asarray(masks)[indices]

    return torch.tensor(input_windows).float(), torch.tensor(input_masks).float()



In [None]:
def calc_diff_attns(prior, post, mask=None):
    """
    Mean absolute change between two stacks of attention maps.

    prior, post: arrays of shape (L, T, H, D, D)
    mask: optional array-like of shape (T,) or (1, T); for a 2-D mask only the first
        row is used. Non-zero entries mark valid time steps. Lists, NumPy arrays and
        CPU torch tensors are accepted.

    Returns:
        diff: shape (D, D) — |post - prior| averaged over axes (L, T, H), restricted to
        valid time steps when a mask is given (entries with no valid step become NaN).
    """
    diff = np.abs(post - prior)  # (L, T, H, D, D)

    if mask is None:
        return np.mean(diff, axis=(0, 1, 2))  # plain mean over everything

    mask = np.asarray(mask)
    if mask.ndim == 2:
        mask = mask[0]  # shape: (T,)

    # Broadcast the mask over (L, T, H, D, D); masked-out steps become NaN and are
    # ignored by nanmean.
    valid = mask.astype(bool)[None, :, None, None, None]  # (1, T, 1, 1, 1)
    return np.nanmean(np.where(valid, diff, np.nan), axis=(0, 1, 2))


In [None]:
def _attn_to_numpy(attn):
    """
    Convert attention weights returned by the model to one NumPy array.

    The VTT forward returns a list with one tensor per transformer block; such a list
    is stacked along a new leading layer axis. A single tensor is converted directly.
    """
    if isinstance(attn, (list, tuple)):
        return np.stack([a.detach().cpu().numpy() for a in attn])
    return attn.detach().cpu().numpy()


target_uid = highlight_uids["Mat_Irms"]["Right-Top"][0]  # example: the first UID
input_tensor, mask_tensor = get_input_by_uid(preds, actuals, lotids, wafer_numbers, masks, target_uid)
# Run on the device build_model placed the network on (not the notebook-global `device`).
model_input = input_tensor.to(model.device)
score_fn = model._select_criterion()

model.model.eval()
with torch.no_grad():
    input_prior_pred, [input_prior_vattn, input_prior_tattn] = model.model(model_input, use_attn=True)
    input_post_pred, [input_post_vattn, input_post_tattn] = model.model(input_prior_pred, use_attn=True)

normal_prior_score = score_fn(model_input, input_prior_pred).cpu().detach().numpy().mean()
normal_post_score = score_fn(input_prior_pred, input_post_pred).cpu().detach().numpy().mean()
print(f'Prior score: {normal_prior_score:.6f}, Post score: {normal_post_score:.6f}')

input_prior_vattn = _attn_to_numpy(input_prior_vattn)
input_post_vattn = _attn_to_numpy(input_post_vattn)

# NOTE(review): calc_diff_attns expects (L, T, H, D, D); confirm the stacked per-layer
# attention for a multi-window input matches that layout.
input_diff = calc_diff_attns(input_prior_vattn, input_post_vattn, mask_tensor)


In [None]:
import os
import torch
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
import matplotlib.colors as mcl
from collections import defaultdict
def _stack_attn_layers(attn):
    """
    Return attention weights as one NumPy array.

    A list/tuple of per-layer tensors (one per transformer block) is stacked along a
    new leading layer axis; a single tensor is converted as-is.
    """
    if isinstance(attn, (list, tuple)):
        return np.stack([a.detach().cpu().numpy() for a in attn])
    return attn.detach().cpu().numpy()


def save_attn_diff_per_wafer_train(trainloader, model, feature_names, save_root='attn_diff_results_train', device='cuda'):
    """
    Save an attention-diff heatmap for every wafer in the training set.

    Windows are grouped by UID ('lotid-wafer_number'). For each UID the stacked windows
    are run through the model twice (input -> prediction -> re-prediction); the
    variable-attention maps of both passes are compared with calc_diff_attns, masked
    with that UID's own masks. The (D, D) result is saved as '<uid>_attn_diff.png'
    in save_root.

    Parameters
    ----------
    trainloader : DataLoader whose dataset[i] is a dict with 'given', 'mask', 'lotid',
        'wafer_number'
    model : build_model wrapper (model.model is the network)
    feature_names : tick labels for both heatmap axes
    save_root : output directory (created if missing)
    device : device the inputs are moved to; must match the model's device

    Errors for a single UID are printed and that UID is skipped.
    """
    os.makedirs(save_root, exist_ok=True)
    model.model.eval()

    # Collect windows and their masks per UID (window level -> UID level).
    uid_to_windows = defaultdict(list)
    uid_to_masks = defaultdict(list)
    dataset = trainloader.dataset
    for i in range(len(dataset)):
        sample = dataset[i]
        uid = f"{sample['lotid']}-{sample['wafer_number']}"
        uid_to_windows[uid].append(sample['given'])
        uid_to_masks[uid].append(sample['mask'])

    # Colormap from white (saturation 0) to fully saturated hue 24°; identical for
    # every UID, so it is built once.
    h, s, v = 24, 0.99, 0.99
    colors = [
        mcl.hsv_to_rgb((h/360, 0, v)),
        mcl.hsv_to_rgb((h/360, 0.5, v)),
        mcl.hsv_to_rgb((h/360, 1, v))
    ]
    cmap = LinearSegmentedColormap.from_list('my_cmap', colors, gamma=3)

    for uid, window_list in uid_to_windows.items():
        fig = None
        try:
            input_tensor = torch.tensor(np.stack(window_list)).float().to(device)  # (B, W, D)
            # This UID's own masks, (B, W); calc_diff_attns uses the first row of a 2-D mask.
            mask = np.stack([np.asarray(m) for m in uid_to_masks[uid]])

            with torch.no_grad():
                prior_pred, [prior_vattn, _] = model.model(input_tensor, use_attn=True)
                _, [post_vattn, _] = model.model(prior_pred, use_attn=True)

            diff = calc_diff_attns(_stack_attn_layers(prior_vattn),
                                   _stack_attn_layers(post_vattn), mask)  # (D, D)

            # Visualise and save
            fig = plt.figure(figsize=(12, 10))
            ax = sns.heatmap(
                diff,
                cmap=cmap,
                annot=True,
                fmt=".4f",
                annot_kws={"size": 8},
                xticklabels=feature_names,
                yticklabels=feature_names,
                cbar=True
            )

            ax.xaxis.tick_top()
            plt.xticks(rotation=90, ha='center', fontsize=10)
            plt.yticks(rotation=0, va='center', fontsize=10)
            plt.title(f'Change of Correlations: {uid}', fontsize=14)
            plt.tight_layout()

            save_path = os.path.join(save_root, f'{uid}_attn_diff.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"[저장 완료] {save_path}")

        except Exception as e:
            # Best effort: report the failing UID and continue with the next one.
            print(f"[에러] {uid} 처리 중 오류 발생: {e}")
        finally:
            # Always release the figure, also when plotting failed midway.
            if fig is not None:
                plt.close(fig)


In [None]:
save_attn_diff_per_wafer_train(
    trainloader=trainloader,
    model=model,
    feature_names=feature_names,
    save_root='./logs/vtt_all_step/VTTSAT/version5/plots/attn_diff_results_train',
    # Use the device build_model placed the network on; a hard-coded 'cuda' fails on
    # CPU-only runs and can disagree with args.gpu.
    device=model.device,
)

# 250521 피드백 후 자료만들기!

In [None]:
import os
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def get_top_bottom_wafers_per_sensor(score_df, top_k=10, draw_plot=True, save_dir=None):
    """
    For every sensor column, find the wafers with the lowest and highest scores.

    Parameters
    ----------
    score_df : pd.DataFrame
        must contain 'unique_id'; every column other than 'step' and 'unique_id' is
        treated as a sensor score column.
    top_k : int
        number of lowest / highest wafers to pick per sensor.
    draw_plot : bool
        draw a strip plot per sensor with the picked wafers highlighted.
    save_dir : str or None
        if given, plots are saved there (and closed) instead of shown, and all picks
        are written to 'highest_lowest_results.csv'.

    Returns
    -------
    dict
        {sensor: {'bottom': [unique_id, ...], 'top': [unique_id, ...]}}, both lists in
        ascending score order. Rows whose score is NaN are ignored.
    """
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    sensor_columns = [col for col in score_df.columns if col not in ['step', 'unique_id']]
    result = {}
    result_rows = []

    for sensor in sensor_columns:
        # Drop NaN scores first: sort_values places them last, so they would otherwise
        # be reported as the "highest" wafers.
        df_sorted = score_df[['unique_id', sensor]].dropna(subset=[sensor]).sort_values(by=sensor)
        bottom_rows = df_sorted.head(top_k)
        top_rows = df_sorted.tail(top_k)

        bottom_ids = bottom_rows['unique_id'].tolist()
        top_ids = top_rows['unique_id'].tolist()

        result[sensor] = {
            'bottom': bottom_ids,
            'top': top_ids,
        }

        # bottom k first, then top k (CSV row order)
        for rank_type, rows in (('bottom', bottom_rows), ('top', top_rows)):
            result_rows.extend(
                {'sensor': sensor, 'rank_type': rank_type, 'unique_id': uid, 'score': score}
                for uid, score in zip(rows['unique_id'], rows[sensor])
            )

        if draw_plot:
            fig = plt.figure(figsize=(6, 2.5))

            others_mask = ~score_df['unique_id'].isin(top_ids + bottom_ids)
            sns.stripplot(x=score_df[others_mask][sensor], color='gray', size=2.5, jitter=0.25, label='others')
            sns.stripplot(x=bottom_rows[sensor], color='red', size=3, jitter=0.25, label=f'Lowest {top_k}')
            sns.stripplot(x=top_rows[sensor], color='blue', size=3, jitter=0.25, label=f'Highest {top_k}')

            plt.title(sensor)
            plt.xlabel("VTT Score")
            plt.yticks([])
            plt.legend(loc='upper right', fontsize=8)
            plt.tight_layout()

            if save_dir:
                plt.savefig(os.path.join(save_dir, f"{sensor}_score_dist.png"), dpi=150)
                # Close saved figures; otherwise one figure per sensor stays open.
                plt.close(fig)
            else:
                plt.show()

    if save_dir:
        df_result = pd.DataFrame(result_rows)
        df_result.to_csv(os.path.join(save_dir, 'highest_lowest_results.csv'), index=False)

    return result


# Pick the 10 lowest / highest scoring wafers per sensor and save the
# distribution plots plus the CSV summary under score_dist_dir.
score_dist_dir = "./logs/vtt_all_step/VTTSAT/version5/plots/score_dist"
top_bottom_ids_per_sensor = get_top_bottom_wafers_per_sensor(
    vtt_all_scores_df, top_k=10, draw_plot=True, save_dir=score_dist_dir,
)


### reference, highest, lowest 실제값 시계열 plot 그리기

In [None]:
import zipfile
import os

from glob import glob
from tqdm import tqdm
P_FG_ch1 = pd.read_parquet("drystrip_dataset/bigger_train_df.parquet")

# Sensor channels kept for the actual-value overlay plots below.
interested_sensors = [
    'APC_Position', 'APC_Pressure', 'Gas1_Monitor', 'Gas6_Monitor',
    'Mat_Irms', 'Mat_Phase', 'Mat_Vrms',
    'Mat_VC1_Position', 'Mat_VC2_Position', 'SourcePwr_Read',
    'Temp', 'Wall_Temp_Monitor',
]

# Meta columns followed by the selected sensors; rows ordered per wafer by
# time and re-indexed 0..N-1.
filterd_df = (
    P_FG_ch1[['time', 'lotid', 'wafer_number', 'Recipe_Step_Num'] + interested_sensors]
    .sort_values(by=["lotid", "wafer_number", "time"])
    .reset_index(drop=True)
)

In [None]:
# Re-assemble all (lotid, wafer_number) groups into one reference frame.
# NOTE: groupby drops rows whose lotid / wafer_number is NaN (dropna=True default).
grouped = list(filterd_df.groupby(['lotid', 'wafer_number']))
normal_groups = list(grouped)
train_df = pd.concat([group[1] for group in normal_groups])
# Same rows without the timestamp column (kept for inspection in other cells).
train_df_test = train_df.drop(columns='time')

from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
import numpy as np
sns.set_style('darkgrid')

# Min-max scale the sensor channels over all reference wafers. Columns are
# selected by name rather than by position so a column reorder cannot
# silently scale the wrong data.
scaler = MinMaxScaler()
train_data_combined = train_df[interested_sensors]
train_scaled = scaler.fit_transform(train_data_combined)
# Re-attach the id columns. The scaled array is given train_df's own index:
# a fresh DataFrame would get a RangeIndex, and concat(axis=1) aligns on
# index labels, so any non-contiguous train_df index would misalign rows / add NaNs.
train_scaled = pd.concat(
    [
        train_df[['lotid', 'wafer_number', 'Recipe_Step_Num']],
        pd.DataFrame(train_scaled, columns=interested_sensors, index=train_df.index),
    ],
    axis=1,
)

In [None]:
def highest_lowest_bottom_actuals_with_reference(
    actuals, lotids, wafer_numbers, feature_names,
    highest_lowest_df, reference_df,
    mask=None,
    output_dir="top_bottom_overlay_plots",
    ref_max_count=100
):
    """Save one overlay plot per sensor: reference wafers vs. Highest/Lowest wafers.

    For each sensor listed in ``highest_lowest_df`` that is also in
    ``feature_names``, draws up to ``ref_max_count`` randomly sampled
    reference wafers from ``reference_df`` (gray), then the actual sequences
    of the wafers ranked ``'top'`` (blue) and ``'bottom'`` (red) for that
    sensor, and writes ``<sensor>_highest_lowest_actual_overlay.png`` into
    ``output_dir``.

    Args:
        actuals: 3-D array indexed ``actuals[wafer, step, feature]``; axis 0 is
            aligned with ``lotids`` / ``wafer_numbers``, axis 1 is plotted as time.
        lotids, wafer_numbers: per-wafer identifiers; joined as
            ``"<lotid>-<wafer_number>"`` to match ``highest_lowest_df['unique_id']``.
        feature_names: list of feature names; list position = feature index.
        highest_lowest_df: DataFrame with ``sensor``, ``rank_type``
            (``'top'`` / ``'bottom'``) and ``unique_id`` columns.
        reference_df: DataFrame with ``lotid``, ``wafer_number`` and sensor
            columns, or ``None`` to skip the reference curves.
        mask: optional per-wafer mask over axis 1; only steps where it equals 1
            are plotted.
        output_dir: directory for the PNG files (created if missing).
        ref_max_count: maximum number of reference wafers sampled per sensor.

    Returns:
        A status string naming ``output_dir``.
    """
    os.makedirs(output_dir, exist_ok=True)

    lotids = np.array(lotids).astype(str)
    wafer_numbers = np.array(wafer_numbers).astype(str)
    unique_ids = np.char.add(lotids, np.char.add("-", wafer_numbers))

    # The reference grouping is the same for every sensor: build it once,
    # lazily, the first time a sensor actually needs it.
    ref_grouped = None
    ref_keys_all = []

    def _plot_ranked(ax, sensor, sensor_idx, rank_type, color, label):
        # Overlay every wafer ranked `rank_type` for `sensor`.
        ranked_ids = highest_lowest_df[
            (highest_lowest_df['sensor'] == sensor)
            & (highest_lowest_df['rank_type'] == rank_type)
        ]['unique_id']
        labelled = False
        for uid in ranked_ids:
            idxs = np.where(unique_ids == uid)[0]
            if len(idxs) == 0:
                continue  # wafer not present in `actuals`
            idx = idxs[0]
            if mask is not None:
                valid = mask[idx] == 1
                actual_seq = actuals[idx, valid, sensor_idx]
            else:
                actual_seq = actuals[idx, :, sensor_idx]
            # Label the first curve actually drawn, so the legend entry survives
            # even when the first-ranked id is missing from `actuals`.
            ax.plot(actual_seq, color=color, alpha=0.6, linewidth=0.8,
                    label=None if labelled else label)
            labelled = True

    for sensor in highest_lowest_df['sensor'].unique():
        if sensor not in feature_names:
            continue
        sensor_idx = feature_names.index(sensor)

        fig, ax = plt.subplots(figsize=(10, 4))
        try:
            ax.set_title(f"{sensor} | Lowest vs Highest actuals", fontsize=13)

            # Gray background: randomly sampled reference wafers.
            if reference_df is not None and sensor in reference_df.columns:
                if ref_grouped is None:
                    ref_grouped = reference_df.groupby(['lotid', 'wafer_number'])
                    ref_keys_all = list(ref_grouped.groups.keys())
                ref_keys = random.sample(ref_keys_all, min(ref_max_count, len(ref_keys_all)))
                for j, key in enumerate(ref_keys):
                    ref_series = ref_grouped.get_group(key).reset_index(drop=True)[sensor].values
                    ax.plot(ref_series, color='#999999', alpha=1.0, linewidth=1.0,
                            label='Reference' if j == 0 else None)

            _plot_ranked(ax, sensor, sensor_idx, 'top', 'blue', 'Highest 10 Actual')
            _plot_ranked(ax, sensor, sensor_idx, 'bottom', 'red', 'Lowest 10 Actual')

            ax.set_xlabel('Time')
            ax.set_ylabel('Sensor Value')
            ax.legend(fontsize=8)
            ax.grid(True, linestyle='--', alpha=0.5)
            fig.tight_layout()

            save_path = os.path.join(output_dir, f'{sensor}_highest_lowest_actual_overlay.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    return f"✅ 저장 완료: {output_dir}"



In [None]:
# get_top_bottom_wafers_per_sensor(..., save_dir=".../plots/score_dist") writes
# highest_lowest_results.csv into that save_dir, so read it back from there.
highest_lowest_df = pd.read_csv('./logs/vtt_all_step/VTTSAT/version5/plots/score_dist/highest_lowest_results.csv')

# Per sensor: reference wafers with the Highest / Lowest ranked wafers'
# actual sequences overlaid.
highest_lowest_bottom_actuals_with_reference(
    actuals=actuals,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    feature_names=feature_names,
    highest_lowest_df=highest_lowest_df,
    reference_df=train_scaled,
    mask=masks,
    output_dir="./logs/vtt_all_step/VTTSAT/version5/plots/highest_lowest_overlay",
    ref_max_count=1163
)

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random

def highest_lowest_bottom_actuals_with_reference_custom_single(
    actuals, lotids, wafer_numbers, feature_names,
    highest_lowest_df, reference_df,
    mask=None,
    output_dir="custom_overlay_plots_single",
    ref_max_count=100,
    sensor_groups=None,
):
    """Save one stacked figure per key sensor showing its related sensors.

    For each ``key_sensor -> related_sensors`` entry of ``sensor_groups``,
    draws one subplot per related sensor containing up to ``ref_max_count``
    randomly sampled reference wafers (gray) plus the first ``'top'``
    (blue, "Highest") and first ``'bottom'`` (red, "Lowest") wafer ranked for
    the *key* sensor, and writes ``<key_sensor>_context_overlay_single.png``
    into ``output_dir``.

    Args:
        actuals: 3-D array indexed ``actuals[wafer, step, feature]``; axis 0 is
            aligned with ``lotids`` / ``wafer_numbers``, axis 1 is plotted as time.
        lotids, wafer_numbers: per-wafer identifiers; joined as
            ``"<lotid>-<wafer_number>"`` to match ``highest_lowest_df['unique_id']``.
        feature_names: list of feature names; list position = feature index.
        highest_lowest_df: DataFrame with ``sensor``, ``rank_type``
            (``'top'`` / ``'bottom'``) and ``unique_id`` columns.
        reference_df: DataFrame with ``lotid``, ``wafer_number`` and sensor
            columns, or ``None`` to skip the reference curves.
        mask: optional per-wafer mask over axis 1; only steps where it equals 1
            are plotted.
        output_dir: directory for the PNG files (created if missing).
        ref_max_count: maximum number of reference wafers sampled per subplot.
        sensor_groups: mapping ``key_sensor -> list of sensors to plot``.
            ``None`` (default) uses the built-in grouping below.

    Returns:
        A status string naming ``output_dir``.
    """
    if sensor_groups is None:
        sensor_groups = {
            'APC_Position': ['APC_Position', 'APC_Pressure', 'Mat_Irms', 'Mat_VC1_Position'],
            'APC_Pressure': ['APC_Pressure'],
            'Gas1_Monitor': ['Gas1_Monitor', 'Wall_Temp_Monitor'],
            'Gas6_Monitor': ['Gas6_Monitor', 'Mat_Irms', 'Mat_VC2_Position'],
            'Mat_Irms': ['Mat_Irms', 'APC_Pressure', 'Mat_Vrms', 'Wall_Temp_Monitor'],
            'Mat_Phase': ['Mat_Phase', 'Mat_VC2_Position'],
            'Wall_Temp_Monitor': ['Wall_Temp_Monitor'],
            'Temp': ['Temp', 'Wall_Temp_Monitor', 'Mat_Irms'],
            'SourcePwr_Read': ['APC_Pressure', 'SourcePwr_Read', 'Temp'],
            'Mat_Vrms': ['APC_Pressure', 'Mat_Vrms', 'Mat_VC1_Position', 'Temp'],
            'Mat_VC2_Position': ['Mat_VC2_Position', 'Mat_Vrms', 'Wall_Temp_Monitor'],
            'Mat_VC1_Position': ['Mat_VC1_Position', 'APC_Pressure', 'Mat_Phase', 'Mat_VC2_Position', 'Temp', 'Wall_Temp_Monitor'],
        }

    os.makedirs(output_dir, exist_ok=True)

    lotids = np.array(lotids).astype(str)
    wafer_numbers = np.array(wafer_numbers).astype(str)
    unique_ids = np.char.add(lotids, np.char.add("-", wafer_numbers))

    # The reference grouping never changes: build it once, lazily, on first use
    # instead of regrouping for every (key sensor, subplot) pair.
    ref_grouped = None
    ref_keys_all = []

    def _plot_first_ranked(ax, key_sensor, sensor_idx, rank_type, color, label):
        # Overlay only the first wafer ranked `rank_type` for the key sensor.
        first_ids = highest_lowest_df[
            (highest_lowest_df['sensor'] == key_sensor)
            & (highest_lowest_df['rank_type'] == rank_type)
        ]['unique_id'].values[:1]
        for uid in first_ids:
            idxs = np.where(unique_ids == uid)[0]
            if len(idxs) == 0:
                continue  # wafer not present in `actuals`
            idx = idxs[0]
            if mask is not None:
                valid = mask[idx] == 1
                actual_seq = actuals[idx, valid, sensor_idx]
            else:
                actual_seq = actuals[idx, :, sensor_idx]
            ax.plot(actual_seq, color=color, alpha=0.6, linewidth=1.2, label=label)

    for key_sensor, related_sensors in sensor_groups.items():
        if not related_sensors:
            continue  # plt.subplots(0, ...) would raise
        n_rows = len(related_sensors)
        # squeeze=False always yields a 2-D array, so one column works for n_rows == 1 too.
        fig, axes = plt.subplots(n_rows, 1, figsize=(10, 3 * n_rows), sharex=True, squeeze=False)
        axes = axes[:, 0]

        try:
            for ax, sensor in zip(axes, related_sensors):
                if sensor not in feature_names:
                    continue
                sensor_idx = feature_names.index(sensor)

                # Gray background: randomly sampled reference wafers.
                if reference_df is not None and sensor in reference_df.columns:
                    if ref_grouped is None:
                        ref_grouped = reference_df.groupby(['lotid', 'wafer_number'])
                        ref_keys_all = list(ref_grouped.groups.keys())
                    ref_keys = random.sample(ref_keys_all, min(ref_max_count, len(ref_keys_all)))
                    for j, key in enumerate(ref_keys):
                        ref_series = ref_grouped.get_group(key).reset_index(drop=True)[sensor].values
                        ax.plot(ref_series, color='#999999', alpha=0.5, linewidth=1.0,
                                label='Reference' if j == 0 else None)

                _plot_first_ranked(ax, key_sensor, sensor_idx, 'top', 'blue', 'Highest')
                _plot_first_ranked(ax, key_sensor, sensor_idx, 'bottom', 'red', 'Lowest')

                ax.set_ylabel(sensor)
                ax.legend(fontsize=7)
                ax.grid(True, linestyle='--', alpha=0.5)

            axes[-1].set_xlabel('Time')
            fig.suptitle(f"{key_sensor} | Highest vs Lowest", fontsize=13)
            fig.tight_layout(rect=[0, 0.03, 1, 0.95])
            save_path = os.path.join(output_dir, f'{key_sensor}_context_overlay_single.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if plotting/saving fails.
            plt.close(fig)

    return f"✅ 저장 완료: {output_dir}"


In [None]:
# One stacked figure per key sensor: its related channels, each with the
# training reference wafers and the single Highest / Lowest wafer overlaid.
highest_lowest_bottom_actuals_with_reference_custom_single(
    highest_lowest_df=highest_lowest_df,
    reference_df=train_scaled,
    ref_max_count=1163,
    actuals=actuals,
    mask=masks,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    feature_names=feature_names,
    output_dir="./logs/vtt_all_step/VTTSAT/version5/plots/highest_lowest_context_overlay_single",
)

### attention map difference 그리기

In [None]:
def plot_attn_diff_top_bottom_per_sensor(
    preds, actuals, lotids, wafer_numbers, masks,
    highest_lowest_df, model, feature_names,
    get_input_by_uid, calc_diff_attns,
    save_root='attn_diff_top_bottom', device='cuda'
):
    """Save a Lowest-vs-Highest attention-difference heatmap pair per sensor.

    For each sensor in ``highest_lowest_df`` that is also in ``feature_names``,
    the first ``'bottom'`` and first ``'top'`` wafer are fetched with
    ``get_input_by_uid``. Each wafer's input is run through ``model.model``
    (prior attention), and the resulting prediction is run through it again
    (post attention). ``calc_diff_attns(prior_vattn, post_vattn, mask)`` is then
    drawn as two heatmaps that share one colour range. The result is written to
    ``<sensor>_attn_diff_top_bottom.png`` in ``save_root``.

    A failure for one wafer is printed and leaves that panel empty. A sensor
    produces no file when it lacks a top or bottom entry or when both diffs fail.

    Returns:
        A status string naming ``save_root``.
    """
    os.makedirs(save_root, exist_ok=True)
    model.model.eval()

    # get_input_by_uid is handed the string-converted id arrays.
    lotids = np.array(lotids).astype(str)
    wafer_numbers = np.array(wafer_numbers).astype(str)

    cmap = None  # white -> saturated hue-24° colormap, built once on first use

    for sensor in highest_lowest_df['sensor'].unique():
        if sensor not in feature_names:
            continue

        sensor_df = highest_lowest_df[highest_lowest_df['sensor'] == sensor]
        top_uids = sensor_df.loc[sensor_df['rank_type'] == 'top', 'unique_id'].values
        bottom_uids = sensor_df.loc[sensor_df['rank_type'] == 'bottom', 'unique_id'].values
        if len(top_uids) == 0 or len(bottom_uids) == 0:
            # Indexing [0] would raise IndexError and abort every remaining sensor.
            print(f"[건너뜀] {sensor}: top/bottom unique_id 없음")
            continue
        uids = {'Lowest': bottom_uids[0], 'Highest': top_uids[0]}

        # 1. Compute both attention differences before creating a figure, so a
        #    sensor with no valid diff does not leave an open figure behind.
        diffs = {}
        for title, uid in uids.items():
            try:
                input_seq, input_mask = get_input_by_uid(
                    preds, actuals, lotids, wafer_numbers, masks, uid
                )
                input_seq = input_seq.to(device)
                mask_np = input_mask.cpu().numpy()

                with torch.no_grad():
                    # Prior: attention on the input; post: attention on the
                    # model's own reconstruction of it.
                    prior_pred, [prior_vattn, _] = model.model(input_seq, use_attn=True)
                    _post_pred, [post_vattn, _] = model.model(prior_pred, use_attn=True)

                prior_vattn = prior_vattn.cpu().detach().numpy()
                post_vattn = post_vattn.cpu().detach().numpy()
                diffs[title] = calc_diff_attns(prior_vattn, post_vattn, mask_np)

            except Exception as e:
                # Best-effort per wafer: report and leave this panel empty.
                print(f"[에러] {sensor} | {title} | {uid} 처리 실패: {e}")
                diffs[title] = None

        # 2. Shared colour range so both panels are directly comparable.
        valid_diffs = [d for d in diffs.values() if d is not None]
        if not valid_diffs:
            continue

        global_min = min(d.min() for d in valid_diffs)
        global_max = max(d.max() for d in valid_diffs)

        # 3. Colormap (loop-invariant, so built only once).
        if cmap is None:
            h = 24
            colors = [
                mcl.hsv_to_rgb((h / 360, 0, 1)),
                mcl.hsv_to_rgb((h / 360, 0.5, 1)),
                mcl.hsv_to_rgb((h / 360, 1, 1))
            ]
            cmap = LinearSegmentedColormap.from_list('custom_cmap', colors, gamma=3)

        # 4. Heatmaps.
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))
        try:
            for i, (title, diff) in enumerate(diffs.items()):
                if diff is None:
                    continue

                hm = sns.heatmap(
                    diff,
                    cmap=cmap,
                    annot=True,
                    fmt=".4f",
                    annot_kws={"size": 4},
                    xticklabels=feature_names,
                    yticklabels=feature_names,
                    cbar=True,
                    ax=axes[i],
                    vmin=global_min,
                    vmax=global_max  # fixed range shared by both panels
                )

                axes[i].set_title(f"{title}", fontsize=9)
                axes[i].tick_params(axis='x', labelrotation=90, labelsize=6)
                axes[i].tick_params(axis='y', labelsize=6)
                colorbar = hm.collections[0].colorbar
                colorbar.ax.tick_params(labelsize=5)  # colour-bar tick labels at 5 pt

            fig.suptitle(f"{sensor} | Attention Difference (Lowest vs Highest)", fontsize=11)
            fig.subplots_adjust(wspace=0.3, hspace=0.3)
            save_path = os.path.join(save_root, f"{sensor}_attn_diff_top_bottom.png")
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Always release the figure, even if drawing/saving fails.
            plt.close(fig)

    return f"✅ 저장 완료: {save_root}"


In [None]:
# Attention-difference heatmaps (prior vs. post reconstruction) for the
# Lowest- and Highest-scored wafer of every sensor.
plot_attn_diff_top_bottom_per_sensor(
    model=model,
    device='cuda',
    preds=preds,
    actuals=actuals,
    masks=masks,
    lotids=lotids,
    wafer_numbers=wafer_numbers,
    feature_names=feature_names,
    highest_lowest_df=highest_lowest_df,
    get_input_by_uid=get_input_by_uid,
    calc_diff_attns=calc_diff_attns,
    save_root='./logs/vtt_all_step/VTTSAT/version5/plots/attn_diff_highest_lowest',
)


- 원본 데이터에서 추출한 attention map과 재구성된 데이터에서 추출한 attention map을 비교

---

### STEP 10. 명령어로 실행

```
python main.py \
--train \
--test \
--model VTTPAT \
--dataname SWaT \
--use_multi_gpu \
--devices 0,1
```

---