In [33]:
import os 
import torch
import gym
import numpy as np
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from matplotlib import pyplot as plt
from IPython import display
plt.style.use("ggplot")
from installing_a_printer.utils import load_demos

In [34]:
from mini_behavior.envs import *

In [35]:
# Build the gym environment for the 8x8 variant of the task.
# (16x16 alternative: f'MiniGrid-{task_name}-16x16-N2-v1')
task_name = 'SimpleInstallingAPrinter'
env_name = 'MiniGrid-{}-8x8-N2-v0'.format(task_name)
env = gym.make(env_name)

In [36]:
# Network sizes: one output logit per discrete env action; the input width
# presumably must equal the length of each demo state vector — TODO confirm.
state_space_size = 5
action_space_size = env.action_space.n

if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print("device", device)

split = False

device cpu


In [37]:
def get_split_idx(num_samples, split=0.8, seed=0):
    """Randomly partition ``range(num_samples)`` into train and test indices.

    Args:
        num_samples: total number of samples to split.
        split: fraction of samples assigned to the training set.
        seed: seed for a local RNG so the split is reproducible.

    Returns:
        (train_idx, test_idx): ``train_idx`` is a numpy array of
        ``int(split * num_samples)`` distinct indices; ``test_idx`` is an
        ascending list of the remaining indices.
    """
    # Local RNG: reproducible without reseeding numpy's global state.
    rng = np.random.RandomState(seed)
    # replace=False keeps train indices distinct; numpy's default samples with
    # replacement, which produced duplicates and a smaller effective train set.
    train_idx = rng.choice(num_samples, int(split * num_samples), replace=False)
    # Set membership is O(1); `in` on a numpy array scans it every time.
    train_set = set(train_idx.tolist())
    test_idx = [i for i in range(num_samples) if i not in train_set]

    return train_idx, test_idx


class DemoDataset(Dataset):
    """Map-style dataset of (state, action) pairs taken from a demo list.

    Args:
        demo: sized iterable of ``(state, action)`` tuples; ``action`` must
            expose ``.value`` (e.g. an Enum member) giving its integer id.
        idxs: optional sequence of indices into ``demo`` selecting the subset
            this dataset exposes; ``None`` keeps every sample.
    """

    def __init__(self, demo, idxs=None):
        self.demo = demo

        # Separate the pairs; actions are stored by their integer id.
        states = []
        actions = []
        for state, action in self.demo:
            states.append(state)
            actions.append(action.value)

        if idxs is None:
            idxs = range(len(states))
        # An explicit long tensor indexes the same way for lists, numpy
        # arrays and the empty selection.
        index = torch.as_tensor(list(idxs), dtype=torch.long)

        self.states = torch.tensor(states)[index]
        self.actions = torch.tensor(actions)[index]

    def __len__(self):
        # Size of the selected subset, not of the full demo list: returning
        # len(self.demo) made a DataLoader request out-of-range items.
        return len(self.actions)

    def __getitem__(self, idx):
        return self.states[idx], self.actions[idx]

In [38]:
# Load all demonstrations: a list of (state, action) tuples.
# The default path is machine-specific; set $MINI_BEHAVIOR_DEMO_DIR to override.
demo_dir = os.environ.get(
    'MINI_BEHAVIOR_DEMO_DIR',
    '/Users/emilyjin/Code/behavior/mini_behavior/installing_a_printer/demo_8',
)
demos = load_demos(demo_dir)

def get_dataloaders(split=None, batch_size=32):
    """Build train/test DataLoaders over the notebook-level ``demos`` list.

    Args:
        split: train fraction passed to ``get_split_idx``. A falsy value
            (None/0/False) means no split: the test loader reuses the
            training dataset.
        batch_size: batch size of the shuffled training loader.

    Returns:
        (train_dataloader, test_dataloader). The test loader is unshuffled and
        keeps DataLoader's default batch size of 1.
    """
    if not split:
        shared = DemoDataset(demos)
        return (
            DataLoader(shared, batch_size=batch_size, shuffle=True),
            DataLoader(shared),
        )

    train_idxs, test_idxs = get_split_idx(len(demos), split)
    train_set = DemoDataset(demos, train_idxs)
    test_set = DemoDataset(demos, test_idxs)
    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True),
        DataLoader(test_set),
    )

In [42]:
def get_policy(small, in_dim=None, out_dim=None):
    """Build an MLP policy mapping a state vector to action logits.

    Args:
        small: if truthy, hidden widths are 32-64-100-64-32; otherwise
            32-64-100-256-100-64-32.
        in_dim: input width; defaults to the notebook's ``state_space_size``.
        out_dim: number of output logits; defaults to ``action_space_size``.

    Returns:
        ``nn.Sequential`` of Linear layers with a ReLU after every hidden layer
        and no final activation (raw logits for ``nn.CrossEntropyLoss``).
        Layer order and indices match the original hand-written stacks, so
        saved models/state_dicts stay compatible.
    """
    if in_dim is None:
        in_dim = state_space_size
    if out_dim is None:
        out_dim = action_space_size

    hidden = [32, 64, 100, 64, 32] if small else [32, 64, 100, 256, 100, 64, 32]
    widths = [in_dim, *hidden]

    layers = []
    for w_in, w_out in zip(widths[:-1], widths[1:]):
        layers += [nn.Linear(w_in, w_out), nn.ReLU()]
    # Output layer: logits only, no activation.
    layers.append(nn.Linear(widths[-1], out_dim))

    return nn.Sequential(*layers)

