In [1]:
from tools import *

In [2]:
class MAML:
    """Checkpointed MAML training experiment on the 'RandomMiniEnv' task family.

    The experiment owns a policy, a baseline, a train and a test task sampler
    and a ``MetaLearner``. Its progress (model weights, train sampler and
    per-iteration statistics) is written to ``<output_dir>/checkpoint.pth.tar``
    after every meta-iteration, and a textual description of the setting is
    written to ``<output_dir>/config.txt``. Constructing the class on a
    directory that already holds a ``config.txt`` resumes from the checkpoint.

    Args:
        output_dir: Directory for the checkpoint and config file (created if
            missing).
        meta_iter: Total number of meta-iterations ``run`` trains for.
        train_batch_size: Batch size of the train sampler; also passed to the
            meta-learner as ``num_episodes``.
        test_batch_size: Batch size of the test sampler.
    """

    def __init__(self, output_dir, meta_iter=200, train_batch_size=20, test_batch_size=40):
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # Keep the constructor arguments as attributes. The original did this
        # through ``locals()`` filtered with ``k is not 'self'`` -- an identity
        # test on a string literal (SyntaxWarning on Python >= 3.8); explicit
        # assignments give the same attributes without that pitfall.
        self.output_dir = output_dir
        self.meta_iter = meta_iter
        self.train_batch_size = train_batch_size
        self.test_batch_size = test_batch_size

        # Per-iteration bookkeeping; ``history`` doubles as the iteration counter.
        self.history = []
        self.train_loss = []
        self.train_return = []
        # NOTE(review): test_acc is neither written nor checkpointed by any
        # method of this class.
        self.test_acc = []

        self.train_sampler = Sampler('RandomMiniEnv', meta_iter=meta_iter,
                                     batch_size=train_batch_size, device=self.device)
        self.test_sampler = Sampler('RandomMiniEnv', meta_iter=meta_iter,
                                    batch_size=test_batch_size, device=self.device)

        self.policy = Policy()
        self.baseline = LinearFeatureBaseline(input_size=135)
        self.ml = MetaLearner(self.train_sampler, self.policy,
                              self.baseline, num_episodes=train_batch_size, device=self.device)

        os.makedirs(output_dir, exist_ok=True)
        self.checkpoint_path = os.path.join(output_dir, "checkpoint.pth.tar")
        self.config_path = os.path.join(output_dir, "config.txt")

        if os.path.isfile(self.config_path):
            # A previous run left a config here: show it and resume.
            # NOTE(review): the stored config is only displayed, it is not
            # compared with repr(self), so a conflicting setting goes unnoticed.
            with open(self.config_path, 'r') as f:
                print(f.read())
            self.load()
        else:
            self.save()

    @property
    def iteration(self):
        """Number of meta-iterations completed so far."""
        return len(self.history)

    def setting(self):
        """Return the objects that describe this experiment's configuration."""
        return {'Policy': self.policy,
                'Baseline': self.baseline,
                'TrainBatchSize': self.train_batch_size}

    def __repr__(self):
        """One ``Key(value)`` line per entry of ``setting()``."""
        return ''.join('{}({})\n'.format(key, val)
                       for key, val in self.setting().items())

    def state_dict(self):
        """Collect everything that ``save`` writes to the checkpoint.

        The train sampler is stored as a whole object (not as a state dict),
        so the checkpoint is only loadable where its class can be unpickled.
        """
        return {'Policy': self.policy.state_dict(),
                'Baseline': self.baseline.state_dict(),
                'TrainSampler': self.train_sampler,
                'History': self.history,
                'TrainLoss': self.train_loss,
                'TrainReturn': self.train_return}

    def load_state_dict(self, checkpoint):
        """Restore the experiment from a dict produced by ``state_dict``."""
        self.policy.load_state_dict(checkpoint['Policy'])
        self.baseline.load_state_dict(checkpoint['Baseline'])

        # NOTE(review): self.ml was constructed with the sampler created in
        # __init__; rebinding self.train_sampler here does not touch self.ml.
        # Confirm whether MetaLearner keeps its own reference to the sampler.
        self.train_sampler = checkpoint['TrainSampler']
        self.history = checkpoint['History']
        self.train_loss = checkpoint['TrainLoss']
        self.train_return = checkpoint['TrainReturn']

    def save(self):
        """Write the checkpoint and the textual config to ``output_dir``."""
        torch.save(self.state_dict(), self.checkpoint_path)
        with open(self.config_path, 'w') as f:
            print(self, file=f)

    def load(self):
        """Read the checkpoint file and restore the experiment from it."""
        # SECURITY: the checkpoint embeds a pickled sampler object, so loading
        # it executes pickle code -- only load checkpoints you created.
        # NOTE(review): recent PyTorch releases default torch.load to
        # weights_only=True, which presumably rejects that pickled object;
        # verify against the installed version.
        checkpoint = torch.load(self.checkpoint_path,
                                map_location=self.device)
        self.load_state_dict(checkpoint)

    def run(self):
        """Train from the current iteration up to ``meta_iter``.

        Every meta-iteration samples 10 tasks, collects episodes with the
        meta-learner (``first_order=True``), takes one meta-learner step,
        records the average return and loss, and saves a checkpoint.
        """
        start_iter = self.iteration

        print("Start/Continue training from iteration {}".format(start_iter))

        for i in range(start_iter, self.meta_iter):
            tasks = self.train_sampler.sample_tasks(low=1,
                                                    high=10,
                                                    num_tasks=10)
            episodes = self.ml.sample(tasks, first_order=True)

            # The average return is read before the step that consumes the episodes.
            avg_return = self.ml.average_return(episodes)
            loss = self.ml.step(episodes)

            self.history.append(i)
            self.train_return.append(avg_return)
            self.train_loss.append(loss)

            # Checkpoint every iteration so an interrupted run can resume.
            self.save()

            print("Done with meta-iteration {}. Avg Return = {}, Loss = {}".format(i,
                                                                               avg_return,
                                                                               loss))

        print("Finished training for {} meta-iterations".format(self.meta_iter))

