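# Offline Eğitim

Bu not defteri, `tmp/offline_data.pkl` dosyasındaki kayıtlı oyun verisinin ilk %80'iyle (`FRACTION`) bir Decision Mamba modelini çevrimdışı eğitir ve eğitilen modeli `serializedObjs/dm/test_offline` dizinine kaydeder.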

## Hazırlık

In [1]:
# Setup: make the project importable, create the working directories,
# seed the RNGs and import everything the later cells use.
import os
import pickle
import sys
from pathlib import Path

# Assumes the notebook runs from a directory one level below the project
# root (e.g. <root>/notebooks/) — TODO confirm if the notebook is moved.
root = Path.cwd().parent

# sys.path holds strings, so compare the string form; testing the Path object
# itself never matches and every re-run would append duplicate entries.
for import_dir in (root, root / 'src'):
  if str(import_dir) not in sys.path:
    sys.path.append(str(import_dir))

# Kept as str because later cells build file paths from them with f-strings.
TMP_DIR, DM_MDL_DIR = str(root / 'tmp'), str(root / 'serializedObjs' / 'dm')
# exist_ok keeps the cell idempotent and avoids a check-then-create race.
os.makedirs(TMP_DIR, exist_ok=True)
os.makedirs(DM_MDL_DIR, exist_ok=True)

# Project imports only resolve after the sys.path setup above.
from src import utils
utils.seed_everything()  # project helper; presumably seeds all RNGs in use

from transformers import Trainer, TrainingArguments
from transformers.models.gpt2.configuration_gpt2 import GPT2Config

from src.decision_mamba import TrainableDM, DecisionMambaGymDataCollator

## Parametreler

In [2]:
# Run parameters: pickled replay data to read, directory the trained model is
# written to, and the leading fraction of the replay records used for training.
DATA_IN = TMP_DIR + '/offline_data.pkl'
MDL_OUT = DM_MDL_DIR + '/test_offline'
FRACTION = 0.8

## Veri Okuma

In [3]:
# Load the pickled replay data and keep its leading FRACTION as the training set.
if not os.path.exists(DATA_IN):
  # A missing input is FileNotFoundError (FileExistsError means the opposite).
  raise FileNotFoundError(f"Dosya ({DATA_IN}) bulunamadı.")

# NOTE(review): pickle.load can execute arbitrary code — only load files
# produced by this project's own data-collection step.
with open(DATA_IN, 'rb') as f:
  gameReplayData = pickle.load(f)

n_train = int(FRACTION * len(gameReplayData))
# Fail here with a clear message instead of deep inside the collator/Trainer.
if n_train <= 0:
  raise ValueError(
    f"FRACTION={FRACTION} of {len(gameReplayData)} records leaves no training data.")
dataset = gameReplayData[:n_train]

## Eğitim

In [4]:
# Optimisation settings for the Hugging Face Trainer.
training_args = TrainingArguments(
  output_dir=f"{TMP_DIR}/train_of/",
  # Keep every input the collator produces, even those not named in the
  # model's forward() signature.
  remove_unused_columns=False,
  num_train_epochs=5,  # earlier runs used 30 or 120
  per_device_train_batch_size=64,
  learning_rate=1e-4,
  weight_decay=1e-4,
  warmup_ratio=0.1,
  optim="adamw_torch",
  max_grad_norm=0.25,
  report_to="none",  # no external experiment tracker
)

# The collator is built from the dataset and exposes the state/action
# dimensions that size the model config below.
collator = DecisionMambaGymDataCollator(dataset)
n_state = collator.state_dim
n_action = collator.act_dim
n_hidden = 64

# GPT2Config is used as a config container: the non-GPT-2 keys (drop_p,
# max_ep_len, state_dim, act_dim, action_tanh, remove_act_embs) are stored as
# plain attributes, presumably read by TrainableDM — confirm there.
config = GPT2Config(
  vocab_size=1,  # irrelevant: the vocabulary is not used
  n_embd=n_hidden,
  # NOTE(review): positions are sized by the action dimension — confirm intended.
  n_positions=n_action,
  drop_p=0.1,
  n_layer=6,  # earlier runs used 12
  # NOTE(review): inner MLP width of 4 is far below GPT-2's usual 4*n_embd —
  # confirm whether TrainableDM uses this value at all.
  n_inner=4,
  # NOTE(review): max_ep_len is set to the state dimension rather than an
  # episode length — confirm this is what TrainableDM expects.
  max_ep_len=n_state,
  state_dim=n_state,
  act_dim=n_action,
  action_tanh=True,
  remove_act_embs=True)
model = TrainableDM(config)

trainer = Trainer(
  model=model,
  args=training_args,
  train_dataset=dataset,
  data_collator=collator,
)

# Left as the cell's last expression so the TrainOutput summary is displayed.
trainer.train()

Could not estimate the number of tokens of the input, floating-point operations will not be computed


Step,Training Loss


TrainOutput(global_step=10, training_loss=2.0070676803588867, metrics={'train_runtime': 45.5456, 'train_samples_per_second': 8.782, 'train_steps_per_second': 0.22, 'total_flos': 0.0, 'train_loss': 2.0070676803588867, 'epoch': 5.0})

## Kayıt

In [5]:
# Persist the trained model to the MDL_OUT directory via the Trainer.
trainer.save_model(output_dir=MDL_OUT)