## Imports

In [1]:
import pandas as pd
import torch
from pytorch_lightning import Trainer

## Concept

1. Static Metadata, Time-varying Past Inputs, Time-varying Future Inputs을 Input Data로 활용
2. 모든 Inputs은 변수 선택 단계(VSN)를 거쳐 정적 공변량과 Time-varying Inputs들이 함께 LSTM layer에 입력됨
3. Time-varying Past Inputs은 Encoder에, Time-varying Future Inputs은 Decoder에 입력
4. LSTM total layer outputs과 정적 공변량 데이터를 GRN layer에 합께 입력한 후, Masked Interpretable Multi-head attention layer에 입력
5. 최종적으로 LSTM decoder 산출 값과 GRN을 거친 attention layer의 산출 값을 결합하여 quantile forecase를 진행

## 사용법

In [2]:
#!pip install tensorflow
#!pip install tensorboard
#!pip install optuna statsmodels
#!pip install optuna-integration[pytorch_lightning]

In [3]:
import copy
from pathlib import Path
import warnings

import lightning.pytorch as pl
from lightning.pytorch.callbacks import EarlyStopping, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.tuner import Tuner
import numpy as np
import pandas as pd
import torch

from pytorch_forecasting import Baseline, TemporalFusionTransformer, TimeSeriesDataSet
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MAE, SMAPE, PoissonLoss, QuantileLoss
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters
from pytorch_forecasting.data.examples import get_stallion_data

### Preprocessing

In [12]:
from pytorch_forecasting.data.examples import get_stallion_data

data = get_stallion_data() ## pandas df

## add time index
data["time_idx"] = data["date"].dt.year*12 + data["date"].dt.month
data["time_idx"] -= data["time_idx"].min() ## monto 단위, 0 to fin, 시점 당 350개 데이터

## add additional features
data["month"] = data.date.dt.month.astype(str).astype("category")
data["log_volume"] = np.log(data.volume + 1e-8)
data["avg_volume_by_sku"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")
data["avg_volume_by_agency"] = data.groupby(["time_idx", "sku"], observed=True).volume.transform("mean")

## special days encoding
special_days = [
    "easter_day",
    "good_friday",
    "new_year",
    "christmas",
    "labor_day",
    "independence_day",
    "revolution_day_memorial",
    "regional_games",
    "fifa_u_17_world_cup",
    "football_gold_cup",
    "beer_capital",
    "music_fest",
]

## 이게 뭐임? 데이터프레임에 들어간 게 아니네?
data[special_days] = (
    data[special_days].apply(lambda x: x.map({0: "-", 1: x.name})).astype("category")
)

### Create dataset and dataloaders

In [13]:
max_prediction_length = 6   ## validation length
max_encoder_length = 24     ## encoding length
training_cutoff = data["time_idx"].max() - max_prediction_length ## 훈련 데이터 종료 시점

training = TimeSeriesDataSet(
    data[lambda x : x.time_idx <= training_cutoff],
    time_idx = "time_idx", ## integer. 0 >=
    target = "volume",
    group_ids = ["agency", "sku"], ## TimeSeries Group
    min_encoder_length = max_encoder_length // 2, ## 인코더 길이 길게 유지
    max_encoder_length = max_encoder_length,
    min_prediction_length = 1,
    max_prediction_length = max_prediction_length,
    static_categoricals = ["agency", "sku"], ## 분리되는 요소만 시점에 불변
    static_reals = ["avg_population_2017", "avg_yearly_household_income_2017"], ## 수치형 중에선 평균값만 불변
    time_varying_known_categoricals = ["special_days", "month"], ## 시간에 따라 바뀌는 카테고리
    variable_groups = {
        "special_days" : special_days
    }, ## 한 개의 변수로 설명 가능
    time_varying_known_reals = ["time_idx", "price_regular", "discount_in_percent"],
    time_varying_unknown_categoricals = [],
    time_varying_unknown_reals = [
        "volume",
        "log_volume",
        "industry_volume",
        "soda_volume",
        "avg_max_temp",
        "avg_volume_by_agency",
        "avg_volume_by_sku",
    ], ## 정보 유출?
    target_normalizer = GroupNormalizer(
        groups=["agency", "sku"], transformation="softplus"
    ),  ## use softplus and normalize by group -> ReLU랑 비슷함
    add_relative_time_idx = True,
    add_target_scales = True,
    add_encoder_length = True
)

## create validation set
validation = TimeSeriesDataSet.from_dataset(
    training, data, predict = True, stop_randomization = True
)

## create dataloaders for model
batch_size = 32
train_dataloader = training.to_dataloader(
    train = True, batch_size = batch_size, num_workers = 0
)
val_dataloader = validation.to_dataloader(
    train = False, batch_size = batch_size * 10, num_workers = 0
)

### Modeling

In [5]:
baseline_predictions = Baseline().predict(val_dataloader, return_y = True)
SMAPE()(baseline_predictions.output, baseline_predictions.y)

/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA A100-SXM4-80GB MIG 7g.80gb') that has Tensor Cores. To properly utilize 

2025-08-12 16:22:00.584750: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754983320.602240    1207 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754983320.607877    1207 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754983320.622003    1207 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754983320.622016    1207 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754983320.622018    1207 computation_placer.cc:177] computation placer alr

