In [1]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F

from environments.composition import CompositionGrid
from models.embedding_model import LearnableEmbedding
from models.action_network import LowerPolicyTrainer
from environments.pomdp_config import *
from models.abstract_state_network import AbstractStateNetwork, AbstractStateDataGenerator

In [2]:
###################
# CONSTANTS
###################

# Prefer Apple-silicon MPS when it is available, otherwise fall back to CPU.
# A bare torch.device("mps") fails as soon as a model/tensor is moved to it on
# machines without MPS support.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Seed numpy and torch so re-running the notebook is reproducible.
# NOTE(review): if the environment / data generator uses Python's `random`
# module, seed it here as well.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data Generation Constants
COMPOSITION_CONFIG = composite_config2
BASE_CONFIGS = [c1, c2]
NUM_SAMPLES = 100

# Training / rollout constants
HYPER_EPOCHS = 50
BATCH_SIZE = 1
POLICY_TIMESTEPS = 20
INFERENCE_TIMESTEPS = 12

# Checkpoint locations. The root defaults to the original author's directory;
# set the QUALS_MODEL_DIR environment variable to point elsewhere instead of
# editing these paths.
MODEL_DIR = os.environ.get("QUALS_MODEL_DIR", "/Users/vsathish/Documents/Quals/saved_models")
LOWER_STATE_MODEL_PATH = os.path.join(MODEL_DIR, "state_network", "jan_4_run_1_embedding.state")
LOWER_ACTION_MODEL_PATH = os.path.join(MODEL_DIR, "action_network", "dec_6_run_1_embedding.state")
HIGHER_STATE_MODEL_PATH = os.path.join(
    MODEL_DIR, "state_network", "jan_4_higher_state_" + COMPOSITION_CONFIG["name"] + ".state"
)


# Define env
env = CompositionGrid(COMPOSITION_CONFIG)
env.plot_board(name="composition2")


5 5 13 5
In mapping function :  1 3


In [3]:
# Load pretrained weights for the three networks. Every load is best-effort:
# on failure the reason is printed and the model keeps its freshly initialised
# weights. The two lower-level models are handed to the data generator in the
# next cell, so a failed load there means training data is produced by
# untrained networks; check the messages below.


