In [None]:
#Connect to Drive
from google.colab import drive
# force_remount=True remounts /content/drive even if it is already mounted (e.g. when this cell is re-run).
drive.mount('/content/drive', force_remount=True)

Mounted at /content/drive


In [None]:
#get to data
%cd drive/MyDrive/DLR/ProjectData

!pip install gymnasium[mujoco]

/content/drive/MyDrive/DLR/ProjectData
Collecting mujoco>=2.1.5 (from gymnasium[mujoco])
  Downloading mujoco-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.4/44.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
Collecting glfw (from mujoco>=2.1.5->gymnasium[mujoco])
  Downloading glfw-2.8.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl.metadata (5.4 kB)
Downloading mujoco-3.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.6 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.6/6.6 MB[0m [31m85.8 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading glfw-2.8.0-py2.py27.py3.py30.py31.py32.py33.py34.py35.py36.py37.py38.p39.p310.p311.p312.p313-none-manylinux_2_28_x86_64.whl (243 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m243.4/243.4 kB[0m [31m23.9 MB/s[0m eta [36m0:00:0

In [None]:
import torch
from torch import nn, zeros
from torch.optim import Adam
from torch.utils.tensorboard import SummaryWriter
from collections import deque
import random
import copy
import numpy as np
import glob

# Seed every RNG the notebook uses (Python `random`, NumPy, PyTorch) so runs are repeatable.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the RNG on all devices, CUDA included

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

Using device: cuda


In [None]:
# @title Visualization code, sampled from HW2
#TODO: Set up Ant environment (or other, but ant seemed applicable)
import os

# MuJoCo chooses its GL backend when it is first imported, so select EGL (headless
# rendering, needed on Colab) before anything that may import mujoco.
os.environ["MUJOCO_GL"] = "egl"

import gymnasium as gym
from gymnasium.wrappers import RecordVideo  # must match the env library (gymnasium, not legacy gym)
from IPython.display import Video, display, clear_output


def visualize(agent, planner, num_plans=16, plan_horizon=32, video_folder="./"):
    """Roll out a planner/actor pair in InvertedPendulum-v5 and display the recorded video.

    A latent plan is sampled from ``planner`` (conditioned on the current observation and
    a goal equal to the initial observation) and followed for up to ``plan_horizon`` actor
    steps, for at most ``num_plans`` plans. The rollout stops as soon as the episode
    terminates or is truncated.

    Args:
        agent: actor module called as ``agent(obs, plan_z, goal)`` -> ``(actions, hidden)``.
        planner: planner module called as ``planner(obs, goal)`` -> ``(mu, sigma)``.
        num_plans: maximum number of plans to sample.
        plan_horizon: actor steps executed per plan.
        video_folder: directory the episode video is written to and read back from.
    """
    env = gym.make("InvertedPendulum-v5", render_mode="rgb_array", reset_noise_scale=0.2)
    env = RecordVideo(env, video_folder=video_folder, episode_trigger=lambda episode_id: True)

    try:
        obs, _ = env.reset()

        # Custom camera angle; guard in case the renderer has not created a viewer yet.
        viewer = env.unwrapped.mujoco_renderer.viewer
        if viewer is not None:
            viewer.cam.distance = 3.0     # camera distance
            viewer.cam.azimuth = 90       # rotate camera around the pendulum
            viewer.cam.elevation = 0      # tilt up/down

        def as_batch(observation):
            # (obs_dim,) numpy observation -> (1, obs_dim) float32 tensor on `device`
            return torch.as_tensor(observation, dtype=torch.float32, device=device).unsqueeze(0)

        goal_state = as_batch(obs)
        episode_over = False
        with torch.no_grad():
            for _ in range(num_plans):
                plan_mu, plan_sigma = planner(as_batch(obs), goal_state)
                plan_z = plan_mu + plan_sigma * torch.randn_like(plan_sigma)

                for _ in range(plan_horizon):
                    # Training always calls the actor with a fresh (None) LSTM state, so do the same here.
                    actions, _ = agent(as_batch(obs), plan_z, goal_state)
                    # gymnasium returns (obs, reward, terminated, truncated, info)
                    obs, _, terminated, truncated, _ = env.step(actions.squeeze(0).cpu().numpy())
                    episode_over = terminated or truncated
                    if episode_over:
                        break
                if episode_over:
                    # Stop planning too: stepping a finished episode is invalid.
                    break
    finally:
        env.close()

    # Display the latest video
    clear_output(wait=True)
    display(Video(os.path.join(video_folder, "rl-video-episode-0.mp4"), embed=True))

In [None]:
#Getting Data

class Dataset:
    """Fixed-length trajectories loaded from ``.npy`` files, sampled as random windows.

    Every file matching ``pattern`` is loaded (in sorted order) and its first ``steps``
    rows are stacked into ``grouped_data`` of shape (num_runs, steps, num_columns).
    Each row is one time step; ``train_sample`` treats the last column as the action
    and the remaining columns as the state.

    Raises:
        FileNotFoundError: no file matches ``pattern``.
        ValueError: a file has fewer than ``steps`` rows.
    """

    def __init__(self, pattern='*.npy', steps=130):
        files = sorted(glob.glob(pattern))  # glob order is arbitrary; sort for reproducibility
        if not files:
            raise FileNotFoundError(f"No trajectory files match {pattern!r}")
        self.data_list = [np.load(fil) for fil in files]
        too_short = [fil for fil, run in zip(files, self.data_list) if len(run) < steps]
        if too_short:
            raise ValueError(f"Trajectories shorter than {steps} steps: {too_short}")
        # float64 like the zero-initialised buffer this replaces
        self.grouped_data = np.stack([run[:steps] for run in self.data_list]).astype(np.float64)

    def sample_batch(self, batch_size, seq_len=32):
        """Return a float32 tensor (batch_size, seq_len, num_columns) of contiguous windows.

        Each window comes from one randomly chosen run and a uniformly random start step.
        """
        num_runs, num_steps = self.grouped_data.shape[:2]
        if not 1 <= seq_len <= num_steps:
            raise ValueError(f"seq_len must be in [1, {num_steps}], got {seq_len}")
        run_num = np.random.randint(0, num_runs, size=batch_size)
        # randint's high is exclusive: +1 so the window ending on the last step can be drawn.
        start_indices = np.random.randint(0, num_steps - seq_len + 1, size=batch_size)[:, np.newaxis]
        indices = start_indices + np.arange(seq_len)[np.newaxis, :]

        sample = self.grouped_data[run_num[:, np.newaxis], indices]
        return torch.tensor(sample, dtype=torch.float32)

# Load the *.npy trajectories from the working directory; train_sample reads this global `data`.
data = Dataset()

IndexError: list index out of range

In [None]:
import torch.nn.functional as F

class Actor(nn.Module):
    """Plan- and goal-conditioned policy: two stacked LSTMs followed by a linear action head.

    Args:
        obs_dim: size of the observation vector.
        act_dim: size of the action output.
        goal_dim: size of the goal vector.
        layer_size: hidden size of both LSTMs.
        latent_dim: size of the latent plan ``z``.
    """

    def __init__(self, obs_dim, act_dim, goal_dim, layer_size=1024, latent_dim=256):
        super(Actor, self).__init__()

        input_dim = obs_dim + latent_dim + goal_dim
        self.lstm1 = nn.LSTM(input_dim, layer_size, batch_first=True)
        self.lstm2 = nn.LSTM(layer_size, layer_size, batch_first=True)

        self.actions = nn.Linear(layer_size, act_dim)

    def forward(self, obs, z, goal, hidden_state=None):
        """Predict actions from ``[obs, z, goal]`` concatenated on the last dimension.

        Args:
            obs, z, goal: tensors sharing all leading dimensions (e.g. (batch, seq, *)).
            hidden_state: ``None`` (zero initial state for both LSTMs) or the pair returned
                by a previous call.

        Returns:
            ``(actions, (state1, state2))`` where ``state1``/``state2`` are the ``(h, c)``
            of ``lstm1``/``lstm2``.
        """
        x = torch.cat([obs, z, goal], dim=-1)

        state1, state2 = (None, None) if hidden_state is None else hidden_state
        # Each LSTM carries its own recurrent state; lstm2 must not be seeded with lstm1's.
        x, state1 = self.lstm1(x, state1)
        x, state2 = self.lstm2(x, state2)

        return self.actions(x), (state1, state2)

class Encoder(nn.Module):
    """Bidirectional-LSTM posterior over latent plans.

    Encodes a whole sequence into a diagonal Gaussian ``N(mu, sigma)`` over a
    ``latent_dim`` plan and returns a reparameterised sample from it.

    Args:
        enc_in_dim: features per time step.
        layer_size: hidden size of each LSTM direction.
        latent_dim: size of the latent plan.
        epsilon: added after the softplus so sigma stays strictly positive.
    """

    def __init__(self, enc_in_dim, layer_size=2048, latent_dim=256, epsilon=1e-4):
        super(Encoder, self).__init__()

        self.epsilon = epsilon

        self.lstm1 = nn.LSTM(enc_in_dim, layer_size, batch_first=True, bidirectional=True)
        self.lstm2 = nn.LSTM(layer_size * 2, layer_size, batch_first=True, bidirectional=True)

        self.mu = nn.Linear(layer_size * 2, latent_dim)
        self.sigma = nn.Linear(layer_size * 2, latent_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, seq, enc_in_dim), or unbatched (seq, enc_in_dim).

        Returns:
            ``(z, mu, sigma)``, each (batch, latent_dim) — or (latent_dim,) when unbatched.
        """
        x, _ = self.lstm1(x)
        _, (h_n, _) = self.lstm2(x)

        # Summarise with the final state of each direction. The output's last time step
        # would hold the backward direction after seeing only the last element; h_n[-1]
        # is the backward state after reading the whole sequence (h_n[-2] is forward).
        summary = torch.cat([h_n[-2], h_n[-1]], dim=-1)

        mu = self.mu(summary)
        sigma = F.softplus(self.sigma(summary)) + self.epsilon

        sample = torch.randn_like(sigma)
        z = mu + sigma * sample

        return z, mu, sigma

class Planner(nn.Module):
    """Prior over latent plans, conditioned only on the start and goal observations.

    A four-layer ReLU MLP maps ``[obs_init, obs_goal]`` to the mean and standard
    deviation (softplus + ``epsilon``, hence strictly positive) of a diagonal Gaussian
    over a ``latent_dim`` plan.
    """

    def __init__(self, obs_dim, goal_dim, layer_size=2048, latent_dim=256, epsilon=1e-4):
        super().__init__()

        self.epsilon = epsilon

        in_features = obs_dim + goal_dim
        self.fc1 = nn.Linear(in_features, layer_size)
        self.fc2 = nn.Linear(layer_size, layer_size)
        self.fc3 = nn.Linear(layer_size, layer_size)
        self.fc4 = nn.Linear(layer_size, layer_size)

        self.mu = nn.Linear(layer_size, latent_dim)
        self.sigma = nn.Linear(layer_size, latent_dim)

    def forward(self, obs_init, obs_goal):
        """Return ``(mu, sigma)`` of the plan prior for a start/goal pair."""
        hidden = torch.cat([obs_init, obs_goal], dim=-1)
        for layer in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = F.relu(layer(hidden))

        return self.mu(hidden), F.softplus(self.sigma(hidden)) + self.epsilon

In [None]:
import torch.distributions as dist
import torch.optim as optim

def train_sample(batch_size, beta, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer):
    """Run one joint optimisation step of encoder, planner and actor on the pendulum data.

    Draws ``batch_size`` windows from the global ``data`` sampler. In each window the
    last column is the action and the rest is the state. The loss is
    ``beta * KL(encoder posterior || planner prior) + L1(actor action, first action)``.

    Returns:
        The scalar loss, detached from the autograd graph.
    """
    sample = data.sample_batch(batch_size).to(device)
    current_state = sample[:, 0, :-1]
    current_action = sample[:, 0, -1]
    goal_state = sample[:, -1, :-1]

    # Posterior sees the whole window; the prior only sees (start, goal).
    z, mu_phi, sigma_phi = encoder(sample)
    mu_psi, sigma_psi = planner(current_state, goal_state)

    phi_gaussian = dist.Normal(mu_phi, sigma_phi)
    psi_gaussian = dist.Normal(mu_psi, sigma_psi)

    # NOTE(review): summed over batch *and* latent dims while the action loss is a batch
    # mean, so the effective KL weight scales with batch_size — confirm this is intended.
    KL_loss = torch.sum(dist.kl.kl_divergence(phi_gaussian, psi_gaussian))

    # The actor predicts only the first action of each window (sequence length 1).
    policy_action, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))
    action_loss = F.l1_loss(policy_action.squeeze(1), current_action.unsqueeze(1))

    loss = beta * KL_loss + action_loss

    optimizers = (encoder_optimizer, planner_optimizer, actor_optimizer)
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()

    # Detach so callers that keep or log the value do not hold on to the graph.
    return loss.detach()

In [None]:
# Pendulum models: the encoder reads whole 5-column rows (4 state columns + 1 action);
# the planner and actor read a 4-dim state and a 4-dim goal; the actor emits a 1-dim action.
latent_dim = 32

encoder = Encoder(enc_in_dim=5, layer_size=256, latent_dim=latent_dim).to(device)
planner = Planner(obs_dim=4, goal_dim=4, layer_size=512, latent_dim=latent_dim).to(device)
actor = Actor(obs_dim=4, act_dim=1, goal_dim=4, layer_size=512, latent_dim=latent_dim).to(device)

# Separate Adam optimisers; the actor uses a larger learning rate than encoder/planner.
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
planner_optimizer = optim.Adam(planner.parameters(), lr=1e-4)
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)

In [None]:
# 10k joint updates (batch size 32, KL weight 0.9), reporting the loss every 100 batches.
for batch in range(10_000):
    loss = train_sample(
        32, .9,
        encoder, actor, planner,
        encoder_optimizer, actor_optimizer, planner_optimizer,
    )
    if not batch % 100:
        print(f"Batch: {batch}, Loss: {loss}")

Batch: 0, Loss: 3.3969194889068604
Batch: 100, Loss: 0.5342366695404053
Batch: 200, Loss: 0.2794186770915985
Batch: 300, Loss: 0.11472980678081512
Batch: 400, Loss: 0.08341763913631439
Batch: 500, Loss: 0.10533882677555084
Batch: 600, Loss: 0.05830066278576851
Batch: 700, Loss: 0.05278594419360161
Batch: 800, Loss: 0.04674055799841881
Batch: 900, Loss: 0.05848301947116852
Batch: 1000, Loss: 0.04381528124213219
Batch: 1100, Loss: 0.06324674189090729
Batch: 1200, Loss: 0.04114704951643944
Batch: 1300, Loss: 0.057159364223480225
Batch: 1400, Loss: 0.058317627757787704
Batch: 1500, Loss: 0.06413818895816803
Batch: 1600, Loss: 0.04789697378873825
Batch: 1700, Loss: 0.03897293284535408
Batch: 1800, Loss: 0.03721236810088158
Batch: 1900, Loss: 0.060143180191516876
Batch: 2000, Loss: 0.03714115172624588
Batch: 2100, Loss: 0.04689726233482361
Batch: 2200, Loss: 0.049595024436712265
Batch: 2300, Loss: 0.048318155109882355
Batch: 2400, Loss: 0.06195338815450668
Batch: 2500, Loss: 0.04466953128576

In [None]:
# Roll out the trained pendulum planner/actor and display the recorded episode video.
visualize(actor, planner)

NameError: name 'visualize' is not defined

In [None]:
    # Releases (unassigns) this Colab runtime, ending the session. Guarded so "Run all"
    # does not end the session before the Rubik's-cube cells below get to run;
    # set RELEASE_RUNTIME = True when you actually want to free the GPU.
    RELEASE_RUNTIME = False
    if RELEASE_RUNTIME:
        from google.colab import runtime
        runtime.unassign()

In [None]:
!pip install magiccube

Collecting magiccube
  Downloading magiccube-1.0.0-py3-none-any.whl.metadata (3.9 kB)
Downloading magiccube-1.0.0-py3-none-any.whl (16 kB)
Installing collected packages: magiccube
Successfully installed magiccube-1.0.0


In [None]:
import magiccube
import copy
from magiccube.cube_base import Color, Face
from magiccube.cube_move import CubeMove
from magiccube.solver.basic.basic_solver import BasicSolver

# 3x3 cube built from a 54-facelet string with 9 facelets of each colour grouped together (solved).
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

def get_face_state(cube, face):
    """One-hot encode one face of ``cube``.

    Reads the face's stickers row by row (``color.value`` for each) and returns a flat
    int64 tensor of length ``num_stickers * 6``, one 6-way one-hot block per sticker.
    """
    sticker_codes = torch.tensor(
        [color.value for row in cube.get_face(face) for color in row],
        dtype=torch.int64,
    )
    return torch.nn.functional.one_hot(sticker_codes, num_classes=6).reshape(-1)

#state space
def get_cube_state(cube):
    """Stack the one-hot encodings of all six faces, in L, R, D, U, B, F order (dim 0)."""
    face_order = (Face.L, Face.R, Face.D, Face.U, Face.B, Face.F)
    encoded_faces = [get_face_state(cube, face) for face in face_order]
    return torch.stack(encoded_faces, dim=0)

def batch_cube_state(cube_list):
    """Encode every cube with ``get_cube_state`` and flatten each to one row.

    Returns a tensor of shape (len(cube_list), -1).
    """
    stacked = torch.stack([get_cube_state(one_cube) for one_cube in cube_list])
    return stacked.view(stacked.size(0), -1)

def batch_apply_action(cube_list, action_list):
  """Apply ``action_list[k]`` to ``cube_list[k]`` in place and return ``cube_list``.

  Raises IndexError if ``action_list`` is shorter than ``cube_list``.
  """
  for position, target_cube in enumerate(cube_list):
    target_cube._rotate_once(action_list[position])
  return cube_list

#action space: for each face, the quarter turn, its inverse (') and the half turn (2).
movements = ["L", "L'", "L2", "R", "R'", "R2", "D", "D'", "D2", "U", "U'", "U2", "B", "B'", "B2", "F", "F'", "F2"]

def _inverse_move_name(name):
  """Name of the move undoing ``name``: X <-> X', and X2 undoes itself."""
  if name.endswith("2"):
    return name
  return name[:-1] if name.endswith("'") else name + "'"

reversals = [_inverse_move_name(move_str) for move_str in movements]
# reverse_index[i] is the index in `movements` of the move that undoes movements[i]
# (derived from the names instead of hand-typed, so it cannot drift from `reversals`).
reverse_index = {i: movements.index(inverse) for i, inverse in enumerate(reversals)}
reversals = [CubeMove.create(move_str) for move_str in reversals]
movements = [CubeMove.create(move_str) for move_str in movements]
print(reversals)

# Sanity check: scramble with one move, let BasicSolver solve the cube in place,
# then verify that every move followed by its inverse leaves the cube solved.
cube._rotate_once(movements[8])
solver = BasicSolver(cube)
cube_copy = copy.deepcopy(cube)  # snapshot of the scrambled cube
solver.solve()
assert cube.is_done(), "BasicSolver did not solve the cube"

for i in range(18):
  cube._rotate_once(movements[i])
  cube._rotate_once(movements[reverse_index[i]])
  print(cube.is_done())
  assert cube.is_done(), f"{movements[i]} followed by {movements[reverse_index[i]]} did not restore the cube"

[L', L, L2, R', R, R2, D', D, D2, U', U, U2, B', B, B2, F', F, F2]
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True
True


In [None]:
# random move at every step for dataset creation
import torch
import numpy as np
import random

action_dim = 18
state_dim = 54 * 6  # 6 faces x 9 stickers x 6-way one-hot colour
num_samples = 100000

# Random walk from the solved cube. Row layout: [state | one-hot of the move applied from that state].
data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
for i in range(num_samples):
  if not i % 10000:
    print(f"Sample: {i}")

  state = get_cube_state(cube).flatten()
  action = random.choice(range(action_dim))

  data_raw[i, :state_dim] = state
  data_raw[i, state_dim + action] = 1

  cube._rotate_once(movements[action])

Sample: 0
Sample: 10000
Sample: 20000
Sample: 30000
Sample: 40000
Sample: 50000
Sample: 60000
Sample: 70000
Sample: 80000
Sample: 90000


In [None]:
action_dim = 18
state_dim = 54 * 6
num_samples = 100000
move_depth = 4  # Number of forward moves before reversing

# Each episode applies `move_depth` random moves and then undoes them in reverse order,
# returning the cube to solved. Rows are [state | one-hot of the move applied from that state].
data_raw = torch.zeros((num_samples, action_dim + state_dim))
cube = magiccube.Cube(3, "YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")

i = 0
while i < num_samples:
  forward_actions = []
  for _ in range(move_depth):
    # Checked before writing, so no row is ever written past the end of data_raw.
    if i >= num_samples:
      break
    if i % 10000 == 0:
        print(f"Sample: {i}")
    state = get_cube_state(cube).flatten()
    data_raw[i, :state_dim] = state

    action = random.choice(range(action_dim))
    data_raw[i, state_dim + action] = 1
    cube._rotate_once(movements[action])
    forward_actions.append(action)

    i += 1

  undone = 0
  for action in reversed(forward_actions):
    if i >= num_samples:
      break
    if i % 10000 == 0:
        print(f"Sample: {i}")
    state = get_cube_state(cube).flatten()
    data_raw[i, :state_dim] = state

    reverse_action = reverse_index[action]
    data_raw[i, state_dim + reverse_action] = 1
    cube._rotate_once(movements[reverse_action])
    undone += 1

    i += 1

  # Only meaningful if the whole episode was undone; the final one may be cut short by num_samples.
  if undone == len(forward_actions) and not cube.is_done():
    print("NOT DONE")

Sample: 0
Sample: 10000
Sample: 20000
Sample: 30000
Sample: 40000
Sample: 50000
Sample: 60000
Sample: 70000
Sample: 80000
Sample: 90000


In [None]:
# Short-episode dataset: the cube is reset to solved every `episode_length` rows, so each
# recorded state is at most `episode_length - 1` random moves away from the goal state.
import torch
import numpy as np
import random

action_dim = 18
state_dim = 54 * 6
num_samples = 100
episode_length = 4  # rows recorded per episode before resetting to the solved cube

# Row layout: [state | one-hot of the random move applied from that state].
# NOTE(review): windows sampled across rows span several episodes, so the last action of an
# episode is followed by a reset state rather than its actual result — confirm this is intended.
data_raw = torch.zeros((num_samples, action_dim + state_dim))
#list to store the actual cube objects (a snapshot of the cube for every row)
cube_objects = []
cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
for i in range(num_samples):
  if i % 10000 == 0:
    print(f"Sample: {i}")
  if i % episode_length == 0:
    cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
  state = get_cube_state(cube).flatten()
  data_raw[i, :state_dim] = state

  action = random.choice(range(action_dim))
  data_raw[i, state_dim + action] = 1
  cube_objects.append(copy.deepcopy(cube))

  cube._rotate_once(movements[action])

Sample: 0


In [None]:
#Getting Data

class Dataset:
    def __init__(self, data):
        self.data_list = data

    def sample_batch(self, batch_size):
        start_indices = np.random.randint(0, len(self.data_list) - 32, size=batch_size)[:, np.newaxis]
        indices = start_indices + np.arange(32)[np.newaxis, :]

        sample = self.data_list[indices]
        return sample

# Rebinds the global `data` (read by train_sample) to windows over the cube table `data_raw`.
data = Dataset(data_raw)

In [None]:
# Quick inspection: the first encoded row (state one-hot followed by action one-hot),
# the table shape, and the cube snapshot recorded for that first row.
print(data_raw[0])
print(data_raw.shape)
print(cube_objects[0])

tensor([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 1., 0., 0.,
        0., 0., 0., 0., 1., 0., 0., 0., 

In [None]:
# Cube models: the encoder reads whole [state | action] rows; the planner and actor
# condition on the current state and a goal state; the actor scores the 18 moves.
latent_dim = 32

encoder = Encoder(enc_in_dim=state_dim + action_dim, layer_size=256, latent_dim=latent_dim).to(device)
planner = Planner(obs_dim=state_dim, goal_dim=state_dim, layer_size=512, latent_dim=latent_dim).to(device)
actor = Actor(obs_dim=state_dim, act_dim=action_dim, goal_dim=state_dim, layer_size=512, latent_dim=latent_dim).to(device)

# Separate Adam optimisers; the actor uses a larger learning rate than encoder/planner.
encoder_optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
planner_optimizer = optim.Adam(planner.parameters(), lr=1e-4)
actor_optimizer = optim.Adam(actor.parameters(), lr=3e-4)

In [None]:
import torch.distributions as dist
import torch.optim as optim

def train_sample(batch_size, beta, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer):
    """Run one joint optimisation step on windows from the global cube ``data`` sampler.

    Rows are ``[state | one-hot action (18)]``. The goal is the state at the window's last
    row. The loss is ``beta * KL(encoder posterior || planner prior)`` plus the
    cross-entropy of the actor's move scores against the window's first action.

    Returns:
        The scalar loss, detached from the autograd graph.
    """
    num_actions = 18
    sample = data.sample_batch(batch_size).to(device)
    current_state = sample[:, 0, :-num_actions]
    current_action = sample[:, 0, -num_actions:]
    goal_state = sample[:, -1, :-num_actions]

    # Posterior sees the whole window; the prior only sees (start, goal).
    z, mu_phi, sigma_phi = encoder(sample)
    mu_psi, sigma_psi = planner(current_state, goal_state)

    phi_gaussian = dist.Normal(mu_phi, sigma_phi)
    psi_gaussian = dist.Normal(mu_psi, sigma_psi)

    # NOTE(review): summed over batch and latent dims while cross_entropy is a batch mean.
    KL_loss = torch.sum(dist.kl.kl_divergence(phi_gaussian, psi_gaussian))

    policy_action, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))

    # One-hot float targets are treated as class probabilities by cross_entropy.
    action_loss = F.cross_entropy(policy_action.squeeze(1), current_action)

    loss = beta * KL_loss + action_loss

    optimizers = (encoder_optimizer, planner_optimizer, actor_optimizer)
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()

    # Detach so callers that keep or log the value do not hold on to the graph.
    return loss.detach()

In [None]:
# 10k joint updates (batch size 32, KL weight 0.9), reporting the loss every 100 batches.
for batch in range(10_000):
    loss = train_sample(
        32, .9,
        encoder, actor, planner,
        encoder_optimizer, actor_optimizer, planner_optimizer,
    )
    if not batch % 100:
        print(f"Batch: {batch}, Loss: {loss}")

Batch: 0, Loss: 6.241903305053711
Batch: 100, Loss: 2.429344892501831
Batch: 200, Loss: 2.1109321117401123
Batch: 300, Loss: 1.9098554849624634
Batch: 400, Loss: 1.8812134265899658
Batch: 500, Loss: 2.522042751312256
Batch: 600, Loss: 2.55255126953125
Batch: 700, Loss: 2.1031744480133057
Batch: 800, Loss: 2.155884027481079
Batch: 900, Loss: 2.1561708450317383
Batch: 1000, Loss: 2.158642292022705
Batch: 1100, Loss: 2.298733711242676
Batch: 1200, Loss: 2.0353195667266846
Batch: 1300, Loss: 1.9320507049560547
Batch: 1400, Loss: 1.8095418214797974
Batch: 1500, Loss: 1.973937749862671
Batch: 1600, Loss: 2.283557176589966
Batch: 1700, Loss: 2.1172311305999756
Batch: 1800, Loss: 1.972545862197876
Batch: 1900, Loss: 2.67281436920166
Batch: 2000, Loss: 1.9489545822143555
Batch: 2100, Loss: 2.2429769039154053
Batch: 2200, Loss: 2.546633243560791
Batch: 2300, Loss: 1.8986903429031372
Batch: 2400, Loss: 2.199829578399658
Batch: 2500, Loss: 1.777984619140625
Batch: 2600, Loss: 2.1823060512542725
Ba

In [None]:
import torch.distributions as dist
import torch.optim as optim

# Move notation (as printed by str(CubeMove)) -> action index; same order as `movements`.
move_dict = {
    move_name: move_index
    for move_index, move_name in enumerate(
        ["L", "L'", "L2", "R", "R'", "R2", "D", "D'", "D2",
         "U", "U'", "U2", "B", "B'", "B2", "F", "F'", "F2"]
    )
}

def _scrambled_cube(num_moves):
    """Return a new solved 3x3 cube after ``cube.scramble(num_moves)``."""
    cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    cube.scramble(num_moves)
    return cube


def _solve_trajectory(cube):
    """Roll BasicSolver's solution for ``cube`` out into a (num_solver_moves, state_dim + 18) float tensor.

    Row j is ``[flattened state before move j | one-hot of move j]``, the same state-first
    layout as ``data_raw``. ``cube`` itself is not modified. An empty solution yields a
    tensor with zero rows.
    """
    num_actions = len(move_dict)
    # The solver mutates the cube it is given, so it gets its own copy; a second copy replays the moves.
    solution = BasicSolver(copy.deepcopy(cube)).solve()
    walker = copy.deepcopy(cube)

    rows = []
    for move in solution:
        state = get_cube_state(walker).flatten().float()
        action = torch.zeros(num_actions)
        # NOTE(review): KeyError if the solver ever emits a move outside move_dict
        # (e.g. slice or whole-cube rotations) — confirm BasicSolver's move set.
        action[move_dict[str(move)]] = 1
        rows.append(torch.cat((state, action)))
        walker._rotate_once(move)

    if not rows:
        state_dim = get_cube_state(walker).numel()
        return torch.zeros((0, state_dim + num_actions))
    return torch.stack(rows)


def generate_training_data(set_size, num_moves):
    """Scramble ``set_size`` cubes with ``num_moves`` random moves and record solver trajectories.

    Returns:
        ``(cubes, data)``: the scrambled cubes, and a flat list of 1-D rows
        ``[state | one-hot action]`` covering every step of every cube's solution,
        with trajectories concatenated in order.
    """
    cubes = []
    data = []
    for _ in range(set_size):
        cube = _scrambled_cube(num_moves)
        data.extend(_solve_trajectory(cube).unbind(0))
        cubes.append(cube)
    return cubes, data


def train_sample(batch_size, beta, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer, num_moves=1):
    """Run one joint optimisation step on freshly generated solver trajectories.

    Scrambles ``batch_size`` cubes with ``num_moves`` moves each. Each cube's BasicSolver
    solution becomes one ``[state | one-hot action]`` sequence for the encoder. The planner
    and actor are conditioned on the first state and the solved (goal) state, and the actor
    is trained to predict the first solver move.

    Returns:
        The scalar loss, detached from the autograd graph.

    Raises:
        RuntimeError: every scrambled cube produced an empty solution.
    """
    num_actions = len(move_dict)

    trajectories = [_solve_trajectory(_scrambled_cube(num_moves)) for _ in range(batch_size)]
    trajectories = [traj for traj in trajectories if len(traj) > 0]
    if not trajectories:
        raise RuntimeError("every scrambled cube had an empty solution; nothing to train on")

    goal_cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    goal_state_flat = get_cube_state(goal_cube).flatten().float()

    # Solutions differ in length. Pad shorter ones at the end with "cube at the goal,
    # no move" rows so the batch is rectangular and every padded sequence ends at the goal.
    max_len = max(len(traj) for traj in trajectories)
    pad_row = torch.cat((goal_state_flat, torch.zeros(num_actions)))
    sample = torch.stack([
        torch.cat((traj, pad_row.repeat(max_len - len(traj), 1)))
        for traj in trajectories
    ]).to(device)

    current_state = sample[:, 0, :-num_actions]
    current_action = sample[:, 0, -num_actions:]
    goal_state = goal_state_flat.to(device).unsqueeze(0).repeat(sample.size(0), 1)

    # Posterior sees the whole solution; the prior only sees (start, goal).
    z, mu_phi, sigma_phi = encoder(sample)
    mu_psi, sigma_psi = planner(current_state, goal_state)

    phi_gaussian = dist.Normal(mu_phi, sigma_phi)
    psi_gaussian = dist.Normal(mu_psi, sigma_psi)

    KL_loss = torch.sum(dist.kl.kl_divergence(phi_gaussian, psi_gaussian))

    policy_action, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))

    # One-hot float targets are treated as class probabilities by cross_entropy.
    action_loss = F.cross_entropy(policy_action.squeeze(1), current_action)

    loss = beta * KL_loss + action_loss

    optimizers = (encoder_optimizer, planner_optimizer, actor_optimizer)
    for optimizer in optimizers:
        optimizer.zero_grad()
    loss.backward()
    for optimizer in optimizers:
        optimizer.step()

    return loss.detach()

In [None]:
for batch in range(10000):
    loss = train_sample(32, .9, encoder, actor, planner, encoder_optimizer, actor_optimizer, planner_optimizer)
    if batch % 100 == 0:
        print(f"Batch: {batch}, Loss: {loss}")

RuntimeError: stack expects each tensor to be equal size, but got [1, 18] at entry 0 and [109, 18] at entry 12

In [None]:
def test_batch(batch_size, num_moves=1):
    """Create ``batch_size`` solved cubes, each scrambled with ``num_moves`` random moves.

    Returns:
        ``(cubes, histories)`` where ``histories[k]`` is what ``Cube.scramble`` returned for ``cubes[k]``.
    """
    cubes = []
    histories = []

    for _ in range(batch_size):
      cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
      histories.append(cube.scramble(num_moves))
      cubes.append(cube)

    return cubes, histories

def test_sample(batch_size, encoder, actor, planner, max_steps=1):
    """Greedy evaluation on cubes from ``test_batch`` (one scramble move each).

    At every step a plan ``z`` is sampled from the planner prior for the current states
    and the solved goal. The actor's highest-scoring move is then applied to every cube
    not yet solved. ``encoder`` is accepted for symmetry with training but is not used.

    Returns:
        The number of cubes solved within ``max_steps`` moves (also printed).
    """
    goal_cube = magiccube.Cube(3,"YYYYYYYYYRRRRRRRRRGGGGGGGGGOOOOOOOOOBBBBBBBBBWWWWWWWWW")
    # Float from the start so planner and actor get consistent dtypes: (batch, state_dim).
    goal_state = get_cube_state(goal_cube).flatten().float().to(device)
    goal_state = goal_state.unsqueeze(0).repeat(batch_size, 1)

    cubes, _ = test_batch(batch_size)

    solved = [False] * batch_size
    steps_taken = [0] * batch_size

    with torch.no_grad():
      for _ in range(max_steps):
        if all(solved):
          break
        current_state = batch_cube_state(cubes).float().to(device)

        mu_psi, sigma_psi = planner(current_state, goal_state)
        z = torch.normal(mu_psi, sigma_psi)
        move_scores, _ = actor(current_state.unsqueeze(1), z.unsqueeze(1), goal_state.unsqueeze(1))

        # (batch, 1, 18) scores -> one plain int move index per cube
        best_actions = move_scores.argmax(dim=-1).view(-1).tolist()

        #evaluate
        for i, action_index in enumerate(best_actions):
          if solved[i]:
            continue
          cubes[i]._rotate_once(movements[action_index])
          steps_taken[i] += 1
          solved[i] = cubes[i].is_done()

    num_successful = sum(solved)
    print("Number of successful solves: ", num_successful)
    return num_successful



In [None]:
# Evaluate the trained cube planner/actor on 32 one-move scrambles.
test_sample(32, encoder, actor, planner)

Number of successful solves:  32


False
