In [None]:
import os
import torch
import numpy as np
import pandas as pd
from d2l import torch as d2l
import spacy
from tqdm import tqdm, trange
from tqdm.notebook import tqdm_notebook
import pickle
from torch.utils.data import Dataset
import warnings
import wandb
import time

from alive_progress import alive_bar, config_handler

In [None]:
# Register the WikiText-2 archive with d2l's download registry.
# The tuple is (archive URL, hash string); presumably the hash is the
# checksum d2l verifies after download -- TODO confirm against d2l docs.
d2l.DATA_HUB['wikitext-2'] = (
    'https://s3.amazonaws.com/research.metamind.io/wikitext/'
    'wikitext-2-v1.zip', '3c914d17d80b1459be871a5039ac23e752a53cbe')
# Directory used for every cache written by this notebook (tokenized
# paragraphs, pickled assembler, model checkpoints).
# NOTE(review): absolute, machine-specific Windows path -- adjust before
# running anywhere else.
running_dir = r'C:\Code\NLP\BERT\Notebooks'


In [None]:
# Bare expression: only displays the registered (url, hash) tuple when it
# is the last statement of a cell; it has no effect here.
d2l.DATA_HUB['wikitext-2']
# Fetch and unpack the archive; `data_dir` is later joined with
# 'wiki.train.tokens' by `_read_wiki`.
data_dir = d2l.download_extract('wikitext-2', 'wikitext-2')


In [None]:
def save_object(obj, filename):
    """Pickle ``obj`` to ``filename`` using the highest available protocol.

    Args:
        obj: any picklable object
        filename (str): destination path; opened in binary write mode, so
                        an existing file is overwritten
    """
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink, protocol=pickle.HIGHEST_PROTOCOL)


def load_object(path):
    """Read back an object previously written with pickle.

    Args:
        path (str): location of the pickle file

    Returns:
        the unpickled object
    """
    # NOTE: unpickling can execute arbitrary code; only load files this
    # notebook wrote itself.
    with open(path, mode='rb') as source:
        restored = pickle.load(source)
    return restored


In [19]:
def _read_wiki(data_dir):
    """Load tokenized WikiText paragraphs, using an on-disk cache if present.

    If 'paragraphs_data.pkl' already exists under ``running_dir`` it is
    returned as-is. Otherwise the first 1000 lines of 'wiki.train.tokens'
    are read, every line that splits into at least two pieces on '. ' is
    kept as a paragraph, each piece is tokenized with spacy and lowercased,
    and the result is pickled to the cache before being returned.

    Args:
        data_dir (str): directory containing 'wiki.train.tokens'

    Returns:
        list[list[list[str]]]: paragraphs -> sentences -> lowercase tokens
    """
    cache_path = os.path.join(running_dir, 'paragraphs_data.pkl')
    if os.path.isfile(cache_path):
        return load_object(cache_path)

    nlp = spacy.load('en_core_web_sm')

    filepath = os.path.join(data_dir, 'wiki.train.tokens')
    print(filepath)
    with open(filepath, 'r', errors='replace') as handle:
        lines = handle.readlines()

    print('[INFO] File read')

    # Only a fixed-size prefix of the corpus is processed.
    lines = lines[:1000]
    print(len(lines))

    paragraphs = []
    for line in tqdm(lines):
        # Skip headings / blank lines: anything without two '. ' pieces.
        if len(line.split('. ')) < 2:
            continue

        tokenized_sentences = []
        for sentence in line.strip().split('. '):
            doc = nlp(sentence.strip())
            tokenized_sentences.append([token.text.lower() for token in doc])
        paragraphs.append(tokenized_sentences)

    save_object(paragraphs, cache_path)

    return paragraphs

In [None]:
# Helper Functions
def clip_dicts(ds, max_len):
    """Truncate every dict in ``ds`` to its first ``max_len`` entries.

    Entries are kept in insertion order. The previous ``i <= max_len``
    comparison kept ``max_len + 1`` entries, one more than requested.

    Args:
        ds (iterable[dict]): dictionaries to clip
        max_len (int | float): maximum number of entries to keep per dict;
                               ``float('inf')`` keeps everything

    Returns:
        list[dict]: new dictionaries, the inputs are not modified
    """
    return [
        {key: d[key] for position, key in enumerate(d) if position < max_len}
        for d in ds
    ]


In [None]:
# HELPER FUNCTIONS
def _get_next_sentence(next_sentence, paragraphs):
    """Get next sentence from paragraphs . This function 
       randomly decided whether to use a correct or random next sentence
       in the next sentence prediciton task. BERT will then predict whether
       this sentence follows the previous or is new.

    Args:
        next_sentence (list): the origional next sentence
        paragraphs (list[list[str]]): a list of sentences where sentences
                                      is a list of tokens

    Returns:
        list: the next sentence in the sequence to predict on
    """
    if np.random.rand() < 0.5:
        is_next = True
    else:
        paragraph_idx = np.random.randint(0, high=len(paragraphs))
        sentence_idx = np.random.randint(
            0, high=len(paragraphs[paragraph_idx]))
        next_sentence = paragraphs[paragraph_idx][sentence_idx]
        is_next = False

    return next_sentence, is_next


def _get_tokens_and_segments(tokens_a, tokens_b=None):
    """Get tokens of the BERT input sequence and their segment IDs .
       This function is taken from the dl2 library but rewritten here
       for clarity and to reduce reliance on dependecys .
       """
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 0 and 1 are marking segment A and B, respectively
    segments = [0] * (len(tokens_a) + 2)
    if tokens_b is not None:
        tokens += tokens_b + ['<sep>']
        segments += [1] * (len(tokens_b) + 1)
    return tokens, segments


