In [1]:
%reset

Once deleted, variables cannot be recovered. Proceed (y/[n])? y


In [2]:
from rl.tests.frozen_lake_policy import build_best_policy
from rl.valfuncs.model_free import gen_dataset
import gym
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import os.path as path
from rl.valfuncs.hyperparams import Hyperparams

In [13]:
class ValueFunc(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        self.fc = torch.nn.Linear(1, 2)
        self.out = torch.nn.Linear(2, 1)
        self.loss_fn = torch.nn.MSELoss(reduction="mean")
        
        self.hparams = kwargs.get("hparams")
        self.trainds = kwargs.get("trainds")
        self.valds = kwargs.get("valds")
        
    @classmethod
    def for_training(cls, hparams):
        fl = gym.make("FrozenLake-v0")
        policy = build_best_policy(fl)
        trainds = gen_dataset(fl, policy, 1000)
        valds = gen_dataset(fl, policy, 100)
        return cls(hparams=hparams, trainds=trainds, valds=valds)
    
    @classmethod
    def for_inference(cls, ckpt):
        obj = cls()
        obj.load(ckpt)
    
    def forward(self, x):
        x = self.fc(x)
        x = F.relu(x)
        return self.out(x)
    
    def training_step(self, batch, batch_num):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss_fn(y_hat, y)
        metrics = {"train_loss": loss.detach()}
        return {"loss": loss, "log": metrics, "progress_bar": metrics}
    
    def validation_step(self, batch, batch_num):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss_fn(y_hat, y)
        metrics = {"val_loss": loss.detach()}
        return metrics
    
    def validation_end(self, outputs):
        avg_loss = torch.stack([output["val_loss"] for output in outputs]).mean()
        rmse = torch.sqrt(avg_loss)
        metrics = {"val_loss": avg_loss, "val_rmse": rmse}
        return {"val_loss": avg_loss, "log": metrics, "progress_bar": metrics}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
    
    @pl.data_loader
    def train_dataloader(self):
        return DataLoader(self.trainds, batch_size=self.hparams.batch_size, shuffle=True)
    
    @pl.data_loader
    def val_dataloader(self):
        return DataLoader(self.valds, batch_size=100)

In [None]:
tblogs = path.expanduser("~/mldata/tblogs/frozen-lake")
hparams = Hyperparams(batch_size=8, epochs=10, lr=0.01)
trainer = pl.Trainer(default_save_path=tblogs, max_nb_epochs=hparams.epochs)
model = ValueFunc.for_training(hparams)

In [None]:
trainer.fit(model)

In [14]:
v1 = path.expanduser("~/mldata/tblogs/frozen-lake/lightning_logs/version_1")
ckpt = path.join("checkpoints")


AttributeError: 'ValueFunc' object has no attribute 'load'

model = ValueFunc()

In [21]:
model.load_state_dict(v1)

AttributeError: 'str' object has no attribute 'copy'

In [17]:
model.forward(torch.Tensor[[15]])

AttributeError: 'dict' object has no attribute 'forward'

In [22]:
s = torch.load(v1)

In [23]:
s

{'epoch': 4,
 'global_step': 6249,
 'checkpoint_callback_best': 0.018623653799295425,
 'early_stop_callback_wait': 0,
 'early_stop_callback_patience': 3,
 'optimizer_states': [{'state': {5072305536: {'step': 6250,
     'exp_avg': tensor([[0.0003],
             [0.0484]]),
     'exp_avg_sq': tensor([[0.0028],
             [0.0413]])},
    5072328312: {'step': 6250,
     'exp_avg': tensor([0.0001, 0.0035]),
     'exp_avg_sq': tensor([4.2870e-05, 3.1053e-04])},
    5072328528: {'step': 6250,
     'exp_avg': tensor([[0.0055, 0.0496]]),
     'exp_avg_sq': tensor([[0.0019, 0.0348]])},
    5072327952: {'step': 6250,
     'exp_avg': tensor([0.0383]),
     'exp_avg_sq': tensor([0.0141])}},
   'param_groups': [{'lr': 0.01,
     'betas': (0.9, 0.999),
     'eps': 1e-08,
     'weight_decay': 0,
     'amsgrad': False,
     'params': [5072305536, 5072328312, 5072328528, 5072327952]}]}],
 'lr_schedulers': [],
 'state_dict': OrderedDict([('fc.weight', tensor([[-0.0652],
                       [ 0.3858