In [1]:
import os
import sys

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.optim as optim
import torch.nn.functional as F

from environments.composition import CompositionGrid
from models.embedding_model import LearnableEmbedding
from models.action_network import LowerPolicyTrainer
from environments.pomdp_config import *
from models.abstract_state_network import AbstractStateNetwork, AbstractStateDataGenerator

In [2]:
###################
# CONSTANTS
###################

# Prefer Apple's Metal (MPS) backend when this PyTorch build/machine supports
# it; otherwise fall back to the CPU so the notebook still runs elsewhere.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Data Generation Constants
# NOTE(review): composite_config2, c1 and c2 are not imported by name;
# presumably they come from `from environments.pomdp_config import *` -- confirm.
COMPOSITION_CONFIG = composite_config2
BASE_CONFIGS = [c1, c2]
NUM_SAMPLES = 100


HYPER_EPOCHS = 50
BATCH_SIZE = 1
POLICY_TIMESTEPS = 20
INFERENCE_TIMESTEPS = 12

# Root directory holding the saved checkpoints. The default is the original
# author's location; set QUALS_SAVED_MODELS_DIR to point somewhere else without
# editing the notebook.
SAVED_MODELS_DIR = os.environ.get(
    "QUALS_SAVED_MODELS_DIR",
    "/Users/vsathish/Documents/Quals/saved_models",
)
LOWER_STATE_MODEL_PATH = os.path.join(
    SAVED_MODELS_DIR, "state_network", "jan_4_run_1_embedding.state"
)
LOWER_ACTION_MODEL_PATH = os.path.join(
    SAVED_MODELS_DIR, "action_network", "dec_6_run_1_embedding.state"
)
# The higher-level checkpoint is keyed by the composition's "name" entry.
HIGHER_STATE_MODEL_PATH = os.path.join(
    SAVED_MODELS_DIR,
    "state_network",
    "jan_20_higher_state_" + COMPOSITION_CONFIG["name"] + ".state",
)


# Define env
env = CompositionGrid(COMPOSITION_CONFIG)
env.plot_board(name="composition2")


5 5 13 5
In mapping function :  1 3


In [3]:
def _try_load_weights(model, checkpoint_path, failure_message):
    """Best-effort restore of ``model``'s parameters from ``checkpoint_path``.

    Args:
        model: module whose ``load_state_dict`` receives the loaded checkpoint.
        checkpoint_path: path of the file handed to ``torch.load``.
        failure_message: text printed when the checkpoint cannot be loaded.

    Returns:
        True if the weights were loaded, False otherwise. A failure is
        reported (message plus the underlying error) but never raised, so
        the notebook can continue with freshly initialised weights.
    """
    try:
        # map_location remaps tensors that were saved on a different device
        # onto the device used in this session.
        state_dict = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state_dict)
    except Exception as err:
        # `Exception` (not a bare except) keeps the best-effort behaviour while
        # letting KeyboardInterrupt/SystemExit propagate; the reason is shown
        # so a missing file can be told apart from an incompatible checkpoint.
        print(failure_message)
        print(f"  reason: {type(err).__name__}: {err}")
        return False
    return True


# Define lower state model
lower_state_model = LearnableEmbedding(device, BATCH_SIZE, timesteps=INFERENCE_TIMESTEPS).to(device)
if _try_load_weights(
    lower_state_model,
    LOWER_STATE_MODEL_PATH,
    "################## NOPE #######################",
):
    print("################## LOAD SUCCESS #################")

# Define lower action model
lower_action_model = LowerPolicyTrainer(device, BATCH_SIZE, POLICY_TIMESTEPS).to(device)
_try_load_weights(
    lower_action_model,
    LOWER_ACTION_MODEL_PATH,
    "COULD NOT LOAD ACTION NETWORK",
)

# Define higher state model
# NOTE(review): the meaning of the positional arguments 4 and 16 is defined by
# AbstractStateNetwork, which is not visible here -- confirm against the class.
higher_state_model = AbstractStateNetwork(4, 16, COMPOSITION_CONFIG["num_blocks"]).to(device)
_try_load_weights(
    higher_state_model,
    HIGHER_STATE_MODEL_PATH,
    "COULD NOT LOAD HIGHER STATE NETWORK",
)


################## LOAD SUCCESS #################
COULD NOT LOAD HIGHER STATE NETWORK


