# Training the IRIS model ported into Hugging Face Transformers

In this notebook, we train the IRIS reinforcement learning agent from the [Transformers are Sample-Efficient World Models](https://arxiv.org/abs/2209.00588) paper on Atari, playing Breakout or any other environment whose id you pass to the `make_atari` function below. IRIS (Imagination with auto-Regression over an Inner Speech) is an agent that learns inside an imagined world model made of a discrete autoencoder and an autoregressive Transformer. It acquires behaviors by precisely simulating millions of trajectories.

The paper's approach frames dynamics learning as a sequence modeling problem, with an autoencoder constructing a language of image tokens and a Transformer orchestrating that language over time.

There is also a [Medium blog post](https://medium.com/@cedric.vandelaer/paper-review-transformers-are-sample-efficient-world-models-d0f9144f9c09) that explains how the algorithm works.

I have ported IRIS into Hugging Face Transformers, and the PR is under review. This notebook serves as a test of the ported model before the final merge.

`Note`: The notebook will be a demo for the ported model when the [PR](https://github.com/huggingface/transformers/pull/30883) gets merged.

With the PR branch of Hugging Face Transformers, you can load any [IRIS checkpoint](https://huggingface.co/models?other=iris) for the different Atari environments directly from the Hub into an `IrisModel`.

<img src="https://huggingface.co/datasets/ruffy369/sample-images/resolve/main/iris_model_original_arch.png"
alt="drawing"/>

Figure from the paper illustrating how IRIS works

## Set-up environment

Let's begin by installing the requirements for training IRIS from the original repository.

In [None]:
!git clone https://github.com/eloialonso/iris

In [None]:
%cd /kaggle/working/iris


In [None]:
!pip3 install setuptools==66  # gym 0.21's setup.py fails to install with newer setuptools releases


In [None]:
!pip install git+https://github.com/openai/gym.git@9180d12e1b66e7e2a1a622614f787a6ec147ac40

In [None]:
!pip install -r requirements.txt

In [None]:
%cd /kaggle/working

#### Next, reinstall `transformers` in dev (editable) mode from the PR branch that contains the ported model

In [None]:
# Ensure clean environment
!pip uninstall transformers -y
!pip uninstall tokenizers -y

# Clone the repository
!git clone  -b add_iris --single-branch https://github.com/RUFFY-369/transformers.git
%cd transformers

# Initialize submodules
!git submodule update --init --recursive

# Install in dev mode
!pip install -e ".[quality]"

# Clear cache
!find . -type d -name "__pycache__" -exec rm -r {} +

# Restart runtime (do this manually or via menu)


#### `Restart the kernel after the cell above finishes executing`

## Hold onto your neurons 🤗 here comes the code

`Note:` The code below is taken from the original repository to keep the training of the ported model as faithful to the original as possible.

In [None]:
%cd /kaggle/working

### Performing some necessary imports

In [None]:
from __future__ import annotations
from typing import Any, Dict, List, Optional, Tuple, Union
import gym
import numpy as np
from PIL import Image
from dataclasses import dataclass
from pathlib import Path
import warnings
import torch

from collections import deque
import math
import random

import psutil
import transformers as t
from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset, DataLoader
import sys

from einops import rearrange
from tqdm import tqdm
import wandb


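### `make_atari` is a function that creates an Atari environment, applies the preprocessing wrappers defined below (observation resizing, optional reward clipping, an optional time limit, random no-op starts, frame skipping with max-pooling, and optional end-of-life terminations), and returns the wrapped environment.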

In [None]:
def make_atari(id, size=64, max_episode_steps=None, noop_max=30, frame_skip=4, done_on_life_loss=False, clip_reward=False):
    env = gym.make(id)
    assert 'NoFrameskip' in env.spec.id or 'Frameskip' not in env.spec
    env = ResizeObsWrapper(env, (size, size))
    if clip_reward:
        env = RewardClippingWrapper(env)
    if max_episode_steps is not None:
        env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
    if noop_max is not None:
        env = NoopResetEnv(env, noop_max=noop_max)
    env = MaxAndSkipEnv(env, skip=frame_skip)
    if done_on_life_loss:
        env = EpisodicLifeEnv(env)
    return env
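Because `make_atari` applies `gym.wrappers.TimeLimit` before `MaxAndSkipEnv`, `max_episode_steps` counts raw emulator frames rather than agent decisions. The sketch below uses hypothetical toy classes (stand-ins, not Gym's wrappers) that mimic only the step counting of those two wrappers, nested in the same order, to show that with `frame_skip=4` a limit of 100 frames gives the agent 25 decisions.

```python
from typing import Any, Dict, Tuple

StepResult = Tuple[int, float, bool, Dict[str, Any]]


class ToyFrameEnv:
    """Hypothetical stand-in for an Atari env: every step() advances one raw frame."""

    def __init__(self) -> None:
        self.frames = 0

    def reset(self) -> int:
        self.frames = 0
        return self.frames

    def step(self, action: int) -> StepResult:
        self.frames += 1
        return self.frames, 1.0, False, {}


class ToyTimeLimit:
    """Mimics gym.wrappers.TimeLimit: done after `max_steps` calls to the wrapped step()."""

    def __init__(self, env: ToyFrameEnv, max_steps: int) -> None:
        self.env, self.max_steps, self.elapsed = env, max_steps, 0

    def reset(self) -> int:
        self.elapsed = 0
        return self.env.reset()

    def step(self, action: int) -> StepResult:
        obs, reward, done, info = self.env.step(action)
        self.elapsed += 1
        return obs, reward, done or self.elapsed >= self.max_steps, info


class ToySkip:
    """Mimics MaxAndSkipEnv's action repeat and reward sum (max-pooling omitted)."""

    def __init__(self, env: ToyTimeLimit, skip: int) -> None:
        self.env, self.skip = env, skip

    def reset(self) -> int:
        return self.env.reset()

    def step(self, action: int) -> StepResult:
        total_reward, done, obs, info = 0.0, False, 0, {}
        for _ in range(self.skip):
            obs, reward, done, info = self.env.step(action)
            total_reward += reward
            if done:
                break
        return obs, total_reward, done, info


def run_episode(max_episode_steps: int, frame_skip: int) -> Tuple[int, int]:
    """Return (agent steps, raw frames) for one episode, using make_atari's nesting order."""
    base = ToyFrameEnv()
    env = ToySkip(ToyTimeLimit(base, max_episode_steps), frame_skip)  # TimeLimit sits inside the skip wrapper
    env.reset()
    agent_steps, done = 0, False
    while not done:
        _, _, done, _ = env.step(0)
        agent_steps += 1
    return agent_steps, base.frames


print(run_episode(max_episode_steps=100, frame_skip=4))  # (25, 100)
```

When the limit is not a multiple of `frame_skip`, the last agent step is cut short: `run_episode(10, 4)` gives 3 decisions over 10 frames.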

### `ResizeObsWrapper` is a Gym wrapper that resizes observations from an environment to a specified width and height while preserving color channels.

In [None]:
class ResizeObsWrapper(gym.ObservationWrapper):
    def __init__(self, env: gym.Env, size: Tuple[int, int]) -> None:
        gym.ObservationWrapper.__init__(self, env)
        self.size = tuple(size)
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=(size[0], size[1], 3), dtype=np.uint8)
        self.unwrapped.original_obs = None

    def resize(self, obs: np.ndarray):
        img = Image.fromarray(obs)
        img = img.resize(self.size, Image.BILINEAR)
        return np.array(img)

    def observation(self, observation: np.ndarray) -> np.ndarray:
        self.unwrapped.original_obs = observation
        return self.resize(observation)



### `RewardClippingWrapper` is a Gym wrapper that clips rewards to -1, 0, or +1 by applying the sign function, so only the sign of each reward is kept.

In [None]:
class RewardClippingWrapper(gym.RewardWrapper):
    def reward(self, reward):
        return np.sign(reward)
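Since the sign function maps zero to zero, a clipped reward takes one of three values. A plain-Python sketch of the same mapping (the helper name is hypothetical; `np.sign` returns floats such as `-1.0` instead of ints):

```python
from typing import List


def clip_rewards_by_sign(rewards: List[float]) -> List[int]:
    """Keep only the sign of each reward, like RewardClippingWrapper's np.sign."""
    # (r > 0) - (r < 0) is 1, 0, or -1
    return [(r > 0) - (r < 0) for r in rewards]


print(clip_rewards_by_sign([-25.0, 0.0, 1.0, 400.0]))  # [-1, 0, 1, 1]
```

Clipping discards reward magnitude: a 400-point bonus and a 1-point hit look identical to the learner.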

### `NoopResetEnv` is a Gym wrapper that performs a random number of no-op actions (action 0) during environment reset to sample initial states, with the number of no-ops sampled from a range of 1 to `noop_max`.

In [None]:
class NoopResetEnv(gym.Wrapper):
    def __init__(self, env, noop_max=30):
        """Sample initial states by taking random number of no-ops on reset.
        No-op is assumed to be action 0.
        """
        gym.Wrapper.__init__(self, env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == 'NOOP'

    def reset(self, **kwargs):
        """ Do no-op action for a number of steps in [1, noop_max]."""
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        assert noops > 0
        obs = None
        for _ in range(noops):
            obs, _, done, _ = self.env.step(self.noop_action)
            if done:
                obs = self.env.reset(**kwargs)
        return obs

    def step(self, action):
        return self.env.step(action)


### `EpisodicLifeEnv` is a Gym wrapper that treats loss of life as the end of an episode, while only resetting the environment when true game over occurs, facilitating more stable value estimation in reinforcement learning.

In [None]:
class EpisodicLifeEnv(gym.Wrapper):
    def __init__(self, env):
        """Make end-of-life == end-of-episode, but only reset on true game over.
        Done by DeepMind for the DQN and co. since it helps value estimation.
        """
        gym.Wrapper.__init__(self, env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action):
        obs, reward, done, info = self.env.step(action)
        self.was_real_done = done
        # check current lives, make loss of life terminal,
        # then update lives to handle bonus lives
        lives = self.env.unwrapped.ale.lives()
        if lives < self.lives and lives > 0:
            # for Qbert sometimes we stay in lives == 0 condition for a few frames
            # so it's important to keep lives > 0, so that we only reset once
            # the environment advertises done.
            done = True
        self.lives = lives
        return obs, reward, done, info

    def reset(self, **kwargs):
        """Reset only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.
        """
        if self.was_real_done:
            obs = self.env.reset(**kwargs)
        else:
            # no-op step to advance from terminal/lost life state
            obs, _, _, _ = self.env.step(0)
        self.lives = self.env.unwrapped.ale.lives()
        return obs
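The done flag that `EpisodicLifeEnv.step` reports can be replayed over a sequence of life counts. The helper below is a hypothetical sketch of that rule only. It does not model the no-op step that `reset` performs after a life loss. A lost life counts as terminal while lives remain, and a drop to 0 waits for the real game-over signal, which avoids double resets in games like Qbert.

```python
from typing import List


def dones_seen_by_agent(lives_after_step: List[int], real_dones: List[bool], lives_at_reset: int) -> List[bool]:
    """Replicate EpisodicLifeEnv.step's done logic for a sequence of steps."""
    prev_lives = lives_at_reset
    dones = []
    for lives, real_done in zip(lives_after_step, real_dones):
        # Losing a life while lives remain is reported as terminal
        dones.append(real_done or 0 < lives < prev_lives)
        prev_lives = lives
    return dones


print(dones_seen_by_agent([3, 2, 2, 1, 0], [False, False, False, False, True], lives_at_reset=3))
# [False, True, False, True, True]
```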



### `MaxAndSkipEnv` is a Gym wrapper that repeats each action for `skip` frames, sums the rewards collected over those frames, and returns the pixel-wise maximum of the last two observations, so the agent only sees every `skip`-th frame.

In [None]:
class MaxAndSkipEnv(gym.Wrapper):
    def __init__(self, env, skip=4):
        """Return only every `skip`-th frame"""
        gym.Wrapper.__init__(self, env)
        assert skip > 0
        # most recent raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=np.uint8)
        self._skip = skip
        self.max_frame = np.zeros(env.observation_space.shape, dtype=np.uint8)

    def step(self, action):
        """Repeat action, sum reward, and max over last observations."""
        total_reward = 0.0
        done = None
        for i in range(self._skip):
            obs, reward, done, info = self.env.step(action)
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += reward
            if done:
                break
        # Note that the observation on the done=True frame
        # doesn't matter
        self.max_frame = self._obs_buffer.max(axis=0)

        return self.max_frame, total_reward, done, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)
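Some Atari games draw certain objects only on alternate frames (flicker). Taking the pixel-wise maximum of the last two frames keeps both objects visible, which is the usual motivation for this step in the DQN preprocessing pipeline. A list-based sketch of one `MaxAndSkipEnv` step, assuming the episode does not end during the repeat and using a hypothetical helper name:

```python
from typing import List, Tuple

Frame = List[List[int]]


def max_and_skip_step(frames: List[Frame], rewards: List[float]) -> Tuple[Frame, float]:
    """Combine the raw frames/rewards produced while one action is repeated."""
    assert len(frames) >= 2 and len(frames) == len(rewards)
    second_last, last = frames[-2], frames[-1]
    # Pixel-wise max of the last two frames only; earlier frames are dropped
    pooled = [[max(a, b) for a, b in zip(row_a, row_b)] for row_a, row_b in zip(second_last, last)]
    return pooled, sum(rewards)


frames = [
    [[0, 0], [0, 0]],
    [[0, 0], [0, 0]],
    [[0, 200], [0, 0]],  # sprite A drawn on this frame only
    [[150, 0], [0, 0]],  # sprite B drawn on this frame only
]
print(max_and_skip_step(frames, [0.0, 1.0, 0.0, 1.0]))  # ([[150, 200], [0, 0]], 2.0)
```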

### `DoneTrackerEnv` is a class that tracks the done status of multiple environments with one code per environment: 0 while it is running, 1 on the step it finishes, and 2 on every later step. Its properties report how many environments are done and which ones still produce new transitions.


In [None]:
class DoneTrackerEnv:
    def __init__(self, num_envs: int) -> None:
        """Monitor env dones: 0 when not done, 1 when done, 2 when already done."""
        self.num_envs = num_envs
        self.done_tracker = None
        self.reset_done_tracker()

    def reset_done_tracker(self) -> None:
        self.done_tracker = np.zeros(self.num_envs, dtype=np.uint8)

    def update_done_tracker(self, done: np.ndarray) -> None:
        self.done_tracker = np.clip(2 * self.done_tracker + done, 0, 2)

    @property
    def num_envs_done(self) -> int:
        return (self.done_tracker > 0).sum()

    @property
    def mask_dones(self) -> np.ndarray:
        return np.logical_not(self.done_tracker)

    @property
    def mask_new_dones(self) -> np.ndarray:
        return np.logical_not(self.done_tracker[self.done_tracker <= 1])
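The update rule `clip(2 * tracker + done, 0, 2)` moves an environment's code from 0 to 1 on the step it finishes and from 1 to 2 on the next update, where it stays. `Collector` relies on this through `len(mask_new_dones)`, which counts environments whose code is still 0 or 1, i.e. environments that actually produced a transition on the current step. A plain-Python sketch of the same rule (hypothetical helper name):

```python
from typing import List


def advance_done_codes(tracker: List[int], done: List[bool]) -> List[int]:
    """Same rule as DoneTrackerEnv.update_done_tracker: clip(2 * tracker + done, 0, 2)."""
    return [min(2 * t + int(d), 2) for t, d in zip(tracker, done)]


tracker = [0, 0, 0]
tracker = advance_done_codes(tracker, [False, True, False])  # [0, 1, 0]
tracker = advance_done_codes(tracker, [True, False, False])  # [1, 2, 0]
tracker = advance_done_codes(tracker, [False, False, True])  # [2, 2, 1]

num_envs_done = sum(t > 0 for t in tracker)   # 3, like num_envs_done
num_new_steps = sum(t <= 1 for t in tracker)  # 1, like len(mask_new_dones)
```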


### `SingleProcessEnv` is a class for managing a single environment process that tracks done status, performs actions, and handles environment interactions including resetting, stepping, rendering, and closing.

In [None]:
class SingleProcessEnv(DoneTrackerEnv):
    def __init__(self, env_fn):
        super().__init__(num_envs=1)
        self.env = env_fn  # an already-constructed env (it must expose action_space), not a factory
        self.num_actions = self.env.action_space.n

    def should_reset(self) -> bool:
        return self.num_envs_done == 1

    def reset(self) -> np.ndarray:
        self.reset_done_tracker()
        obs = self.env.reset()
        return obs[None, ...]

    def step(self, action) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Any]:
        obs, reward, done, _ = self.env.step(action[0])  # action is supposed to be ndarray (1,)
        done = np.array([done])
        self.update_done_tracker(done)
        return obs[None, ...], np.array([reward]), done, None

    def render(self) -> None:
        self.env.render()

    def close(self) -> None:
        self.env.close()

### `EpisodeMetrics` is a dataclass that stores metrics for a single episode, including its length and total return.

In [None]:
@dataclass
class EpisodeMetrics:
    episode_length: int
    episode_return: float

### `Episode` is a dataclass that represents an episode with observations, actions, rewards, end markers, and padding masks, providing methods for length management, merging, segmenting, computing metrics, and saving to disk.

In [None]:
@dataclass
class Episode:
    observations: torch.ByteTensor
    actions: torch.LongTensor
    rewards: torch.FloatTensor
    ends: torch.LongTensor
    mask_padding: torch.BoolTensor

    def __post_init__(self):
        assert len(self.observations) == len(self.actions) == len(self.rewards) == len(self.ends) == len(self.mask_padding)
        if self.ends.sum() > 0:
            idx_end = torch.argmax(self.ends) + 1
            self.observations = self.observations[:idx_end]
            self.actions = self.actions[:idx_end]
            self.rewards = self.rewards[:idx_end]
            self.ends = self.ends[:idx_end]
            self.mask_padding = self.mask_padding[:idx_end]

    def __len__(self) -> int:
        return self.observations.size(0)

    def merge(self, other: Episode) -> Episode:
        return Episode(
            torch.cat((self.observations, other.observations), dim=0),
            torch.cat((self.actions, other.actions), dim=0),
            torch.cat((self.rewards, other.rewards), dim=0),
            torch.cat((self.ends, other.ends), dim=0),
            torch.cat((self.mask_padding, other.mask_padding), dim=0),
        )

    def segment(self, start: int, stop: int, should_pad: bool = False) -> Episode:
        assert start < len(self) and stop > 0 and start < stop
        padding_length_right = max(0, stop - len(self))
        padding_length_left = max(0, -start)
        assert padding_length_right == padding_length_left == 0 or should_pad

        def pad(x):
            pad_right = torch.nn.functional.pad(x, [0 for _ in range(2 * x.ndim - 1)] + [padding_length_right]) if padding_length_right > 0 else x
            return torch.nn.functional.pad(pad_right, [0 for _ in range(2 * x.ndim - 2)] + [padding_length_left, 0]) if padding_length_left > 0 else pad_right

        start = max(0, start)
        stop = min(len(self), stop)
        segment = Episode(
            self.observations[start:stop],
            self.actions[start:stop],
            self.rewards[start:stop],
            self.ends[start:stop],
            self.mask_padding[start:stop],
        )

        segment.observations = pad(segment.observations)
        segment.actions = pad(segment.actions)
        segment.rewards = pad(segment.rewards)
        segment.ends = pad(segment.ends)
        segment.mask_padding = torch.cat((torch.zeros(padding_length_left, dtype=torch.bool), segment.mask_padding, torch.zeros(padding_length_right, dtype=torch.bool)), dim=0)

        return segment

    def compute_metrics(self) -> EpisodeMetrics:
        return EpisodeMetrics(len(self), self.rewards.sum())

    def save(self, path: Path) -> None:
        torch.save(self.__dict__, path)
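`Episode.segment` always returns exactly `stop - start` steps. Any part of the window that falls outside the episode is zero-padded and flagged `False` in `mask_padding`. A list-based sketch of the same bookkeeping (hypothetical helper standing in for the tensor version):

```python
from typing import List, Tuple


def padded_segment(values: List[int], start: int, stop: int, pad_value: int = 0) -> Tuple[List[int], List[bool]]:
    """Slice values[start:stop], padding with pad_value where the window leaves the episode."""
    n = len(values)
    assert start < n and stop > 0 and start < stop  # same precondition as Episode.segment
    left = max(0, -start)     # steps missing before the episode start
    right = max(0, stop - n)  # steps missing after the episode end
    body = values[max(0, start):min(n, stop)]
    padded = [pad_value] * left + body + [pad_value] * right
    mask = [False] * left + [True] * len(body) + [False] * right  # mask_padding
    return padded, mask


print(padded_segment([10, 11, 12], start=-2, stop=2))  # ([0, 0, 10, 11], [False, False, True, True])
print(padded_segment([10, 11, 12], start=1, stop=5))   # ([11, 12, 0, 0], [True, True, False, False])
```

In the tensor version, `torch.nn.functional.pad` lists padding amounts starting from the last dimension, which is why the code puts the amounts for dimension 0 (time) at the end of the list.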


### `Collector` gathers experience from an environment: it steps the environment with the agent's actions (replaced by uniformly random actions with probability `epsilon`), records observations, actions, rewards, and done flags, adds the resulting episodes to the dataset (incomplete ones are extended on the next call), and saves each finished episode through the `EpisodeDirManager`.

In [None]:
class Collector:
    def __init__(self, env: Union[SingleProcessEnv, MultiProcessEnv], dataset: EpisodesDataset, episode_dir_manager: EpisodeDirManager) -> None:
        self.env = env
        self.dataset = dataset
        self.episode_dir_manager = episode_dir_manager
        self.obs = self.env.reset()
        self.episode_ids = [None] * self.env.num_envs
        self.heuristic = RandomHeuristic(self.env.num_actions)

    @torch.no_grad()
    def collect(self, agent, epoch: int, epsilon: float, should_sample: bool, temperature: float, burn_in: int, *, num_steps: Optional[int] = None, num_episodes: Optional[int] = None):
        print(f"\nEpoch {epoch} / {actual_epochs}\n")
        assert self.env.num_actions == agent.world_model.act_vocab_size
        assert 0 <= epsilon <= 1

        assert (num_steps is None) != (num_episodes is None)
        should_stop = lambda steps, episodes: steps >= num_steps if num_steps is not None else episodes >= num_episodes

        to_log = []
        steps, episodes = 0, 0
        returns = []
        observations, actions, rewards, dones = [], [], [], []

        burnin_obs_rec, mask_padding = None, None
        if set(self.episode_ids) != {None} and burn_in > 0:
            current_episodes = [self.dataset[episode_id][0] for episode_id in self.episode_ids]
            segmented_episodes = [episode.segment(start=len(episode) - burn_in, stop=len(episode), should_pad=True) for episode in current_episodes]
            mask_padding = torch.stack([episode.mask_padding for episode in segmented_episodes], dim=0).to("cuda:0")
            burnin_obs = torch.stack([episode.observations for episode in segmented_episodes], dim=0).float().div(255).to("cuda:0")
            burnin_obs_rec = torch.clamp(agent.discrete_autoencoder.encode_decode(burnin_obs, should_preprocess=True, should_postprocess=True)[0], 0, 1)

        agent.actor_critic.reset(n=self.env.num_envs, burnin_observations=burnin_obs_rec, mask_padding=mask_padding)
        pbar = tqdm(total=num_steps if num_steps is not None else num_episodes, desc=f'Experience collection ({self.dataset.name})', file=sys.stdout)

        while not should_stop(steps, episodes):

            observations.append(self.obs)
            obs = rearrange(torch.FloatTensor(self.obs).div(255), 'n h w c -> n c h w').to("cuda:0")
            act = agent.act(obs, should_sample=should_sample, temperature=temperature).cpu().numpy()

            if random.random() < epsilon:
                act = self.heuristic.act(obs).cpu().numpy()

            self.obs, reward, done, _ = self.env.step(act)

            actions.append(act)
            rewards.append(reward)
            dones.append(done)

            new_steps = len(self.env.mask_new_dones)
            steps += new_steps
            pbar.update(new_steps if num_steps is not None else 0)

            # Warning: with EpisodicLifeEnv + MultiProcessEnv, reset is ignored if not a real done.
            # Thus, segments of experience following a life loss and preceding a general done are discarded.
            # Not a problem with a SingleProcessEnv.

            if self.env.should_reset():
                self.add_experience_to_dataset(observations, actions, rewards, dones)

                new_episodes = self.env.num_envs
                episodes += new_episodes
                pbar.update(new_episodes if num_episodes is not None else 0)

                for episode_id in self.episode_ids:
                    episode = self.dataset[episode_id][0]
                    self.episode_dir_manager.save(episode, episode_id, epoch)
                    metrics_episode = {k: v for k, v in episode.compute_metrics().__dict__.items()}
                    metrics_episode['episode_num'] = episode_id
                    metrics_episode['action_histogram'] = wandb.Histogram(np_histogram=np.histogram(episode.actions.numpy(), bins=np.arange(0, self.env.num_actions + 1) - 0.5, density=True))
                    to_log.append({f'{self.dataset.name}/{k}': v for k, v in metrics_episode.items()})
                    returns.append(metrics_episode['episode_return'])

                self.obs = self.env.reset()
                self.episode_ids = [None] * self.env.num_envs
                agent.actor_critic.reset(n=self.env.num_envs)
                observations, actions, rewards, dones = [], [], [], []

        # Add incomplete episodes to dataset, and complete them later.
        if len(observations) > 0:
            self.add_experience_to_dataset(observations, actions, rewards, dones)

        agent.actor_critic.clear()

    def add_experience_to_dataset(self, observations: List[np.ndarray], actions: List[np.ndarray], rewards: List[np.ndarray], dones: List[np.ndarray]) -> None:
        assert len(observations) == len(actions) == len(rewards) == len(dones)
        for i, (o, a, r, d) in enumerate(zip(*map(lambda arr: np.swapaxes(arr, 0, 1), [observations, actions, rewards, dones]))):  # Make everything (N, T, ...) instead of (T, N, ...)
            episode = Episode(
                observations=torch.ByteTensor(o).permute(0, 3, 1, 2).contiguous(),  # channel-first
                actions=torch.LongTensor(a),
                rewards=torch.FloatTensor(r),
                ends=torch.LongTensor(d),
                mask_padding=torch.ones(d.shape[0], dtype=torch.bool),
            )
            if self.episode_ids[i] is None:
                self.episode_ids[i] = self.dataset.add_episode(episode)
            else:
                self.dataset.update_episode(self.episode_ids[i], episode)
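In `collect`, exploration is decided once per step for the whole batch of environments: a single `random.random()` draw determines whether every environment takes a `RandomHeuristic` action instead of the agent's. A plain-Python sketch of that mixing rule, with a hypothetical helper name and a seeded `random.Random` for reproducibility:

```python
import random
from typing import List


def epsilon_mix(agent_actions: List[int], num_actions: int, epsilon: float, rng: random.Random) -> List[int]:
    """One coin flip per step decides for the whole batch, as in Collector.collect."""
    assert 0 <= epsilon <= 1
    if rng.random() < epsilon:
        # Like RandomHeuristic: uniform over all actions, one per environment
        return [rng.randrange(num_actions) for _ in agent_actions]
    return list(agent_actions)


rng = random.Random(0)
print(epsilon_mix([2, 2, 2], num_actions=4, epsilon=0.0, rng=rng))  # [2, 2, 2]
```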


### `EpisodesDataset` is a subclass of `torch.utils.data.Dataset` that manages a collection of episodes. It provides methods to add, update, and retrieve episodes, track modifications, and save dataset checkpoints. `update_component` lets the training loop change the sequence length, the sampling mode, and the component being trained. `__getitem__` returns these settings together with the episode so that the data collator can use them.


In [None]:
Batch = Dict[str, torch.Tensor]

class EpisodesDataset(Dataset):
    def __init__(self,  max_num_episodes: Optional[int] = None, name: Optional[str] = None,sequence_length: int = None, sample_from_start: bool = True, component: str = None) -> None:
        self.max_num_episodes = max_num_episodes
        self.name = name if name is not None else 'dataset'
        self.num_seen_episodes = 0
        self.episodes = deque()
        self.episode_id_to_queue_idx = dict()
        self.newly_modified_episodes, self.newly_deleted_episodes = set(), set()
        self.sample_from_start = sample_from_start
        self.sequence_length = sequence_length
        self.component = component

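    def __len__(self) -> int:
        return len(self.episodes)

    def clear(self) -> None:
        # Drop every stored episode (called by EpisodesDatasetRamMonitoring.clear)
        self.episodes = deque()
        self.episode_id_to_queue_idx = dict()

    def add_episode(self, episode: Episode) -> int:
        # Evict the oldest episode when a fixed episode budget is set
        if self.max_num_episodes is not None and len(self.episodes) == self.max_num_episodes:
            self._popleft()
        return self._append_new_episode(episode)

    def _popleft(self) -> Episode:
        # The oldest episode sits at queue index 0; shift the remaining ids one position left
        id_to_delete = [k for k, v in self.episode_id_to_queue_idx.items() if v == 0]
        assert len(id_to_delete) == 1
        self.newly_deleted_episodes.add(id_to_delete[0])
        self.episode_id_to_queue_idx = {k: v - 1 for k, v in self.episode_id_to_queue_idx.items() if v > 0}
        return self.episodes.popleft()

    def __getitem__(self, episode_id: int) -> Tuple[Episode, bool, int, str]:
        assert episode_id in self.episode_id_to_queue_idx
        queue_idx = self.episode_id_to_queue_idx[episode_id]
        return (self.episodes[queue_idx], self.sample_from_start, self.sequence_length, self.component)
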
    def update_component(self,sequence_length, sample_from_start, component):
        self.sequence_length = sequence_length
        self.sample_from_start = sample_from_start
        self.component = component
        
    def update_dataset(self,agent, epoch, epsilon,should_sample,temperature, num_steps, burn_in):
        train_collector.collect(agent, epoch=epoch, epsilon=epsilon,should_sample= should_sample,temperature=temperature,num_steps= num_steps, burn_in=burn_in)

    def update_episode(self, episode_id: int, new_episode: Episode) -> None:
        assert episode_id in self.episode_id_to_queue_idx
        queue_idx = self.episode_id_to_queue_idx[episode_id]
        merged_episode = self.episodes[queue_idx].merge(new_episode)
        self.episodes[queue_idx] = merged_episode
        self.newly_modified_episodes.add(episode_id)

    def _append_new_episode(self, episode):
        episode_id = self.num_seen_episodes
        self.episode_id_to_queue_idx[episode_id] = len(self.episodes)
        self.episodes.append(episode)
        self.num_seen_episodes += 1
        self.newly_modified_episodes.add(episode_id)
        return episode_id

    def update_disk_checkpoint(self, directory: Path) -> None:
        assert directory.is_dir()
        for episode_id in self.newly_modified_episodes:
            episode = self[episode_id][0]
            episode.save(directory / f'{episode_id}.pt')
        for episode_id in self.newly_deleted_episodes:
            (directory / f'{episode_id}.pt').unlink()
        self.newly_modified_episodes, self.newly_deleted_episodes = set(), set()
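Episode ids are stable counters, while positions in the deque shift whenever the oldest episode is evicted, and `episode_id_to_queue_idx` bridges the two. A hypothetical minimal store (plain Python, strings instead of `Episode` objects, with a fixed capacity for simplicity) showing the remapping:

```python
from collections import deque
from typing import Deque, Dict


class EpisodeQueue:
    """Hypothetical store: stable episode ids mapped to shifting deque positions."""

    def __init__(self, max_len: int) -> None:
        self.episodes: Deque[str] = deque()
        self.id_to_idx: Dict[int, int] = {}
        self.next_id = 0
        self.max_len = max_len

    def add(self, episode: str) -> int:
        if len(self.episodes) == self.max_len:
            self.popleft()
        episode_id = self.next_id
        self.id_to_idx[episode_id] = len(self.episodes)
        self.episodes.append(episode)
        self.next_id += 1
        return episode_id

    def popleft(self) -> str:
        # The oldest episode is at index 0; every other position shifts down by one
        self.id_to_idx = {k: v - 1 for k, v in self.id_to_idx.items() if v > 0}
        return self.episodes.popleft()

    def __getitem__(self, episode_id: int) -> str:
        return self.episodes[self.id_to_idx[episode_id]]


q = EpisodeQueue(max_len=2)
ids = [q.add(name) for name in "ABC"]
print(ids, q.id_to_idx)  # [0, 1, 2] {1: 0, 2: 1}
```

After the first eviction, the valid ids are no longer `0..len(dataset) - 1`, so anything that indexes the dataset by position (such as a default sampler) would need to be checked against `episode_id_to_queue_idx`.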


### `EpisodesDatasetRamMonitoring` is a subclass of `EpisodesDataset` that keeps the episode buffer within a RAM budget, given either as a system-wide percentage (e.g. `'80%'`) or as GiB used by the current process (e.g. `'12G'`). The first time the limit is exceeded, it freezes the current total number of stored steps as a cap. From then on, it removes the oldest episodes whenever adding a new one would exceed that cap.

In [None]:
class EpisodesDatasetRamMonitoring(EpisodesDataset):
    """
    Prevent episode dataset from going out of RAM.
    Warning: % looks at system wide RAM usage while G looks only at process RAM usage.
    """
    def __init__(self,max_ram_usage: str, name: Optional[str] = None,sequence_length: int = None, sample_from_start: bool = True, component: str = None) -> None:
        super().__init__(max_num_episodes=None, name=name, sequence_length= sequence_length,sample_from_start=sample_from_start,component=component)
        self.max_ram_usage = max_ram_usage
        self.num_steps = 0
        self.max_num_steps = None

        max_ram_usage = str(max_ram_usage)
        if max_ram_usage.endswith('%'):
            m = int(max_ram_usage.split('%')[0])
            assert 0 < m < 100
            self.check_ram_usage = lambda: psutil.virtual_memory().percent > m
        else:
            assert max_ram_usage.endswith('G')
            m = float(max_ram_usage.split('G')[0])
            self.check_ram_usage = lambda: psutil.Process().memory_info()[0] / 2 ** 30 > m

    def clear(self) -> None:
        super().clear()
        self.num_steps = 0

    def add_episode(self, episode: Episode) -> int:
        if self.max_num_steps is None and self.check_ram_usage():
            self.max_num_steps = self.num_steps
        self.num_steps += len(episode)
        while (self.max_num_steps is not None) and (self.num_steps > self.max_num_steps):
            self._popleft()
        episode_id = self._append_new_episode(episode)
        return episode_id

    def _popleft(self) -> Episode:
        episode = super()._popleft()
        self.num_steps -= len(episode)
        return episode
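The RAM check only matters until it first fires. After that, the dataset behaves like a fixed step budget equal to the number of steps stored at that moment. The sketch below uses hypothetical helpers, with the RAM check replaced by a scripted list of booleans, to replay the threshold parsing and the eviction logic. Like the original, it assumes that each new episode fits under the frozen cap; otherwise it runs out of episodes to evict.

```python
from collections import deque
from typing import Deque, List, Optional, Tuple


def parse_ram_limit(max_ram_usage: str) -> Tuple[str, float]:
    """'80%' -> system-wide percentage, '12G' -> GiB used by this process."""
    spec = str(max_ram_usage)
    if spec.endswith('%'):
        percent = int(spec[:-1])
        assert 0 < percent < 100
        return 'system_percent', float(percent)
    assert spec.endswith('G')
    return 'process_gib', float(spec[:-1])


def simulate_step_cap(episode_lengths: List[int], ram_alarm: List[bool]) -> Tuple[List[int], int, Optional[int]]:
    """Replay EpisodesDatasetRamMonitoring.add_episode with a scripted RAM check."""
    kept: Deque[int] = deque()
    num_steps, max_num_steps = 0, None
    for length, alarm in zip(episode_lengths, ram_alarm):
        if max_num_steps is None and alarm:
            max_num_steps = num_steps  # the cap is frozen at the first alarm and never changes
        num_steps += length
        while max_num_steps is not None and num_steps > max_num_steps:
            num_steps -= kept.popleft()  # evict the oldest episode
        kept.append(length)
    return list(kept), num_steps, max_num_steps


print(parse_ram_limit('80%'), parse_ram_limit('12G'))  # ('system_percent', 80.0) ('process_gib', 12.0)
print(simulate_step_cap([100, 100, 100, 100], [False, False, True, True]))  # ([100, 100], 200, 200)
```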

### `EpisodeDirManager` manages the storage of episodes in a specified directory. It ensures that no more than a maximum number of episodes are saved by removing the oldest episodes when necessary. It also tracks and saves the episode with the best return, updating it if a new episode exceeds the previous best.


In [None]:
class EpisodeDirManager:
    def __init__(self, episode_dir: Path, max_num_episodes: int) -> None:
        self.episode_dir = episode_dir
        self.episode_dir.mkdir(parents=False, exist_ok=True)
        self.max_num_episodes = max_num_episodes
        self.best_return = float('-inf')

    def save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        if self.max_num_episodes is not None and self.max_num_episodes > 0:
            self._save(episode, episode_id, epoch)

    def _save(self, episode: Episode, episode_id: int, epoch: int) -> None:
        ep_paths = [p for p in self.episode_dir.iterdir() if p.stem.startswith('episode_')]
        assert len(ep_paths) <= self.max_num_episodes
        if len(ep_paths) == self.max_num_episodes:
            to_remove = min(ep_paths, key=lambda ep_path: int(ep_path.stem.split('_')[1]))
            to_remove.unlink()
        episode.save(self.episode_dir / f'episode_{episode_id}_epoch_{epoch}.pt')

        ep_return = episode.compute_metrics().episode_return
        if ep_return > self.best_return:
            self.best_return = ep_return
            path_best_ep = [p for p in self.episode_dir.iterdir() if p.stem.startswith('best_')]
            assert len(path_best_ep) in (0, 1)
            if len(path_best_ep) == 1:
                path_best_ep[0].unlink()
            episode.save(self.episode_dir / f'best_episode_{episode_id}_epoch_{epoch}.pt')
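`EpisodeDirManager._save` picks the file to delete by parsing the episode id as an integer, so `episode_10` is correctly treated as newer than `episode_9` (a plain string sort would get this wrong). Files starting with `best_` are ignored by the rotation. A hypothetical stdlib sketch of the rotation, writing empty placeholder files into a temporary directory:

```python
import tempfile
from pathlib import Path


def save_with_rotation(episode_dir: Path, episode_id: int, epoch: int, max_num_episodes: int) -> None:
    """Keep at most max_num_episodes 'episode_*' files, deleting the lowest episode id first."""
    ep_paths = [p for p in episode_dir.iterdir() if p.stem.startswith('episode_')]
    if len(ep_paths) >= max_num_episodes:
        # Compare ids as integers: as strings, 'episode_10' would sort before 'episode_9'
        oldest = min(ep_paths, key=lambda p: int(p.stem.split('_')[1]))
        oldest.unlink()
    (episode_dir / f'episode_{episode_id}_epoch_{epoch}.pt').touch()  # empty placeholder file


with tempfile.TemporaryDirectory() as tmp:
    directory = Path(tmp)
    for episode_id in range(8, 12):
        save_with_rotation(directory, episode_id, epoch=1, max_num_episodes=3)
    remaining = sorted(int(p.stem.split('_')[1]) for p in directory.iterdir())
    print(remaining)  # [9, 10, 11]
```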


### `RandomHeuristic` generates random actions for a given number of possible actions. It produces random actions uniformly across all available actions for each observation in a batch.

In [None]:
class RandomHeuristic:
    def __init__(self, num_actions):
        self.num_actions = num_actions

    def act(self, obs):
        assert obs.ndim == 4  # (N, C, H, W): Collector passes channel-first observations
        n = obs.size(0)
        return torch.randint(low=0, high=self.num_actions, size=(n,))
  

### `IrisModel` and `IrisConfig` are imported from the `transformers` branch installed above. `IrisModel` is the model class for the IRIS architecture, and `IrisConfig` holds the configuration used to initialize it.


In [None]:
from transformers import IrisModel, IrisConfig

### `TrainableIRIS` is a subclass of `IrisModel` whose `forward` returns only `{"loss": ...}`, containing the total IRIS loss that `IrisModel` already computes (`output.losses.loss_total`), which is the format `Trainer` expects. `original_forward` gives access to the unmodified output of `IrisModel.forward`.

In [None]:
class TrainableIRIS(IrisModel):
    def __init__(self, config):
        super().__init__(config)

    def forward(self, **kwargs):
        output = super().forward(**kwargs)
        # Trainer reads the training loss from the "loss" key; IrisModel already computes the total IRIS loss
        loss = output.losses.loss_total
        return {"loss": loss}

    def original_forward(self, **kwargs):
        return super().forward(**kwargs)

### `iris_gym_data_collator` prepares training batches. It samples episodes with replacement from the items it receives, cuts a segment of the configured sequence length from each (padding where needed), stacks the segments into tensors, and scales observations to [0, 1]. The number of sampled segments depends on the component being trained (the `batch_num_samples_*` constants below), not on the DataLoader batch size.

In [None]:
from functools import partial
import random
from collections import deque
batch_num_samples_discrete_autoencoder = 128
batch_num_samples_world_model = 32
batch_num_samples_actor_critc = 32

def iris_gym_data_collator(samples):
    sampled_episodes_segments = []
    sample_from_start = samples[0][1]
    sequence_length = samples[0][2]
    episodes = [s[0] for s in samples]
    episodes = deque(episodes)
    
    component = samples[0][3]
    if component == "discrete_autoencoder":
        batch_num_samples = batch_num_samples_discrete_autoencoder
    elif component == "world_model":
        batch_num_samples = batch_num_samples_world_model
    else:
        batch_num_samples = batch_num_samples_actor_critc
        
    sampled_episodes = random.choices(episodes, k=batch_num_samples)
    for sampled_episode in sampled_episodes:
        if sample_from_start:
            start = random.randint(0, len(sampled_episode) - 1)
            stop = start + sequence_length
        else:
            stop = random.randint(1, len(sampled_episode))
            start = stop - sequence_length
        sampled_episodes_segments.append(sampled_episode.segment(start, stop, should_pad=True))
        assert len(sampled_episodes_segments[-1]) == sequence_length
    sampled_episodes_segments = [e_s.__dict__ for e_s in sampled_episodes_segments]
    batch = {}
    for k in sampled_episodes_segments[0]:
        batch[k] = torch.stack([e_s[k] for e_s in sampled_episodes_segments])
    batch['observations'] = batch['observations'].float() / 255.0  # int8 to float and scale
    return {
            "observations": batch['observations'],
            "actions": batch['actions'],
            "rewards": batch['rewards'],
            "ends": batch['ends'],
            "mask_padding": batch['mask_padding'],
            "component": component,
            "should_preprocess": True,
            "should_postprocess": True,
        }
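The two sampling modes differ in where padding can appear. With `sample_from_start=True`, the window starts on a real step and any padding goes on the right. With `False`, it ends on a real step and any padding goes on the left. In the settings defined in the `CustomTrainer` cell further down, only the actor-critic (sequence length 21) uses `sample_from_start=False`. A sketch of the window choice (hypothetical helper name, seeded `random.Random`):

```python
import random
from typing import Tuple


def sample_window(episode_len: int, sequence_length: int, sample_from_start: bool, rng: random.Random) -> Tuple[int, int]:
    """Pick [start, stop) the same way iris_gym_data_collator does."""
    if sample_from_start:
        start = rng.randint(0, episode_len - 1)  # begins on a real step, may run past the end
        stop = start + sequence_length
    else:
        stop = rng.randint(1, episode_len)  # ends on a real step, may begin before step 0
        start = stop - sequence_length
    return start, stop


rng = random.Random(0)
windows = [sample_window(5, 21, sample_from_start=False, rng=rng) for _ in range(3)]
```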

### `CustomTrainer` subclasses Hugging Face's `Trainer` so that training can follow the original IRIS training loop from the paper, in which experience collection alternates with training the discrete autoencoder, the world model, and the actor-critic. Only `_inner_training_loop` is overridden.
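The cell below pairs each component with a sequence length and a sampling mode, and delays the world model and the actor-critic by a number of epochs. The sketch below shows which components would be active at a given epoch. It assumes the rule of the original IRIS trainer (a component trains once `epoch > start_after_epochs`, with epochs counted from 1) and a start of 0 for the discrete autoencoder. Both assumptions, and the constant names, are hypothetical rather than taken from the truncated loop.

```python
from typing import List, Tuple

# Values mirror the cell below; the autoencoder start epoch of 0 is an assumption
START_AFTER_EPOCHS = {"discrete_autoencoder": 0, "world_model": 5, "actor_critic": 10}
COMPONENTS = ["discrete_autoencoder", "world_model", "actor_critic"]
SEQUENCE_LENGTHS = [1, 20, 21]
SAMPLE_FROM_START = [True, True, False]


def active_components(epoch: int) -> List[Tuple[str, int, bool]]:
    """Components trained at `epoch`, assuming the rule `epoch > start_after_epochs`."""
    return [
        (name, seq_len, from_start)
        for name, seq_len, from_start in zip(COMPONENTS, SEQUENCE_LENGTHS, SAMPLE_FROM_START)
        if epoch > START_AFTER_EPOCHS[name]
    ]


for epoch in (1, 6, 11):
    print(epoch, [name for name, _, _ in active_components(epoch)])
# 1 ['discrete_autoencoder']
# 6 ['discrete_autoencoder', 'world_model']
# 11 ['discrete_autoencoder', 'world_model', 'actor_critic']
```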

In [None]:
import time
from transformers.utils import logging, is_sagemaker_mp_enabled, is_torch_xla_available, is_accelerate_available, is_datasets_available
from transformers.trainer_utils import has_length, seed_worker
from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
from transformers.trainer_callback import TrainerState, ExportableState
from transformers.trainer_pt_utils import get_model_param_count
if is_accelerate_available():
    from accelerate.utils import DistributedType
if is_datasets_available():
    import datasets
logger = logging.get_logger(__name__)

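# Schedule thresholds, counted in blocks of `num_env_steps` Trainer epochs
world_model_start_after_epochs = 5   # blocks before the world model starts training
actor_critc_start_after_epochs = 10  # blocks before the actor-critic starts training

num_env_steps = 200  # Trainer epochs spent on one component before switching

# Per-component sampling settings, indexed like `components_to_train`
sequence_length_comp = [1, 20, 21]  # autoencoder: single frames; world model: 20 steps; actor-critic: 21 steps (burn-in of 20 plus 1)
sample_from_start_comp = [True, True, False]  # False: windows end at a sampled step instead of starting at one
components_to_train = ["discrete_autoencoder", "world_model", "actor_critic"]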

class CustomTrainer(Trainer):
    def __init__(self, model, args=None, data_collator=None, train_dataset=None, eval_dataset=None, tokenizer=None,
                 model_init=None, compute_metrics=None, callbacks=None, optimizers=(None, None)):
        # All arguments are forwarded positionally to `Trainer.__init__` unchanged
        super().__init__(model, args, data_collator, train_dataset, eval_dataset, tokenizer, model_init,
                         compute_metrics, callbacks, optimizers)

    def _inner_training_loop(
        self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None
    ):
        self.accelerator.free_memory()
        self._train_batch_size = batch_size
        if self.args.auto_find_batch_size:
            if self.state.train_batch_size != self._train_batch_size:
                from accelerate.utils import release_memory

                (self.model_wrapped,) = release_memory(self.model_wrapped)
                self.model_wrapped = self.model

                # Check for DeepSpeed *after* the initial pass and modify the config
                if self.is_deepspeed_enabled:
                    # Temporarily unset `self.args.train_batch_size`
                    original_bs = self.args.per_device_train_batch_size
                    self.args.per_device_train_batch_size = self._train_batch_size // max(1, self.args.n_gpu)
                    self.propagate_args_to_deepspeed(True)
                    self.args.per_device_train_batch_size = original_bs
            self.state.train_batch_size = self._train_batch_size
        logger.debug(f"Currently training with a batch size of: {self._train_batch_size}")
        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        if self.is_fsdp_xla_v2_enabled:
            train_dataloader = tpu_spmd_dataloader(train_dataloader)

        # Setting up training control variables:
        # number of training epochs: num_train_epochs
        # number of training steps per epoch: num_update_steps_per_epoch
        # total number of training steps to execute: max_steps
        total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size

        len_dataloader = None
        num_train_tokens = None
        if has_length(train_dataloader):
            len_dataloader = len(train_dataloader)
            num_update_steps_per_epoch = len_dataloader // args.gradient_accumulation_steps
            num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
            num_examples = self.num_examples(train_dataloader)
            if args.max_steps > 0:
                max_steps = args.max_steps
                num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
                    args.max_steps % num_update_steps_per_epoch > 0
                )
                # May be slightly incorrect if the last batch in the training dataloader has a smaller size but it's
                # the best we can do.
                num_train_samples = args.max_steps * total_train_batch_size
                if args.include_tokens_per_second:
                    num_train_tokens = (
                        self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
                    )
            else:
                max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
                num_train_epochs = math.ceil(args.num_train_epochs)
                num_train_samples = self.num_examples(train_dataloader) * args.num_train_epochs
                if args.include_tokens_per_second:
                    num_train_tokens = self.num_tokens(train_dataloader) * args.num_train_epochs
        elif args.max_steps > 0:  # Rely on max_steps when dataloader does not have a working size
            max_steps = args.max_steps
            # Setting a very large number of epochs so we go as many times as necessary over the iterator.
            num_train_epochs = sys.maxsize
            num_update_steps_per_epoch = max_steps
            num_examples = total_train_batch_size * args.max_steps
            num_train_samples = args.max_steps * total_train_batch_size
            if args.include_tokens_per_second:
                num_train_tokens = self.num_tokens(train_dataloader, args.max_steps) * args.gradient_accumulation_steps
        else:
            raise ValueError(
                "args.max_steps must be set to a positive value if dataloader does not have a length, was"
                f" {args.max_steps}"
            )

        if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
            if self.args.n_gpu > 1:
                # nn.DataParallel(model) replicates the model, creating new variables and module
                # references registered here no longer work on other gpus, breaking the module
                raise ValueError(
                    "Currently --debug underflow_overflow is not supported under DP. Please use DDP"
                    " (torchrun or torch.distributed.launch (deprecated))."
                )
            else:
                debug_overflow = DebugUnderflowOverflow(self.model)  # noqa

        delay_optimizer_creation = is_sagemaker_mp_enabled() or self.is_fsdp_xla_enabled or self.is_fsdp_enabled

        # We need to reset the scheduler, as its parameters may be different on subsequent calls
        if self._created_lr_scheduler:
            self.lr_scheduler = None
            self._created_lr_scheduler = False

        if self.is_deepspeed_enabled:
            self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps)

        if not delay_optimizer_creation:
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        self.state = TrainerState(
            stateful_callbacks=[
                cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
            ]
        )
        self.state.is_hyper_param_search = trial is not None
        self.state.train_batch_size = self._train_batch_size

        # Compute absolute values for logging, eval, and save if given as ratio
        if args.logging_steps is not None:
            if args.logging_steps < 1:
                self.state.logging_steps = math.ceil(max_steps * args.logging_steps)
            else:
                self.state.logging_steps = args.logging_steps
        if args.eval_steps is not None:
            if args.eval_steps < 1:
                self.state.eval_steps = math.ceil(max_steps * args.eval_steps)
            else:
                self.state.eval_steps = args.eval_steps
        if args.save_steps is not None:
            if args.save_steps < 1:
                self.state.save_steps = math.ceil(max_steps * args.save_steps)
            else:
                self.state.save_steps = args.save_steps

        # Activate gradient checkpointing if needed
        if args.gradient_checkpointing:
            if args.gradient_checkpointing_kwargs is None:
                gradient_checkpointing_kwargs = {}
            else:
                gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs

            self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

        model = self._wrap_model(self.model_wrapped)

        # as the model is wrapped, don't use `accelerator.prepare`
        # this is for unhandled cases such as
        # FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
        use_accelerator_prepare = True if model is self.model else False

        if delay_optimizer_creation:
            if use_accelerator_prepare:
                self._fsdp_qlora_plugin_updates()
                self.model = self.accelerator.prepare(self.model)
            self.create_optimizer_and_scheduler(num_training_steps=max_steps)

        # prepare using `accelerator` prepare
        if use_accelerator_prepare:
            self.model.train()
            if hasattr(self.lr_scheduler, "step"):
                if self.use_apex:
                    model = self.accelerator.prepare(self.model)
                else:
                    model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer)
            else:
                # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config.
                model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
                    self.model, self.optimizer, self.lr_scheduler
                )
        elif self.args.optim in [OptimizerNames.LOMO, OptimizerNames.ADALOMO]:
            # In this case we are in DDP + LOMO, which should be supported
            self.optimizer = self.accelerator.prepare(self.optimizer)

        if self.is_fsdp_enabled:
            self.model = self.model_wrapped = model

        # for the rest of this function `model` is the outside model, whether it was wrapped or not
        if model is not self.model:
            self.model_wrapped = model

        # backward compatibility
        if self.is_deepspeed_enabled:
            self.deepspeed = self.model_wrapped

        # ckpt loading
        if resume_from_checkpoint is not None:
            if self.is_deepspeed_enabled:
                deepspeed_load_checkpoint(
                    self.model_wrapped, resume_from_checkpoint, load_module_strict=not _is_peft_model(self.model)
                )
            elif is_sagemaker_mp_enabled() or self.is_fsdp_enabled:
                self._load_from_checkpoint(resume_from_checkpoint, self.model_wrapped)

        # Check if saved optimizer or scheduler states exist
        self._load_optimizer_and_scheduler(resume_from_checkpoint)

        # important: at this point:
        # self.model         is the Transformers Model
        # self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model),
        # FSDP(Transformers Model), Dynamo Optimized Module(Transformers Model) etc.

        # Train!
        logger.info("***** Running training *****")
        logger.info(f"  Num examples = {num_examples:,}")
        logger.info(f"  Num Epochs = {num_train_epochs:,}")
        logger.info(f"  Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
        if self.args.per_device_train_batch_size != self._train_batch_size:
            logger.info(f"  Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}")
        logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size:,}")
        logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
        logger.info(f"  Total optimization steps = {max_steps:,}")
        logger.info(f"  Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}")

        self.state.epoch = 0
        start_time = time.time()
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        steps_trained_progress_bar = None

        # Check if continuing training from a checkpoint
        if resume_from_checkpoint is not None and os.path.isfile(
            os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)
        ):
            self.state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME))
            self.compare_trainer_and_checkpoint_args(self.args, self.state)
            self._load_callback_state()
            epochs_trained = self.state.global_step // num_update_steps_per_epoch
            if not args.ignore_data_skip:
                steps_trained_in_current_epoch = self.state.global_step % (num_update_steps_per_epoch)
                steps_trained_in_current_epoch *= args.gradient_accumulation_steps
            else:
                steps_trained_in_current_epoch = 0

            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info(f"  Continuing training from epoch {epochs_trained}")
            logger.info(f"  Continuing training from global step {self.state.global_step}")
            if not args.ignore_data_skip:
                logger.info(
                    f"  Will skip the first {epochs_trained} epochs then the first"
                    f" {steps_trained_in_current_epoch} batches in the first epoch."
                )

        # Update the references
        self.callback_handler.model = self.model
        self.callback_handler.optimizer = self.optimizer
        self.callback_handler.lr_scheduler = self.lr_scheduler
        self.callback_handler.train_dataloader = train_dataloader
        if self.hp_name is not None and self._trial is not None:
            # use self._trial because the SigOpt/Optuna hpo only call `_hp_search_setup(trial)` instead of passing trial
            # parameter to Train when using DDP.
            self.state.trial_name = self.hp_name(self._trial)
        if trial is not None:
            assignments = trial.assignments if self.hp_search_backend == HPSearchBackend.SIGOPT else trial
            self.state.trial_params = hp_params(assignments)
        else:
            self.state.trial_params = None
        # This should be the same if the state has been saved but in case the training arguments changed, it's safer
        # to set this after the load.
        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        # tr_loss is a tensor to avoid synchronization of TPUs through .item()
        tr_loss = torch.tensor(0.0).to(args.device)
        # _total_loss_scalar is updated every time .item() has to be called on tr_loss and stores the sum of all losses
        self._total_loss_scalar = 0.0
        self._globalstep_last_logged = self.state.global_step
        model.zero_grad()
        grad_norm: Optional[float] = None
        self.control = self.callback_handler.on_train_begin(args, self.state, self.control)

        if args.eval_on_start:
            self._evaluate(trial, ignore_keys_for_eval, skip_scheduler=True)

        total_batched_samples = 0
        
        component_idx = 0
        
        for epoch in range(epochs_trained, num_train_epochs):
            if epoch == 0:
                print(f'Training {components_to_train[component_idx]} for {num_env_steps} steps:')   
            
            epoch_iterator = train_dataloader
            if hasattr(epoch_iterator, "set_epoch"):
                epoch_iterator.set_epoch(epoch)

            # Reset the past mems state at the beginning of each epoch if necessary.
            if args.past_index >= 0:
                self._past = None

            steps_in_epoch = (
                len(epoch_iterator)
                if len_dataloader is not None
                else args.max_steps * args.gradient_accumulation_steps
            )
            self.control = self.callback_handler.on_epoch_begin(args, self.state, self.control)

            if epoch == epochs_trained and resume_from_checkpoint is not None and steps_trained_in_current_epoch == 0:
                self._load_rng_state(resume_from_checkpoint)

            rng_to_sync = False
            steps_skipped = 0
            if steps_trained_in_current_epoch > 0:
                epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch)
                steps_skipped = steps_trained_in_current_epoch
                steps_trained_in_current_epoch = 0
                rng_to_sync = True

            step = -1
            
            ################################## component selection for training ##################################
            # Every `num_env_steps` Trainer epochs (one "block"), either switch to the next IRIS component or,
            # when cycling back to the discrete autoencoder, collect fresh episodes with the current agent.
            # `self.model` (not the possibly wrapped `model`) is used to reach `rl_agent` under DDP/DeepSpeed too.
            block = epoch // num_env_steps
            if epoch % num_env_steps == 0 and epoch != 0 and block > world_model_start_after_epochs:
                component_idx += 1
                if component_idx == 1:
                    # Put the autoencoder in eval mode and train the world model
                    self.model.rl_agent.discrete_autoencoder.eval()
                elif component_idx == 2 and block > actor_critc_start_after_epochs:
                    # Put the world model in eval mode and train the actor-critic
                    self.model.rl_agent.world_model.eval()
                else:
                    # Wrap around to the autoencoder
                    component_idx = 0
                epoch_iterator.dataset.update_component(
                    sequence_length_comp[component_idx],
                    sample_from_start_comp[component_idx],
                    components_to_train[component_idx],
                )
                if component_idx == 0:
                    # Collect new episodes (the +5 continues the epoch numbering after the initial collection
                    # iterations), then put every component back in training mode
                    epoch_iterator.dataset.update_dataset(
                        self.model.rl_agent, epoch=block + 5, epsilon=0.01, should_sample=True,
                        temperature=1.0, num_steps=200, burn_in=20,
                    )
                    self.model.train()
                print(f'Training {components_to_train[component_idx]} for {num_env_steps} steps:')
            elif epoch % num_env_steps == 0 and block < world_model_start_after_epochs:
                # Warm-up blocks: only the autoencoder trains, with fresh episodes collected each block
                epoch_iterator.dataset.update_dataset(
                    self.model.rl_agent, epoch=block + 5, epsilon=0.01, should_sample=True,
                    temperature=1.0, num_steps=200, burn_in=20,
                )
                print(f'Training {components_to_train[component_idx]} for {num_env_steps} steps:')
            ########################################################################################################
            for step, inputs in enumerate(epoch_iterator):
                total_batched_samples += 1

                if self.args.include_num_input_tokens_seen:
                    main_input_name = getattr(self.model, "main_input_name", "input_ids")
                    if main_input_name not in inputs:
                        logger.warning(
                            "Tried to track the number of tokens seen, however the current model is "
                            "not configured properly to know what item is the input. To fix this, add "
                            "a `main_input_name` attribute to the model class you are using."
                        )
                    else:
                        input_device = inputs[main_input_name].device
                        self.state.num_input_tokens_seen += torch.sum(
                            self.accelerator.gather(
                                torch.tensor(inputs[main_input_name].numel(), device=input_device, dtype=torch.int64)
                            )
                        ).item()
                if rng_to_sync:
                    self._load_rng_state(resume_from_checkpoint)
                    rng_to_sync = False

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    if steps_trained_progress_bar is not None:
                        steps_trained_progress_bar.update(1)
                    if steps_trained_in_current_epoch == 0:
                        self._load_rng_state(resume_from_checkpoint)
                    continue
                elif steps_trained_progress_bar is not None:
                    steps_trained_progress_bar.close()
                    steps_trained_progress_bar = None

                if step % args.gradient_accumulation_steps == 0:
                    self.control = self.callback_handler.on_step_begin(args, self.state, self.control)

                with self.accelerator.accumulate(model):
                    tr_loss_step = self.training_step(model, inputs)

                if (
                    args.logging_nan_inf_filter
                    and not is_torch_xla_available()
                    and (torch.isnan(tr_loss_step) or torch.isinf(tr_loss_step))
                ):
                    # if loss is nan or inf simply add the average of previous logged losses
                    tr_loss += tr_loss / (1 + self.state.global_step - self._globalstep_last_logged)
                else:
                    if tr_loss.device != tr_loss_step.device:
                        raise ValueError(
                            f"Calculated loss must be on the original device: {tr_loss.device} but device in use is {tr_loss_step.device}"
                        )
                    tr_loss += tr_loss_step

                self.current_flos += float(self.floating_point_ops(inputs))

                is_last_step_and_steps_less_than_grad_acc = (
                    steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch
                )

                if (
                    total_batched_samples % args.gradient_accumulation_steps == 0
                    or
                    # last step in epoch but step is always smaller than gradient_accumulation_steps
                    is_last_step_and_steps_less_than_grad_acc
                ):
                    # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered
                    # in accelerate. So, explicitly enable sync gradients to True in that case.
                    if is_last_step_and_steps_less_than_grad_acc:
                        self.accelerator.gradient_state._set_sync_gradients(True)

                    # Gradient clipping
                    if args.max_grad_norm is not None and args.max_grad_norm > 0:
                        # deepspeed does its own clipping

                        if is_sagemaker_mp_enabled() and args.fp16:
                            _grad_norm = self.optimizer.clip_master_grads(args.max_grad_norm)
                        elif self.use_apex:
                            # Revert to normal clipping otherwise, handling Apex or full precision
                            _grad_norm = nn.utils.clip_grad_norm_(
                                amp.master_params(self.optimizer),
                                args.max_grad_norm,
                            )
                        else:
                            _grad_norm = self.accelerator.clip_grad_norm_(
                                model.parameters(),
                                args.max_grad_norm,
                            )

                        if (
                            is_accelerate_available()
                            and self.accelerator.distributed_type == DistributedType.DEEPSPEED
                        ):
                            grad_norm = model.get_global_grad_norm()
                            # In some cases the grad norm may not return a float
                            if hasattr(grad_norm, "item"):
                                grad_norm = grad_norm.item()
                        else:
                            grad_norm = _grad_norm

                    self.optimizer.step()

                    self.control = self.callback_handler.on_optimizer_step(args, self.state, self.control)

                    optimizer_was_run = not self.accelerator.optimizer_step_was_skipped
                    if optimizer_was_run:
                        # Delay optimizer scheduling until metrics are generated
                        if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                            self.lr_scheduler.step()

                    model.zero_grad()
                    self.state.global_step += 1
                    self.state.epoch = epoch + (step + 1 + steps_skipped) / steps_in_epoch
                    self.control = self.callback_handler.on_step_end(args, self.state, self.control)

                    self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)
                else:
                    self.control = self.callback_handler.on_substep_end(args, self.state, self.control)

                if self.control.should_epoch_stop or self.control.should_training_stop:
                    # PyTorch/XLA relies on the data loader to insert the mark_step for
                    # each step. Since we are breaking the loop early, we need to manually
                    # insert the mark_step here.
                    if is_torch_xla_available():
                        xm.mark_step()
                    break
                 
            if step < 0:
                logger.warning(
                    "There seems to be not a single sample in your epoch_iterator, stopping training at step"
                    f" {self.state.global_step}! This is expected if you're using an IterableDataset and set"
                    f" num_steps ({max_steps}) higher than the number of available samples."
                )
                self.control.should_training_stop = True

            self.control = self.callback_handler.on_epoch_end(args, self.state, self.control)
            self._maybe_log_save_evaluate(tr_loss, grad_norm, model, trial, epoch, ignore_keys_for_eval)

            if DebugOption.TPU_METRICS_DEBUG in self.args.debug:
                if is_torch_xla_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if self.control.should_training_stop:
                break

        if args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        if args.load_best_model_at_end and self.state.best_model_checkpoint is not None:
            # Wait for everyone to get here so we are sure the model has been saved by process 0.
            if is_torch_xla_available():
                xm.rendezvous("load_best_model_at_end")
            elif args.parallel_mode == ParallelMode.DISTRIBUTED:
                dist.barrier()
            elif is_sagemaker_mp_enabled():
                smp.barrier()

            self._load_best_model()

        # add remaining tr_loss
        self._total_loss_scalar += tr_loss.item()
        effective_global_step = max(self.state.global_step, 0.001)  # Avoid ZeroDivisionError
        train_loss = self._total_loss_scalar / effective_global_step

        metrics = speed_metrics(
            "train",
            start_time,
            num_samples=num_train_samples,
            num_steps=self.state.max_steps,
            num_tokens=num_train_tokens,
        )
        self.store_flos()
        metrics["total_flos"] = self.state.total_flos
        metrics["train_loss"] = train_loss

        self.is_in_train = False

        self._memory_tracker.stop_and_update_metrics(metrics)

        self.log(metrics)

        run_dir = self._get_output_dir(trial)
        checkpoints_sorted = self._sorted_checkpoints(use_mtime=False, output_dir=run_dir)

        # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save.
        if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1:
            for checkpoint in checkpoints_sorted:
                if not os.path.samefile(checkpoint, self.state.best_model_checkpoint):
                    logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit")
                    shutil.rmtree(checkpoint)

        self.control = self.callback_handler.on_train_end(args, self.state, self.control)

        # Wait for the checkpoint to be uploaded.
        self._finish_current_push()

        # After training we make sure to retrieve back the original forward pass method
        # for the embedding layer by removing the forward post hook.
        if self.neftune_noise_alpha is not None:
            self._deactivate_neftune(self.model)

        return TrainOutput(self.state.global_step, train_loss, metrics)
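The only IRIS-specific logic in the overridden `_inner_training_loop` is the component-selection block: every `num_env_steps` Trainer epochs it either advances to the next component or wraps back to the discrete autoencoder and collects fresh episodes. The pure-Python sketch below replays that decision rule without a model or dataset so the resulting schedule can be inspected. `component_schedule` is a hypothetical helper that mirrors the branch conditions of the loop; it records only which component is active and whether data is collected at each block boundary, and it ignores the collection iterations run before training.

```python
from typing import List, Tuple

COMPONENTS = ["discrete_autoencoder", "world_model", "actor_critic"]


def component_schedule(num_train_epochs: int, num_env_steps: int,
                       world_model_start: int, actor_critic_start: int) -> List[Tuple[int, str, bool]]:
    """Replay the component-selection rule; returns (block, component, collected_data) per block."""
    schedule = []
    component_idx = 0
    for epoch in range(0, num_train_epochs, num_env_steps):  # only block boundaries change anything
        block = epoch // num_env_steps
        collected = False
        if epoch != 0 and block > world_model_start:
            component_idx += 1
            # Index 1 (world model) is always allowed; index 2 (actor-critic) only after its start block.
            allowed = component_idx == 1 or (component_idx == 2 and block > actor_critic_start)
            if not allowed:
                component_idx = 0  # wrap around to the autoencoder and collect new episodes
                collected = True
        elif block < world_model_start:
            collected = True  # warm-up blocks: autoencoder only, with fresh data each block
        schedule.append((block, COMPONENTS[component_idx], collected))
    return schedule


if __name__ == "__main__":
    # Notebook values: num_env_steps=200, world model after block 5, actor-critic after block 10.
    for block, component, collected in component_schedule(16 * 200, 200, 5, 10):
        print(block, component, "collect" if collected else "")
```

With the notebook's values, blocks 0 to 4 train the autoencoder on freshly collected data. Block 5 keeps training the autoencoder without collecting, because the two branch conditions use `>` and `<`, so `block == world_model_start_after_epochs` matches neither. Blocks 6 to 10 alternate between the world model and the autoencoder, and from block 11 onward the actor-critic joins the rotation.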


### The code below creates the IRIS training environment with `SingleProcessEnv` wrapping the `BreakoutNoFrameskip-v4` Atari environment. It converts the original training budget (600 epochs of 200 steps for each of the three components) into a number of Trainer epochs by dividing by 2.3, the number of Trainer steps per epoch. It then creates the directories for media and episodes, warning instead of failing if they already exist. `EpisodeDirManager` manages the saved episodes, `EpisodesDatasetRamMonitoring` holds the episode dataset while monitoring RAM usage, and `Collector` gathers data from the environment.
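To see what the epoch formula in the training cell below produces, here is the same arithmetic with its values: `actual_epochs = 600`, `num_steps = 200`, three components, and the 2.3 Trainer steps per epoch quoted in its comment. `epoch_budget` is a hypothetical helper that restates that expression and also counts the `num_env_steps`-epoch component blocks it yields.

```python
from typing import Tuple


def epoch_budget(actual_epochs: int, num_steps: int, num_components: int,
                 steps_per_trainer_epoch: float, num_env_steps: int) -> Tuple[int, int, int]:
    """Return (target steps, Trainer epochs, component blocks) for the notebook's epoch formula."""
    target_steps = actual_epochs * num_steps * num_components          # total updates to match
    num_train_epochs = int(target_steps / steps_per_trainer_epoch)     # same expression as the training cell
    num_blocks = num_train_epochs // num_env_steps                     # components switch every block
    return target_steps, num_train_epochs, num_blocks


print(epoch_budget(600, 200, 3, 2.3, 200))  # → (360000, 156521, 782)
```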

### `TrainableIRIS` is built from a default `IrisConfig` and moved to the target device. Before training starts, the untrained agent collects episodes for `start_after_collect_iterations` (5) iterations so the dataset holds data before the first update. The Hugging Face `TrainingArguments` then set the batch size, learning rate, weight decay, warmup ratio, optimizer, and gradient-clipping norm, and `CustomTrainer` is instantiated and used to train the model.
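With `warmup_ratio=0.1` and the default `lr_scheduler_type` (`"linear"`), the Trainer raises the learning rate linearly over the first 10% of optimizer steps and then decays it linearly to zero. Because the custom loop creates the optimizer and scheduler once, that single schedule spans all three components. The function below restates the schedule in plain Python as an illustration; it mirrors, but is not, `transformers.get_linear_schedule_with_warmup`, and `max_steps` is whatever the Trainer derives from the dataloader length, which this notebook does not print.

```python
import math


def linear_warmup_lr(step: int, max_steps: int, base_lr: float = 1e-4, warmup_ratio: float = 0.1) -> float:
    """Learning rate at an optimizer step under linear warmup followed by linear decay to zero."""
    warmup_steps = math.ceil(max_steps * warmup_ratio)  # how a warmup ratio becomes a step count
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    return base_lr * max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))


if __name__ == "__main__":
    max_steps = 1000  # hypothetical value, for illustration only
    for step in (0, 50, 100, 550, 1000):
        print(step, linear_warmup_lr(step, max_steps))
```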


In [None]:
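# make_atari positional args (IRIS helper): size=64, max_episode_steps=20000, noop_max=30, frame_skip=4,
# done_on_life_loss=True, clip_reward=False
train_env = SingleProcessEnv(make_atari("BreakoutNoFrameskip-v4", 64, 20000, 30, 4, True, False))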

actual_epochs = 600
num_steps = 200
# 2.3 is the number of steps per epoch of the Trainer class (chosen to match the original IRIS training)
num_train_epochs = int((actual_epochs * num_steps * len(components_to_train)) / 2.3)

media_dir = Path('media')
episode_dir = media_dir / 'episodes'
device = "cuda:0"

try:
    media_dir.mkdir(exist_ok=False, parents=False)
except FileExistsError:
    warnings.warn(f"Directory {media_dir} already exists, skipping creation.")

try:
    episode_dir.mkdir(exist_ok=False, parents=False)
except FileExistsError:
    warnings.warn(f"Directory {episode_dir} already exists, skipping creation.")
    
episode_manager_train = EpisodeDirManager(episode_dir / 'train', max_num_episodes=10)
global train_collector
train_dataset = EpisodesDatasetRamMonitoring("30G", "train_dataset", sequence_length_comp[0], sample_from_start_comp[0], components_to_train[0])
train_collector = Collector(train_env, train_dataset, episode_manager_train)

model = TrainableIRIS(IrisConfig())
model.to(device)

# Collect initial episodes with the untrained agent before training starts
start_after_collect_iterations = 5
for i in range(start_after_collect_iterations):
    train_dataset.update_dataset(model.rl_agent, epoch=i, epsilon=0.01, should_sample=True, temperature=1.0, num_steps=200, burn_in=20)

training_args = TrainingArguments(
    output_dir="/kaggle/working/output/",
    remove_unused_columns=False,
    num_train_epochs=num_train_epochs,
    per_device_train_batch_size=128,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.1,
    optim="adamw_torch",
    max_grad_norm=10.0,
)

trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=iris_gym_data_collator,
)

trainer.train()

<img src="https://huggingface.co/ruffy369/iris-breakout/resolve/main/iris_trained_agent.gif"
alt="drawing"/>
