# Online Eğitim

## Hazırlık

In [1]:
import sys
from pathlib import Path

root = Path.cwd().parents[0]
if root not in sys.path:
  sys.path.append(str(root))
  sys.path.append(str(root / 'src'))

TMP_DIR, DM_MDL_DIR = str(root / 'tmp'), str(root / 'serializedObjs' / 'dm')
import os
if not os.path.exists(TMP_DIR):
  os.makedirs(TMP_DIR)
if not os.path.exists(DM_MDL_DIR):
  os.makedirs(DM_MDL_DIR)

from src import utils
utils.seed_everything()

import pickle

from tqdm import tqdm

import random

from transformers import Trainer, TrainingArguments

from src.decision_mamba import TrainableDM, DecisionMambaGymDataCollator

from src.env import Game, LEARNER, OPPONENT
from src.agents import MinimaxAgent, ModelAgent

## Değişkenler

In [2]:
DATA_FILE = f'{TMP_DIR}/offline_data.pkl'
MDL_IN = f'{DM_MDL_DIR}/test_offline'
MDL_OUT = f'{DM_MDL_DIR}/test_online'
ROUNDS = 2 #2000
#Top N
N = 7 #500
TARGET_RETURN = 2
EXPLORATORY_MAX = 0.5
EXPLORATORY_MIN = 0.1

## Veri / Model Okuma

In [3]:
if not os.path.exists(DATA_FILE):
  raise FileExistsError(f"Dosya ({DATA_FILE}) bulunamadı.")

with open(DATA_FILE, 'rb') as f:
  gameReplayData = pickle.load(f)

model = TrainableDM.from_pretrained(MDL_IN)

replay_buffer = sorted(gameReplayData, key=lambda x: x[4][0], reverse = True)[:N]

## Eğitim

In [4]:
training_args = TrainingArguments(
  output_dir=f"{TMP_DIR}/train_on/",
  remove_unused_columns=False,
  num_train_epochs=3, # 120
  per_device_train_batch_size=64, # 64
  learning_rate=1e-4,
  weight_decay=1e-4,
  warmup_ratio=0.1,
  optim="adamw_torch",
  max_grad_norm=0.25,
  disable_tqdm=True,
  report_to="none",
)

learner = ModelAgent(model, player=LEARNER)
opponent = MinimaxAgent(depth=2, epsilon=0.0, player=OPPONENT)
game = Game(learner, opponent)


for round in tqdm(range(1, ROUNDS)):
    replay_buffer = sorted(replay_buffer, key=lambda x: x[4][0], reverse = True)
    exploratory = EXPLORATORY_MIN + (EXPLORATORY_MAX - EXPLORATORY_MIN) * (ROUNDS - float(round)) / ROUNDS
    random.seed(round)
    _, _, traj1 = game.play(explore=exploratory)
    _, _, traj2 = game.play(explore=exploratory)
    _, _, traj3 = game.play(explore=exploratory)
    replay_buffer[-1] = traj1
    replay_buffer[-2] = traj2
    replay_buffer[-3] = traj3
    trainer = Trainer(
      model=model,
      args=training_args,
      train_dataset=replay_buffer,
      data_collator=DecisionMambaGymDataCollator(replay_buffer),
    )
    trainer.train()

  0%|          | 0/1 [00:00<?, ?it/s]Could not estimate the number of tokens of the input, floating-point operations will not be computed
100%|██████████| 1/1 [00:02<00:00,  2.78s/it]

{'train_runtime': 2.0363, 'train_samples_per_second': 10.313, 'train_steps_per_second': 1.473, 'train_loss': 1.9670173327128093, 'epoch': 3.0}





## Kayıt

In [5]:
trainer.save_model(MDL_OUT)