In [1]:
import os
import glob
import gc

import pandas as pd
import numpy as np
import torch
from torch.utils.data import DataLoader

from transformers import PatchTSTForPrediction

from datasets import Dataset

2025-11-04 11:12:45.067020: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
data = "coin"

output_dir = "saved_models"
log_dir = os.path.join('logstf', data)

os.makedirs(log_dir, exist_ok=True)
os.makedirs(output_dir, exist_ok = True)

loss_name = "mape"

learning_rate = 5e-5
num_train_epochs = 400

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
## target domain
target_X = pd.read_csv(f"../data/{data}/train_input_7.csv").iloc[:, 1:].values.astype(np.float32)

np.random.seed(2)
random_indices1 = np.random.choice(pd.read_csv("../data/M4_train.csv").iloc[:, (1):].index,
                                   size=target_X.shape[0] * 20, replace=True)

X_data = pd.read_csv("../data/M4_train.csv").iloc[:, 1 + (24 * 0):].loc[random_indices1].values.astype(np.float32)
y_data = pd.read_csv("../data/M4_test.csv").iloc[:, 1:].loc[random_indices1].values.astype(np.float32)

In [4]:
## bootstrap
np.random.seed(42)
select = np.random.choice(len(X_data), size=len(X_data), replace=True)
X_bootstrap = X_data[select]
y_bootstrap = y_data[select]

val_split_index = int(len(X_bootstrap) * 0.8)

def to_tensor_and_reshape(array):
    result = torch.tensor(array)
    result = result.reshape(-1, result.shape[1], 1)

    return result

X_train, X_valid = to_tensor_and_reshape(X_bootstrap[:val_split_index]), to_tensor_and_reshape(X_bootstrap[val_split_index:])
y_train, y_valid = to_tensor_and_reshape(y_bootstrap[:val_split_index]), to_tensor_and_reshape(y_bootstrap[val_split_index:])

## setting dataloader
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size = 256, shuffle = True, num_workers = 16)

test_dataset = torch.utils.data.TensorDataset(X_valid, y_valid)
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size = 256, num_workers = 16)

In [5]:
backbone_model = PatchTSTForPrediction.from_pretrained(os.path.join(output_dir, "PatchTSTBackbone")).to(device)

In [None]:
## custom loss function
def SMAPE(yhat, y):
    numerator = 100*torch.abs(y - yhat)
    denominator = (torch.abs(y) + torch.abs(yhat))/2
    smape = torch.mean(numerator / denominator)
    return smape

def MAPE(yhat, y):
    return torch.mean(100*torch.abs((y - yhat) / y))

class MASE(torch.nn.Module):
    def __init__(self, training_data, period = 1):
        super().__init__()
        self.scale = torch.mean(torch.abs(training_data[:, period:] - data[:, :-period]))    ## 모든 훈련 데이터에 대한 평균 스케일 계산
    
    def forward(self, yhat, y):
        error = torch.abs(y - yhat)
        return torch.mean(error / self.scale)

In [None]:
if loss_name == "mse":
    loss_fn = torch.nn.MSELoss()
    lr = learning_rate
else:
    lr = learning_rate*2
    if loss_name == "mae":
        loss_fn = torch.nn.L1Loss()
    elif loss_name == "SMAPE":
        loss_fn = SMAPE
    elif loss_name == "mape":
        loss_fn = MAPE
    elif loss_name == "MASE":
        loss_fn = MASE
    else:
        raise Exception("Your loss name is not valid.")

optimizer = torch.optim.AdamW(backbone_model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max = num_train_epochs)
log_data = []

## early stopping
PATIENCE = 15
best_val_loss = np.inf
patience_counter = 0