def _generate_nsp_data(paragraph, paragraphs, max_len):
    """Build next-sentence-prediction examples from one paragraph.

    Every adjacent sentence pair yields one candidate example. The second
    sentence is either the true successor or a random sentence (see
    ``_get_next_sentence``). Pairs whose final sequence would be longer
    than ``max_len`` are dropped.

    Args:
        paragraph (list): sentences of one paragraph, each a list of tokens
        paragraphs (list): all paragraphs, used to draw random sentences
        max_len (int): maximum length of a single example, including the
                       '<cls>' token and the two '<sep>' tokens

    Returns:
        list: (tokens, segments, is_next) tuples
    """
    # '<cls>' + sentence + '<sep>' + next_sentence + '<sep>'
    num_special_tokens = 3

    nsp_data = list()
    # len(paragraph)-1 because the partner sentence is read at index i+1
    for i in range(len(paragraph) - 1):
        sentence = paragraph[i]
        next_sentence, is_next = _get_next_sentence(paragraph[i + 1], paragraphs)

        # The special tokens count towards the budget. Without them an
        # example could be up to three tokens longer than max_len and
        # could not be padded to a fixed length downstream.
        if len(sentence) + len(next_sentence) + num_special_tokens > max_len:
            continue

        tokens, segments = _get_tokens_and_segments(sentence, next_sentence)
        nsp_data.append((tokens, segments, is_next))

    return nsp_data


def _replace_masked_tokens(tokens, possible_prediction_indexes, num_preds, vocab):
    """Replace tokens to be masked with either "<mask>", the correct word, 
       or a random word selected from vocab.

    Args:
        tokens (list): unmasked tokens to be replaced partially with masked ones
        possible_prediction_indexes ([type]): indexes of non-token characters
                                              to possibly replace with <mask>
                                              or random word tokens
        num_preds (int): number of tokens to theoretically replace
                            with <mask>
        vocab (Vocab): vocabulary object for the dataset

    Returns:
        mlm_tokens (list): tokens input but with some tokens
                           replaced wtih masks and random tokens
        pred_lables_and_positions (list[tuples]): list containing pairs of 
                                                  indexes and their tokens 
    """
    mlm_tokens = [token for token in tokens]
    pred_labels_and_positions = list()

    for i, idx in enumerate(possible_prediction_indexes):
        if i >= num_preds:
            break

        random_num = np.random.random()*100
        if random_num < 80:
            masked_token = '<mask>'

        elif random_num < 90:
            masked_token = tokens[idx]

        else:
            masked_token = np.random.choice(list(vocab.stoi.keys()))

        mlm_tokens[idx] = masked_token
        # Pred labels and positions stores unamsked ground truth y values
        # thats why it indexes tokens instead of mlm_tokens which has been modified
        # with masks
        pred_labels_and_positions.append((idx, tokens[idx]))

    return mlm_tokens, pred_labels_and_positions


def _generate_mlm_data_from_tokens(tokens, vocab):
    """Create masked-language-model inputs and labels for one example.

    Args:
        tokens (list[str]): BERT input sequence including '<cls>'/'<sep>'
        vocab (Vocab): vocabulary; calling it int-encodes a token list

    Returns:
        tuple: (encoded masked tokens, prediction positions sorted
                ascending, encoded original tokens at those positions)
    """
    special_tokens = ('<cls>', '<sep>')
    # Every non-special position is a masking candidate.
    candidate_positions = [
        position for position, token in enumerate(tokens)
        if token not in special_tokens
    ]

    # 15% of the candidates, but never fewer than one.
    mask_budget = max(1, len(candidate_positions) * 0.15)
    masked_tokens, labelled_positions = _replace_masked_tokens(
        tokens, candidate_positions, mask_budget, vocab)

    # Order the (position, original token) pairs by position.
    labelled_positions.sort(key=lambda pair: pair[0])

    positions = [position for position, _ in labelled_positions]
    originals = [token for _, token in labelled_positions]

    # vocab(...) converts token strings to integer ids
    return (vocab(masked_tokens), positions, vocab(originals))


In [None]:
class InvalidVocabLengths (Exception):
    """Signals that a vocabulary's stoi and itos mappings differ in size.

    The class attribute ``behavior`` controls what constructing the
    exception does in addition to building the message:
        'raise'  -- nothing extra; the caller's ``raise`` statement raises it
        'warn'   -- additionally emits a warning with the same message
        'ignore' -- nothing extra

    Constructing the exception never raises by itself. The previous bare
    ``raise`` inside ``__init__`` produced a ``RuntimeError`` ("No active
    exception to reraise") instead of this exception.
    """
    behavior = 'raise'

    def __init__(self, stoi, itos):
        # args stays (stoi, itos) so the exception can be copied/unpickled
        super().__init__(stoi, itos)
        self.message = (
            f'Invalid Vocab stoi/itos lengths; len(stoi): {len(stoi)}, len(itos): {len(itos)}')

        if self.__class__.behavior == 'warn':
            warnings.warn(self.message)

    def __str__(self):
        return self.message


class DatasetWrapper:
    """Mixin with display helpers for dataset container classes.

    Subclasses are expected to provide ``self.datasets`` (a name -> dataset
    mapping, used by ``__str__``) and ``self.vocab`` (used by
    ``visualize_vocab``).
    """

    def __init__(self) -> None:
        pass

    def __str__(self):
        return '{} (Datasets: (\n\t{}))'.format(self.__class__.__name__, '\n\t'.join(
            [f'{key}: {value}'for key, value in self.datasets.items()]))

    def _frequency_table(self, n):
        """Return the ``n`` most frequent tokens as a DataFrame.

        Reads ``self.frequencies`` when the subclass defines it, otherwise
        ``self.vocab.sorted_frequencies``; either must be a
        token -> count mapping.
        """
        counts = getattr(self, 'frequencies', None)
        if counts is None:
            counts = self.vocab.sorted_frequencies

        ranked = sorted(
            counts.items(), key=lambda item: item[1], reverse=True)[:n]
        return pd.DataFrame(ranked, columns=['token', 'frequency'])

    def visualize_vocab(self, n=30):
        """Print and bar-plot the ``n`` most frequent vocabulary tokens.

        Args:
            n (int, optional): number of tokens to show. Defaults to 30.
        """
        # TODO: complete data visualization with avg word lengths and
        #       sentence length plots .
        frequencies_df = self._frequency_table(n)
        print(frequencies_df)
        frequencies_df.plot.bar(x='token', y='frequency')

        return