tensor(0.4709, device='cuda:0')

In [16]:
pl.seed_everything(42) ## setting pytorch lightning seed

trainer = pl.Trainer(
    accelerator = "gpu",
    gradient_clip_val = 0.1
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate = 0.03,
    hidden_size = 8,
    attention_head_size = 1,
    dropout = 0.1, ## 0.1 to 0.3
    hidden_continuous_size = 8, ## <= hidden_size
    loss = QuantileLoss(),
    optimizer = "ranger",
    # reduce_on_plateau_patience = 1000 ## 에폭 이후에도 validation loss 미감소시 사용
)

print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

Seed set to 42
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of parameters in network: 13.5k


In [15]:
## find optimal learning rate
from lightning.pytorch.tuner import Tuner

res = Tuner(trainer).lr_find(
    tft,
    train_dataloaders = train_dataloader,
    val_dataloaders = val_dataloader,
    max_lr = 10.0,
    min_lr = 1e-6
)

print(f"suggested learning rate: {res.suggestion()}")
fig = res.plot(show=True, suggest=True)
fig.show()

/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/loops/utilities.py:73: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
You are using a CUDA device ('NVIDIA A100-SXM4-80GB MIG 7g.80gb') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision


2025-08-12 20:12:01.645077: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1754997121.663897   38413 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754997121.670316   38413 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754997121.686391   38413 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754997121.686406   38413 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754997121.686408   38413 computation_placer.cc:177] computation placer alr

Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

### Train

In [8]:
early_stop_callback = EarlyStopping(
    monitor = "val_loss", min_delta = 1e-4, patience = 10, verbose = False, mode = "min"
)
lr_logger = LearningRateMonitor() ## log the learning rate
logger = TensorBoardLogger("lightning_logs") ## logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs = 50,
    accelerator = "gpu",
    enable_model_summary = True,
    gradient_clip_val = 0.1,
    limit_train_batches = 50, ## comment in for training, running validation every 30 batches
    callbacks = [lr_logger, early_stop_callback],
    logger = logger
)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate = 0.03,
    hidden_size = 16,
    attention_head_size = 2,
    dropout = 0.1,
    hidden_continuous_size = 8,
    loss = QuantileLoss(),
    log_interval = 10, ## 10 batchs마다 로그
    optimizer = "ranger",
    reduce_on_plateau_patience = 4
)

print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of parameters in network: 29.4k


/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.


In [None]:
## fit network
trainer.fit(
    tft,
    train_dataloaders = train_dataloader,
    val_dataloaders = val_dataloader
)

In [None]:
predictions = trainer.model.predict(val_dataloader, return_y = True, trainer_kwargs = dict(accelerator = "gpu"))
SMAPE()(predictions.output, predictions.y) ## Baseline보다 못한데요...? 아직 덜됐나...

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


tensor(0.6430, device='cuda:0')

### Hyperparameter Tuning

In [None]:
import pickle
from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

## create study
study = optimize_hyperparameters(
    train_dataloader,
    val_dataloader,
    model_path="optuna_test",
    n_trials=200,
    max_epochs=50,
    gradient_clip_val_range=(0.01, 1.0),
    hidden_size_range=(8, 128),
    hidden_continuous_size_range=(8, 128),
    attention_head_size_range=(1, 4),
    learning_rate_range=(0.001, 0.1),
    dropout_range=(0.1, 0.3),
    trainer_kwargs=dict(limit_train_batches=30),
    reduce_on_plateau_patience=4,
    use_learning_rate_finder=False,  # use Optuna to find ideal learning rate or use in-built learning rate finder
)

## save study results
# with open("test_study.pkl", "wb") as f :
#     pickle.dump(study, f)
    
## show best hyperparameters
print(study.best_trial.params)

## Data

In [9]:
df_train = pd.read_csv("/root/Dacon_comp/2025 전력사용량 예측/train.csv")
df_test = pd.read_csv("/root/Dacon_comp/2025 전력사용량 예측/test.csv")
building_info = pd.read_csv("/root/Dacon_comp/2025 전력사용량 예측/building_info.csv")

