In [1]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
from IPython.display import display, clear_output
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from tqdm import tqdm

from src.config import Config
from src.episode import Episode
from src.episode_dataset import EpisodeSupervisedDataset
from src.rl_data_record import RLDataRecord
from src.policy_factory import PolicyMode, PolicyFactory
from src.reward_model import RewardModel
from src.pre_trainer import PreTrainer
from src.policy_model_utils import load_policy_model, save_policy_model, train_and_plot_policy, inference_and_plot_pre_train_policy, inference_and_plot_policy_v2
from src.utils import get_color, normalize_min_max, to_device_collate, top_k_sampling
from src.episode_batch_repeat_sampler import EpisodeBatchRepeatSampler

# Setup
---

In [None]:
# Shared configuration object; every component in this notebook is built from it.
config = Config()

# Reward model, constructed from the shared config.
reward_model = RewardModel(config=config)

# Standalone policy instance using the late-position-fusion transformer mode.
test_policy = PolicyFactory.create(
    config=config,
    policy_mode=PolicyMode.TRANSFORMER_WITH_LATE_POSITION_FUSION,
)

# Datasets: one EpisodeSupervisedDataset per split; the length of each is printed
# right after it is built.
_make_dataset = partial(EpisodeSupervisedDataset, config=config)

train_dataset = _make_dataset(split="TRAIN")
print(f"train_dataset : {len(train_dataset)}")

test_dataset = _make_dataset(split="TEST")
print(f"test_dataset : {len(test_dataset)}")

eval_dataset = _make_dataset(split="EVAL")
print(f"eval_dataset : {len(eval_dataset)}")


# DataLoaders
def get_data_loader(
    dataset: EpisodeSupervisedDataset, batch_size: int, shuffle: bool = True
) -> DataLoader:
    """Wrap ``dataset`` in a ``DataLoader`` and print its split name and batch count.

    Args:
        dataset: Dataset to iterate over; its ``split`` attribute is printed.
        batch_size: Batch size forwarded to the ``DataLoader``.
        shuffle: Forwarded to the ``DataLoader``; defaults to ``True``.

    Returns:
        The constructed ``DataLoader``.
    """
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        # `config.device` (read from the notebook-level `config`) is bound as the
        # first positional argument of `to_device_collate`.
        collate_fn=partial(to_device_collate, config.device),
    )
    print(f"data loader: {dataset.split}, {len(loader)}")
    return loader


# Training loader: relies on the helper's default shuffle=True.
train_dataloader = get_data_loader(train_dataset, config.train_batch_size)

# Test loader: shuffling explicitly disabled.
test_dataloader = get_data_loader(
    test_dataset, config.test_batch_size, shuffle=False
)

# NOTE(review): the eval loader keeps the default shuffle=True, unlike the test
# loader above - confirm that this is intended.
eval_dataloader = get_data_loader(eval_dataset, config.eval_batch_size)

In [None]:
# Visualise a few training episodes, one per row:
#   column 0 - es.viz (with the reward model)
#   column 1 - es.viz_fov
#   column 2 - es.viz_optimal_path
episode_samples = 4
fig, axes = plt.subplots(
    nrows=episode_samples, ncols=3, squeeze=False, figsize=(15, 20)
)

# `cur_eidx` is the index passed to `get_episode`; it is advanced by the length of
# each episode's best path so that every row starts from a different index.
# NOTE(review): presumably the dataset holds one sample per step of `best_path`,
# so this jumps to the next distinct episode - confirm against the dataset class.
cur_eidx = 0
for eidx in range(episode_samples):
    es = train_dataset.get_episode(cur_eidx)
    cur_eidx += len(es.best_path)
    es.viz(ax=axes[eidx][0], reward_model=reward_model)

    # Field of view centred on the agent's start position.
    # NOTE(review): `fov` is not read afterwards in this cell; the call is kept in
    # case `viz_fov` depends on it - confirm and drop if not.
    fov = es.fov(center_pos=es.agent.start_state.position())
    es.viz_fov(ax=axes[eidx][1])
    es.viz_optimal_path(ax=axes[eidx][2])

# Training Loop
---

In [None]:
# Build the policy that will be trained, then move it to the configured device.
policy = PolicyFactory.create(
    config=config,
    policy_mode=PolicyMode.TRANSFORMER_WITH_LATE_POSITION_FUSION,
)
policy = policy.to(config.device)

# Hand the policy to PreTrainer and run it on the train and eval datasets.
pre_trainer = PreTrainer(policy=policy, config=config)
pre_trainer.run(eval_dataset=eval_dataset, train_dataset=train_dataset)

In [None]:
# NOTE(review): disabled manual training setup (policy, AdamW optimizer,
# cross-entropy criterion, cosine-annealing LR schedule). Training in this
# notebook is run through PreTrainer.run instead; presumably kept for reference -
# confirm it is still needed or delete the cell.
# policy = PolicyFactory.create(
#     policy_mode=PolicyMode.TRANSFORMER_WITH_LATE_POSITION_FUSION, config=config
# ).to(config.device)

# optimizer = torch.optim.AdamW(policy.parameters(), lr=1e-4, weight_decay=0.01)

# # Instantiate the CrossEntropyLoss
# criterion = nn.CrossEntropyLoss()

# total_steps = config.epoches * len(train_dataloader)
# warmup_steps = 0
# # Cosine annealing scheduler
# learning_rate_scheduler = CosineAnnealingLR(optimizer, T_max=total_steps - warmup_steps)

In [None]:
# NOTE(review): disabled manual training loop (forward pass, cross-entropy loss,
# optimizer step, scheduler step, progress bar). It relies on `optimizer`,
# `criterion` and `learning_rate_scheduler`, whose definitions are also commented
# out in this notebook, so it cannot be re-enabled on its own.
# with tqdm(total=len(train_dataloader) * config.epoches) as pbar:
#     for _ in range(config.epoches):
#         for batch_data in train_dataloader:
#             batch_fov: torch.Tensor = batch_data["fov"]
#             batch_cur_position: torch.Tensor = batch_data["agent_current_pos"]
#             batch_target_position: torch.Tensor = batch_data["agent_target_pos"]
#             batch_best_next_pos: torch.Tensor = batch_data["best_next_pos"]
#             batch_best_next_action: torch.Tensor = batch_data["best_next_action"]
#             batch_logits = policy(
#                 batch_fov=batch_fov,
#                 batch_cur_position=batch_cur_position,
#                 batch_target_position=batch_target_position,
#             )
#             batch_probability = F.softmax(batch_logits, dim=1)
#             # print(f"batch_logits: {batch_logits.shape}")
#             # print(f"batch_probability: {batch_probability.shape}")

#             loss = criterion(batch_logits, batch_best_next_action)
#             optimizer.zero_grad()
#             loss.backward()
#             optimizer.step()
#             learning_rate_scheduler.step()
#             current_lr = learning_rate_scheduler.get_last_lr()[0]

#             pbar.set_description(f"Loss: {loss}, Lr: {current_lr}")
#             pbar.update(1)

# Eval Policy
---

In [None]:
# Rebuild the reward model from the shared config.
# NOTE(review): it is not passed to the call below - confirm whether this
# re-creation is still needed.
reward_model = RewardModel(config=config)

# Run `policy` against the test dataset / dataloader through
# inference_and_plot_pre_train_policy (defined in src.policy_model_utils).
inference_and_plot_pre_train_policy(
    policy=policy,
    dataset=test_dataset,
    dataloader=test_dataloader,
    config=config,
    steps=20,
)