In [3]:
# Create the experiment, or resume it when "experiment2/config.txt" already
# exists (in that case the stored config is printed and the checkpoint loaded).
exp = MAML(output_dir="experiment2")

Policy(Policy(
  (fc1): Linear(in_features=135, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=100, bias=True)
  (fc3): Linear(in_features=100, out_features=2, bias=True)
  (tanh): Tanh()
))
Baseline(LinearFeatureBaseline(
  (linear): Linear(in_features=274, out_features=1, bias=False)
))
TrainBatchSize(20)




In [5]:
def create_episodes(task, policy, params, num_trajectories=10, device=None):
    """Roll out ``policy`` on ``task`` and collect complete trajectories.

    Args:
        task: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return ``(next_state, reward, done, info)``.
        policy: Callable ``policy(state_tensor, params=...)`` returning an
            object with a ``sample()`` method that yields a tensor action.
        params: Parameters forwarded to ``policy`` (``None`` is passed through
            unchanged).
        num_trajectories: Number of trajectories to collect; also used as the
            ``batch_size`` of the returned ``BatchEpisodes``. Defaults to 10.
        device: Device for the state tensors and the ``BatchEpisodes``.
            Defaults to the notebook-level ``exp.device``.

    Returns:
        The ``BatchEpisodes`` holding every recorded step, tagged with the
        index of the trajectory it belongs to.
    """
    if device is None:
        device = exp.device

    episodes = BatchEpisodes(batch_size=num_trajectories, device=device)

    for traj_id in range(num_trajectories):
        state = task.reset()
        done = False

        while not done:
            # Sampling needs no gradients; actions go back to the env as numpy.
            with torch.no_grad():
                action = policy(torch.Tensor(state).to(device=device),
                                params=params).sample()
                action = action.cpu().numpy()

            next_state, reward, done, _ = task.step(action)
            # NOTE(review): the state stored with each action is the one
            # *after* the step, not the one the action was chosen in --
            # confirm this matches what BatchEpisodes / the learner expect.
            episodes.append(next_state, action, reward, traj_id)

            state = next_state

    return episodes

