## Generate the JPEG human-replay dataset for training the visual encoder

In [1]:
import glob
import io
import pickle
from pathlib import Path
from random import shuffle

import cv2
from tqdm import tqdm
import numpy as np
import os


In [2]:
# Root folder of the human replay screenshots; it is searched for PNG files below.
original_replay_folder_path = "/home/sukai/Downloads/MontezumaRevenge_human/atari_v1/screens/revenge"
# Side length in pixels: every frame is resized to RESIZE x RESIZE.
RESIZE = 84
# Output root (a relative path) under which the converted JPEG files are written.
target_folder_relative_path = "human_replay_screenshot_jpeg"

def fromfile(file, dtype, count, *args, **kwargs):
    """Read up to ``count`` items of ``dtype`` from a file-like object.

    Example::

        im = io.BytesIO(f.read())
        m = fromfile(im, dtype=np.uint8, count=im.__sizeof__())

    Args:
        file: Object exposing ``readinto(buffer)`` (e.g. ``io.BytesIO`` or a
            file opened in binary mode).
        dtype: NumPy dtype of the returned array.
        count: Maximum number of items to read; the stream may hold fewer.
        *args, **kwargs: Accepted for signature compatibility and ignored.

    Returns:
        A 1-D array of ``dtype`` holding the items that were actually read.
        Trailing bytes that do not fill a whole item are dropped.
    """
    itemsize = np.dtype(dtype).itemsize
    buffer = np.zeros(count * itemsize, np.uint8)
    offset = 0
    # Keep filling the remaining slice until the buffer is full or the stream
    # is exhausted.
    while offset < buffer.size:
        bytes_read = file.readinto(buffer[offset:])
        # 0 means end of stream; None is what a non-blocking raw stream
        # returns when no data is available. Stop in both cases.
        if not bytes_read:
            break
        offset += bytes_read
    rounded_bytes = (offset // itemsize) * itemsize
    # Reinterpret the bytes through a view rather than assigning to
    # ``buffer.dtype`` in place, which NumPy discourages.
    return buffer[:rounded_bytes].view(dtype)

def get_image_folder_and_name(path):
    """Return ``(parent folder name, file name up to its first dot)`` for a '/'-separated path."""
    # Only the last two path components matter, so split from the right.
    parent, filename = path.rsplit('/', 2)[-2:]
    stem, _, _ = filename.partition('.')
    return parent, stem
    


In [3]:
# Collect every PNG below the replay folder, at any depth ("**" with recursive=True).
imagefilenames = glob.glob(original_replay_folder_path +"/**/*.png", recursive=True)

In [4]:
imagefilenames[0]

'/home/sukai/Downloads/MontezumaRevenge_human/atari_v1/screens/revenge/931/977.png'

In [5]:
get_image_folder_and_name(imagefilenames[0])

('931', '977')

In [6]:
# Convert every replay PNG into a RESIZE x RESIZE JPEG stored at
# <target_folder_relative_path>/<source folder name>/<frame name>.jpeg
created_folders = set()  # output folders already created; avoids one mkdir per image

for path in tqdm(imagefilenames):
    with open(path, 'rb') as f:
        raw = f.read()
    # Decode from the in-memory bytes. cv2.imdecode returns None when the data
    # is not a readable image, so fail here with the offending path instead of
    # with an opaque error inside cv2.resize.
    im = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR) if raw else None
    if im is None:
        raise ValueError(f"Could not decode image: {path}")
    # rescale
    im = cv2.resize(im, (RESIZE, RESIZE), interpolation=cv2.INTER_AREA)
    # Pixel values are left unnormalised so the saved file stays viewable.
    folder_name, name = get_image_folder_and_name(path)
    out_folder = os.path.join(target_folder_relative_path, folder_name)
    if out_folder not in created_folders:
        Path(out_folder).mkdir(parents=True, exist_ok=True)
        created_folders.add(out_folder)

    store_pathname = os.path.join(out_folder, name+".jpeg")
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(store_pathname, im):
        raise Exception("Could not write image")