df_train = df_train.rename({pre:new for pre, new in zip(df_train.columns, ["num_date_time", "build_num", "date", "temp", "precip", "wind", "humidity", "sunhour", "sunweight", "power"])}, axis = 1)
df_test = df_test.rename({pre:new for pre, new in zip(df_test.columns, ["num_date_time", "build_num", "date", "temp", "precip", "wind", "humidity", "sunhour", "sunweight"])}, axis = 1)

In [10]:
## information processing
building_info = building_info.replace("-", "0").rename({c:n for c, n in zip(building_info.columns, ["build_num", "build_type", "GFA", "CA", "solar_gen", "ESS", "PCS"])}, axis = 1)\
    .assign(solar_gen = lambda _df : _df.solar_gen.astype("float64"))\
    .assign(ESS = lambda _df : _df.ESS.astype("float64"))\
    .assign(PCS = lambda _df : _df.PCS.astype("float64"))
    
## one-hot encoding -> 안해도 됨
# building_info = pd.get_dummies(building_info, dtype = int)
# building_info = building_info.rename({c:f"type_{i}" for i, c in enumerate(building_info.columns[6:])}, axis = 1)

In [11]:
## dttm으로 타입 변경
df_train["date"] = pd.to_datetime(df_train.date)
df_test["date"] = pd.to_datetime(df_test.date)
holilist = ["06-06", "08-15"] ## 2024 추석은 9월 17일
start_date = df_train.date.min()

## train data
del_time = df_train.date - start_date
df_train["time_idx"] = del_time.dt.days*24 + del_time.dt.seconds//3600
df_train["month"] = df_train.date.dt.month.astype(str).astype("category")
df_train["wday"] = df_train.date.dt.weekday.astype(str).astype("category")
df_train["is_holiday"] = df_train.date.astype(str).str[5:10].map(lambda x : 1 if x in holilist else 0)
df_train = df_train.drop(["num_date_time", "date"], axis = 1)
df_train = pd.merge(df_train, building_info, on = "build_num")
df_train["build_num"] = df_train["build_num"].astype(str).astype("category")

## 일단 국경일 이외 special day는 다루지 않기로 함

## test data
del_time = df_test.date - start_date
df_test["time_idx"] = del_time.dt.days*24 + del_time.dt.seconds//3600
df_test["month"] = df_test.date.dt.month.astype(str).astype("category")
df_test["wday"] = df_test.date.dt.weekday.astype(str).astype("category")
df_test["is_holiday"] = 0 ## 8.25 이후 휴일 없음
df_test = df_test.drop(["num_date_time", "date"], axis = 1)
df_test = pd.merge(df_test, building_info, on = "build_num")
df_test["build_num"] = df_test["build_num"].astype(str).astype("category")

In [12]:
df_train.build_num.cat.categories

Index(['1', '10', '100', '11', '12', '13', '14', '15', '16', '17', '18', '19',
       '2', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '3',
       '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '4', '40',
       '41', '42', '43', '44', '45', '46', '47', '48', '49', '5', '50', '51',
       '52', '53', '54', '55', '56', '57', '58', '59', '6', '60', '61', '62',
       '63', '64', '65', '66', '67', '68', '69', '7', '70', '71', '72', '73',
       '74', '75', '76', '77', '78', '79', '8', '80', '81', '82', '83', '84',
       '85', '86', '87', '88', '89', '9', '90', '91', '92', '93', '94', '95',
       '96', '97', '98', '99'],
      dtype='object')

In [13]:
set(df_train.columns) - set(df_test.columns)

{'power', 'sunhour', 'sunweight'}

In [14]:
max_prediction_length = 120     ## validation length
max_encoder_length = 336        ## lookback (14일)
training_cutoff = 1919          ## 0 to 1919 (1920 periods)

training = TimeSeriesDataSet(
    df_train[df_train.time_idx <= training_cutoff],
    time_idx = "time_idx",
    target = "power",
    group_ids = ["build_num"],
    min_encoder_length = max_encoder_length // 2,
    max_encoder_length = max_encoder_length,
    min_prediction_length = 1,
    max_prediction_length = max_prediction_length,
    static_categoricals = ["build_num", "build_type"],
    static_reals = ["GFA", "CA", "solar_gen", "ESS", "PCS"],
    # variable_groups = {}, ## 아직은 없는듯?
    time_varying_known_categoricals = ["wday", "month"],
    time_varying_known_reals = ["temp", "precip", "wind", "humidity", "is_holiday"],
    time_varying_unknown_reals = ["sunhour", "sunweight"],
    target_normalizer = GroupNormalizer(groups = ["build_num"], transformation = "softplus"),
    add_relative_time_idx = True,
    add_target_scales = True, ## 정규화되지 않은 시계열의 중심과 스케일을 피쳐로
    add_encoder_length = True ## 인코더의 길이를 static feature에 추가
)

