In [None]:
# default_exp agent

# Agent

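> Model agents: wrappers that bundle a model with its loss function, dataset and optimizer, and hook into the training loop as callbacks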

In [None]:
#hide
from nbdev.showdoc import *
%load_ext autoreload
%autoreload 2

In [None]:
# export

from mrl.imports import *
from mrl.core import *
from mrl.torch_imports import *
from mrl.torch_core import *
from mrl.callbacks import *


**TODO:** move latent handling into the sampler, and change the latent data format from `{name: idxs}` (indices into stored latents) to `{name: latent_vectors}`.


In [None]:
# export

class Agent(Callback):
    '''
    Agent - base class bundling a model with its loss, dataset and optimizer

    `Agent` is a `Callback` (`order=2`), so the training loop can call its
    hooks (`after_sample`, `zero_grad`, `before_step`, `step`, ...). Most
    hooks are no-ops here and exist to be overridden by subclasses.

    Inputs:

    - `model nn.Module`: model to train; moved with `to_device` on init

    - `loss_function Callable`: called as `loss_function(output, y)`

    - `dataset`: dataset object; used for `new`, `split`, `dataloader`
      and `collate_function`

    - `opt_kwargs dict`: keyword arguments for the optimizer (see `get_opt`)

    - `clip float`: max gradient norm applied in `before_step`

    - `name str`: callback name
    '''
    def __init__(self, model, loss_function, dataset, opt_kwargs={}, clip=1., name='agent'):
        super().__init__(name=name, order=2)

        self.model = model
        to_device(self.model)

        self.loss_function = loss_function
        self.dataset = dataset

        self.opt = self.get_opt(self.model.parameters(), **opt_kwargs)
        self.clip = clip

    def get_opt(self, parameters, **optim_kwargs):
        'Builds the optimizer for `parameters` (Adam; override to change)'
        return optim.Adam(parameters, **optim_kwargs)

    # training-loop hooks: no-ops unless overridden

    def setup(self):
        pass

    def before_train(self):
        pass

    def build_buffer(self):
        pass

    def after_build_buffer(self):
        pass

    def before_batch(self):
        pass

    def sample_batch(self):
        pass

    def after_sample(self):
        '''
        Collates `batch_state.samples` into a batch via `self.dataset.new`
        and stores `x`, `y`, `bs`, zero `rewards` and
        `trajectory_rewards=None` on the batch state.
        '''
        env = self.environment
        batch_state = env.batch_state
        sequences = batch_state.samples

        batch_ds = self.dataset.new(sequences)
        batch = batch_ds.collate_function([batch_ds[i] for i in range(len(batch_ds))])
        batch = to_device(batch)
        bs = len(batch_ds)
        x,y = batch

        batch_state.x = x
        batch_state.y = y
        batch_state.bs = bs
        batch_state.rewards = to_device(torch.zeros(bs))
        batch_state.trajectory_rewards = None

    def get_model_outputs(self):
        pass

    def compute_reward(self):
        pass

    def after_compute_reward(self):
        pass

    def reward_modification(self):
        pass

    def compute_loss(self):
        pass

    def zero_grad(self):
        self.opt.zero_grad()

    def before_step(self):
        'Clips the gradient norm of `self.model` to `self.clip`'
        nn.utils.clip_grad_norm_(self.model.parameters(), self.clip)

    def step(self):
        self.opt.step()

    def after_batch(self):
        pass

    def after_train(self):
        pass

    def one_batch(self, batch):
        '''
        Computes the loss on a `(x, y)` batch.

        A list/tuple `x` is unpacked into positional model arguments;
        otherwise `x` is the single argument.
        '''
        batch = to_device(batch)
        x,y = batch
        if not isinstance(x, (list, tuple)):
            x = [x]
        output = self.model(*x)
        loss = self.loss_function(output, y)
        return loss

    def train_supervised(self, bs, epochs, lr, percent_valid=0.05):
        '''
        Supervised training on `self.dataset` with a one-cycle LR schedule

        Inputs:

        - `bs int`: batch size

        - `epochs int`: number of epochs; also sets the length of the
          `OneCycleLR` schedule

        - `lr float`: peak learning rate (`max_lr` of `OneCycleLR`)

        - `percent_valid float`: fraction of `self.dataset` held out
          for validation

        The model is put in train mode for the training pass and eval mode
        for the validation pass; its original mode is restored afterwards.
        Each epoch writes the smoothed train/valid loss to the progress table.
        '''
        train_ds, valid_ds = self.dataset.split(percent_valid)

        train_dl = train_ds.dataloader(bs, shuffle=True)
        valid_dl = valid_ds.dataloader(bs)

        # the schedule must span exactly the steps taken: a fixed epoch count
        # either leaves the LR mid-cycle or makes OneCycleLR raise once the
        # run goes past its total number of steps
        scheduler = optim.lr_scheduler.OneCycleLR(self.opt, max_lr=lr,
                                                  steps_per_epoch=len(train_dl),
                                                  epochs=epochs)

        was_training = self.model.training
        try:
            mb = master_bar(range(epochs))
            mb.write(['Epoch', 'Train Loss', 'Valid  Loss', 'Time'], table=True)
            for epoch in mb:
                start = time.time()
                train_losses = self._train_epoch(train_dl, scheduler, mb)
                valid_losses = self._valid_epoch(valid_dl, mb)
                end = time.time() - start
                mb.write([epoch, self._format_loss(train_losses),
                          self._format_loss(valid_losses), f'{format_time(end)}'], table=True)
        finally:
            # leave the model in the mode the caller had it in
            self.model.train(was_training)

    def _train_epoch(self, train_dl, scheduler, mb):
        'One optimizer pass over `train_dl` in train mode; returns per-batch losses'
        self.model.train()
        losses = []
        for batch in progress_bar(train_dl, parent=mb):
            loss = self.one_batch(batch)

            self.zero_grad()
            loss.backward()
            self.step()
            scheduler.step()

            losses.append(loss.detach().cpu())
            mb.child.comment = f"{losses[-1]:.5f}"
        return losses

    def _valid_epoch(self, valid_dl, mb):
        'Per-batch losses on `valid_dl` in eval mode, without gradient tracking'
        self.model.eval()
        losses = []
        with torch.no_grad():
            for batch in progress_bar(valid_dl, parent=mb):
                loss = self.one_batch(batch)
                losses.append(loss.detach().cpu())
                mb.child.comment = f"{losses[-1]:.5f}"
        return losses

    @staticmethod
    def _format_loss(losses):
        'Smoothed loss for the progress table, or `-` when there were no batches'
        if not losses:
            return '-'
        return f'{smooth_batches(losses):.5f}'

    def update_dataset(self, dataset):
        self.dataset = dataset

    def update_dataset_from_inputs(self, *dataset_inputs):
        'Replaces `self.dataset` with `self.dataset.new(*dataset_inputs)`'
        dataset = self.dataset.new(*dataset_inputs)
        self.update_dataset(dataset)

    def load_weights(self, filename):
        'Loads a state dict saved by `save_weights` onto the model\'s device'
        state_dict = torch.load(filename, map_location=get_model_device(self.model))

        self.model.load_state_dict(state_dict)

    def save_weights(self, filename):
        'Saves `self.model.state_dict()` to `filename`'
        state_dict = self.model.state_dict()
        torch.save(state_dict, filename)
    

In [None]:
# export
    
class PredictiveAgent(Agent):
    '''
    PredictiveAgent - `Agent` with inference helpers

    `predict_tensor` runs `self.model` on already-batched inputs;
    `predict_data` first turns raw inputs into a batch through `self.dataset`.
    '''

    def predict_tensor(self, x):
        '''
        Returns `self.model` applied to `x`.

        A list/tuple `x` is unpacked into positional model arguments;
        anything else is passed as the single argument.
        '''
        model_inputs = x if isinstance(x, (list, tuple)) else [x]
        return self.model(*model_inputs)

    def predict_data(self, data):
        '''
        Builds a batch from raw `data` with `self.dataset.new`, passing a
        placeholder `0` per item as the second field (it is unpacked as `y`
        and ignored), moves it with `to_device` and returns
        `predict_tensor` of the inputs.
        '''
        placeholder = [0 for _ in data]
        dataset = self.dataset.new(data, placeholder)
        items = [dataset[idx] for idx in range(len(dataset))]
        inputs, _ = to_device(dataset.collate_function(items))
        return self.predict_tensor(inputs)

In [None]:
# export

class BaselineAgent(Agent):
    '''
    BaselineAgent - `Agent` that also keeps a base (baseline) model

    Inputs:

    - `model`, `loss_function`, `dataset`, `opt_kwargs`, `clip`, `name`:
      as in `Agent`

    - `base_update float`: `alpha` passed to `merge_models` when updating
      the base model; values >= 1 disable the update

    - `base_update_iter int`: the base model is updated every
      `base_update_iter` iterations (see `after_batch`)

    - `base_model Union[bool, nn.Module]`: `True` uses a deep copy of
      `model`; anything else is stored as the base model as-is
    '''
    def __init__(self, model, loss_function, dataset, base_update=0.99,
                 base_update_iter=10, base_model=True, opt_kwargs={},
                 clip=1., name='baseline_agent'):
        super().__init__(model, loss_function, dataset, opt_kwargs, clip, name)

        self.set_models(base_model)
        self.base_update = base_update
        self.base_update_iter = base_update_iter

    def after_batch(self):
        'Updates the base model every `base_update_iter` iterations (not at iteration 0)'
        iterations = self.environment.log.iterations
        if iterations > 0 and iterations % self.base_update_iter == 0:
            self.update_base_model()

    def set_models(self, base_model):
        'Sets `self.base_model` (deep copy of `self.model` when `base_model is True`)'
        if base_model is True:
            self.base_model = copy.deepcopy(self.model)
        else:
            self.base_model = base_model

        # best effort: the base model may be None/False or another object
        # `to_device` cannot move; only swallow ordinary errors
        try:
            to_device(self.base_model)
        except Exception:
            pass

    def base_to_model(self):
        'Copies `self.model` weights INTO `self.base_model` (same-type models only)'
        if type(self.base_model)==type(self.model):
            self.base_model.load_state_dict(self.model.state_dict())

    def model_to_base(self):
        'Copies `self.base_model` weights INTO `self.model` (same-type models only)'
        if type(self.base_model)==type(self.model):
            self.model.load_state_dict(self.base_model.state_dict())

    def update_base_model(self):
        'Calls `merge_models(base_model, model, alpha=base_update)` when `base_update < 1`'
        if type(self.base_model)==type(self.model):
            if self.base_update < 1:
                merge_models(self.base_model, self.model, alpha=self.base_update)

    def save_weights(self, filename):
        'Saves `{"model": ..., "base_model": ... or None}` to `filename`'
        state_dict = {}
        state_dict['model'] = self.model.state_dict()

        if isinstance(self.base_model, nn.Module):
            state_dict['base_model'] = self.base_model.state_dict()
        else:
            state_dict['base_model'] = None

        torch.save(state_dict, filename)

    def load_weights(self, filename):
        '''
        Loads a checkpoint written by `save_weights`.

        The base model is only restored when it is an `nn.Module` and the
        checkpoint actually contains base-model weights (it may hold `None`).
        '''
        state_dict = torch.load(filename, map_location=get_model_device(self.model))

        self.model.load_state_dict(state_dict['model'])

        base_state = state_dict.get('base_model')
        if isinstance(self.base_model, nn.Module) and base_state is not None:
            self.base_model.load_state_dict(base_state)


In [None]:
# export

class CriticAgent(BaselineAgent):
    '''
    CriticAgent - `BaselineAgent` for predictive (critic) models

    Writes `model_output` and `base_output` to the batch state in
    `get_model_outputs`.
    '''

    def predict_tensor(self, x, baseline=False):
        '''
        Runs `self.model` (or `self.base_model` when `baseline=True`) on `x`.

        A list/tuple `x` is unpacked into positional model arguments, the same
        convention as `Agent.one_batch`. Returns `None` when `baseline=True`
        and the base model is not an `nn.Module`.
        '''
        if not isinstance(x, (list, tuple)):
            x = [x]

        if not baseline:
            return self.model(*x)

        if isinstance(self.base_model, nn.Module):
            return self.base_model(*x)
        return None

    def predict_data(self, data):
        '''
        Builds a batch from raw `data` via `self.dataset.new` (placeholder `0`
        per item as the second field, unpacked as `y` and ignored) and returns
        `predict_tensor` of the inputs.
        '''
        ds = self.dataset.new(data, [0 for i in data])
        batch = ds.collate_function([ds[i] for i in range(len(ds))])
        batch = to_device(batch)
        x,y = batch
        return self.predict_tensor(x)

    def get_model_outputs(self):
        'Stores model predictions and (gradient-free) base-model predictions for `batch_state.x`'
        batch_state = self.environment.batch_state
        x = batch_state.x

        batch_state.model_output = self.predict_tensor(x, baseline=False)

        with torch.no_grad():
            batch_state.base_output = self.predict_tensor(x, baseline=True)
    


In [None]:
# export

class GenerativeAgent(BaselineAgent):
    '''
    GenerativeAgent - `BaselineAgent` for generative sequence models

    Inputs:

    - `model`: generative model; must provide
      `get_rl_tensors(x, y[, latent=...])` returning four outputs

    - `vocab`: vocabulary; this class uses `vocab.stoi['pad']`,
      `vocab.itos` and `vocab.reconstruct`

    - `loss_function`, `dataset`, `base_update`, `base_update_iter`,
      `base_model`, `opt_kwargs`, `clip`, `name`: as in `BaselineAgent`
    '''
    def __init__(self, model, vocab, loss_function, dataset,
                 base_update=0.99, base_update_iter=10, base_model=True,
                 opt_kwargs={}, clip=1., name='generative_agent'):
        super().__init__(model, loss_function, dataset,
                         base_update=base_update,
                         base_update_iter=base_update_iter,
                         base_model=base_model,
                         opt_kwargs=opt_kwargs,
                         clip=clip,
                         name=name)

        self.vocab = vocab

    def reconstruct(self, preds):
        'Decodes each row of `preds` with `vocab.reconstruct` (via `maybe_parallel`)'
        return maybe_parallel(self.vocab.reconstruct, [i for i in preds.detach().cpu()])

    def after_sample(self):
        '''
        Runs `Agent.after_sample` (collates samples into `x`, `y`, `bs`,
        `rewards`), then adds sequence fields: `mask` (non-pad tokens of `y`),
        `lengths`, `sl` (padded length) and a zero `trajectory_rewards`
        tensor shaped like `y`.
        '''
        super().after_sample()

        batch_state = self.environment.batch_state
        y = batch_state.y

        mask = y != self.vocab.stoi['pad']
        batch_state.mask = mask
        batch_state.lengths = mask.sum(-1)
        batch_state.sl = y.shape[-1]
        batch_state.trajectory_rewards = to_device(torch.zeros(y.shape))

    def get_rl_tensors(self, model, x, y, latent_info, sources):
        '''
        Runs `model.get_rl_tensors` on the batch, passing latent vectors to
        the samples that need them.

        Inputs:

        - `model`: the model to run (agent model or base model)

        - `x`, `y`: batch tensors (`x` may be a list; handled by `subset_tensor`)

        - `latent_info Optional[dict]`: maps a source name to latent vectors
          for the samples whose entry in `sources` equals that name
          (presumably one row per such sample, in batch order - TODO confirm
          with the sampler)

        - `sources list`: per-sample source name, aligned with the batch

        Returns `(output, logprobs, gathered_logprobs, encoded)`, with rows
        in the same order as `y`.
        '''
        if not latent_info:
            mo, mlp, mglp, me = model.get_rl_tensors(x,y)
            return mo, mlp, mglp, me

        device = y.device
        latent_sources = list(latent_info.keys())
        group_idxs = []
        group_outputs = []

        for latent_source, latents in latent_info.items():
            latent_mask = torch.tensor([s == latent_source for s in sources],
                                       dtype=torch.bool, device=device)
            if not latent_mask.any():
                continue
            group_idxs.append(latent_mask.nonzero(as_tuple=True)[0])
            group_outputs.append(model.get_rl_tensors(subset_tensor(x, latent_mask),
                                                      subset_tensor(y, latent_mask),
                                                      latent=latents))

        non_latent_mask = torch.tensor([s not in latent_sources for s in sources],
                                       dtype=torch.bool, device=device)
        if non_latent_mask.any():
            group_idxs.append(non_latent_mask.nonzero(as_tuple=True)[0])
            group_outputs.append(model.get_rl_tensors(subset_tensor(x, non_latent_mask),
                                                      subset_tensor(y, non_latent_mask)))

        # outputs were produced group by group, i.e. permuted relative to the
        # batch; invert that permutation so every row lines up with `y`
        restore = torch.argsort(torch.cat(group_idxs))
        mo, mlp, mglp, me = [torch.cat([out[k] for out in group_outputs], 0)[restore]
                             for k in range(4)]

        return mo, mlp, mglp, me

    def get_model_outputs(self):
        '''
        Computes model outputs for the batch and stores `model_output`,
        `model_logprobs`, `model_gathered_logprobs`, `model_encoded`,
        `y_gumbel`, `value_input`, plus the `base_*` equivalents (computed
        without gradients, or `None` when there is no base `nn.Module`).
        '''
        batch_state = self.environment.batch_state

        x = batch_state.x
        y = batch_state.y
        sources = batch_state.sources
        latent_info = batch_state.latent_data

        mo, mlp, mglp, me = self.get_rl_tensors(self.model, x, y, latent_info, sources)
        mprob = mlp.exp()

        batch_state.model_output = mo
        batch_state.model_logprobs = mlp
        batch_state.model_gathered_logprobs = mglp
        batch_state.model_encoded = me
        # straight-through trick: forward value equals one_hot(y), while the
        # gradient flows through the model probabilities
        batch_state.y_gumbel = F.one_hot(y, len(self.vocab.itos)) + mprob - mprob.detach()
        # NOTE(review): encodings exposed for value estimation; presumably
        # consumed by a separate value-head callback - confirm
        batch_state.value_input = me

        # `base_model` may be None/False (see BaselineAgent.set_models)
        if isinstance(self.base_model, nn.Module):
            with torch.no_grad():
                bo, blp, bglp, be = self.get_rl_tensors(self.base_model, x, y, latent_info, sources)
        else:
            bo, blp, bglp, be = None, None, None, None

        batch_state.base_output = bo
        batch_state.base_logprobs = blp
        batch_state.base_gathered_logprobs = bglp
        batch_state.base_encoded = be


In [None]:
# export

class SupevisedCB(Callback):
    '''
    SupevisedCB - periodically fine-tunes `agent` on its best logged samples

    Every `frequency` iterations, `train_model`:

    1. reads `samples` and `log_term` from the environment log (`log_to_df`)
    2. drops duplicate samples
    3. keeps samples whose `log_term` value is strictly above the
       `percentile`-th percentile
    4. replaces the agent's dataset with them and trains for one epoch
    5. updates the agent's base model with
       `merge_models(agent.base_model, agent.model, alpha=base_update)`

    Inputs:

    - `agent`: agent to train; must have a `base_model` (e.g. `BaselineAgent`)

    - `frequency int`: run every `frequency` iterations

    - `base_update float`: `alpha` passed to `merge_models`

    - `percentile float`: selection threshold, in `[0, 100]`

    - `lr float`, `bs int`: learning rate and batch size for `train_supervised`

    - `log_term str`: logged column used to rank samples
    '''
    def __init__(self, agent, frequency, base_update, percentile,
                 lr, bs, log_term='rewards'):
        super().__init__(name='supervised', order=1000)
        self.agent = agent
        self.frequency = frequency
        self.base_update = base_update
        self.percentile = percentile
        self.lr = lr
        self.bs = bs
        self.log_term = log_term

    def after_batch(self):
        'Triggers `train_model` every `frequency` iterations (not at iteration 0)'
        iterations = self.environment.log.iterations

        if iterations>0 and iterations%self.frequency==0:
            self.train_model()

    def train_model(self):
        'Fine-tunes the agent on the top logged samples (no-op when none qualify)'
        env = self.environment
        df = log_to_df(env.log.log, ['samples', self.log_term])
        df = df.drop_duplicates(subset='samples')
        if df.empty:
            return

        # rank by the configured log term (previously hard-coded to `rewards`)
        scores = df[self.log_term]
        top = df[scores > np.percentile(scores.values, self.percentile)]
        if top.empty:
            # e.g. all scores tied: nothing lies strictly above the threshold
            return

        self.agent.update_dataset_from_inputs(top['samples'].values)
        self.agent.train_supervised(self.bs, 1, self.lr)

        merge_models(self.agent.base_model, self.agent.model, alpha=self.base_update)

In [None]:
# slow

# standard lm

from mrl.vocab import *
from mrl.dataloaders import *
from mrl.g_models import *

df = pd.read_csv('files/smiles.csv')
vocab = CharacterVocab(SMILES_CHAR_VOCAB)

ds = Text_Dataset(df.smiles.values, vocab)

d_vocab = len(vocab.itos)
d_embedding = 256
d_hidden = 1024
n_layers = 3
input_dropout = 0.3
lstm_dropout = 0.3
bos_idx = vocab.stoi['bos']
bidir = False
tie_weights = True

model = LSTM_LM(d_vocab, 
                d_embedding,
                d_hidden, 
                n_layers,
                input_dropout,
                lstm_dropout,
                bos_idx, 
                bidir, 
                tie_weights)

# load onto the CPU so the checkpoint loads regardless of the device it was
# saved from; `Agent.__init__` moves the model with `to_device` afterwards
model.load_state_dict(torch.load('untracked_files/lstm_lm_zinc.pt', map_location='cpu'))

<All keys matched successfully>

In [None]:
# generative agent around the pretrained LSTM LM, trained with cross-entropy
agent = GenerativeAgent(model=model, vocab=vocab, loss_function=CrossEntropy(),
                        dataset=ds, opt_kwargs=dict(lr=1e-4))

In [None]:
# one epoch of supervised fine-tuning on `ds`
agent.train_supervised(bs=64, epochs=1, lr=1e-4)

Epoch,Train Loss,Valid Loss,Time
0,1.17034,1.2267,00:07


In [None]:
latent_dim = 512
prior = NormalPrior(torch.zeros(latent_dim), torch.zeros(latent_dim))

In [None]:
# move `prior` to the default device; like `Agent.__init__`, this relies on
# `to_device` moving modules in place (the return value is only displayed)
to_device(prior)

NormalPrior()

In [None]:

# class AgentCallback(Callback):
#     def __init__(self, agent, name, clip=1.):
#         super().__init__(order=20)
#         self.agent = agent
#         self.name = name
#         self.clip = clip
        
#     def zero_grad(self):
#         self.agent.zero_grad()
    
#     def before_step(self):
#         nn.utils.clip_grad_norm_(self.agent.model.parameters(), self.clip)
        
#     def step(self):
#         self.agent.step()
        
#     def after_sample(self):
#         env = self.environment
#         sequences = self.batch_state.samples
#         diversity = len(set(sequences))/len(sequences)
                
#         bs = len(sequences)
#         self.batch_state.rewards = to_device(torch.zeros(bs))
        
#     def get_model_outputs(self):
#         # get relevant model outputs
#         pass
        
# class GenAgentCallback(AgentCallback):
#     def __init__(self, agent, name, contrastive=False):
#         super().__init__(agent, name)
#         self.contrastive = contrastive
    
#     def after_sample(self):
#         env = self.environment
#         sequences = self.batch_state.samples
                
#         batch_ds = self.agent.dataset.new(sequences)
#         batch = batch_ds.collate_function([batch_ds[i] for i in range(len(batch_ds))])
#         batch = to_device(batch)
#         bs = len(batch_ds)
#         x,y = batch
            
#         self.batch_state.x = x
#         self.batch_state.y = y
#         self.batch_state.bs = bs
#         mask = ~(y==self.agent.vocab.stoi['pad'])
#         self.batch_state.mask = mask
#         self.batch_state.lengths = mask.sum(-1)
#         self.batch_state.sl = y.shape[-1]
#         self.batch_state.rewards = to_device(torch.zeros(bs))
#         self.batch_state.trajectory_rewards = to_device(torch.zeros(y.shape))
        
#     def subset_tensor(self, x, mask):
#         if type(x)==list:
#             x = [i[mask] for i in x]
#         else:
#             x = x[mask]
        
#         return x
    
#     def get_rl_tensors(self, model, x, y, latent_info, sources):
#         if latent_info:
#             latent_sources = []
#             output_tensors = []
#             for (latent_source, latent_idxs) in latent_info.items():
#                 latent_sources.append(latent_source)
#                 latent_mask = torch.tensor([i==latent_source for i in sources]).bool()
#                 latents = self.agent.latents[latent_idxs]
#                 out = self.agent.model.get_rl_tensors(self.subset_tensor(x, latent_mask), 
#                                                       self.subset_tensor(y, latent_mask),
#                                                       latent=latents)
#                 output_tensors.append(out)
                
#             non_latent_mask = torch.tensor([not i in latent_sources for i in sources]).bool()
#             if non_latent_mask.sum()>0:
#                 out = model.get_rl_tensors(self.subset_tensor(x, non_latent_mask), 
#                                                       self.subset_tensor(y, non_latent_mask))
#                 output_tensors.append(out)
            
#             mo = torch.cat([i[0] for i in output_tensors], 0)
#             mlp = torch.cat([i[1] for i in output_tensors], 0)
#             mglp = torch.cat([i[2] for i in output_tensors], 0)
#             me = torch.cat([i[3] for i in output_tensors], 0)
            
#         else:
#             mo, mlp, mglp, me = model.get_rl_tensors(x,y)
            
#         return mo, mlp, mglp, me 
        
#     def get_model_outputs(self):
            
#         x = self.batch_state.x
#         y = self.batch_state.y
#         sources = self.batch_state.sources
#         latent_info = self.batch_state.latent_data
            
#         mo, mlp, mglp, me = self.get_rl_tensors(self.agent.model, x, y, latent_info, sources)
#         mprob = mlp.exp()
        
#         self.batch_state.model_output = mo
#         self.batch_state.model_logprobs = mlp
#         self.batch_state.model_gathered_logprobs = mglp
#         self.batch_state.model_encoded = me
#         self.batch_state.y_gumbel = F.one_hot(y, len(self.agent.vocab.itos)) + mprob - mprob.detach()
        
#         if self.agent.value_head is not None:
#             value_predictions = self.agent.value_head(me)
#             with torch.no_grad():
#                 base_value_predictions = self.agent.base_value_head(me)
#         else:
#             value_predictions = None
#             base_value_predictions = None
            
#         self.batch_state.state_values = value_predictions
#         self.batch_state.ref_state_values = base_value_predictions
        
#         if self.agent.base_model is not None:
#             with torch.no_grad():
# #                 bo, blp, bglp, be = self.agent.base_model.get_rl_tensors(x,y)
#                 bo, blp, bglp, be = self.get_rl_tensors(self.agent.base_model, x, y, latent_info, sources)    
#         else:
#             bo, blp, bglp, be = None, None, None, None
            
#         self.batch_state.reference_output = bo
#         self.batch_state.reference_logprobs = blp
#         self.batch_state.reference_gathered_logprobs = bglp
#         self.batch_state.reference_encoded = be
        


In [None]:

# class BaselineAgent(Agent):
#     def __init__(self, model, loss_function, dataset, base_update=0.99, v_update=0.95,
#                 base_model=True, value_head=None, opt_kwargs={}, vopt_kwargs={}):
#         super().__init__(model, loss_function, dataset, opt_kwargs)
        
#         self.opts = [self.opt]
#         self.set_models(base_model, value_head, vopt_kwargs)
#         self.base_update = base_update
#         self.v_update = v_update
        
#     def set_models(self, base_model, value_head, vopt_kwargs):
        
#         if base_model==True:
#             self.base_model = copy.deepcopy(self.model)
#         else:
#             self.base_model = base_model
            
#         try:
#             to_device(self.base_model)
#         except:
#             pass
        
#         self.value_head = value_head
#         if self.value_head is not None:
#             self.base_value_head = copy.deepcopy(self.value_head)
#             to_device(self.value_head)
#             to_device(self.base_value_head)
            
#             self.value_opt = self.get_opt(self.value_head.parameters(), **vopt_kwargs)
#             self.opts.append(self.value_opt)
            
#     def zero_grad(self):
#         for opt in self.opts:
#             opt.zero_grad()
            
#     def step(self):
#         for opt in self.opts:
#             opt.step()
            
#     def base_to_model(self):
#         if type(self.base_model)==type(self.model):
#             self.base_model.load_state_dict(self.model.state_dict())
            
#     def model_to_base(self):
#         if type(self.base_model)==type(self.model):
#             self.model.load_state_dict(self.base_model.state_dict())
            
#     def update_base_models(self):
#         if type(self.base_model)==type(self.model):
#             if self.base_update < 1:
#                 merge_models(self.base_model, self.model, alpha=self.base_update)
            
#         if self.value_head is not None:
#             if self.v_update < 1:
#                 merge_models(self.base_value_head, self.value_head, alpha=self.v_update)
                
#     def save_weights(self, filename):
#         state_dict = {}
#         state_dict['model'] = self.model.state_dict()
        
#         if isinstance(self.base_model, nn.Module):
#             state_dict['base_model'] = self.base_model.state_dict()
#         else:
#             state_dict['base_model'] = None
            
#         if isinstance(self.value_head, nn.Module):
#             state_dict['value_head'] = self.value_head.state_dict()
#             state_dict['base_value_head'] = self.base_value_head.state_dict()
#         else:
#             state_dict['value_head'] = None
#             state_dict['base_value_head'] = None
            
#         torch.save(state_dict, filename)
        
#     def load_weights(self, filename):
#         state_dict = torch.load(filename, map_location=get_model_device(self.model))
        
#         self.model.load_state_dict(state_dict['model'])
        
#         if isinstance(self.base_model, nn.Module):
#             self.base_model.load_state_dict(state_dict['base_model'])
            
#         if isinstance(self.value_head, nn.Module):
#             self.value_head.load_state_dict(state_dict['value_head'])
#             self.base_value_head.load_state_dict(state_dict['base_value_head'])


In [None]:

# class GenerativeAgent(BaselineAgent):
#     def __init__(self, model, vocab, loss_function, dataset, base_update=0.99,
#                  v_update=0.95, base_model=True, value_head=None, latents=None,
#                  opt_kwargs={}, vopt_kwargs={}, lopt_kwargs={}):
#         super().__init__(model, loss_function, dataset, 
#                          base_update=base_update, v_update=v_update,
#                          base_model=base_model, value_head=value_head, 
#                          opt_kwargs=opt_kwargs, vopt_kwargs=vopt_kwargs)
        
#         self.vocab = vocab
#         self.set_latent(latents, lopt_kwargs)
        
#     def set_latent(self, latents, lopt_kwargs):
                    
#         self.latents = latents
#         if self.latents is not None:
#             self.latents = to_device(self.latents)
#             self.latents = nn.Parameter(self.latents)
#             self.latent_opt = self.get_opt([self.latents], **lopt_kwargs)
#             self.opts.append(self.latent_opt)
            
#     def reconstruct(self, preds):
#         return maybe_parallel(self.vocab.reconstruct, [i for i in preds.detach().cpu()])
    
#     def reconstruct_trajectory(self, preds):
#         trajectories = maybe_parallel(self.vocab.reconstruct_trajectory, [i for i in preds.detach().cpu()])
#         return trajectories
    
#     def get_batch_params(self, model_output):
#         x = model_output['x']
#         y = model_output['y']
#         mask = ~(y==self.vocab.stoi['pad'])
#         lengths = mask.sum(-1)
#         sl = y.shape[-1]
#         trajectories = self.reconstruct_trajectory(y)
#         smiles = [i[-1] if i else '' for i in trajectories]
        
#         model_output['mask'] = mask
#         model_output['lengths'] = lengths
#         model_output['sl'] = sl
#         model_output['sequences'] = smiles
#         model_output['sequence_trajectories'] = trajectories
        
#         return model_output
        
    
#     def get_model_outputs(self, model_output):
#         x = model_output['x']
#         y = model_output['y']
#         latent = model_output['latent']
#         mo, mlp, mglp, me = self.model.get_rl_tensors(x,y,latent=latent)
#         mprob = mlp.exp()
    
#         model_output['model_output'] = mo
#         model_output['model_logprobs'] = mlp
#         model_output['model_gathered_logprobs'] = mglp
#         model_output['model_encoded'] = me
#         model_output['y_gumbel'] = F.one_hot(y, len(self.vocab.itos)) + mprob - mprob.detach()
        
#         if self.value_head is not None:
#             value_predictions = self.value_head(me)
#             with torch.no_grad():
#                 base_value_predictions = self.base_value_head(me)
#         else:
#             value_predictions = None
#             base_value_predictions = None
            
#         model_output['state_values'] = value_predictions
#         model_output['old_state_values'] = base_value_predictions
        
#         if self.base_model is not None:
#             with torch.no_grad():
#                 bo, blp, bglp, be = self.base_model.get_rl_tensors(x,y)
#         else:
#             bo, blp, bglp, be = None, None, None, None

#         model_output['reference_output'] = bo
#         model_output['reference_logprobs'] = blp
#         model_output['reference_gathered_logprobs'] = bglp
#         model_output['reference_encoded'] = be
    
#         return model_output

In [None]:

# class CriticAgent(BaselineAgent):
#     def __init__(self, model, loss_function, dataset, base_update=0.99,
#                 base_model=True, opt_kwargs={}):
#         super().__init__(model, loss_function, dataset, 
#                          base_update=base_update,
#                          base_model=base_model, value_head=None, 
#                          opt_kwargs=opt_kwargs, vopt_kwargs={})
    
#     def predict_tensor(self, x, baseline=False):
#         if not type(x)==list:
#             x = [x]
        
#         if baseline:
#             output = self.base_model(*x)
#         else:
#             output = self.model(*x)
        
#     def predict_data(self, data):
#         ds = self.dataset.new(data, [0 for i in data])
#         batch = ds.collate_function([ds[i] for i in range(len(ds))])
#         batch = to_device(batch)
#         x,y = batch
#         return self.predict_tensor(x)
    
#     def get_model_outputs(self, model_output):
#         x = model_output['x']
#         y = model_output['y']
        
#         model_output['model_output'] = self.predict_tensor(x, baseline=False)
#         model_output['reference_output'] = self.predict_tensor(x, baseline=True)
    
#         return model_output