def load_checkpoint(model, path, failure_msg, success_msg=None):
    """Best-effort load of a saved state dict from ``path`` into ``model``.

    Args:
        model: module whose ``load_state_dict`` receives the checkpoint.
        path: path to a file readable by ``torch.load`` holding a state dict.
        failure_msg: printed, together with the underlying error, on failure.
        success_msg: optional message printed after a successful load.

    Returns:
        True if the weights were loaded, False otherwise.
    """
    try:
        # map_location remaps tensors saved on another device onto `device`.
        state_dict = torch.load(path, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as err:  # unlike a bare `except:`, lets Ctrl-C / SystemExit through
        print(failure_msg, f"({type(err).__name__}: {err})")
        return False
    if success_msg is not None:
        print(success_msg)
    return True


# Define lower state model
lower_state_model = LearnableEmbedding(device, BATCH_SIZE, timesteps=INFERENCE_TIMESTEPS).to(device)
lower_state_loaded = load_checkpoint(
    lower_state_model,
    LOWER_STATE_MODEL_PATH,
    failure_msg="################## NOPE #######################",
    success_msg="################## LOAD SUCCESS #################",
)

# Define lower action model
lower_action_model = LowerPolicyTrainer(device, BATCH_SIZE, POLICY_TIMESTEPS).to(device)
lower_action_loaded = load_checkpoint(
    lower_action_model,
    LOWER_ACTION_MODEL_PATH,
    failure_msg="COULD NOT LOAD ACTION NETWORK",
)

# Define higher state model
# NOTE(review): the meaning of the positional arguments 4 and 16 is not
# visible here; consider naming them as constants in the constants cell.
higher_state_model = AbstractStateNetwork(4, 16, COMPOSITION_CONFIG["num_blocks"]).to(device)
higher_state_loaded = load_checkpoint(
    higher_state_model,
    HIGHER_STATE_MODEL_PATH,
    failure_msg="COULD NOT LOAD HIGHER STATE NETWORK",
)


################## LOAD SUCCESS #################


In [4]:
# Build the training set for the higher state network. The generator receives
# the composition config, the base configs and both lower-level models, and
# returns an (inputs, targets) pair of tensors with one row per sample.
data_gen = AbstractStateDataGenerator(
    COMPOSITION_CONFIG,
    BASE_CONFIGS,
    lower_state_model,
    lower_action_model,
    device,
)
(x_train, y_train) = data_gen.generate_data(NUM_SAMPLES)

5 5 13 5
In mapping function :  1 3
################## BEGIN ABS STATE ##################
init state :  (1, 9)
Next state after taking action  0  is  (0, 9)
Action  0  leads to wall. Hence staying at  (0, 9)
Action  0  leads to wall. Hence staying at  (0, 9)
Next state after taking action  1  is  (1, 9)
Next state after taking action  1  is  (2, 9)
Next state after taking action  1  is  (3, 9)
Action  2  leads to wall. Hence staying at  (3, 9)
Action  2  leads to wall. Hence staying at  (3, 9)
Action  2  leads to wall. Hence staying at  (3, 9)
Next state after taking action  3  is  (3, 10)
Action  3  leads to wall. Hence staying at  (3, 10)
Action  3  leads to wall. Hence staying at  (3, 10)
in forward :  torch.Size([12, 1, 13]) torch.Size([12, 1, 32])
Center index :  1
Final state :  (1, 9)
higher state :  [-0.02549605  0.0482652  -0.03164882  0.02881731  0.          0.
  1.        ]
################## DONE WITH ABS STATE ##################
Running policy for subgoal :  2 (0, 2)
State

In [5]:
print(f"X SHAPE:  {x_train.shape} Y SHAPE:  {y_train.shape}")  # first dim is the sample count

X SHAPE:  torch.Size([100, 23]) Y SHAPE:  torch.Size([100, 7])


In [6]:
# Train the higher state network to regress y_train from x_train with
# full-batch Adam (one optimizer step per epoch over all samples), then save
# its weights to HIGHER_STATE_MODEL_PATH.
# NOTE(review): in the recorded run the loss only moves from ~0.0708 to
# ~0.0700; the learning rate / epoch count may need tuning.

# Adam with PyTorch's default hyperparameters
optimizer = optim.Adam(higher_state_model.parameters())

# Mean-squared-error regression loss
loss_fn = F.mse_loss

# Use the epoch count from the constants cell rather than a second hard-coded 50
num_epochs = HYPER_EPOCHS

# Make sure the data lives on the same device as the model
# (.to is a no-op when it already does).
train_inputs = x_train.to(device)
train_targets = y_train.to(device)

loss_history = []

# Nothing inside the loop switches the model to eval mode, so set train mode once.
higher_state_model.train()
for epoch in range(num_epochs):
    optimizer.zero_grad()

    outputs = higher_state_model(train_inputs)

    # F.mse_loss broadcasts mismatched shapes with only a warning, which would
    # silently train against the wrong targets; fail loudly instead.
    if outputs.shape != train_targets.shape:
        raise ValueError(
            f"model output shape {tuple(outputs.shape)} does not match "
            f"target shape {tuple(train_targets.shape)}"
        )

    loss = loss_fn(outputs, train_targets)
    loss.backward()
    optimizer.step()

    loss_history.append(loss.item())
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss_history[-1]}")

# Save the trained model to HIGHER_STATE_MODEL_PATH, creating the directory
# first so a finished run is not lost to a missing folder at save time.
save_dir = os.path.dirname(HIGHER_STATE_MODEL_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
torch.save(higher_state_model.state_dict(), HIGHER_STATE_MODEL_PATH)


Epoch 1/50, Loss: 0.07075288891792297
Epoch 2/50, Loss: 0.07156673073768616
Epoch 3/50, Loss: 0.07062701135873795
Epoch 4/50, Loss: 0.07067577540874481
Epoch 5/50, Loss: 0.07066217809915543
Epoch 6/50, Loss: 0.07046829909086227
Epoch 7/50, Loss: 0.07033491134643555
Epoch 8/50, Loss: 0.07031524926424026
Epoch 9/50, Loss: 0.07033121585845947
Epoch 10/50, Loss: 0.07030526548624039
Epoch 11/50, Loss: 0.07022245973348618
Epoch 12/50, Loss: 0.07015134394168854
Epoch 13/50, Loss: 0.07015864551067352
Epoch 14/50, Loss: 0.0701887235045433
Epoch 15/50, Loss: 0.07014196366071701
Epoch 16/50, Loss: 0.07006939500570297
Epoch 17/50, Loss: 0.07006629556417465
Epoch 18/50, Loss: 0.07010640949010849
Epoch 19/50, Loss: 0.07011041045188904
Epoch 20/50, Loss: 0.07006539404392242
Epoch 21/50, Loss: 0.07003417611122131
Epoch 22/50, Loss: 0.07004876434803009
Epoch 23/50, Loss: 0.07006671279668808
Epoch 24/50, Loss: 0.07005142420530319
Epoch 25/50, Loss: 0.07002617418766022
Epoch 26/50, Loss: 0.07002305239439