for epoc in range(num_train_epochs):
    backbone_model.train()

    total_train_loss = 0

    for X, y in train_dataloader:
        X, y = X.to(device), y.to(device)

        optimizer.zero_grad()
        yhat = backbone_model(X).prediction_outputs
        loss = loss_fn(yhat, y)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(backbone_model.parameters(), max_norm = 1.0)
        optimizer.step()

        total_train_loss += loss.item()*X.shape[0]

    avg_train_loss = total_train_loss / len(train_dataloader.dataset)

    backbone_model.eval()

    with torch.no_grad():
        yys = []
        yyhats = []

        for XX, yy in test_dataloader:
            XX = XX.to(device)
            yys.append(yy.to(device))
            yyhats.append(backbone_model(XX).prediction_outputs)

        yyhat = torch.concat(yyhats)
        yy = torch.concat(yys)

        val_loss = loss_fn(yyhat, yy)

    print(f"Epoch {epoc+1}/{num_train_epochs} | Train Loss: {avg_train_loss:.6f}\t\t Val Loss: {val_loss:.6f}")

    log_data.append({"epoch": epoc, "loss": avg_train_loss, "eval_loss": val_loss.item()})

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(backbone_model.state_dict(), os.path.join(output_dir, f"model_{loss_name}_{1}.pth"))
        patience_counter = 0
    else:
        patience_counter += 1

    if patience_counter >= PATIENCE:
        break

    scheduler.step()

Epoch 1/400 | Train Loss: 15.086191		 Val Loss: 14.617880
Epoch 2/400 | Train Loss: 10.989111		 Val Loss: 9.811539
Epoch 3/400 | Train Loss: 10.246380		 Val Loss: 9.372941
Epoch 4/400 | Train Loss: 9.864548		 Val Loss: 9.472543
Epoch 5/400 | Train Loss: 9.473016		 Val Loss: 9.064425
Epoch 6/400 | Train Loss: 9.134879		 Val Loss: 9.048741
Epoch 7/400 | Train Loss: 8.998951		 Val Loss: 9.460318
Epoch 8/400 | Train Loss: 8.786859		 Val Loss: 8.704053
Epoch 9/400 | Train Loss: 8.625377		 Val Loss: 8.758979
Epoch 10/400 | Train Loss: 8.484347		 Val Loss: 8.544407
Epoch 11/400 | Train Loss: 8.401431		 Val Loss: 8.313139
Epoch 12/400 | Train Loss: 8.294785		 Val Loss: 8.572848
Epoch 13/400 | Train Loss: 8.267933		 Val Loss: 8.735223
Epoch 14/400 | Train Loss: 8.043027		 Val Loss: 8.438051
Epoch 15/400 | Train Loss: 7.887270		 Val Loss: 7.978919
Epoch 16/400 | Train Loss: 7.817568		 Val Loss: 7.972125
Epoch 17/400 | Train Loss: 7.734387		 Val Loss: 8.017379
Epoch 18/400 | Train Loss: 7.649780	

In [None]:
## save log
pd.DataFrame(log_data).to_csv(os.path.join(log_dir, f"pretrain_{loss_name}_model{1}.csv"))

## load best model
backbone_model.load_state_dict(torch.load(os.path.join(output_dir, f"model_{loss_name}_{1}.pth")))

<All keys matched successfully>

In [None]:
yyhats = []
yys = []

with torch.no_grad():
    for XX, yy in test_dataloader:
        XX = XX.to(device)
        yys.append(yy.to(device))
        yyhats.append(backbone_model(XX).prediction_outputs)

In [None]:
yyhat, yy = torch.concat(yyhats).squeeze(), torch.concat(yys).squeeze()

In [None]:
mseLoss = torch.nn.MSELoss()
maeLoss = torch.nn.L1Loss()

def smape(yy, yyhat):
    numerator = 100*abs(yy - yyhat)
    denominator = (abs(yy) + abs(yyhat))/2
    smape = torch.mean(numerator / denominator)
    return smape

print(f"test RMSE: {torch.sqrt(mseLoss(yyhat, yy))}")
print(f"test MAE: {maeLoss(yyhat, yy)}")
print(f"test SMAPE: {smape(yy, yyhat)}")

test RMSE: 500.53314208984375
test MAE: 140.67575073242188
test SMAPE: 3.8587536811828613


