In [2]:
import torch
import torch.nn as nn
import os

In [8]:
# Locate the project root as the parent of the directory that holds this code.
# When run as a script `__file__` points at this file; inside a notebook it is
# undefined, so the current working directory is used instead.
if '__file__' in globals():
    current_file_dir = os.path.dirname(os.path.abspath(__file__))
else:
    current_file_dir = os.getcwd()
BASE_DIR = os.path.abspath(os.path.join(current_file_dir, ".."))

# Standard sub-directories under the project root.
DATA_DIR, MODEL_DIR = (os.path.join(BASE_DIR, sub_dir) for sub_dir in ("data", "models"))

# Saved weights of the BC model, read by the loading cell further down.
file_path = os.path.join(MODEL_DIR, 'bc_model.pth')

In [9]:
# Early, non-fatal check for the trained weights. The loading cell below reads
# `file_path` unconditionally, so if this warning fires that cell will fail.
if os.path.exists(file_path):
    pass  # weights present; nothing to do here, loading happens in a later cell
else:
    print(f"Warning: Model not found at: {file_path}")

In [11]:
# BC model definition. It is redefined here because only the weights
# (state_dict) are stored in the .pth file. The architecture must match the
# one used when the weights were saved, or `load_state_dict` will reject them.
class BCModel(nn.Module):
    """Two-layer MLP mapping a state vector to an action vector.

    Args:
        state_dim: Size of the last dimension of the input tensor.
        hidden_dim: Width of the single hidden (ReLU) layer.
        action_dim: Size of the last dimension of the output tensor.

    The defaults (7 -> 128 -> 7) reproduce the original hard-coded
    architecture, so ``BCModel()`` builds exactly the same network as before.
    Submodule names (``net.0``, ``net.2``) are also unchanged, so existing
    checkpoints still load.
    """

    def __init__(self, state_dim: int = 7, hidden_dim: int = 128, action_dim: int = 7):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(state_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, action_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map states ``x`` of shape ``(..., state_dim)`` to actions ``(..., action_dim)``."""
        return self.net(x)

# Build the network and restore the trained weights from `file_path`.
model = BCModel()
# map_location="cpu": a checkpoint saved from GPU tensors otherwise fails to
# load on a CPU-only machine.
# weights_only=True: a .pth file is a pickle, so unrestricted loading can run
# arbitrary code. A state_dict only needs tensors and plain containers
# (requires PyTorch >= 1.13).
model.load_state_dict(torch.load(file_path, map_location="cpu", weights_only=True))
# Switch to evaluation mode before inference. As the last expression, this
# also displays the module summary.
model.eval()

BCModel(
  (net): Sequential(
    (0): Linear(in_features=7, out_features=128, bias=True)
    (1): ReLU()
    (2): Linear(in_features=128, out_features=7, bias=True)
  )
)

In [12]:
# Smoke-test the loaded model on random states. Seeding the global torch RNG
# makes the sample states, and therefore the printed predictions, identical
# on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
test_states = torch.randn(10, 7)  # 10 random states with 7 features (BCModel's default input size)

# Inference only: skip building an autograd graph, so no .detach() is needed.
with torch.no_grad():
    predicted_actions = model(test_states).numpy()

print("Predicted Actions for Sample States:\n", predicted_actions)

Predicted Actions for Sample States:
 [[0.7211858  1.0027367  0.71034104 0.5160428  0.70419854 0.9232183
  0.56524724]
 [0.71278924 0.74677205 0.8672247  0.559832   0.5257377  1.2033138
  0.6450333 ]
 [0.87614655 0.5408053  0.5981786  0.46281207 0.6776964  0.8831862
  0.6028803 ]
 [0.6403762  0.54328114 1.0347357  0.5177964  0.61502635 0.7390717
  0.7818668 ]
 [1.4180121  0.14557631 1.2673876  0.5801038  0.57914084 0.6345424
  0.96067256]
 [1.212083   0.59703416 1.3690614  0.6936936  0.529771   1.6309533
  0.8855867 ]
 [1.255387   0.18826197 0.67188406 0.7753591  0.71115637 0.9236139
  0.9446333 ]
 [1.2966685  0.56889176 0.83950806 0.59042764 0.88406026 0.37953246
  0.9803912 ]
 [0.52026737 0.50386894 0.44543633 0.53065777 0.4783346  0.46486184
  0.55192554]
 [0.9196117  0.235091   0.82352775 0.68676645 0.6327492  0.7158605
  0.6584644 ]]
