<a href="https://colab.research.google.com/github/andreaslanga/Lecture-notes/blob/main/stepbystep_v0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import datetime

class StepByStep:
    """Bundle a model, loss function and optimizer with a reusable training loop.

    The model is moved to CUDA when available (CPU otherwise) at construction
    time. Data loaders and an optional TensorBoard writer are supplied later via
    ``set_loaders`` and ``set_tensorboard``. Per-epoch training and validation
    losses accumulate in ``losses`` / ``val_losses`` and can be saved to, or
    restored from, a checkpoint file.
    """

    def __init__(self, model, loss_fn, optimizer):
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Send the model to the chosen device right away.
        self.model.to(self.device)

        # Placeholders for arguments supplied later (set_loaders / set_tensorboard).
        self.train_loader = None
        self.val_loader = None
        self.writer = None

        # Training history: one entry per epoch (None when a loader is missing).
        self.losses = []
        self.val_losses = []
        self.total_epochs = 0

        # Step functions are built once and reused for every mini-batch.
        self.train_step_fn = self._make_train_step_fn()
        self.val_step_fn = self._make_val_step_fn()

    def to(self, device):
        """Move the model to ``device``, falling back to the default device on failure.

        ``Module.to`` raises ``RuntimeError`` for unusable devices (e.g. an
        invalid device string, or CUDA without a driver) but ``AssertionError``
        when torch was built without CUDA support; both mean "unavailable".
        """
        try:
            self.device = device
            self.model.to(self.device)
        except (RuntimeError, AssertionError):
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
            print(f"Couldn't send it to {device}, sending it to {self.device} instead.")
            self.model.to(self.device)

    def set_loaders(self, train_loader, val_loader=None):
        """Store the training loader and the (optional) validation loader."""
        self.train_loader = train_loader
        self.val_loader = val_loader

    def set_tensorboard(self, name, folder='runs'):
        """Create a SummaryWriter logging to ``{folder}/{name}_{YYYYmmdd_HHMMSS}``."""
        suffix = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        self.writer = SummaryWriter(f'{folder}/{name}_{suffix}')

    def _make_train_step_fn(self):
        """Return a closure that performs one optimization step on a mini-batch."""
        def perform_train_step_fn(x, y):
            # Training mode (matters for layers such as dropout / batch-norm).
            self.model.train()
            # Forward pass and loss.
            yhat = self.model(x)
            loss = self.loss_fn(yhat, y)
            # Backward pass, parameter update, then clear gradients for the next step.
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
            return loss.item()

        return perform_train_step_fn

    def _make_val_step_fn(self):
        """Return a closure that computes the loss of one mini-batch in eval mode.

        Gradient tracking is not disabled here; ``train`` wraps validation in
        ``torch.no_grad()``.
        """
        def perform_val_step_fn(x, y):
            self.model.eval()
            yhat = self.model(x)
            loss = self.loss_fn(yhat, y)
            return loss.item()

        return perform_val_step_fn

    def _mini_batch(self, validation=False):
        """Run one pass over the training or validation loader.

        Returns:
            The mean of the per-batch losses as a builtin ``float``, or ``None``
            when the corresponding loader has not been set.
        """
        if validation:
            data_loader, step_fn = self.val_loader, self.val_step_fn
        else:
            data_loader, step_fn = self.train_loader, self.train_step_fn

        if data_loader is None:
            return None

        mini_batch_losses = []
        for x_batch, y_batch in data_loader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)
            mini_batch_losses.append(step_fn(x_batch, y_batch))

        # Plain float (not np.float64) so the loss history stored in checkpoints
        # holds only builtin types and stays loadable by torch.load's
        # weights-only unpickler.
        return float(np.mean(mini_batch_losses))

    def set_seed(self, seed=42):
        """Seed torch and NumPy and make cuDNN deterministic."""
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(seed)
        np.random.seed(seed)

    def train(self, n_epochs, seed=42):
        """Train for ``n_epochs`` epochs, running validation after each one.

        Losses are appended to ``losses`` / ``val_losses``. If a TensorBoard
        writer is set, both losses are logged under the main tag ``loss`` with a
        step that continues from ``total_epochs``, so resumed training (e.g.
        after ``load_checkpoint``) does not restart at step 0.
        """
        # To ensure reproducibility of the training process.
        self.set_seed(seed)

        for _ in range(n_epochs):
            # Keep track of the total number of epochs across calls.
            self.total_epochs += 1

            # Performs training using mini-batches.
            loss = self._mini_batch(validation=False)
            self.losses.append(loss)

            # Validation needs no gradients.
            with torch.no_grad():
                val_loss = self._mini_batch(validation=True)
                self.val_losses.append(val_loss)

            if self.writer:
                scalars = {}
                if loss is not None:
                    scalars['training'] = loss
                if val_loss is not None:
                    scalars['validation'] = val_loss
                if scalars:
                    self.writer.add_scalars(main_tag='loss',
                                            tag_scalar_dict=scalars,
                                            global_step=self.total_epochs - 1)

        if self.writer:
            # Close once after training rather than after every epoch.
            self.writer.close()

    def save_checkpoint(self, filename):
        """Save model/optimizer state, epoch count and loss history to ``filename``."""
        checkpoint = {
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'epoch': self.total_epochs,
            'loss': self.losses,
            'val_loss': self.val_losses,
        }
        torch.save(checkpoint, filename)

    def load_checkpoint(self, filename):
        """Restore state written by ``save_checkpoint`` and leave the model in train mode."""
        # map_location lets a checkpoint saved on another device (e.g. GPU) be
        # restored onto this instance's device (e.g. a CPU-only machine).
        checkpoint = torch.load(filename, map_location=self.device)

        # Restore state for model & optimizer.
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        self.total_epochs = checkpoint['epoch']
        self.losses = checkpoint['loss']
        self.val_losses = checkpoint['val_loss']

        # Ready for resuming training.
        self.model.train()

    def predict(self, x):
        """Return model predictions for array-like ``x`` as a NumPy array.

        The input is converted to a float tensor on ``self.device``. The model
        runs in eval mode without autograd and is set back to train mode after.
        """
        self.model.eval()
        x_tensor = torch.as_tensor(x).float()
        # Inference only: no need to build the autograd graph.
        with torch.no_grad():
            yhat_tensor = self.model(x_tensor.to(self.device))
        self.model.train()
        # Bring the result back to the CPU and convert to NumPy.
        return yhat_tensor.detach().cpu().numpy()

    def plot_losses(self):
        """Plot training and validation loss per epoch on a log scale; return the figure."""
        fig = plt.figure(figsize=(10, 4))
        plt.plot(self.losses, label='Training Loss', c='b')
        plt.plot(self.val_losses, label='Validation Loss', c='r')
        plt.yscale('log')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.tight_layout()
        return fig

    def add_graph(self):
        """Log the model graph to TensorBoard using one training mini-batch as input."""
        if self.train_loader and self.writer:
            x_sample, _ = next(iter(self.train_loader))
            self.writer.add_graph(self.model, x_sample.to(self.device))