class DatasetAssembler (DatasetWrapper):
    """Builds one shared vocabulary and one WikiTextDataset per split.

    The first three splits are exposed as train / valid / test (missing
    ones are None); every split is reachable by integer index. Cross
    validation is not implemented.
    """

    def __init__(self, paragraphs, **kwargs) -> None:
        """Initialize multiple datasets on a single object
           with a stored vocabulary .

        Args:
            paragraphs (array): List holding paragraph data
            **kwargs:
                splits (list[percents]): fractions to split the data
                                         into; must sum to 1.
                                         Defaults to [1].

                max_example_len (int): maximum example length passed to
                                       every dataset. Defaults to 60.

                reserved_tokens (list[str]): list of tokens to
                                             reserve for nlp modeling

                Every keyword (including splits / max_example_len) is
                forwarded to Vocab.
        """
        sentences = [
            sentence for paragraph in paragraphs for sentence in paragraph]

        # The vocabulary is built from all paragraphs, before splitting.
        self.vocab = Vocab(sentences, **kwargs)

        if 'splits' not in kwargs.keys():
            kwargs['splits'] = [1]
        # Compare with a tolerance: e.g. 0.7 + 0.2 + 0.1 is not exactly 1
        # in floating point and would fail an exact equality check.
        split_total = float(np.sum(kwargs['splits']))
        assert abs(split_total - 1) < 1e-8, \
            'Invalid split percentages; splits must add to 100%'

        max_example_len = kwargs.get('max_example_len', 60)
        num_paragraphs = len(paragraphs)
        last_split = len(kwargs['splits']) - 1

        self.splits = list()
        for i, split_percent in enumerate(kwargs['splits']):
            prev_split = np.sum(kwargs['splits'][:i])
            start = int(prev_split * num_paragraphs)
            if i == last_split:
                # The splits sum to 1, so the last one ends at the end of
                # the data; avoids losing a paragraph to float truncation.
                stop = num_paragraphs
            else:
                stop = int(prev_split * num_paragraphs
                           + split_percent * num_paragraphs)

            print(f'prev split {prev_split}, split percent: {split_percent}')
            print(f'term 1: {start}, term 2: {stop}')
            self.splits.append(
                WikiTextDataset(
                    paragraphs[start:stop],
                    self.vocab,
                    split_percent,
                    max_example_len=max_example_len
                )
            )

        # Pad with None when fewer than three splits exist; with more than
        # three, only the first three get a name (the rest stay reachable
        # through integer indexing). Previously more than three splits left
        # these attributes unset and crashed below.
        self._trainset, self._validset, self._testset = (
            self.splits + [None] * 3)[:3]

        self.datasets = {
            'train': self._trainset,
            'valid': self._validset,
            'test': self._testset
        }

        # Filled by build_loaders; None means "no loader built / no dataset".
        self.dataloaders = {name: None for name in self.datasets}

    def build_loaders(self, batch_size, shuffle, num_workers):
        """Create one DataLoader per named dataset (None for missing ones).

        Args:
            batch_size (int): batch size for every loader
            shuffle (bool): whether every loader shuffles
            num_workers (int): worker processes per loader

        Returns:
            dict: name -> DataLoader or None; also stored on self.dataloaders
        """
        self.dataloaders = {}

        for name, dataset in self.datasets.items():
            if dataset is not None:
                loader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=batch_size,
                    shuffle=shuffle,
                    num_workers=num_workers
                )

            else:
                loader = None

            self.dataloaders[name] = loader

        return self.dataloaders

    def __getitem__(self, idx):
        """Fetch a dataset by name ('train'/'valid'/'test') or split index."""
        if isinstance(idx, str):
            return self.datasets[idx]

        elif isinstance(idx, int):
            return self.splits[idx]

        else:
            raise IndexError(f'Invalid dataset idx type {type(idx)}')

    # More methods to access datasets
    @property
    def trainset(self): return self._trainset

    @property
    def validset(self): return self._validset

    @property
    def testset(self): return self._testset

    @property
    def trainloader(self):
        return self.dataloaders['train']

    @property
    def testloader(self):
        return self.dataloaders['test']

    @property
    def validloader(self):
        return self.dataloaders['valid']


class WikiTextDataset (Dataset):
    """Torch dataset of padded BERT pretraining examples (MLM + NSP)."""

    def __init__(
        self,
        paragraphs,
        vocab,
        dataset_percent,
        max_example_len=None
    ):
        """Generate, mask and pad every example for the given paragraphs.

        Args:
            paragraphs (list): paragraphs belonging to this split
            vocab (Vocab): shared vocabulary
            dataset_percent (float): fraction of the full dataset this
                                     split represents; only used by __str__
            max_example_len (int, optional): maximum example length
        """
        self.dataset_percent = dataset_percent
        self.vocab = vocab
        self.paragraphs = paragraphs
        self.max_example_len = max_example_len
        print (f'max example len: {self.max_example_len}')

        # Stage 1: sentence pairs for next sentence prediction.
        nsp_examples = []
        for paragraph in paragraphs:
            nsp_examples += _generate_nsp_data(
                paragraph, paragraphs, max_example_len)

        # Stage 2: mask each pair and append its segments / NSP label.
        full_examples = []
        for tokens, segments, is_next in tqdm(
                nsp_examples, desc='Building Training Examples'):
            mlm_part = _generate_mlm_data_from_tokens(tokens, self.vocab)
            full_examples.append(mlm_part + (segments, is_next))

        # Stage 3: padding is delegated to the d2l helper.
        padded = d2l._pad_bert_inputs(
            full_examples, max_example_len, self.vocab)
        (self.all_token_ids, self.all_segments, self.valid_lens,
         self.all_pred_positions, self.all_mlm_weights, self.all_mlm_labels,
         self.nsp_labels) = padded

    def __getitem__(self, idx):
        """Return the 7-tuple for example ``idx``.

        Any non-scalar entry longer than max_example_len along its first
        dimension is cut down to max_example_len.
        """
        fields = (self.all_token_ids, self.all_segments,
                  self.valid_lens, self.all_pred_positions,
                  self.all_mlm_weights, self.all_mlm_labels,
                  self.nsp_labels)

        sample = []
        for field in fields:
            entry = field[idx]
            if len(entry.shape) > 0 and entry.shape[0] > self.max_example_len:
                entry = entry[:self.max_example_len]
            sample.append(entry)

        return tuple(sample)

    def __len__(self):
        return len(self.all_token_ids)

    def __str__(self):
        return f'WikiTextDataset ({self.dataset_percent*100}% of full dataset)'