In [6]:
# Alternative kept for reference: draw the evaluation task over the full 1..10 range.
# task = exp.test_sampler.sample_tasks(low=1, high=10, num_tasks=1)[0]

# Pin the evaluation task by making the sampling range a single value wide.
# NOTE(review): assumes `high` is exclusive, so this selects exactly task 5 -- confirm in Sampler.
task_seed = 5
task = exp.test_sampler.sample_tasks(low=task_seed, high=task_seed+1, num_tasks=1)[0]

In [7]:
# Each entry is (adaptation step, episodes collected at that step, params used).
# Step 0 rolls out the un-adapted policy (params=None).
test_episodes = [(0, create_episodes(task, exp.ml.policy, params=None), None)]

# Two adaptation steps: adapt on the latest episodes, then roll out again with the result.
for i in range(2):
    _, episodes_prev, _ = test_episodes[-1]
    
    # NOTE(review): adapt() receives only the episodes, not the previous params --
    # verify whether successive steps chain or each restarts from the meta-parameters.
    params = exp.ml.adapt(episodes_prev)
    episodes_next = create_episodes(task, exp.ml.policy, params=params)
    
    test_episodes.append((i+1, episodes_next, params))

In [8]:
# Parameters produced by the last adaptation step (third field of the last entry).
final_params = test_episodes[-1][2]

In [15]:
# Start a fresh episode. The rollout cell reads `state` and `done` from here,
# so this cell must be run before it (the saved execution counts are out of order).
state = task.reset()
done = False

task.render()

In [14]:
# Roll out one episode on `task`, printing the reward of every step.
# Relies on `state` and `done` set by the reset cell, which must be run first.
r = 0  # running sum of the rewards of this episode
which_one = 1  # 1: act with the adapted policy; anything else: random actions

while not done:
    
    if which_one == 1:
        # Inference only: skip building an autograd graph, as create_episodes does.
        with torch.no_grad():
            action = exp.ml.policy(torch.Tensor(state).to(exp.device), 
                                  params=final_params)
            # Act with the distribution's `loc` instead of sampling from it.
            action = action.loc.cpu().detach().numpy()

        next_state, reward, done, _ = task.step(action)
    else:
        next_state, reward, done, _ = task.step(task.action_space.sample())
    
    r += reward
    
    print(reward)
        
    task.render()
    
    state = next_state

-1.5761158700615971
-1.5406282942550993
-1.5119582940422394
-1.609404077613262
-1.524193681834201
-1.495800771470192
-1.4435235987751454
-1.427888525597278
-1.3401237200003058
-1.3591767113734505
-1.342942467005984
-1.2651055120510397
-1.2250678635378864
-1.1781794579282514
-1.2008261828367142
-1.1438650742836258
-1.1046566978530075
-1.061244558369876
-1.0539644325755275
-1.0328029744622813
-1.0073963792124578
-0.9961529588997886
-0.9622234566230083
-0.9854127232912243
-0.9850130438222044
-0.977671467998085
-0.9668816221287406
-0.9561388225984209
-0.9489430021444315
-0.9365147512592006
-0.9188489395862028
-0.9125948863019933
-0.9089731505644062
-0.8983341182960572
-0.898221783772676
-0.887583038598095
-0.8813303350261429
-0.8777090378390043
-0.8644384746006922
-0.8502908512010322
-0.843160482856979
-0.8395390090114423
-0.832408555942688
-0.8287872574770202
-0.8286750217603922
-0.8250543526127185
-0.8214335466278041
-0.8116724602506619
-0.8089294048579105
-0.805309606473951
-0.798180897