In [None]:
yyhat = pd.DataFrame(yyhat.to("cpu"))
yyhat.columns = [f"{i}A" for i in range(yyhat.shape[1])]
yy = pd.DataFrame(yy.to("cpu"))
yy.columns = [f"{i}B" for i in range(yyhat.shape[1])]

In [None]:
val_result = pd.concat([yyhat, yy], axis = 1).sort_index(axis = 1)
val_result.columns = [f"prediction_{(i+1)//2}" if i%2 == 1 else f"ground_truth_{(i+1)//2}" for i in range(1, val_result.shape[1]+1)]
val_result.to_csv(os.path.join(log_dir, f"prediction_val_results_{loss_name}_model{1}.csv"), index = False)

In [None]:
val_result

Unnamed: 0,prediction_1,ground_truth_1,prediction_2,ground_truth_2,prediction_3,ground_truth_3,prediction_4,ground_truth_4,prediction_5,ground_truth_5,...,prediction_20,ground_truth_20,prediction_21,ground_truth_21,prediction_22,ground_truth_22,prediction_23,ground_truth_23,prediction_24,ground_truth_24
0,1704.263672,1745.800049,1769.205322,1869.199951,1763.832275,1854.199951,1764.050781,1863.800049,1765.525879,1853.800049,...,1766.165405,1833.199951,1750.739136,1851.099976,1771.409790,1854.900024,1767.369995,1854.199951,1780.279663,1870.400024
1,9950.673828,9980.000000,10081.349609,10070.000000,10252.224609,10190.000000,10213.568359,10200.000000,10260.244141,10210.000000,...,10011.796875,9940.000000,9980.976562,9960.000000,9908.910156,9950.000000,9951.951172,9970.000000,10048.781250,10050.000000
2,1450.028076,1445.834961,1439.325073,1433.120972,1441.031250,1439.545044,1442.589844,1441.876953,1462.861450,1440.109009,...,1457.601562,1441.203003,1446.141113,1435.150024,1451.248779,1420.425049,1437.592407,1441.050049,1438.310181,1425.350952
3,6555.224609,6546.056641,6925.402344,6936.538574,6731.679199,6724.092285,6659.072754,6635.496582,6881.517090,6887.514160,...,6153.918457,6130.813477,6293.760254,6259.539551,6577.294922,6557.993164,6939.978027,6951.220703,7035.430176,7044.781250
4,4718.234375,4723.200195,5153.504883,5204.700195,5186.511719,5262.500000,5244.418457,5392.399902,5289.799805,5476.600098,...,4983.645508,5034.799805,5030.506348,5064.700195,5000.416992,5092.600098,5031.646484,5120.899902,5115.606445,5156.399902
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2887,8805.758789,8830.000000,8832.996094,8800.000000,8799.887695,8780.000000,8739.701172,8710.000000,8746.764648,8790.000000,...,8884.001953,8810.000000,8752.689453,8690.000000,8785.265625,8710.000000,8835.542969,8810.000000,8786.583984,8810.000000
2888,966.815552,958.630005,1080.357910,1087.250000,1074.203857,1063.640015,1067.055054,1064.489990,1044.267334,1038.270020,...,1053.053711,1048.619995,1097.981812,1077.099976,1081.737549,1073.660034,1085.221436,1079.089966,1094.151733,1105.900024
2889,5617.546387,5577.244629,6292.315430,6268.951660,6216.923828,6390.646484,6161.664062,6173.075195,5771.603516,5789.339844,...,5227.836914,5299.357910,4917.280762,4838.526367,4774.562500,4506.627441,5986.719727,6024.740234,5806.917969,5739.643066
2890,6584.265625,6588.406250,6386.041016,6459.303711,6267.056641,6185.271484,6358.790527,6149.212402,6393.826172,6352.244629,...,6459.753418,6698.416016,6488.013672,6632.667480,6554.808105,6713.057617,6506.551758,6679.114746,6379.812500,6442.051758


In [57]:
torch.cuda.empty_cache()
gc.collect()

204