class Vocab ():
    """Universal text vocabulary class .

    Holds two dictionaries built in tandem: ``stoi`` (token -> index) and
    ``itos`` (index -> token), plus ``sorted_frequencies`` (token -> count,
    most frequent first).
    """

    def __init__(
        self,
        sentences,
        reserved_tokens=None,
        max_vocab_size=np.inf,
        frequency_threshold=2,
        splits=None, # a little bit of poor coding practices
        max_example_len=None # using this to allow for kwargs in Assembler
        # **kwargs
    ):
        """Initialize vocabulary from a list of sentences .

        Args:
            sentences (list): sentences derived from paragraphs array
            reserved_tokens (list, optional): tokens given the first indexes regardless of frequency. Defaults to None (no reserved tokens).
            max_vocab_size (int, optional): maximum number of vocab tokens to build. Defaults to infinite tokens.
            frequency_threshold (int, optional): min number of time a word must appear to be placed in vocab. Defaults to 2.
            splits: ignored; accepted so DatasetAssembler can forward its kwargs
            max_example_len: ignored; accepted for the same reason

        Raises:
            InvalidVocabLengths: If vocab lengths are somehow altered so that itos and stoi
                                 are not exactly opposite eachother
        """
        # The documented default (None) used to crash in enumerate().
        if reserved_tokens is None:
            reserved_tokens = []

        self.stoi = {token: i for i, token in enumerate(reserved_tokens)}
        self.itos = {v: k for k, v in self.stoi.items()}

        frequencies = {}
        for sentence in sentences:
            for word in sentence:
                frequencies[word] = frequencies.get(word, 0) + 1

        self.sorted_frequencies = dict(
            sorted(frequencies.items(), key=lambda item: item[1], reverse=True))

        # Remaining indexes are handed out in descending frequency order.
        idx = len(self.stoi)
        for word, freq in self.sorted_frequencies.items():
            if word not in self.stoi and freq >= frequency_threshold:

                self.stoi[word] = idx
                self.itos[idx] = word

                idx += 1

        if len(self.stoi) != len(self.itos):
            raise InvalidVocabLengths(self.stoi, self.itos)

        if len(self.stoi) > max_vocab_size:
            self.stoi, self.itos = clip_dicts(
                (self.stoi, self.itos), max_vocab_size)

    def __getitem__(self, idx):
        # allows a token to be fetched
        # either by idx or word string

        if isinstance(idx, str):
            return self.stoi[idx]

        # numpy integers (e.g. from np.random.randint) are accepted too
        elif isinstance(idx, (int, np.integer)):
            return self.itos[idx]

        else:
            raise TypeError(f'Invalid index {idx}')

    def _unknown_id(self):
        """Index of the unknown-word token ('<unk>' or 'unk'), or None."""
        for candidate in ('<unk>', 'unk'):
            if candidate in self.stoi:
                return self.stoi[candidate]
        return None

    def __call__(self, tokens):
        """Int-encode a list of token strings.

        Tokens missing from the vocabulary are encoded as the unknown-word
        token when the vocabulary has one, so the output keeps the input's
        length and positions stay aligned. Without an unknown-word token
        the missing token is reported and skipped, as before.
        """
        assert all([isinstance(token, str) for token in tokens]), \
            'tokens list does not contain strings'

        unknown_id = self._unknown_id()

        int_encoded_tokens = []
        for i, token in enumerate(tokens):
            if token in self.stoi:
                int_encoded_tokens.append(self.stoi[token])

            elif unknown_id is not None:
                int_encoded_tokens.append(unknown_id)

            else:
                print(f'KeyError at {i} token: {token}')

        return int_encoded_tokens

    def __len__(self):
        if len(self.stoi) != len(self.itos):
            raise InvalidVocabLengths(self.stoi, self.itos)
        return len(self.stoi.keys())

    def __str__(self):
        return ('Vocab: (\nlength: {}\n{})'
                .format(self.__len__(), list(self.stoi.keys())))

    # NOTE: on instances these two methods are shadowed by the `stoi` /
    # `itos` dictionaries assigned in __init__, so `vocab.stoi` is the dict.
    # They are only reachable through the class, e.g. Vocab.stoi(vocab, 'a').
    def stoi(self, string):
        return self.stoi[string]

    def itos(self, idx):
        return self.itos[idx]


def load_wikitext():
    """Read the wiki paragraphs and assemble the train/valid datasets.

    Uses the module-level ``data_dir``. The data is split 80% / 20%, every
    token is kept in the vocabulary (frequency threshold 0) and examples
    are limited to 120 tokens.

    Returns:
        DatasetAssembler: assembler holding the vocabulary and datasets
    """
    assembler_options = dict(
        splits=[0.8, 0.2],
        reserved_tokens=['<sep>', 'unk', '<cls>', '<mask>', '<pad>'],
        frequency_threshold=0,
        max_example_len=120,
    )
    return DatasetAssembler(_read_wiki(data_dir), **assembler_options)
    

In [22]:
# optional assembler name for choosing assembler saves
# appended to "assembler {}" when finding file name
ASSEMBLER_NAME = '64' 

# Cache file for the fully built assembler (vocab, datasets, loaders).
assembler_save_path = os.path.join(running_dir, f'assembler{ASSEMBLER_NAME}.pkl')

# Reuse the pickled assembler when one exists for this name.
# NOTE(review): the pickle is loaded as-is, so the loader settings below
# only apply when the cache is rebuilt -- delete the file to change them.
if os.path.isfile(assembler_save_path):
    assembler = load_object(assembler_save_path)
    print('assembler loaded from save')


else:
    # No cache: build the datasets, attach loaders, then pickle everything.
    assembler = load_wikitext()
    batch_size = 32
    shuffle = False
    num_workers = 0

    assembler.build_loaders(batch_size, shuffle, num_workers)

    save_object(assembler, assembler_save_path)


assembler loaded from save


In [None]:
# paragraphs = (_read_wiki(data_dir))
# paragraphs = paragraphs[:200]
# assembler = DatasetAssembler (paragraphs, reserved_tokens = ['<sep>','unk','<cls>','<mask>','<pad>'], splits = [0.9, 0.1], frequency_threshold=0)
# assembler.build_loaders(512, False, 0)

In [None]:
# Pull the loaders off the assembler; a loader is None when its split
# does not exist (build_loaders stores None for missing datasets).
trainloader = assembler.trainloader
validloader = assembler.validloader
testloader = assembler.testloader
# Number of batches in the training loader
print(len(trainloader))

