# In [1]:
import random
import sys
from typing import List, Optional, Union
import os 
import h5py

from einops import rearrange
import numpy as np
import torch
from tqdm import tqdm
import wandb

from agent import Agent
from dataset import EpisodesDataset
from envs import SingleProcessEnv, MultiProcessEnv
from episode import Episode
from utils import EpisodeDirManager, RandomHeuristic
from datetime import datetime, timedelta 
from pathlib import Path


class Radar_Collector:
    """Collect agent experience on the RAD_NL25_RAP_5min radar dataset.

    Adapted from the IRIS ``Collector``: instead of taking observations from the
    environment, each call to :meth:`collect` loads one calendar period (day,
    week or month) of 5-minute radar frames from disk and feeds them one by one
    to the agent. Once the environment reports it should reset, the whole period
    is stored in ``dataset`` as a single episode.
    """

    # Root of the radar archive. Files are read from
    # <root>/<YYYY>/<MM>/RAD_NL25_RAP_5min_<YYYYMMDDHHMM>.h5.
    # TODO: move this into the configuration files.
    RADAR_DATASET_PATH = Path('/home/hbi/RAD_NL25_RAP_5min')

    def __init__(self, env: Union[SingleProcessEnv, MultiProcessEnv], dataset: EpisodesDataset, episode_dir_manager: EpisodeDirManager) -> None:
        self.env = env
        self.dataset = dataset
        self.episode_dir_manager = episode_dir_manager
        self.obs = self.env.reset()
        # One dataset episode id per env; None means "no episode stored yet".
        self.episode_ids = [None] * self.env.num_envs
        # Kept for the (currently disabled) epsilon-random exploration.
        self.heuristic = RandomHeuristic(self.env.num_actions)

    @torch.no_grad()
    def collect(self, agent: Agent, epoch: int, epsilon: float, should_sample: bool, temperature: float, burn_in: int, *, num_steps: Optional[int] = None, num_episodes: Optional[int] = None, current_date: datetime, radar_dataset_frequency: str):
        """Run the agent over one radar period and return logging records.

        Exactly one period of length ``radar_dataset_frequency`` ("Daily",
        "Weekly" or "Monthly") is processed per call, starting at
        ``current_date``. The caller resumes from the returned ``next_date``.
        The ``num_steps`` / ``num_episodes`` budget only decides whether that
        period is collected at all: a budget that is already exhausted (e.g. 0)
        skips it.

        Args:
            agent: policy whose ``act`` is queried once per radar frame.
            epoch: epoch index forwarded to ``episode_dir_manager.save``.
            epsilon: must lie in [0, 1]. Currently unused, because the
                random-heuristic exploration of the original collector is
                disabled.
            should_sample: forwarded to ``agent.act``.
            temperature: forwarded to ``agent.act``.
            burn_in: number of trailing steps of the in-progress episodes that
                are re-encoded to warm up the actor-critic.
            num_steps: step budget. Exactly one of ``num_steps`` and
                ``num_episodes`` must be given.
            num_episodes: episode budget.
            current_date: first timestamp of the period to load.
            radar_dataset_frequency: "Daily", "Weekly" or "Monthly".

        Returns:
            ``(to_log, next_date)``. ``to_log`` is a list of metric dicts: one
            per finished episode plus a collection summary. ``next_date`` is the
            start of the following period, or ``current_date`` when nothing was
            collected.
        """
        self.current_date = current_date
        self.radar_dataset_frequency = radar_dataset_frequency
        self.num_steps = num_steps
        self.num_episodes = num_episodes
        assert self.env.num_actions == agent.world_model.act_vocab_size
        assert 0 <= epsilon <= 1
        assert (num_steps is None) != (num_episodes is None)

        def should_stop(steps: int, episodes: int) -> bool:
            return steps >= num_steps if num_steps is not None else episodes >= num_episodes

        to_log = []
        steps, episodes = 0, 0
        returns = []
        actions, rewards, dones = [], [], []

        # Warm up the actor-critic with the tail of the episodes still in
        # progress (if any), reconstructed through the tokenizer.
        burnin_obs_rec, mask_padding = None, None
        if set(self.episode_ids) != {None} and burn_in > 0:
            current_episodes = [self.dataset.get_episode(episode_id) for episode_id in self.episode_ids]
            segmented_episodes = [episode.segment(start=len(episode) - burn_in, stop=len(episode), should_pad=True) for episode in current_episodes]
            mask_padding = torch.stack([episode.mask_padding for episode in segmented_episodes], dim=0).to(agent.device)
            burnin_obs = torch.stack([episode.observations for episode in segmented_episodes], dim=0).float().div(255).to(agent.device)
            burnin_obs_rec = torch.clamp(agent.tokenizer.encode_decode(burnin_obs, should_preprocess=True, should_postprocess=True), 0, 1)

        agent.actor_critic.reset(n=self.env.num_envs, burnin_observations=burnin_obs_rec, mask_padding=mask_padding)
        pbar = tqdm(total=num_steps if num_steps is not None else num_episodes, desc=f'Experience collection ({self.dataset.name})', file=sys.stdout)

        next_date = current_date
        if not should_stop(steps, episodes):
            frames, next_date = self.Radar_Data_Loader(current_date, radar_dataset_frequency)
            # (N, H, W) -> (N, 1, H, W): a single channel per frame.
            # NOTE(review): the loader already scales frames to [0, 3.2]; the
            # extra /255 mirrors IRIS pixel normalisation. Confirm this double
            # scaling is intended.
            observations = torch.from_numpy(np.stack(frames)).float().unsqueeze(1).div(255).to(agent.device)

            for frame in observations:
                act = agent.act(frame.unsqueeze(0), should_sample=should_sample, temperature=temperature).cpu().numpy()
                # Step this collector's own env. The returned observation is
                # ignored because observations come from the radar archive.
                _, reward, done, _ = self.env.step(act)
                actions.append(act)
                rewards.append(reward)
                dones.append(done)

                new_steps = len(self.env.mask_new_dones)
                steps += new_steps
                pbar.update(new_steps if num_steps is not None else 0)

            # NOTE(review): if the env is not done at the end of the period, the
            # transitions gathered above are discarded (they are never stored).
            if self.env.should_reset():
                self.add_experience_to_dataset(observations, actions, rewards, dones)

                new_episodes = self.env.num_envs
                episodes += new_episodes
                pbar.update(new_episodes if num_episodes is not None else 0)

                for episode_id in self.episode_ids:
                    episode = self.dataset.get_episode(episode_id)
                    self.episode_dir_manager.save(episode, episode_id, epoch)
                    metrics_episode = dict(episode.compute_metrics().__dict__)
                    metrics_episode['episode_num'] = episode_id
                    metrics_episode['action_histogram'] = wandb.Histogram(np_histogram=np.histogram(episode.actions.numpy(), bins=np.arange(0, self.env.num_actions + 1) - 0.5, density=True))
                    to_log.append({f'{self.dataset.name}/{k}': v for k, v in metrics_episode.items()})
                    returns.append(metrics_episode['episode_return'])

                self.obs = self.env.reset()
                self.episode_ids = [None] * self.env.num_envs
                agent.actor_critic.reset(n=self.env.num_envs)

        pbar.close()
        agent.actor_critic.clear()

        metrics_collect = {
            '#episodes': len(self.dataset),
            '#steps': sum(map(len, self.dataset.episodes)),
        }
        if len(returns) > 0:
            metrics_collect['return'] = np.mean(returns)
        metrics_collect = {f'{self.dataset.name}/{k}': v for k, v in metrics_collect.items()}
        to_log.append(metrics_collect)
        return to_log, next_date

    @staticmethod
    def _load_radar_frame(file_path) -> np.ndarray:
        """Read one radar frame from an HDF5 file and rescale it (float32).

        The image is cropped to rows 264:520 and columns 242:498. Pixels equal
        to 65535 (presumably the no-data marker; confirm) are set to 0. The
        result is then computed as ``clip(image / 100 * 12, 0, 128) / 40``.
        """
        # Context manager so the file handle is released for every frame.
        with h5py.File(file_path, 'r') as h5_file:
            image = h5_file['image1']['image_data'][264:520, 242:498]
        image[image == 65535] = 0
        image = image.astype('float32')
        image = image / 100 * 12
        image = np.clip(image, 0, 128)
        return image / 40

    def _load_radar_day(self, day: datetime, dataset_root: Path) -> List[np.ndarray]:
        """Load the 288 frames of ``day`` (00:00 to 23:55, every 5 minutes)."""
        frames = []
        for hour in range(0, 24):
            for minute in range(0, 60, 5):
                file_name = f'RAD_NL25_RAP_5min_{day.year}{day.month:02d}{day.day:02d}{hour:02d}{minute:02d}.h5'
                file_path = os.path.join(dataset_root, str(day.year), f'{day.month:02d}', file_name)
                frames.append(self._load_radar_frame(file_path))
        return frames

    def Radar_Data_Loader(self, start_date, radar_dataset_frequency, radar_dataset_path=None):
        """Load every 5-minute radar frame of one period starting at ``start_date``.

        Args:
            start_date: first day to load, as a ``datetime``. For "Daily" and
                "Weekly" its time of day carries over into ``next_date``.
            radar_dataset_frequency: the period to load.
                "Monthly" loads from ``start_date``'s day to the end of that month.
                "Weekly" loads 7 consecutive days.
                "Daily" loads a single day.
            radar_dataset_path: optional archive root that overrides
                ``RADAR_DATASET_PATH``.

        Returns:
            ``(frames, next_date)``. ``frames`` is a list of float32 arrays,
            288 per day. ``next_date`` is the start of the following period;
            for "Monthly" it is midnight on the first day of the next month.

        Raises:
            ValueError: if ``radar_dataset_frequency`` is not recognised.
        """
        dataset_root = Path(radar_dataset_path) if radar_dataset_path is not None else self.RADAR_DATASET_PATH

        if radar_dataset_frequency == "Monthly":
            year, month = start_date.year, start_date.month
            next_date = datetime(year, month + 1, 1) if month < 12 else datetime(year + 1, 1, 1)
        elif radar_dataset_frequency == "Weekly":
            next_date = start_date + timedelta(weeks=1)
        elif radar_dataset_frequency == "Daily":
            next_date = start_date + timedelta(days=1)
        else:
            raise ValueError(f"Unknown radar_dataset_frequency {radar_dataset_frequency!r}; expected 'Daily', 'Weekly' or 'Monthly'")

        output = []
        day = start_date
        # Each day's files live in that day's own <YYYY>/<MM> folder, so a week
        # that crosses a month or year boundary is handled correctly.
        while day < next_date:
            output.extend(self._load_radar_day(day, dataset_root))
            day += timedelta(days=1)
        return output, next_date

    def add_experience_to_dataset(self, observations: torch.Tensor, actions: List[np.ndarray], rewards: List[np.ndarray], dones: List[np.ndarray]) -> None:
        """Store one radar period (day, week or month) as a single episode.

        Unlike IRIS, which builds an episode from per-step observations,
        actions, rewards and dones, here the whole period is converted to
        tensors at once and saved as one episode.

        Args:
            observations: frames of the period as built by ``collect``, i.e.
                (N, 1, H, W). They are permuted to (1, N, H, W) before storage.
            actions: one array per step, as returned by ``agent.act``.
            rewards: one array per step, as returned by ``env.step``.
            dones: one array per step, as returned by ``env.step``.
        """
        observations = observations.permute(1, 0, 2, 3)
        # Per-step arrays are stacked along time, then transposed so that the
        # env axis comes first, matching the permuted observations.
        actions = torch.as_tensor(np.array(actions), dtype=torch.long).permute(1, 0)
        rewards = torch.tensor(np.array(rewards)).permute(1, 0)
        dones = torch.as_tensor(np.array(dones), dtype=torch.long).permute(1, 0)
        assert len(observations) == len(actions) == len(rewards) == len(dones)
        episode = Episode(
            observations,
            actions,
            rewards,
            ends=dones,
            mask_padding=torch.ones(dones.shape[0], dtype=torch.bool),
        )
        # TODO: only env 0's episode id is tracked; revisit if several
        # envs / episode ids are needed.
        if self.episode_ids[0] is None:
            self.episode_ids[0] = self.dataset.add_episode(episode)
        else:
            self.dataset.update_episode(self.episode_ids[0], episode)