In [4]:
# Build the data generator from the composition config, the base configs, the
# two lower-level models and the target device, then draw NUM_SAMPLES samples
# from `env`. The resulting (x_train, y_train) pair is what the training cell
# fits `higher_state_model` on.
data_gen = AbstractStateDataGenerator(
    COMPOSITION_CONFIG,
    BASE_CONFIGS,
    lower_state_model,
    lower_action_model,
    device,
)
x_train, y_train = data_gen.generate_data(
    NUM_SAMPLES,
    env,
)

############## SAMPLE 0 ##############
################## BEGIN ABS STATE ##################
init state :  (1, 10)
Next state after taking action  0  is  (0, 10)
Action  0  leads to wall. Hence staying at  (0, 10)
Action  0  leads to wall. Hence staying at  (0, 10)
Next state after taking action  1  is  (1, 10)
Next state after taking action  1  is  (2, 10)
Next state after taking action  1  is  (3, 10)
Next state after taking action  2  is  (3, 9)
Action  2  leads to wall. Hence staying at  (3, 9)
Action  2  leads to wall. Hence staying at  (3, 9)
Next state after taking action  3  is  (3, 10)
Action  3  leads to wall. Hence staying at  (3, 10)
Action  3  leads to wall. Hence staying at  (3, 10)
torch.Size([12, 1, 13]) torch.Size([12, 1, 9])
in forward :  torch.Size([12, 1, 13]) torch.Size([12, 1, 32])
Center index :  1
Final state :  (1, 10)
higher state :  [-0.02549605  0.0482652  -0.03164882  0.02881731  0.          0.
  1.        ]
################## DONE WITH ABS STATE ##########

In [5]:
# Sanity check: report the shapes of the generated inputs and targets.
shape_report = ("X SHAPE: ", x_train.shape, "Y SHAPE: ", y_train.shape)
print(*shape_report)

X SHAPE:  torch.Size([100, 23]) Y SHAPE:  torch.Size([100, 7])


In [9]:
# Train the higher-level state network on the generated (x_train, y_train)
# pairs with full-batch gradient descent: every epoch is a single optimizer
# step computed over the whole of x_train.

# Define the optimizer
optimizer = optim.Adam(higher_state_model.parameters(), lr=0.0001)

# Define the loss function (mean squared error between outputs and y_train)
loss_fn = F.mse_loss

# Define the number of epochs
num_epochs = 100

# Set the model to training mode. Nothing in the loop switches the mode, so
# this only needs to happen once rather than on every epoch.
higher_state_model.train()

# Training loop
for epoch in range(num_epochs):
    # Clear gradients accumulated by the previous step
    optimizer.zero_grad()

    # Forward pass
    outputs = higher_state_model(x_train)

    # Compute the loss
    loss = loss_fn(outputs, y_train)

    # Backward pass
    loss.backward()

    # Update the weights
    optimizer.step()

    # Print the loss for each epoch
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {loss.item()}")

# Save the trained model to HIGHER_STATE_MODEL_PATH. Create the target
# directory first so a missing folder does not throw away the training run.
checkpoint_dir = os.path.dirname(HIGHER_STATE_MODEL_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)
torch.save(higher_state_model.state_dict(), HIGHER_STATE_MODEL_PATH)


Epoch 1/100, Loss: 0.0037914016284048557
Epoch 2/100, Loss: 0.003812347073107958
Epoch 3/100, Loss: 0.003794345771893859
Epoch 4/100, Loss: 0.0038000706117600203
Epoch 5/100, Loss: 0.003802367253229022
Epoch 6/100, Loss: 0.0037986969109624624
Epoch 7/100, Loss: 0.0037945671938359737
Epoch 8/100, Loss: 0.0037932211998850107
Epoch 9/100, Loss: 0.0037948598619550467
Epoch 10/100, Loss: 0.003796657547354698
Epoch 11/100, Loss: 0.0037960815243422985
Epoch 12/100, Loss: 0.0037938738241791725
Epoch 13/100, Loss: 0.0037923294585198164
Epoch 14/100, Loss: 0.003792570671066642
Epoch 15/100, Loss: 0.0037937164306640625
Epoch 16/100, Loss: 0.003794183721765876
Epoch 17/100, Loss: 0.0037934724241495132
Epoch 18/100, Loss: 0.003792409086599946
Epoch 19/100, Loss: 0.003791988827288151
Epoch 20/100, Loss: 0.00379230291582644
Epoch 21/100, Loss: 0.0037927059456706047
Epoch 22/100, Loss: 0.0037927369121462107
Epoch 23/100, Loss: 0.003792430739849806
Epoch 24/100, Loss: 0.0037920065224170685
Epoch 25/100