100%|███████████████████████████████| 2361871/2361871 [31:33<00:00, 1247.27it/s]


## Generate GIF clips for training the video diffusion model

In [7]:
# List the entries directly inside the replay folder. The pattern contains no
# "**", so nothing deeper is matched despite recursive=True.
# Presumably each entry is one recorded episode folder — confirm against the dataset.
imagefilefoldernames = glob.glob(original_replay_folder_path +"/*", recursive=True)

In [24]:
def sort_image_files(arr, folder_name):
    """Index the frame images of one folder by their integer frame number.

    Args:
        arr: Iterable of image paths whose file name is ``<frame index>.png``.
        folder_name: Label used only in the warning that is printed when the
            frame numbering has gaps.

    Returns:
        ``(file_dict, last_frame_ind)``: a dict mapping frame index to path,
        and the highest frame index found. For an empty ``arr`` this is
        ``({}, -1)``, so ``last_frame_ind + 1`` is still the frame count.

    Raises:
        ValueError: If a file name does not start with an integer index.
    """
    file_dict = {}
    for p in arr:
        # basename handles the platform's path separator, not just '/'.
        key = int(os.path.basename(p).split('.png')[0])
        file_dict[key] = p
    if not file_dict:
        # max() of an empty dict would raise; report "no frames" instead.
        return file_dict, -1
    last_frame_ind = max(file_dict)
    if last_frame_ind + 1 != len(file_dict): # include 0 ind
        print("last frame ind not equivalent to size")
        print("folder is", folder_name)

    return file_dict, last_frame_ind

In [32]:
# Output folder (a relative path) that receives the generated GIF clips.
target_gif_folder_relative_path = "human_replay_screenshot_gif"


In [31]:
import imageio

In [38]:
# Cut the frames of every replay folder into fixed-length GIF clips.
GIF_CLIP_LENGTH = 150            # frames per clip
GIF_CLIP_START_OFFSETS = (0, 75)  # second pass starts half a clip later, so clips overlap


def _load_rgb_frame(image_path):
    """Read one image file and return it as a RESIZE x RESIZE RGB array."""
    with open(image_path, 'rb') as f:
        raw = f.read()
    # cv2.imdecode returns None for unreadable data; fail with the path.
    im = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR) if raw else None
    if im is None:
        raise ValueError(f"Could not decode image: {image_path}")
    # OpenCV decodes to BGR channel order; the GIF writer is given RGB frames.
    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    # rescale
    return cv2.resize(im, (RESIZE, RESIZE), interpolation=cv2.INTER_AREA)


# The writer does not create the output folder itself.
os.makedirs(target_gif_folder_relative_path, exist_ok=True)

cur_output_gif_ind = 0  # running index used as the output file name

for image_folder in tqdm(imagefilefoldernames):
    image_paths = glob.glob(image_folder + '/*.png', recursive=True)
    if not image_paths:
        # Nothing to export (empty folder, or a plain file matched by the glob).
        continue
    file_dict, last_frame_ind = sort_image_files(image_paths, image_folder)
    frame_count = last_frame_ind + 1
    for offset in GIF_CLIP_START_OFFSETS:
        # "+ 1" on the stop keeps the final full clip when exactly
        # GIF_CLIP_LENGTH frames remain (range() excludes its stop value).
        for start_ind in range(offset, frame_count - GIF_CLIP_LENGTH + 1, GIF_CLIP_LENGTH):
            clip_indices = range(start_ind, start_ind + GIF_CLIP_LENGTH)
            if any(ind not in file_dict for ind in clip_indices):
                # Frame numbering has gaps; skip this clip instead of
                # aborting the whole export with a KeyError.
                print("skipping clip with missing frames in", image_folder, "starting at", start_ind)
                continue
            # Decode the whole clip first so a bad frame never leaves a
            # partially written GIF behind.
            frames = [_load_rgb_frame(file_dict[ind]) for ind in clip_indices]
            gif_path = os.path.join(target_gif_folder_relative_path, f"{cur_output_gif_ind}.gif")
            with imageio.get_writer(gif_path, mode="I") as writer:
                for frame in frames:
                    writer.append_data(frame)
            cur_output_gif_ind += 1
        