# Print the shape of every tensor in the first two batches (i = 0 and 1).
for i, batch in enumerate (trainloader):
    for data in batch:
        print (data.shape)

    print ('')

    if i > 0:
        break 

In [None]:
# Unpack one batch in the order WikiTextDataset.__getitem__ returns it:
# token ids, segments, valid lengths, prediction positions,
# MLM weights, MLM labels, NSP labels.
for (
        tokens_X,
        segments_X,
        valid_lens_x,
        pred_positions_X,
        mlm_weights_X, mlm_Y,
        nsp_y)\
        in trainloader:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)
    print (segments_X)

    # only the first batch is inspected
    break


In [None]:
import torch.nn as nn


class EncoderBlock(nn.Module):
    """One transformer encoder layer.

    Multi-head self-attention followed by a feed-forward network, each
    wrapped in dropout, a residual connection and layer normalization.
    """

    def __init__(self, key_size, query_size, value_size, num_hiddens,
                 norm_shape, ffn_num_input, ffn_num_hiddens, num_heads,
                 dropout, use_bias=False, **kwargs):
        super().__init__(**kwargs)

        # Attention comes from d2l # TODO: write multi headed attention myself
        self.attention = d2l.MultiHeadAttention(
            key_size, query_size, value_size,
            num_hiddens, num_heads, dropout, use_bias)

        self.norm1 = nn.LayerNorm(norm_shape)
        self.norm2 = nn.LayerNorm(norm_shape)
        # the same dropout module is applied after both sub-layers
        self.dropout = nn.Dropout(dropout)

        self.ff = SimpleFF(ffn_num_input, ffn_num_hiddens, num_hiddens)

    def forward(self, X, valid_lens):
        # Sub-layer 1: self-attention (queries, keys and values are all X)
        attended = self.attention(X, X, X, valid_lens)
        Y = self.norm1(self.dropout(attended) + X)

        # Sub-layer 2: feed-forward network
        transformed = self.ff(Y)
        return self.norm2(self.dropout(transformed) + Y)


class SimpleFF (nn.Module):
    """Two-layer feed-forward network: Linear -> ReLU -> Linear."""

    def __init__(self, ff_num_input, ff_num_hiddens, ff_out):
        super().__init__()

        expand = nn.Linear(ff_num_input, ff_num_hiddens)
        activation = nn.ReLU()
        project = nn.Linear(ff_num_hiddens, ff_out)

        # Kept in a ModuleList so the layers are registered in order
        self.feed_forward = nn.ModuleList([expand, activation, project])

    def forward(self, Y):
        expand, activation, project = self.feed_forward
        return project(activation(expand(Y)))


class Encoder(nn.Module):
    """BERT encoder: token + segment + positional embeddings followed by
    a stack of ``num_layers`` EncoderBlock layers.
    """

    def __init__(
        self, vocab_size, num_hiddens, norm_shape,
        ffn_num_input, ffn_num_hiddens, num_heads,
        num_layers, dropout, max_len=1000, key_size=768,
        query_size=768, value_size=768, **kwargs
    ):

        super(Encoder, self).__init__(**kwargs)

        self.token_embedding = nn.Embedding(vocab_size, num_hiddens)
        # two segments: sentence A (0) and sentence B (1)
        self.segment_embedding = nn.Embedding(2, num_hiddens)

        self.blocks = nn.Sequential()
        for i in range(num_layers):
            self.blocks.add_module(
                "{}".format(i),

                EncoderBlock(
                    key_size,
                    query_size,
                    value_size,
                    num_hiddens,
                    norm_shape,
                    ffn_num_input,
                    ffn_num_hiddens,
                    num_heads,
                    dropout,
                    True
                )
            )

        # positional embedding is learnable
        # so we make it a param
        self.pos_embedding = nn.Parameter(torch.randn(1, max_len,
                                                      num_hiddens))

    def forward(self, tokens, segments, valid_lens):
        """Encode a batch.

        Args:
            tokens: integer token ids; indexed as (batch, sequence)
            segments: segment ids (0/1) with the same shape as tokens
            valid_lens: passed unchanged to every encoder block

        Returns:
            tensor with one ``num_hiddens`` vector per input token
        """
        X = self.token_embedding(tokens) + self.segment_embedding(segments)

        # Slice the parameter itself, not `.data`: going through `.data`
        # bypasses autograd, so the positional embedding would never
        # receive a gradient and would stay at its random initialization.
        X = X + self.pos_embedding[:, :X.shape[1], :]

        for blk in self.blocks:
            X = blk(X, valid_lens)
        return X


class NSP (nn.Module):
    """Next Sentence Prediction output layer.

    A single linear layer mapping its input to two raw scores
    (one per class); no activation is applied here.
    """

    def __init__(self, input_size):
        super().__init__()
        self.ff = nn.Linear(input_size, 2)

    def forward(self, x):
        scores = self.ff(x)
        return scores


class MLM (nn.Module):
    """Masked Language Model output head.

    Gathers the encoder outputs at the requested prediction positions and
    maps each of them to vocabulary scores through
    Linear -> ReLU -> LayerNorm -> Linear.
    """

    def __init__(
        self,
        vocab_size,
        hiddens,
        num_inputs=768,
        **kwargs
    ):
        super().__init__()

        # Applied in this order in forward()
        self.ff1 = nn.Linear(num_inputs, hiddens)
        self.relu = nn.ReLU()
        self.norm = nn.LayerNorm(hiddens)
        self.ff2 = nn.Linear(hiddens, vocab_size)

    def forward(self, x, preds):
        # x is indexed as (batch, position, features),
        # preds as (batch, predictions per example)
        batch_size = x.size(0)
        num_preds = preds.size(1)
        flat_positions = preds.reshape(-1)

        # Row index for every flattened position: each batch index
        # 0..batch_size-1 repeated num_preds times, e.g. [0, 0, 1, 1].
        rows = torch.arange(0, batch_size).repeat_interleave(num_preds)

        gathered = x[rows, flat_positions].reshape(
            (batch_size, num_preds, -1))

        hidden = self.ff1(gathered)
        hidden = self.relu(hidden)
        hidden = self.norm(hidden)
        return self.ff2(hidden)