def get_criterion():
    """Return a new cross-entropy loss over raw action logits."""
    criterion = nn.CrossEntropyLoss()
    return criterion

In [43]:
def train_epochs(train_dataloader, lr, max_epochs, model_size,
                 batch_size=None, model_dir=None, save_every=50):
    """Train a fresh MLP policy with Adam and cross-entropy.

    Args:
        train_dataloader: yields ``(state, action)`` batches.
        lr: Adam learning rate.
        max_epochs: number of passes over ``train_dataloader``.
        model_size: forwarded to ``get_policy`` as ``small`` (truthy = small MLP).
        batch_size: only used to name checkpoint files; defaults to
            ``train_dataloader.batch_size``.
        model_dir: directory for periodic checkpoints. ``None`` (default)
            disables them; the sweep cells save the final model themselves.
        save_every: checkpoint when ``epoch % save_every == 0`` (0-based epoch).

    Returns:
        (losses, policy, criterion), where ``losses`` holds each epoch's mean
        batch loss.

    Raises:
        ValueError: if ``train_dataloader`` has no batches.
    """
    n_batches = len(train_dataloader)
    if n_batches == 0:
        raise ValueError("train_dataloader is empty")
    if batch_size is None:
        batch_size = train_dataloader.batch_size

    # Put the model on the same device the batches are moved to.
    policy = get_policy(model_size).to(device)
    criterion = get_criterion()
    optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    losses = []
    # Iterate over the max_epochs argument (previously the global max_epoch).
    for epoch in range(max_epochs):
        total_loss = 0.0
        for s, a in train_dataloader:
            s, a = s.to(device), a.to(device)

            optimizer.zero_grad()
            a_pred = policy(s.type(torch.float))
            loss = criterion(a_pred, a)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        epoch_loss = total_loss / n_batches
        losses.append(epoch_loss)

        # Clear before printing so only the latest epoch line stays visible,
        # including after the final epoch.
        display.clear_output(wait=True)
        print("[EPOCH]: %i, [CE LOSS]: %.6f" % (epoch + 1, epoch_loss))

        if model_dir is not None and save_every and epoch % save_every == 0:
            # NOTE(review): checkpoint names follow save_model_and_plot's
            # lr/batch/epochs pattern and can overwrite a final model saved
            # for a shorter run with the same epoch count.
            save_model_and_plot(policy, losses, lr, batch_size, epoch, model_dir)

    return losses, policy, criterion
    
    
def test_policy(test_dataloader, policy, criterion):
    with torch.set_grad_enabled(False):
        correct = 0
        total_loss = 0
        for s, a in test_dataloader:
            s, a = s.to(device), a.to(device)
            a_pred = policy(s.type(torch.float))
            loss = criterion(a_pred, a)
            total_loss += loss.item()

            if a == np.argmax(a_pred):
                correct += 1

    avg_loss = total_loss / len(test_dataloader)
    accuracy = correct / len(test_dataloader)
    
    print(f'TEST [CE LOSS]: {avg_loss}')
    print(f'TEST [ACCURACY]: {accuracy}')
    
    return avg_loss, accuracy


def save_model_and_plot(policy, losses, lr, batch_size, epochs, model_dir):
    """Save the whole ``policy`` module into ``model_dir``.

    The file is named ``lr={lr}_batch={batch_size}_epochs={epochs}`` (no
    extension). ``losses`` is accepted for a loss plot that is currently
    commented out, and is otherwise unused.

    Returns:
        The path the model was written to.
    """
    model_filename = f"lr={lr}_batch={batch_size}_epochs={epochs}"
    # Create the target directory so saving does not fail on a fresh checkout.
    os.makedirs(model_dir, exist_ok=True)
    model_path = os.path.join(model_dir, model_filename)

    # Saves the full module (not just a state_dict); loading it needs the same
    # class definitions and, on recent torch, torch.load(..., weights_only=False).
    torch.save(policy, model_path)
    return model_path

#     # save plot
#     loss_plots_dir = '/Users/emilyjin/Code/behavior/mini_behavior/mini_behavior/loss_plots/loss_plots_8'
#     loss_plot_path = os.path.join(loss_plots_dir, f'{model_filename}.png')

#     plt.plot(losses, label='train loss')
#     plt.xlabel("num epochs")
#     plt.ylabel("ce loss")
#     plt.show()
#     plt.savefig(loss_plot_path)

In [41]:
# Small-model hyperparameter sweep. split=None trains on all demos, so the
# "test" metrics recorded below are measured on the training data.
split = None
test_losses = {}
test_accuracies = {}

