In [1]:
import os
import random

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from model import Pixel2StateNet

In [2]:
# Run-wide constants (presumably consumed by later training cells).
BATCH_SIZE = 32
NUM_EPOCHS = 20

# Use the first CUDA device when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print("Device: ", DEVICE)

Device:  cuda:0


In [3]:
def set_seed(seed) -> None:
    '''
    Seeds the global random number generators for reproducibility.

    Covers Python's built-in `random` module, NumPy's legacy global RNG
    and PyTorch (`torch.manual_seed`), and configures cuDNN to avoid
    run-to-run variation.

    Args:
        seed: Integer seed applied to every generator.
    '''
    # Python's own RNG is independent of NumPy/PyTorch and must be seeded
    # separately, otherwise anything using `random` stays non-deterministic.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Ask cuDNN for deterministic algorithms and disable its autotuner,
    # which may otherwise pick different kernels between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [4]:
def concatenate_state_space(state_space):
    '''
    Flattens a state-space mapping into a single 1-D vector, for ease of
    processing when fed into the model.

    Values are concatenated in the mapping's iteration order. Each value
    may be an array of any rank (including 0-d), a plain scalar, or a
    list/tuple; every entry is flattened before being joined.

    Example input:
    data = OrderedDict([
        ('joint_angles', array([7 entries])),
        ('upright', 1 entry),
        ('target', array([3 entries])),
        ('velocity', array([13 entries]))
    ])

    Args:
        state_space: Mapping (e.g. OrderedDict) from name to array/scalar.

    Returns:
        1-D np.ndarray containing all values back to back.

    Raises:
        ValueError: If `state_space` is empty (raised by np.concatenate).
    '''
    # np.ravel turns scalars and 0-d arrays into shape (1,) and flattens
    # anything of higher rank, so np.concatenate always sees 1-D pieces.
    flat_parts = [np.ravel(value) for value in state_space.values()]
    return np.concatenate(flat_parts)

In [5]:
set_seed(seed=42)

# Loading data
dataset_path_and_file = "dataset/augmented_camera_view/proprio_pixel_dataset-100k_2024-06-02_17-44-33.npz" 

print("Loading dataset")
# NOTE(security): allow_pickle=True lets np.load unpickle arbitrary Python
# objects, which can execute code. It is presumably needed here because the
# 'observations' entries are dict-like objects; only load trusted files.
# The context manager closes the underlying .npz file handle once both
# arrays have been read into memory (`dataset` itself is closed afterwards).
with np.load(dataset_path_and_file, allow_pickle=True) as dataset:
    dataset_images = dataset['frames']
    dataset_proprios = dataset['observations']

# Converting to pandas dataframe 
print("Converting to pandas dataframe")
# One row per sample: list() splits each array along its first axis so every
# cell holds a single image / a single observation object.
data = {
    'image': list(dataset_images),
    'state_space': list(dataset_proprios)
}
dataset_df = pd.DataFrame(data)

print("Converting state_space column of dataframe")
# Replace each observation mapping by its flattened 1-D vector.
dataset_df['state_space'] = dataset_df['state_space'].apply(concatenate_state_space)

print(dataset_df.head(5))

Loading dataset
Converting to pandas dataframe
Converting state_space column of dataframe
                                               image  \
0  [[[25, 52, 77], [25, 52, 77], [25, 52, 77], [2...   
1  [[[25, 52, 77], [25, 52, 77], [25, 52, 77], [2...   
2  [[[25, 52, 77], [25, 52, 77], [25, 52, 77], [2...   
3  [[[25, 52, 77], [25, 52, 77], [25, 52, 77], [2...   
4  [[[25, 52, 77], [25, 52, 77], [25, 52, 77], [2...   

                                         state_space  
0  [-0.2574390085273436, 0.0009264072188474451, 0...  
1  [-0.027611403723891537, 0.0414119146556137, -0...  
2  [-0.1885185001670155, 0.08072699084908999, 0.1...  
3  [-0.21707487284132598, 0.05409772032614757, 0....  
4  [-0.42672974533903885, 0.05006412443187066, 0....  