class BERT (nn.Module):
    """Encoder plus the two pretraining heads (MLM and NSP)."""

    def __init__(
        self, vocab_size, num_hiddens,
        norm_shape, ffn_num_input,
        ffn_num_hiddens, num_heads,
        num_layers, dropout,
        max_len=1000, key_size=768,
        query_size=768, value_size=768,
        hidden_in_features=768,
        mlm_in_features=768,
        nsp_in_features=768
    ):
        super().__init__()

        self.encoder = Encoder(
            vocab_size, num_hiddens, norm_shape,
            ffn_num_input, ffn_num_hiddens, num_heads,
            num_layers, dropout,
            max_len=max_len,
            key_size=key_size,
            query_size=query_size,
            value_size=value_size
        )

        # Linear + tanh applied to the first-position output before NSP
        self.hidden = nn.Sequential(
            nn.Linear(hidden_in_features, num_hiddens),
            nn.Tanh()
        )

        # Output layers for nsp and mlm
        self.nsp = NSP(nsp_in_features)
        self.mlm = MLM(vocab_size, num_hiddens, mlm_in_features)

    def forward(
        self,
        tokens,
        segments,
        valid_lens=None,
        pred_positions=None
    ):
        """Run the encoder and both heads.

        Returns:
            tuple: (encoder output, MLM scores or None when
                    pred_positions is None, NSP scores)
        """
        X = self.encoder(tokens, segments, valid_lens)

        mlm_pred = None
        if pred_positions is not None:
            mlm_pred = self.mlm(X, pred_positions)

        # position 0 holds the <cls> token
        cls_state = X[:, 0, :]
        nsp_pred = self.nsp(self.hidden(cls_state))

        return X, mlm_pred, nsp_pred


In [None]:
# Key into bert_models_params selecting which configuration to build.
MODEL_NAME = 'bert_med'

# Keyword arguments for BERT(...), one entry per named configuration.
# The two entries differ only in num_heads (2 vs 8) and num_layers (2 vs 4).
bert_models_params = {
    'tiny_bert': {
        'vocab_size':len(assembler.vocab),
        'num_hiddens':128,
        'norm_shape':[128],
        'ffn_num_input':128,
        'ffn_num_hiddens':256,
        'num_heads':2,   
        'num_layers':2,
        'dropout':0.2,
        'key_size':128,
        'query_size':128,
        'value_size':128,
        'hidden_in_features':128,
        'mlm_in_features':128,
        'nsp_in_features':128
    },

    'bert_med': {
        'vocab_size':len(assembler.vocab),
        'num_hiddens':128,
        'norm_shape':[128],
        'ffn_num_input':128,
        'ffn_num_hiddens':256,
        'num_heads':8,   
        'num_layers':4,
        'dropout':0.2,
        'key_size':128,
        'query_size':128,
        'value_size':128,
        'hidden_in_features':128,
        'mlm_in_features':128,
        'nsp_in_features':128
    }
}

# Allows model choice from bert_models_param dict
# Input chosen model into MODEL_NAME constant
model = BERT(**bert_models_params[MODEL_NAME])


# Smoke test: push a single batch through the untrained model and print
# the shape of each output (encoder output, MLM scores, NSP scores).
for (
        tokens_X,
        segments_X,
        valid_lens_x,
        pred_positions_X,
        mlm_weights_X, mlm_Y,
        nsp_y)\
        in trainloader:
    print(tokens_X.shape, segments_X.shape, valid_lens_x.shape,
          pred_positions_X.shape, mlm_weights_X.shape, mlm_Y.shape,
          nsp_y.shape)

    # valid lengths are flattened to one value per example
    out = model(tokens_X, segments_X,
               valid_lens_x.reshape(-1), pred_positions_X)
    for obj in out:
        print(obj.shape)

    break
# POGGERS IT WOKRS
# YAYAYAYAYYAYAYYYYYYYYYYYYYYYYYYY


In [None]:
# Bare expression: the notebook displays the module's repr.
model

In [None]:
# `is_available` must be CALLED: the bare function object is always truthy,
# which selected 'cuda' even on machines without a usable GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
optimizer = torch.optim.Adam(model.parameters(), lr = 0.0008) # originally 0.001
criterion = nn.CrossEntropyLoss()
model.to(device)

# Gradient-norm clipping threshold used by the training loop; None disables it.
max_norm = None

print (f'BERT Model loaded with device: {device}')

In [None]:
# WANDB BLOCK
# Log the chosen model's hyperparameters (plus its name) as the run config.
wandb_params = bert_models_params[MODEL_NAME].copy()
wandb_params['Model Name']= MODEL_NAME
# NOTE(review): project and entity are hard-coded to one account -- change
# them before running under a different wandb user.
wandb.init(project="BERT", entity="cowsarecool", config=wandb_params)
# presumably registers the model so wandb tracks its gradients/parameters
# -- TODO confirm against the wandb.watch documentation
wandb.watch(model)

In [None]:
# Ensures a save location for running models:
# <running_dir>/model_saves/<MODEL_NAME>, created (with parents) if missing.
save_dir = os.path.join (running_dir, 'model_saves', MODEL_NAME)
if not os.path.isdir (save_dir):
    os.makedirs (save_dir)

In [None]:
# Load model from save
# Epoch number of the checkpoint to restore. Checkpoints are named
# 'epoch_<n>', so -1 matches no file written by the training loop and
# leaves the freshly initialized model in place.
MODEL_SAVE_EPOCH = -1


save_dict_path = os.path.join(save_dir, 'epoch_{}'.format(MODEL_SAVE_EPOCH))
if os.path.isfile (save_dict_path):
    # NOTE(review): no map_location is given -- confirm that a checkpoint
    # saved on a GPU loads correctly on a CPU-only machine.
    save_state_dicts = torch.load(save_dict_path)
    # Only the model weights are restored; the 'optimizer_state' entry
    # stored in the same file is not loaded here.
    model.load_state_dict(save_state_dicts['model_state'])

else:
    print ('file does not exist, or is unspecified')

In [None]:
### TRAIN TESTING
test_iterations = 100 # <<<<--------- TIM

print_every = 5
save_every = 30 
eval_every = 10


runtimes = list()
    # with alive_bar(test_iterations+1, ## Only works in a normal python script
    #                        title='Training', bar='smooth',
    #                        length=75) as bar:
try:

    for epoch in tqdm_notebook(range(test_iterations+1), desc='Training'):
        cumulative_losses, nsp_losses, mlm_losses = [], [], []
        nsp_accuracies, mlm_accuracies = [], []
        start_time = time.perf_counter()
        
        model.train()
        for batch_num, (
            tokens_X,
            segments_X,
            valid_lens_x,
            pred_positions_X,
            mlm_weights_X, mlm_Y,
            nsp_y)\
        in enumerate(trainloader):
            # print ('\n'.join ([str(data.shape) for data in [tokens_X, segments_X, valid_lens_x, pred_positions_X, mlm_weights_X, mlm_Y, nsp_y]]))
            
            # Sending absolutely everything to gpu 
            # Unfortunately there is no way to do this with a loop
            # so we're stuck manually sending each variable to device
            tokens_X, segments_X,\
            valid_lens_x, pred_positions_X,\
            mlm_weights_X, mlm_Y, nsp_y =(
            tokens_X.to(device), segments_X.to(device),
            valid_lens_x.to(device), pred_positions_X.to(device),
            mlm_weights_X.to(device), mlm_Y.to(device), nsp_y.to(device))

            optimizer.zero_grad()

            X, mlm_y_hat, nsp_y_hat =  model(tokens_X, segments_X,
                valid_lens_x.reshape(-1), pred_positions_X)
            
            # nsp loss calculation
            nsp_loss = criterion (nsp_y_hat, nsp_y)

            # mlm loss is a bit more complicated; it involves
            # mlm weights, a one hot encoded vector specifying which 
            # tokens in the mlm input are <pad>. This is used to factor
            # padding out of the loss calculation. Another complication 
            # is the multidimensionality of the mlm output, it just needs
            # flattening though before going through the loss funciton (criterion)
            mlm_loss = criterion (
                mlm_y_hat.reshape(-1, len(trainloader.dataset.vocab)),
                mlm_Y.reshape(-1)
            ) * mlm_weights_X.reshape(-1, 1) # factors out padding
            mlm_loss = mlm_loss.sum()/(mlm_weights_X.sum()+1e-8)

            cumulative_loss = mlm_loss + nsp_loss

            cumulative_losses.append (cumulative_loss.item()) 
            mlm_losses.append (mlm_loss.item())
            nsp_losses.append (nsp_loss.item())

            nsp_accuracy = torch.sum(
                torch.argmax(nsp_y_hat, dim=1) == nsp_y
            )/len(nsp_y_hat)

            mlm_accuracy = torch.sum(
                torch.argmax(mlm_y_hat, dim=2) == mlm_Y
            ) / np.product(list(mlm_Y.shape))       

            nsp_accuracies.append (nsp_accuracy.item())
            mlm_accuracies.append (mlm_accuracy.item())

            # print (f'nsp y_hat: {nsp_y_hat.shape}, nsp_y: {nsp_y.shape}')
            # print (f'mlm y_hat: {mlm_y_hat.shape}, mlm_y: {mlm_Y.shape}')
            if max_norm is not None:
                nn.utils.clip_grad_norm_(model.parameters(), max_norm)

            # cumulative_loss.backward()
            mlm_loss.backward()
            optimizer.step()

            wandb.log ({
                'mlm_loss':mlm_loss,
                'nsp_loss':nsp_loss,
                'cumulative_loss':cumulative_loss,
                'nsp_accuracy':nsp_accuracy,
                'mlm_accuracy':mlm_accuracy
            })
            
            # Unfortunately, alive_bar doesn't work in notebooks
            # Freaking Sucks
            # bar.text(
            #     f'Batch Accuracy: {mlm_accuracy:.3f}\t Batch Num: {batch_num}')
            
        runtimes.append(time.perf_counter()-start_time)
        if (epoch % print_every) ==0:
            (print (
                'Epoch: {}\t Cum Loss: {:.3f}, MLM Loss: {:.3f}, NSP Loss: {:.3f}\t \
MLM Accuracy: {:.3f}, NSP Accuracy {:.3f}\t \
Avg Epoch Runtime: {:.3f}'
            .format(*[
                np.mean(data) for data in [
                    epoch, 
                    cumulative_losses, 
                    mlm_losses, 
                    nsp_losses, 
                    mlm_accuracies, 
                    nsp_accuracies, 
                    runtimes
                    ]
                ]
            )))

            runtimes = list()
        
        if (epoch % save_every)==0:
            model_state = model.state_dict()
            optimizer_state = optimizer.state_dict()

            torch.save({
                'model_state':model_state,
                'optimizer_state':optimizer_state
                }, os.path.join(save_dir, f'epoch_{epoch}'))

        if (epoch % eval_every == 0):
            if validloader is None:
                # if validloader contains nothing
                # dont evaluate
                continue

            model.eval()
            with torch.no_grad():
                valid_cumulative_losses, valid_nsp_accuracies, valid_mlm_accuracies = [], [], []

                for (
                    tokens_X,
                    segments_X,
                    valid_lens_x,
                    pred_positions_X,
                    mlm_weights_X, mlm_Y,
                    nsp_y)\
                in validloader:
                    tokens_X, segments_X,\
                    valid_lens_x, pred_positions_X,\
                    mlm_weights_X, mlm_Y, nsp_y =(
                    tokens_X.to(device), segments_X.to(device),
                    valid_lens_x.to(device), pred_positions_X.to(device),
                    mlm_weights_X.to(device), mlm_Y.to(device), nsp_y.to(device))

                    X, mlm_y_hat, nsp_y_hat =  model(tokens_X, segments_X,
                        valid_lens_x.reshape(-1), pred_positions_X)

                    nsp_loss = criterion (nsp_y_hat, nsp_y)
                    mlm_loss = criterion (
                        mlm_y_hat.reshape(-1, len(trainloader.dataset.vocab)),
                        mlm_Y.reshape(-1)
                    ) * mlm_weights_X.reshape(-1, 1) # factors out padding
                    mlm_loss = mlm_loss.sum()/(mlm_weights_X.sum()+1e-8)

                    cumulative_loss = mlm_loss + nsp_loss

                    nsp_accuracy = torch.sum(
                        torch.argmax(nsp_y_hat, dim=1) == nsp_y
                    )/len(nsp_y_hat)

                    mlm_accuracy = torch.sum(
                        torch.argmax(mlm_y_hat, dim=2) == mlm_Y
                    )/ np.product(list(mlm_Y.shape))           

                    valid_cumulative_losses.append(cumulative_loss.item())
                    valid_nsp_accuracies.append(nsp_accuracy.item())
                    valid_mlm_accuracies.append(mlm_accuracy.item())

                wandb.log ({
                    'valid_cumulative_loss':np.mean(valid_cumulative_losses),
                    'valid_nsp_accuracy':np.mean(valid_nsp_accuracies),
                    'valid_mlm_accuracy':np.mean(valid_mlm_accuracies)
                })

                print (
                    'Eval Epoch: {}\t MLM Acc: {}, NSP Acc: {}, Cum Loss: {}'
                    .format(*
                        [
                            np.round(np.mean(data), decimals=2)
                            for data in [epoch, valid_mlm_accuracies, valid_nsp_accuracies, valid_cumulative_losses]
                        ]
                    )
                )