# hyperparam tuning
max_epochs = [100, 150, 200]
lrs = [5e-4, 2.5e-4, 1e-4]
batch_sizes = [64, 128]

# Set before the loop: a train_epochs that checkpoints via the global
# model_dir needs it to exist on the very first call (assigning it inside the
# loop, after train_epochs, raised NameError on a fresh kernel).
model_dir = "/Users/emilyjin/Code/behavior/mini_behavior/installing_a_printer/models_8/small"

# batch_size / max_epoch stay notebook globals on purpose; train_epochs has
# read them from global scope.
for batch_size in batch_sizes:
    train_dataloader, test_dataloader = get_dataloaders(split, batch_size)

    for max_epoch in max_epochs:
        for lr in lrs:
            losses, policy, criterion = train_epochs(train_dataloader, lr, max_epoch, True)
            avg_loss, accuracy = test_policy(test_dataloader, policy, criterion)

            save_model_and_plot(policy, losses, lr, batch_size, max_epoch, model_dir)

            params = f'lr={lr}_batch={batch_size}_epochs={max_epoch}'
            test_losses[params] = avg_loss
            test_accuracies[params] = accuracy

KeyboardInterrupt: 

In [44]:
# Best small-model configuration by lowest loss and by highest accuracy.
# min/max return the first key reaching the optimum (dict insertion order),
# the same tie-breaking as a strict-comparison scan.
if test_losses:
    best_loss_params = min(test_losses, key=test_losses.get)
    best_loss = test_losses[best_loss_params]
else:
    best_loss_params, best_loss = None, 1e100

if test_accuracies:
    best_accuracy_params = max(test_accuracies, key=test_accuracies.get)
    best_accuracy = test_accuracies[best_accuracy_params]
else:
    best_accuracy_params, best_accuracy = None, -1

print('best_loss')
print(f'model: {best_loss_params}')
print(f'loss: {best_loss}')

print('best_accuracy')
print(f'model: {best_accuracy_params}')
print(f'accuracy: {best_accuracy}')
      
# loss_plots_dir = '/Users/emilyjin/Code/behavior/mini_behavior/mini_behavior/loss_plots'
# loss_plot_path = os.path.join(loss_plots_dir, f'{model_filename}.png')

best_loss
model: lr=0.00025_batch=64_epochs=200
loss: 0.6779619894792943
best_accuracy
model: lr=0.0001_batch=64_epochs=100
loss: 0.6927822047195643


In [None]:
# Large-model hyperparameter sweep. split=None trains on all demos, so the
# "test" metrics recorded below are measured on the training data.
split = None
large_test_losses = {}
large_test_accuracies = {}

# hyperparam tuning
max_epochs = [200]
lrs = [5e-4, 2.5e-4, 1e-4, 5e-5, 1e-5, 1e-6]
batch_sizes = [128]

# Set before the loop: a train_epochs that checkpoints via the global
# model_dir needs it to exist (and point at the large-model directory) on the
# very first call, not after it.
model_dir = "/Users/emilyjin/Code/behavior/mini_behavior/installing_a_printer/models_8/large"

# batch_size / max_epoch stay notebook globals on purpose; train_epochs has
# read them from global scope.
for batch_size in batch_sizes:
    train_dataloader, test_dataloader = get_dataloaders(split, batch_size)

    for max_epoch in max_epochs:
        for lr in lrs:
            losses, policy, criterion = train_epochs(train_dataloader, lr, max_epoch, False)
            avg_loss, accuracy = test_policy(test_dataloader, policy, criterion)

            save_model_and_plot(policy, losses, lr, batch_size, max_epoch, model_dir)

            params = f'lr={lr}_batch={batch_size}_epochs={max_epoch}'
            large_test_losses[params] = avg_loss
            large_test_accuracies[params] = accuracy

[EPOCH]: 9, [CE LOSS]: 1.453768


In [None]:
# large model
# Best large-model configuration by lowest loss and by highest accuracy.
# min/max return the first key reaching the optimum (dict insertion order),
# the same tie-breaking as a strict-comparison scan.
if large_test_losses:
    best_loss_params = min(large_test_losses, key=large_test_losses.get)
    best_loss = large_test_losses[best_loss_params]
else:
    best_loss_params, best_loss = None, 1e100

if large_test_accuracies:
    best_accuracy_params = max(large_test_accuracies, key=large_test_accuracies.get)
    best_accuracy = large_test_accuracies[best_accuracy_params]
else:
    best_accuracy_params, best_accuracy = None, -1

print('best_loss')
print(f'model: {best_loss_params}')
print(f'loss: {best_loss}')

print('best_accuracy')
print(f'model: {best_accuracy_params}')
print(f'accuracy: {best_accuracy}')
      
# loss_plots_dir = '/Users/emilyjin/Code/behavior/mini_behavior/mini_behavior/loss_plots'
# loss_plot_path = os.path.join(loss_plots_dir, f'{model_filename}.png')