# Landscape Modification (LM) on LSTM + PTB Training Notebook

Here we train 1-, 2-, and 3-layer LSTMs on the Penn Treebank (PTB) dataset using AdaBelief-LM. The code has been adapted from the AdaBelief authors' training pipeline, which is available in their GitHub repository. This notebook is configured to run on Kaggle.

## Install packages and set random seeds

In [None]:
!pip install adabelief_pytorch
!pip install torch
!pip install matplotlib
!pip install portalocker

import sys
sys.path.append("/kaggle/input/optimizer/")

from adabelief_lm import AdaBelief_LM

In [None]:
import time
import math
import torch
import torch.nn as nn
from torchtext.datasets import PennTreebank
from torchtext.data.utils import get_tokenizer
from collections import Counter
from adabelief_pytorch import AdaBelief
import os
import numpy as np
import random
import portalocker

# Fix every source of randomness so that repeated runs are reproducible.
SEED = 141  # other seeds tried: 6, 42

# cuBLAS workspace configuration, exported alongside deterministic mode.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Seed the Python, NumPy and PyTorch generators with the same value.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

torch.use_deterministic_algorithms(True)

if torch.cuda.is_available():
    # Seed the current CUDA device, then all CUDA devices.
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    # Request deterministic cuDNN behaviour and switch off cuDNN benchmark mode.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

## Tokenize and Batch the data

In [None]:
import os
import torch

from collections import Counter
from torchtext.datasets import PennTreebank


class Dictionary(object):
    """Vocabulary mapping words to integer ids and back, with occurrence counts."""

    def __init__(self):
        self.word2idx = {}        # word -> id
        self.idx2word = []        # id -> word (position in the list is the id)
        self.counter = Counter()  # id -> number of times add_word saw that word
        self.total = 0            # total number of add_word calls

    def add_word(self, word):
        """Record one occurrence of ``word`` and return its id.

        Unknown words are appended to the vocabulary and receive the next
        free id; known words keep the id they were first given.
        """
        try:
            token_id = self.word2idx[word]
        except KeyError:
            token_id = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = token_id
        self.counter[token_id] += 1
        self.total += 1
        return token_id

    def __len__(self):
        """Number of distinct words seen so far."""
        return len(self.idx2word)


class Corpus(object):
    """Penn Treebank splits encoded as 1-D tensors of token ids.

    A single ``Dictionary`` is shared by the three splits; it is filled while
    the splits are tokenized in the order train, valid, test.
    """

    def __init__(self):
        self.dictionary = Dictionary()
        self.train = self.tokenize(PennTreebank(split='train'))
        self.valid = self.tokenize(PennTreebank(split='valid'))
        self.test = self.tokenize(PennTreebank(split='test'))

    def tokenize(self, lines):
        """Encode an iterable of text lines as one flat ``torch.long`` tensor.

        Each line is split on whitespace and terminated with an ``'<eos>'``
        token. Every token is registered with ``self.dictionary`` (exactly
        once per occurrence) and replaced by the id the dictionary returns.

        Args:
            lines: iterable of strings. It is consumed in a single pass, so
                one-shot iterators and generators are handled correctly.

        Returns:
            1-D ``torch.long`` tensor with one entry per token, in reading
            order; empty when ``lines`` yields nothing.
        """
        add_word = self.dictionary.add_word
        ids = []
        for line in lines:
            for word in line.split():
                ids.append(add_word(word))
            ids.append(add_word('<eos>'))
        return torch.tensor(ids, dtype=torch.long)

In [None]:
# Batchify function
def batchify(data, bsz, device='cuda'):
    """Cut a 1-D token stream into ``bsz`` parallel columns and move it to ``device``.

    Trailing tokens that do not fill a complete row are discarded. The result
    has ``data.size(0) // bsz`` rows and ``bsz`` columns; each column holds a
    contiguous stretch of the original stream.
    """
    # Number of complete rows once the stream is shared between bsz columns.
    rows = data.size(0) // bsz
    # Keep only the tokens that fit into those rows.
    trimmed = data.narrow(0, 0, rows * bsz)
    # One stretch of the stream per column.
    columns = trimmed.view(bsz, -1).t().contiguous()
    return columns.to(device)


def get_batch(source, i, mean_bptt = 70, seq_len=None, evaluation=False):
    """Return the input rows starting at ``i`` and their next-token targets.

    The window length is ``seq_len`` when it is truthy, otherwise
    ``mean_bptt``, shortened if needed so that the target slice (the same
    window shifted by one row) stays inside ``source``. ``evaluation`` is
    accepted but not used.

    Returns:
        ``(data, target)`` where ``data`` is ``source[i:i + length]`` and
        ``target`` is the following window flattened to 1-D.
    """
    length = seq_len or mean_bptt
    rows_left = len(source) - 1 - i
    if rows_left < length:
        length = rows_left
    stop = i + length
    return source[i:stop], source[i + 1:stop + 1].view(-1)


# Repackage hidden function
def repackage_hidden(h):
    """Detach hidden state from the graph that produced it.

    A tensor is returned detached; any other iterable is processed element
    by element and returned as a tuple.
    """
    if not isinstance(h, torch.Tensor):
        return tuple(map(repackage_hidden, h))
    return h.detach()

In [None]:
# Build the vocabulary and encode the three Penn Treebank splits.
corpus = Corpus()

# Vocabulary size after all three splits have been tokenized.
ntokens = len(corpus.dictionary)

# Lay each split out in parallel columns on the GPU: 20 for training,
# 10 for validation and a single column for test.
train_data = batchify(corpus.train, bsz=20, device='cuda')
valid_data = batchify(corpus.valid, bsz=10, device='cuda')
test_data = batchify(corpus.test, bsz=1, device='cuda')

## Weight Drop Implementation

In [None]:
import torch
from torch.nn import Parameter
from functools import wraps

class WeightDrop(torch.nn.Module):
    """Wrap a module and apply dropout to selected weights on every forward pass.

    Args:
        module: the module whose weights are dropped.
        weights: names of the attributes on ``module`` to apply dropout to.
        dropout: dropout probability passed to ``torch.nn.functional.dropout``.
        variational: if true, one mask value is drawn per first-dimension
            index of the weight and expanded over the rest of the weight;
            otherwise dropout is applied element-wise.
    """
    def __init__(self, module, weights, dropout=0, variational=False):
        super(WeightDrop, self).__init__()
        self.module = module
        self.weights = weights
        self.dropout = dropout
        self.variational = variational
        self._setup()

    def widget_demagnetizer_y2k_edition(*args, **kwargs):
        """No-op stand-in for ``flatten_parameters``; accepts anything, returns None."""
        # We need to replace flatten_parameters with a nothing function
        # It must be a function rather than a lambda as otherwise pickling explodes
        # We can't write boring code though, so ... WIDGET DEMAGNETIZER Y2K EDITION!
        # (╯°□°）╯︵ ┻━┻
        return

    def _setup(self):
        """Move each selected weight to a ``<name>_raw`` parameter on the wrapped module."""
        # Terrible temporary solution to an issue regarding compacting weights re: CUDNN RNN
        if issubclass(type(self.module), torch.nn.RNNBase):
            self.module.flatten_parameters = self.widget_demagnetizer_y2k_edition

        for name_w in self.weights:
            print('Applying weight drop of {} to {}'.format(self.dropout, name_w))
            w = getattr(self.module, name_w)
            # Unregister the original parameter and keep its data under '<name>_raw'.
            del self.module._parameters[name_w]
            self.module.register_parameter(name_w + '_raw', Parameter(w.data))

    def _setweights(self):
        """Recompute each dropped weight from its ``<name>_raw`` parameter and set it on the module."""
        for name_w in self.weights:
            raw_w = getattr(self.module, name_w + '_raw')
            w = None
            if self.variational:
                # One mask entry per first-dimension index; dropout is always active here (training=True).
                mask = torch.autograd.Variable(torch.ones(raw_w.size(0), 1))
                if raw_w.is_cuda: mask = mask.cuda()
                mask = torch.nn.functional.dropout(mask, p=self.dropout, training=True)
                w = torch.nn.Parameter(mask.expand_as(raw_w) * raw_w)
            else:
                # Element-wise dropout, active only while this wrapper is in training mode.
                w = torch.nn.Parameter(torch.nn.functional.dropout(raw_w, p=self.dropout, training=self.training))
            # NOTE(review): wrapping the dropped weight in a fresh nn.Parameter appears to make it
            # a new leaf, so gradients may not reach '<name>_raw', and the new parameter is not one
            # the optimizer was constructed with - confirm the recurrent weights actually train.
            setattr(self.module, name_w, w)

    def forward(self, *args):
        """Refresh the dropped weights, then run the wrapped module on ``args``."""
        self._setweights()
        return self.module.forward(*args)

## Cross Entropy Loss Implementation

In [None]:
from collections import defaultdict

import torch
import torch.nn as nn

import numpy as np


class SplitCrossEntropyLoss(nn.Module):
    r'''Cross-entropy over a vocabulary partitioned into a head and optional tail buckets.

    ``splits`` lists the first token id of every tail bucket. Internally it
    becomes ``[0] + splits + [100 * 1000000]``, so bucket ``k`` covers the ids
    in ``[self.splits[k], self.splits[k + 1])``. Bucket 0 is the head. For
    each tail bucket the head softmax contains one extra "tombstone" entry
    (rows of ``tail_vectors`` / ``tail_bias``) standing for "the word is in
    that bucket"; the word itself is then picked by a softmax restricted to
    the bucket. With ``splits=[]`` there is a single bucket and the loss is a
    plain softmax cross-entropy.

    Args:
        hidden_size: width of the hidden vectors (size of each tombstone vector).
        splits: ascending list of bucket start ids (may be empty).
        verbose: if true, ``forward`` prints and records per-bucket size statistics.
    '''
    def __init__(self, hidden_size, splits, verbose=False):
        # We assume splits is [0, split1, split2, N] where N >= |V|
        # For example, a vocab of 1000 words may have splits [0] + [100, 500] + [inf]
        super(SplitCrossEntropyLoss, self).__init__()
        self.hidden_size = hidden_size
        self.splits = [0] + splits + [100 * 1000000]
        self.nsplits = len(self.splits) - 1
        self.stats = defaultdict(list)
        self.verbose = verbose
        # Each of the splits that aren't in the head require a pretend token, we'll call them tombstones
        # The probability given to this tombstone is the probability of selecting an item from the represented split
        if self.nsplits > 1:
            self.tail_vectors = nn.Parameter(torch.zeros(self.nsplits - 1, hidden_size))
            self.tail_bias = nn.Parameter(torch.zeros(self.nsplits - 1))

    def _head_parameters(self, weight, bias):
        """Return the (weight, bias) of the head softmax: head words followed by the tombstones."""
        start, end = self.splits[0], self.splits[1]
        head_weight = None if end - start == 0 else weight[start:end]
        head_bias = None if end - start == 0 else bias[start:end]
        # We only add the tombstones if we have more than one split
        if self.nsplits > 1:
            head_weight = self.tail_vectors if head_weight is None else torch.cat([head_weight, self.tail_vectors])
            head_bias = self.tail_bias if head_bias is None else torch.cat([head_bias, self.tail_bias])
        return head_weight, head_bias

    def logprob(self, weight, bias, hiddens, splits=None, softmaxed_head_res=None, verbose=False):
        """Return log-probabilities of the words in the requested buckets.

        Args:
            weight, bias: decoder parameters, indexed by token id along dim 0.
            hiddens: 2-D tensor with one hidden vector per row.
            splits: bucket indices to evaluate; all buckets when ``None``.
            softmaxed_head_res: pre-computed head log-softmax for ``hiddens``;
                computed here when ``None``.
            verbose: accepted but not used.

        Returns:
            Tensor with one row per hidden vector and the log-probabilities of
            the requested buckets concatenated along dim 1.
        """
        # First we perform the first softmax on the head vocabulary and the tombstones
        if softmaxed_head_res is None:
            head_weight, head_bias = self._head_parameters(weight, bias)
            head_res = torch.nn.functional.linear(hiddens, head_weight, bias=head_bias)
            softmaxed_head_res = torch.nn.functional.log_softmax(head_res, dim=-1)

        if splits is None:
            splits = list(range(self.nsplits))

        n_tombstones = self.nsplits - 1
        results = []
        for idx in splits:

            # For those targets in the head (idx == 0) we only need to return their loss
            if idx == 0:
                # With no tail buckets there are no tombstone columns to strip;
                # slicing with ':-0' would wrongly return an empty tensor.
                if n_tombstones == 0:
                    results.append(softmaxed_head_res)
                else:
                    results.append(softmaxed_head_res[:, :-n_tombstones])

            # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone)
            else:
                start, end = self.splits[idx], self.splits[idx + 1]
                tail_weight = weight[start:end]
                tail_bias = bias[start:end]

                # Calculate the softmax for the words in the tombstone
                tail_res = torch.nn.functional.linear(hiddens, tail_weight, bias=tail_bias)

                # Then we calculate p(tombstone) * p(word in tombstone)
                # Adding is equivalent to multiplication in log space
                head_entropy = (softmaxed_head_res[:, -idx]).contiguous()
                tail_entropy = torch.nn.functional.log_softmax(tail_res, dim=-1)
                results.append(head_entropy.view(-1, 1) + tail_entropy)

        if len(results) > 1:
            return torch.cat(results, dim=1)
        return results[0]

    def split_on_targets(self, hiddens, targets):
        """Group targets, and their hidden vectors, by the bucket each target id falls in.

        Returns:
            ``(split_targets, split_hiddens)``: two lists with one entry per
            bucket. An entry is an empty list when all targets were already
            covered by earlier buckets.
        """
        # If there are no splits, avoid costly masked select
        if self.nsplits == 1:
            return [targets], [hiddens]

        # Bucket index of a target = number of bucket starts that are <= the target.
        # The count is accumulated as integers: adding boolean masks would saturate
        # at True and send every tail target to bucket 1.
        mask = torch.zeros_like(targets, dtype=torch.long)
        for idx in range(1, self.nsplits):
            mask = mask + (targets >= self.splits[idx]).long()

        split_targets = []
        split_hiddens = []
        covered = 0
        for idx in range(self.nsplits):
            # If all the words are covered by earlier targets, we have empties so later stages don't freak out
            if covered == len(targets):
                split_targets.append([])
                split_hiddens.append([])
                continue
            # Are you in our split?
            tmp_mask = mask == idx
            split_targets.append(torch.masked_select(targets, tmp_mask))
            split_hiddens.append(hiddens.masked_select(tmp_mask.unsqueeze(1).expand_as(hiddens)).view(-1, hiddens.size(1)))
            covered += len(split_targets[-1])
        return split_targets, split_hiddens

    def forward(self, weight, bias, hiddens, targets, verbose=False):
        """Return the mean negative log-likelihood of ``targets``.

        Args:
            weight, bias: decoder parameters, indexed by token id along dim 0.
            hiddens: hidden vectors; tensors with more than two dimensions are
                flattened to ``(-1, hiddens.size(2))``.
            targets: 1-D tensor of target token ids, one per hidden vector.
            verbose: print and record bucket statistics for this call.

        Returns:
            Scalar tensor (summed loss divided by ``len(targets)``), cast to
            the type of ``weight``.
        """
        if self.verbose or verbose:
            for idx in sorted(self.stats):
                print('{}: {}'.format(idx, int(np.mean(self.stats[idx]))), end=', ')
            print()

        total_loss = None
        if len(hiddens.size()) > 2: hiddens = hiddens.view(-1, hiddens.size(2))

        split_targets, split_hiddens = self.split_on_targets(hiddens, targets)

        # First we perform the first softmax on the head vocabulary and the tombstones
        head_weight, head_bias = self._head_parameters(weight, bias)

        # Perform the softmax calculation for the word vectors in the head for all splits
        # We need to guard against empty splits as torch.cat does not like random lists
        combo = torch.cat([split_hiddens[i] for i in range(self.nsplits) if len(split_hiddens[i])])
        all_head_res = torch.nn.functional.linear(combo, head_weight, bias=head_bias)
        softmaxed_all_head_res = torch.nn.functional.log_softmax(all_head_res, dim=-1)
        if self.verbose or verbose:
            self.stats[0].append(combo.size()[0] * head_weight.size()[0])

        # Rows of `combo` are grouped by bucket; running_offset marks where the current bucket starts.
        running_offset = 0
        for idx in range(self.nsplits):
            # If there are no targets for this split, continue
            if len(split_targets[idx]) == 0: continue

            softmaxed_head_res = softmaxed_all_head_res[running_offset:running_offset + len(split_hiddens[idx])]

            # For those targets in the head (idx == 0) we only need to return their loss
            if idx == 0:
                entropy = -torch.gather(softmaxed_head_res, dim=1, index=split_targets[idx].view(-1, 1))
            # If the target is in one of the splits, the probability is the p(tombstone) * p(word within tombstone)
            else:
                if self.verbose or verbose:
                    start, end = self.splits[idx], self.splits[idx + 1]
                    tail_weight = weight[start:end]
                    self.stats[idx].append(split_hiddens[idx].size()[0] * tail_weight.size()[0])

                # Calculate the softmax for the words in the tombstone
                tail_res = self.logprob(weight, bias, split_hiddens[idx], splits=[idx], softmaxed_head_res=softmaxed_head_res)

                # Then we calculate p(tombstone) * p(word in tombstone)
                # Adding is equivalent to multiplication in log space
                head_entropy = softmaxed_head_res[:, -idx]
                # All indices are shifted - if the first split handles [0,...,499] then the 500th in the second split will be 0 indexed
                indices = (split_targets[idx] - self.splits[idx]).view(-1, 1)
                # Re-normalising removes the per-row tombstone term that logprob added, leaving
                # log p(word | bucket). Squeeze only dim 1 so the result stays 1-D even for a
                # single row and adds element-wise to head_entropy.
                tail_entropy = torch.gather(torch.nn.functional.log_softmax(tail_res, dim=-1), dim=1, index=indices).squeeze(1)
                entropy = -(head_entropy + tail_entropy)

            running_offset += len(split_hiddens[idx])
            total_loss = entropy.float().sum() if total_loss is None else total_loss + entropy.float().sum()

        return (total_loss / len(targets)).type_as(weight)

In [None]:
criterion = None

if criterion is None:
    # Pick the softmax buckets from the vocabulary size; small vocabularies use none.
    if ntokens > 500000:
        # One Billion
        # This produces fairly even matrix mults for the buckets:
        # 0: 11723136, 1: 10854630, 2: 11270961, 3: 11219422
        splits = [4200, 35000, 180000]
    elif ntokens > 75000:
        # WikiText-103
        splits = [2800, 20000, 76000]
    else:
        splits = []
    print('Using', splits)
    criterion = SplitCrossEntropyLoss(400, splits=splits, verbose=False)

## Model Architecture Implementation

In [None]:
class RNNModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder.

    The recurrent part is a stack of single-layer LSTMs, each optionally
    wrapped in ``WeightDrop`` on ``weight_hh_l0``. ``forward`` returns the
    output of the last LSTM (after dropout), flattened to two dimensions; the
    decoder is not applied inside ``forward``.

    Args:
        ntoken: vocabulary size.
        ninp: embedding width.
        nhid: hidden width of every LSTM except, with tied weights, the last.
        nlayers: number of stacked LSTMs.
        dropout: locked dropout on the final output.
        dropouth: locked dropout between LSTM layers.
        dropouti: locked dropout on the embedded input.
        dropoute: probability of dropping whole embedding rows (training only).
        wdrop: weight-drop probability on the hidden-to-hidden weights; falsy disables it.
        tie_weights: share the decoder weight with the embedding matrix; the
            last LSTM then outputs ``ninp`` features instead of ``nhid``.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.4, dropouth=0.25, dropouti=0.4, dropoute=0.1, wdrop=0.5, tie_weights = True):
        super(RNNModel, self).__init__()
        self.ninp = ninp
        self.nhid = nhid
        self.nlayers = nlayers
        self.dropout = dropout
        self.dropouti = dropouti
        self.dropouth = dropouth
        self.dropoute = dropoute
        self.tie_weights = tie_weights

        # Sub-modules are created in a fixed order (embedding, LSTMs, decoder) so that
        # seeded initialisation stays the same from run to run.
        self.idrop = nn.Dropout(dropouti)
        self.hdrop = nn.Dropout(dropouth)
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnns = [torch.nn.LSTM(ninp if l == 0 else nhid, self._layer_output_size(l), 1, dropout=0) for l in range(nlayers)]
        if wdrop:
            self.rnns = [WeightDrop(rnn, ['weight_hh_l0'], dropout=wdrop) for rnn in self.rnns]

        print(self.rnns)
        self.rnns = torch.nn.ModuleList(self.rnns)
        self.decoder = nn.Linear(nhid, ntoken)

        if tie_weights:
            # The decoder weight becomes the (ntoken, ninp) embedding matrix.
            self.decoder.weight = self.encoder.weight

        self.init_weights()

    def _layer_output_size(self, l):
        """Output width of LSTM layer ``l``: ``ninp`` for the last layer when weights are tied, else ``nhid``."""
        if l == self.nlayers - 1 and self.tie_weights:
            return self.ninp
        return self.nhid

    def init_weights(self):
        """Initialise encoder/decoder weights uniformly in [-0.1, 0.1] and zero the decoder bias."""
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.fill_(0)
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden, return_h = False):
        """Run the embedding and the LSTM stack.

        Args:
            input: tensor of token ids.
            hidden: list with one hidden state per layer, as returned by ``init_hidden``.
            return_h: also return the per-layer outputs.

        Returns:
            ``(result, hidden)`` or, when ``return_h`` is true,
            ``(result, hidden, raw_outputs, outputs)``. ``result`` is the last
            layer's output after dropout, reshaped to
            ``(size(0) * size(1), size(2))``. ``raw_outputs`` holds each
            layer's output before dropout and ``outputs`` after dropout.
        """
        emb = self.embedded_dropout(self.encoder, input, dropout=self.dropoute if self.training else 0)
        emb = self.locked_dropout(emb, self.dropouti)

        raw_output = emb
        new_hidden = []
        raw_outputs = []
        outputs = []
        for l, rnn in enumerate(self.rnns):
            if isinstance(rnn, torch.nn.LSTM):
                rnn.flatten_parameters()
            raw_output, new_h = rnn(raw_output, hidden[l])
            new_hidden.append(new_h)
            raw_outputs.append(raw_output)
            if l != self.nlayers - 1:
                raw_output = self.locked_dropout(raw_output, self.dropouth)
                outputs.append(raw_output)
        hidden = new_hidden

        output = self.locked_dropout(raw_output, self.dropout)
        outputs.append(output)

        result = output.view(output.size(0)*output.size(1), output.size(2))
        if return_h:
            return result, hidden, raw_outputs, outputs
        return result, hidden

    def init_hidden(self, bsz):
        """Return zero ``(h, c)`` state pairs, one per layer, for batch size ``bsz``."""
        weight = next(self.parameters()).data
        return [(weight.new_zeros(1, bsz, self._layer_output_size(l)),
                 weight.new_zeros(1, bsz, self._layer_output_size(l)))
                for l in range(self.nlayers)]

    def locked_dropout(self, x, dropout=0.5):
        """Apply one dropout mask, shared across dim 0, to the 3-D tensor ``x``.

        Returns ``x`` unchanged in eval mode or when ``dropout`` is falsy.
        """
        if not self.training or not dropout:
            return x
        mask = x.new_empty(1, x.size(1), x.size(2)).bernoulli_(1 - dropout)
        mask = mask / (1 - dropout)
        return mask.expand_as(x) * x

    def embedded_dropout(self, embed, words, dropout=0.1, scale=None):
        """Look up ``words`` in ``embed`` after dropping whole embedding rows.

        Args:
            embed: the ``nn.Embedding`` to read from.
            words: tensor of token ids.
            dropout: probability of zeroing an entire row; kept rows are
                rescaled by ``1 / (1 - dropout)``. Falsy disables dropout.
            scale: optional tensor multiplied into the embedding matrix
                (expanded to its shape).

        Returns:
            The embedded ``words``.
        """
        if dropout:
            keep = embed.weight.new_empty((embed.weight.size(0), 1)).bernoulli_(1 - dropout)
            mask = keep.expand_as(embed.weight) / (1 - dropout)
            masked_embed_weight = mask * embed.weight
        else:
            masked_embed_weight = embed.weight
        if scale is not None:
            masked_embed_weight = scale.expand_as(masked_embed_weight) * masked_embed_weight

        # Pass padding_idx through unchanged. Replacing None by -1 makes
        # F.embedding treat the last vocabulary row as padding and zero its gradient.
        return torch.nn.functional.embedding(
            words, masked_embed_weight,
            embed.padding_idx, embed.max_norm, embed.norm_type,
            embed.scale_grad_by_freq, embed.sparse,
        )

## Training

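The training function differs from the AdaBelief authors' version because we use AdaBelief with landscape modification: the optimizer step additionally receives the current loss and a running minimum of the loss.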

In [None]:
def train(model, train_data, batch_size, criterion, optimizer, epoch, mean_bptt=70, alpha= 2, beta = 1, scheduler=None, clip = 0.25):
    """Train ``model`` for one pass over ``train_data`` and return the training perplexity.

    Each step draws a random window length around ``mean_bptt``, temporarily
    rescales the learning rate of the first parameter group in proportion to
    that length, and calls ``optimizer.step`` with the current loss and a
    running minimum. ``scheduler`` is accepted but not used. Gradient
    clipping reads the module-level ``params`` list.

    Args:
        model: network exposing ``init_hidden``, ``decoder`` and a forward
            pass that accepts ``return_h=True``.
        train_data: batchified training tokens.
        batch_size: number of columns in ``train_data``.
        criterion: loss called as ``criterion(weight, bias, output, targets)``.
        optimizer: optimizer whose ``step`` accepts ``running_loss`` and ``c``.
        epoch: epoch number, used only in the progress line.
        mean_bptt: typical window length.
        alpha: weight of the activation regulariser (falsy disables it).
        beta: weight of the temporal activation regulariser (falsy disables it).
        clip: gradient-norm clipping threshold (falsy disables clipping).

    Returns:
        ``exp`` of the length-weighted summed loss divided by ``len(train_data)``.
    """
    log_every = 100
    epoch_loss = 0
    window_loss = 0
    window_start = time.time()
    hidden = model.init_hidden(batch_size)
    batch = 0
    i = 0
    # NOTE(review): starts at 0 and can only decrease; if the loss is never negative the
    # value passed as `c` below stays 0 for the whole epoch - confirm this is intended.
    running_min = 0
    last_start = train_data.size(0) - 1 - 1
    # The permutation itself is not used; it is still drawn so that NumPy's global
    # random stream (and therefore the window lengths below) is left unchanged.
    indices = np.random.permutation(last_start)

    while i < last_start:
        # Mostly full-length windows, occasionally half-length ones.
        if np.random.random() < 0.95:
            bptt = mean_bptt
        else:
            bptt = mean_bptt / 2.
        # Prevent excessively small or negative sequence lengths
        seq_len = max(5, int(np.random.normal(bptt, 5)))
        # There's a very small chance that it could select a very long sequence length resulting in OOM
        # seq_len = min(seq_len, mean_bptt + 10)

        # Scale the learning rate with the window length for this step only.
        base_lr = optimizer.param_groups[0]['lr']
        optimizer.param_groups[0]['lr'] = base_lr * seq_len / mean_bptt

        model.train()
        data, targets = get_batch(train_data, i, mean_bptt, seq_len=seq_len)
        data = data.long().to('cuda')
        targets = targets.long().to('cuda')

        # Cut the hidden state loose from the previous step so that backpropagation
        # does not reach back to the start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()
        output, hidden, rnn_hs, dropped_rnn_hs = model(data, hidden, return_h=True)
        raw_loss = criterion(model.decoder.weight, model.decoder.bias, output, targets)

        loss = raw_loss
        if alpha:
            # Activation regularisation on the last layer's dropped output.
            loss = loss + sum(alpha * dropped_rnn_h.pow(2).mean() for dropped_rnn_h in dropped_rnn_hs[-1:])
        if beta:
            # Temporal activation regularisation (slowness) on the last layer's raw output.
            loss = loss + sum(beta * (rnn_h[1:] - rnn_h[:-1]).pow(2).mean() for rnn_h in rnn_hs[-1:])
        loss.backward()

        running_min = min(loss, running_min)

        if clip:
            # `params` is the module-level list built next to the model.
            torch.nn.utils.clip_grad_norm_(params, clip)
        optimizer.step(running_loss = loss, c = running_min)

        window_loss += raw_loss.data
        epoch_loss += len(data) * raw_loss.data
        optimizer.param_groups[0]['lr'] = base_lr

        if batch > 0 and batch % log_every == 0:
            cur_loss = window_loss.item() / log_every
            elapsed = time.time() - window_start
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:05.5f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} | ppl {:8.2f} | bpc {:8.3f}'.format(
                epoch, batch, len(train_data) // mean_bptt, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / log_every, cur_loss, math.exp(cur_loss), cur_loss / math.log(2)))
            window_loss = 0
            window_start = time.time()

        batch += 1
        i += seq_len

    return math.exp(epoch_loss.item() / len(train_data))


def evaluate(data_source, batch_size=10, mean_bptt = 70):
    """Return the perplexity of the module-level ``model`` on ``data_source``.

    The data is processed in consecutive windows of ``mean_bptt`` rows with
    the hidden state carried from one window to the next. The model is
    switched to evaluation mode (and left in it), and no autograd graph is
    recorded.

    Args:
        data_source: batchified tokens with ``batch_size`` columns.
        batch_size: number of columns in ``data_source``.
        mean_bptt: window length.

    Returns:
        ``exp`` of the length-weighted summed loss divided by ``len(data_source)``.
    """
    # Turn on evaluation mode which disables dropout.
    model.eval()
    total_loss = 0
    hidden = model.init_hidden(batch_size)
    # Nothing is back-propagated during evaluation, so skip graph construction.
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, mean_bptt):
            data, targets = get_batch(data_source, i, mean_bptt, evaluation=True)
            output, hidden = model(data, hidden)
            total_loss += len(data) * criterion(model.decoder.weight, model.decoder.bias, output, targets).data
            hidden = repackage_hidden(hidden)
    # float() accepts both the accumulated tensor and the initial int 0 (no windows processed).
    return math.exp(float(total_loss) / len(data_source))

In [None]:
# Architecture hyper-parameters: embedding width, hidden width and number of LSTM layers.
ninp, nhid, nlayers = 400, 1150, 2
model = RNNModel(ntokens, ninp, nhid, nlayers).to('cuda')

# Parameters of the model followed by those of the loss.
params = [*model.parameters(), *criterion.parameters()]

# Count entries using the first two dimensions of each tensor; 0-d tensors are skipped.
total_params = sum(
    p.size()[0] * p.size()[1] if p.dim() > 1 else p.size()[0]
    for p in params
    if p.dim() > 0
)
print('Model total parameters:', total_params)

In [None]:
# Define Optimizer
def my_function(x):
    """Return ``x`` raised to the power of two."""
    return pow(x, 2)


# AdaBelief with landscape modification; `my_function` is passed as its `function` argument.
optimizer = AdaBelief_LM(
    model.parameters(),
    function=my_function,
    eps_iksa=1,
    lr=1e-2,
    betas=(0.9, 0.999),
    weight_decay=1.2e-6,
    eps=1e-12,
    rectify=False,
    degenerated_to_sgd=False,
    weight_decouple=False,
)  # Following recommended

# Number of epochs and the checkpoint / log locations (loading and saving use the same files).
num_epochs = 199
model_save_path = '/kaggle/working/lstm_adabelief_lm'
log_save_path = '/kaggle/working/lstm_adabelief_lm_log'
model_load_path = '/kaggle/working/lstm_adabelief_lm'
log_load_path = '/kaggle/working/lstm_adabelief_lm_log'


def load_checkpoint(filepath):
    """Restore the module-level ``model`` and ``optimizer`` from a checkpoint file.

    Args:
        filepath: path of a checkpoint written with ``torch.save`` containing
            the keys ``'epoch'``, ``'model_state_dict'``,
            ``'optimizer_state_dict'``, ``'train_ppls'`` and ``'valid_ppls'``.

    Returns:
        ``(start_epoch, train_ppls, valid_ppls)``. When ``filepath`` is not a
        file, nothing is loaded and ``(1, [], [])`` is returned.

    Raises:
        KeyError: if a model parameter other than a ``weight_hh`` one is
            missing from the checkpoint.
    """
    if not os.path.isfile(filepath):
        print("No checkpoint found at specified path!")
        return 1, [], []

    # NOTE: torch.load unpickles the file; only load checkpoints written by this notebook.
    checkpoint = torch.load(filepath)
    start_epoch = checkpoint['epoch'] + 1

    # Custom loading function to handle WeightDrop
    def load_custom(model, state_dict):
        for name, param in model.named_parameters():
            # A weight_hh parameter may be absent from the checkpoint; skip it only
            # in that case, so the ones that were saved (e.g. '*_raw') are restored.
            if 'weight_hh' in name and name not in state_dict:
                continue
            param.data.copy_(state_dict[name])

    load_custom(model, checkpoint['model_state_dict'])

    try:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    except ValueError as e:
        # Best effort: keep the freshly initialised optimizer state and report the mismatch.
        print(f"Error loading optimizer state dict: {e}")
        print("Optimizer state_dict keys:")
        print(optimizer.state_dict().keys())
        print("Checkpoint optimizer_state_dict keys:")
        print(checkpoint['optimizer_state_dict'].keys())

        # Print sizes of parameter groups
        print("Optimizer param group sizes:")
        for group in optimizer.param_groups:
            print(len(group['params']))

        print("Checkpoint param group sizes:")
        for group in checkpoint['optimizer_state_dict']['param_groups']:
            print(len(group['params']))

    train_ppls = checkpoint['train_ppls']
    valid_ppls = checkpoint['valid_ppls']
    print(f"Loaded checkpoint from epoch {start_epoch - 1}")
    return start_epoch, train_ppls, valid_ppls



# Resume from the checkpoint when there is one; otherwise start at epoch 1 with empty histories.
start_epoch, train_ppls, valid_ppls = load_checkpoint(model_load_path)


for epoch in range(start_epoch, start_epoch + num_epochs + 1):
    # Step the learning rate down by a factor of 10 at these epochs.
    if epoch in (100, 145):
        print('Dividing learning rate by 10')
        for param_group in optimizer.param_groups:
            param_group['lr'] /= 10.

    epoch_start_time = time.time()
    train_ppl = train(model, train_data, 20, criterion, optimizer, epoch)
    train_ppls.append(train_ppl)
    # NOTE(review): the metric reported and stored as "valid ppl" is computed on test_data
    # (batch size 1), not valid_data - confirm this is intended.
    valid_ppl = evaluate(test_data, 1, mean_bptt = 70)
    # Record the evaluation history as well, so the checkpoint does not store an empty list.
    valid_ppls.append(valid_ppl)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time), valid_ppl))
    print('-' * 89)

    # Save the model
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_ppls': train_ppls,
        'valid_ppls': valid_ppls
    }, model_save_path)

    # Append this epoch's perplexities to the CSV log, writing the header when the file is new.
    write_header = not os.path.exists(log_save_path)
    with open(log_save_path, 'a') as f:
        if write_header:
            f.write('train_ppl,valid_ppl\n')
        f.write(f'{train_ppl},{valid_ppl}\n')