100%|███████████████████████████████████████████████████████████████████████████████████| 652/652 [2:17:59<00:00, 12.70s/it]


## Produce Video MP4 Version

In [39]:
# Output folder (a relative path) that receives the generated MP4 clips.
target_video_folder_relative_path = "human_replay_screenshot_video"

In [42]:
# Cut the frames of every replay folder into fixed-length MP4 clips.
VIDEO_CLIP_LENGTH = 150            # frames per clip
VIDEO_CLIP_START_OFFSETS = (0, 75)  # second pass starts half a clip later, so clips overlap
VIDEO_FPS = 60                     # frame rate passed to the video writer


def _load_bgr_frame(image_path):
    """Read one image file and return it as a RESIZE x RESIZE BGR array."""
    with open(image_path, 'rb') as f:
        raw = f.read()
    # cv2.imdecode returns None for unreadable data; fail with the path.
    im = cv2.imdecode(np.frombuffer(raw, dtype=np.uint8), cv2.IMREAD_COLOR) if raw else None
    if im is None:
        raise ValueError(f"Could not decode image: {image_path}")
    # Frames stay in OpenCV's BGR order because they go to cv2.VideoWriter.
    # rescale
    return cv2.resize(im, (RESIZE, RESIZE), interpolation=cv2.INTER_AREA)


# cv2.VideoWriter does not create the output folder itself.
os.makedirs(target_video_folder_relative_path, exist_ok=True)

fourcc = cv2.VideoWriter_fourcc(*'mp4v')

# Running index used as the output file name (variable name kept from the GIF cell).
cur_output_gif_ind = 0

for image_folder in tqdm(imagefilefoldernames):
    image_paths = glob.glob(image_folder + '/*.png', recursive=True)
    if not image_paths:
        # Nothing to export (empty folder, or a plain file matched by the glob).
        continue
    file_dict, last_frame_ind = sort_image_files(image_paths, image_folder)
    frame_count = last_frame_ind + 1
    for offset in VIDEO_CLIP_START_OFFSETS:
        # "+ 1" on the stop keeps the final full clip when exactly
        # VIDEO_CLIP_LENGTH frames remain (range() excludes its stop value).
        for start_ind in range(offset, frame_count - VIDEO_CLIP_LENGTH + 1, VIDEO_CLIP_LENGTH):
            clip_indices = range(start_ind, start_ind + VIDEO_CLIP_LENGTH)
            if any(ind not in file_dict for ind in clip_indices):
                # Frame numbering has gaps; skip this clip instead of
                # aborting the whole export with a KeyError.
                print("skipping clip with missing frames in", image_folder, "starting at", start_ind)
                continue
            frames = [_load_bgr_frame(file_dict[ind]) for ind in clip_indices]
            video_path = os.path.join(target_video_folder_relative_path, f"{cur_output_gif_ind}.mp4")
            video = cv2.VideoWriter(video_path, fourcc, VIDEO_FPS, (RESIZE, RESIZE))
            # A writer that failed to open does not raise; it just writes nothing.
            if not video.isOpened():
                video.release()
                raise IOError(f"Could not open video writer for {video_path}")
            try:
                for frame in frames:
                    video.write(frame)
            finally:
                # Always finalise the file, even if a write fails.
                video.release()
            cur_output_gif_ind += 1
        

100%|█████████████████████████████████████████████████████████████████████████████████████| 652/652 [36:17<00:00,  3.34s/it]