except KeyboardInterrupt:
    print ("Keyboard Interrupt Detected\n Ending Run")

except RuntimeError as e:
    if (' '.join(str(e).split(' ')[:4])) != 'CUDA out of memory.':
        raise e 
    
    else:
        print ('CUDA NEEDS MORE FREAKING MEMORY\nProbably about {} more'.format(''.join(str(e).split(' ')[7:9])))

else:
    print ("Program Finished Successfully")

finally:
    wandb.finish()



In [None]:
def example_encode (text):
    """Placeholder for encoding a raw text example; not implemented yet.

    Args:
        text: The input to encode. Currently ignored.

    Returns:
        None, until the encoding logic is written.
    """
    # TODO: write
    return None

In [None]:
def demonstration (model, validloader, device=None):
    """Decode the model's predictions on one batch into readable tables.

    Only the first batch yielded by ``validloader`` is used. The model is put
    in eval mode and the forward pass runs under ``torch.no_grad()``.

    Args:
        model: Callable as ``model(tokens, segments, valid_lens, pred_positions)``
            and returning a 3-tuple whose last two items are the MLM and NSP
            outputs.
        validloader: Iterable yielding 7-tuples of tensors
            ``(tokens_X, segments_X, valid_lens_x, pred_positions_X,
            mlm_weights_X, mlm_Y, nsp_y)``. It must expose
            ``validloader.dataset.vocab.itos`` for index -> token lookup.
        device: Device the batch is moved to before the forward pass. When
            ``None`` (default), the device of the model's first parameter is
            used, falling back to the CPU for a parameterless model. This
            removes the dependency on a notebook-level ``device`` global.

    Returns:
        ``(nsp_data, mlm_data)``, two DataFrames with the columns
        ``'input'``, ``'output'`` and ``'ground truth'``:
          * ``nsp_data``: decoded input tokens, the raw NSP outputs as plain
            lists of floats, and the NSP labels as plain Python values.
          * ``mlm_data``: decoded input tokens, the argmax MLM predictions
            joined with ``', '``, and the decoded MLM targets joined likewise.

    Raises:
        ValueError: If ``validloader`` yields no batches.
    """
    model.eval()

    if device is None:
        first_param = next(iter(model.parameters()), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    batch = next(iter(validloader), None)
    if batch is None:
        # Previously this surfaced as an UnboundLocalError at the return.
        raise ValueError('validloader yielded no batches; nothing to demonstrate')

    (tokens_X, segments_X, valid_lens_x, pred_positions_X,
     _mlm_weights_X, mlm_Y, nsp_y) = [tensor.to(device) for tensor in batch]

    # Inference only: no autograd graph should be built or kept alive.
    with torch.no_grad():
        _, mlm_y_hat, nsp_y_hat = model(tokens_X, segments_X,
            valid_lens_x.reshape(-1), pred_positions_X)

    itos = validloader.dataset.vocab.itos
    vector_itos = np.vectorize(lambda index: itos[index])

    # Bring everything needed for decoding back to host memory in one place.
    token_ids = tokens_X.cpu().numpy()
    mlm_pred_ids = mlm_y_hat.argmax(dim=2).cpu().numpy()
    mlm_true_ids = mlm_Y.cpu().numpy()

    nsp_data = pd.DataFrame(
        {
            # The NSP input is the token sequence itself; decoding the
            # segment ids through the vocabulary yields meaningless tokens.
            'input': [vector_itos(row) for row in token_ids],
            'output': nsp_y_hat.cpu().tolist(),
            'ground truth': nsp_y.cpu().tolist()
        }
    )

    mlm_data = pd.DataFrame(
        {
            'input': [vector_itos(row) for row in token_ids],
            'output': [', '.join(list(vector_itos(row))) for row in mlm_pred_ids],
            'ground truth': [', '.join(list(vector_itos(row))) for row in mlm_true_ids]
        }
    )

    return nsp_data, mlm_data

# Build the demo tables from a single batch. Note that the *training* loader
# is what gets passed in as `validloader` here.
nsp_data, mlm_data = demonstration(model=model, validloader=trainloader)

In [None]:
# Compare predicted vs. target masked tokens for the row at position 4.
mlm_data.loc[:, ['output', 'ground truth']].iloc[4]

In [None]:
# Upload the demo prediction tables to Weights & Biases.
# wandb.Table's first positional parameter is `columns`, so a DataFrame has to
# be handed over through the `dataframe=` keyword for its rows to be ingested.
# NOTE(review): the training cell's `finally` clause calls wandb.finish();
# confirm a run is active (wandb.init) before executing this cell.
wandb.log({
    'nsp_predictions': wandb.Table (dataframe=nsp_data),
    'mlm_predictions': wandb.Table (dataframe=mlm_data)
})


In [None]:
def get_full_class_name(obj):
    """Return the dotted ``module.ClassName`` path of ``obj``'s class.

    Classes whose module is ``None`` or the builtins module are reported by
    bare class name only (e.g. ``'int'`` rather than ``'builtins.int'``).

    Adapted from an answer on http://stackoverflow.com/
    """
    klass = obj.__class__
    owner = klass.__module__
    # `str.__class__` is `type`, whose module is the builtins module.
    builtins_module = str.__class__.__module__
    if owner is not None and owner != builtins_module:
        return '.'.join((owner, klass.__name__))
    return klass.__name__