## create validation set
validation = TimeSeriesDataSet.from_dataset(
    training, df_train, predict = True, stop_randomization = True
)

## create dataloaders
batch_size = 128
train_dataloader = training.to_dataloader(
    train = True, batch_size = batch_size, num_workers = 8
)
val_dataloader = validation.to_dataloader(
    train = False, batch_size = batch_size, num_workers = 0
)

> 이거 임배딩 할 때, 입력순이 아닌 알파벳 순으로 이뤄지므로, 주의할 것

In [15]:
baseline_predictions = Baseline().predict(val_dataloader, return_y = True)
SMAPE()(baseline_predictions.output, baseline_predictions.y[0])*100

/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'loss' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['loss'])`.
/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/utilities/parsing.py:209: Attribute 'logging_metrics' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['logging_metrics'])`.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/py

tensor(31.7957, device='cuda:0')

## Learning

In [None]:
pl.seed_everything(42)

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate = 0.03,
    hidden_size = 128,
    attention_head_size = 4,
    dropout = 0.2,
    hidden_continuous_size = 16,
    loss = QuantileLoss(),
    log_interval = 10,
    optimizer = "ranger",
    # reduce_on_plateau_patience = 1000
)


## Trainer setting
early_stop_callback = EarlyStopping(
    monitor = "val_loss", min_delta = 1e-4, patience = 10, verbose = False, mode = "min"
)
lr_logger = LearningRateMonitor() ## log the learning rate
logger = TensorBoardLogger("lightning_logs") ## logging results to a tensorboard

trainer = pl.Trainer(
    max_epochs = 100,
    accelerator = "gpu",
    enable_model_summary = True,
    gradient_clip_val = 1.0,
    limit_train_batches = 50, ## comment in for training, running validation every 30 batches
    callbacks = [lr_logger, early_stop_callback],
    logger = logger
)

print(f"Number of parameters in network: {tft.size() / 1e3:.1f}k")

Seed set to 42
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


Number of parameters in network: 284.1k


In [19]:
trainer.fit(tft, train_dataloader, val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name                               | Type                            | Params | Mode 
------------------------------------------------------------------------------------------------
0  | loss                               | QuantileLoss                    | 0      | train
1  | logging_metrics                    | ModuleList                      | 0      | train
2  | input_embeddings                   | MultiEmbedding                  | 2.2 K  | train
3  | prescalers                         | ModuleDict                      | 512    | train
4  | static_variable_selection          | VariableSelectionNetwork        | 26.9 K | train
5  | encoder_variable_selection         | VariableSelectionNetwork        | 27.4 K | train
6  | decoder_variable_selection         | VariableSelectionNetwork        | 20.4 K | train
7  | static_context_variable_selection  | GatedResidualNetwork            | 16.8 K | train
8  | static_context_initial_hidden_lstm |

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/root/anaconda3/envs/trch/lib/python3.12/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=127` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined

In [19]:
torch.save(trainer, "TFT_trainer.trch")

In [21]:
model = torch.load("TFT_trainer.trch", weights_only = False)

In [5]:
trainer = torch.load(".lr_find_8cb93e1c-284d-4a94-ad02-fa42882c201e.ckpt")

  trainer = torch.load(".lr_find_8cb93e1c-284d-4a94-ad02-fa42882c201e.ckpt")


In [11]:
trainer.keys()

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters', 'dataset_parameters', '__special_save__'])

In [24]:
trainer["state_dict"]

OrderedDict([('input_embeddings.embeddings.agency.weight',
              tensor([[ 1.9269e+00,  1.4873e+00,  9.0072e-01, -2.1055e+00,  6.7842e-01,
                       -1.2345e+00, -4.3067e-02, -1.6047e+00],
                      [-7.5214e-01,  1.6487e+00, -3.9248e-01, -1.4036e+00, -7.2788e-01,
                       -5.5943e-01, -7.6884e-01,  7.6245e-01],
                      [ 1.6423e+00, -1.5960e-01, -4.9740e-01,  4.3959e-01, -7.5813e-01,
                        1.0783e+00,  8.0080e-01,  1.6806e+00],
                      [ 1.2791e+00,  1.2964e+00,  6.1047e-01,  1.3347e+00, -2.3162e-01,
                        4.1759e-02, -2.5158e-01,  8.5986e-01],
                      [-1.3847e+00, -8.7124e-01, -2.2337e-01,  1.7174e+00,  3.1888e-01,
                       -4.2452e-01,  3.0572e-01, -7.7459e-01],
                      [-1.5576e+00,  9.9564e-01, -8.7979e-01, -6.0114e-01, -1.2742e+00,
                        2.1228e+00, -1.2347e+00, -4.8791e-01],
                      [-9.1382e-01,