In [1]:
import re
import torch
import copy
import math
import time
import torch.nn as nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer

from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Reading, preprocessing and encoding text

In [2]:
class Vocab:
    """
    Simple vocabulary implementation.
    
    vocab = Vocab(text)
    vocab(index) -> Character
    vocab[Character] -> index
    """
    def __init__(self, text):
        self.symbols  = set(text)
        self.c2i = { c : i for i, c in enumerate(self.symbols) }
        self.i2c = { i : c for i, c in enumerate(self.symbols) }

    def __getitem__(self, char):
        return self.c2i[char]

    def __char_by_index__(self, index):
        return self.i2c[index]

    def __len__(self):
        return len(self.i2c)

    __call__ = __char_by_index__


class TextDataset:
    """
    Helper class for working with raw input text.
    Has methods for removing unnecessary characters and for encoding characters.
    """
    def __init__(self, sentences):
        self._replacements = {
            'NL' : lambda x: re.compile('\n').sub('', x),
            '?'  : lambda x: re.compile('\?').sub('.', x),
            '!'  : lambda x: re.compile('!').sub('.', x),
            '('  : lambda x: re.compile('\(').sub('', x),
            ')'  : lambda x: re.compile('\)').sub('', x),
            '"'  : lambda x: re.compile('"').sub('', x),
            '-'  : lambda x: re.compile('\-').sub(' ', x),
            ';'  : lambda x: re.compile(';').sub('.', x),
            'DS'  : lambda x: re.compile(r'[\.]+').sub('.', x)
        }
        self.replace_bad_chars = lambda x: re.compile(r'[^а-я,\.\s]+').sub('', x)
        self.text   = self.preprocess(sentences)
        self.vocab  = Vocab(self.text)

    def replace(self, sentence):
        for replacement in self._replacements.values():
            sentence = replacement(sentence)

        return sentence
    
    def preprocess(self, sentences):
        """
        Preprocesses text: changes text register to lower case, removes newline characters,
        removes non-alphabet characters, translates punctuation marks to dots and commas.
        """
        text = map(self.replace, sentences)
        text = map(str.lower, text)
        text = map(self.replace_bad_chars, text)
        text = map(str.strip, text)
        text = filter(lambda t: t, text)
        text = ' '.join(text)
        return text

    @property
    def vocab_size(self):
        """
        Returns length of associated vocab
        """
        return len(self.vocab)

    def encode(self):
        """
        Encodes processed text using pre-built vocabulary.
        """
        self.text = torch.tensor([self.vocab[char] for char in self.text])

In [3]:
def read(path):
    with open(path, 'r', encoding='utf8') as f:
        return f.readlines()

## Batching data

Given a 1-D vector of sequential data, ``batchify()`` arranges the data
into ``batch_size`` columns. If the data does not divide evenly into
``batch_size`` columns, then the data is trimmed to fit.

\begin{align}\begin{bmatrix}
  \text{A} & \text{B} & \text{C} & \ldots & \text{X} & \text{Y} & \text{Z}
  \end{bmatrix}
  \Rightarrow
  \begin{bmatrix}
  \begin{bmatrix}\text{A} \\ \text{B} \\ \text{C} \\ \text{D} \\ \text{E} \\ \text{F}\end{bmatrix} &
  \begin{bmatrix}\text{G} \\ \text{H} \\ \text{I} \\ \text{J} \\ \text{K} \\ \text{L}\end{bmatrix} &
  \begin{bmatrix}\text{M} \\ \text{N} \\ \text{O} \\ \text{P} \\ \text{Q} \\ \text{R}\end{bmatrix} &
  \begin{bmatrix}\text{S} \\ \text{T} \\ \text{U} \\ \text{V} \\ \text{W} \\ \text{X}\end{bmatrix}
  \end{bmatrix}\end{align}

```get_batch()``` generates a pair of input-target sequences for the transformer model. It subdivides the source data into chunks of length bptt.

\begin{align}
  \begin{bmatrix}
  \begin{bmatrix}\text{A} \\ \text{B} \\ \text{C} \\ \text{D} \\ \text{E} \\ \text{F}\end{bmatrix} &
  \begin{bmatrix}\text{G} \\ \text{H} \\ \text{I} \\ \text{J} \\ \text{K} \\ \text{L}\end{bmatrix} &
  \begin{bmatrix}\text{M} \\ \text{N} \\ \text{O} \\ \text{P} \\ \text{Q} \\ \text{R}\end{bmatrix} &
  \begin{bmatrix}\text{S} \\ \text{T} \\ \text{U} \\ \text{V} \\ \text{W} \\ \text{X}\end{bmatrix}
  \end{bmatrix}
  \Rightarrow
  \begin{bmatrix}
  \begin{bmatrix}\text{A} & \text{G} & \text{M} & \text{S}\end{bmatrix} \\
  \begin{bmatrix}\text{B} & \text{H} & \text{N} & \text{T}\end{bmatrix}
  \end{bmatrix}
  \begin{bmatrix}
  \begin{bmatrix}\text{B} & \text{H} & \text{N} & \text{T}\end{bmatrix} \\
  \begin{bmatrix}\text{C} & \text{I} & \text{O} & \text{U}\end{bmatrix}
  \end{bmatrix}
  \end{align}

In [4]:
def batchify(data, bsz):
    """
    Divides the data into bsz separate sequences, removing extra elements
    that wouldn't cleanly fit.

    (data: Tensor[N], batch_size) -> (data: Tensor[N // batch_size, batch_size])
    """
    seq_len = data.size(0) // bsz
    data = data[:seq_len * bsz]
    data = data.view(bsz, seq_len).t().contiguous()
    return data

bptt = 35
def get_batch(source, i):
    """
    (Tensor[full_seq_len, batch_size], int) -> (Tensor[seq_len, batch_size], Tensor[seq_len, batch_size])
    """
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target

In [5]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

desktop_path = 'C:/Users/Flexatroid/Desktop/AdvancedML/Transformers/text.txt'
drive_path = '/content/drive/MyDrive/AdvancedML/text.txt'
raw_text = read(drive_path)

textDataset = TextDataset(raw_text)
textDataset.encode()

batch_size = 20
eval_batch_size = 10
val_size=0.5
n = int(len(textDataset.text) // batch_size * val_size)

train_data = batchify(textDataset.text[:-2*n], batch_size).to(device)
val_data   = batchify(textDataset.text[-2*n:-n], eval_batch_size).to(device)
test_data  = batchify(textDataset.text[-n:], eval_batch_size).to(device)

print(train_data.size(), val_data.size(), test_data.size())

torch.Size([41949, 20]) torch.Size([2207, 10]) torch.Size([2207, 10])


## Transformer model & Positional encoding

In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        self.encoder     = TransformerEncoder(
            encoder_layer=TransformerEncoderLayer(ninp, nhead, nhid, dropout), 
            num_layers=nlayers
        )
        self.embs = nn.Embedding(ntoken, ninp)
        self.decoder = nn.Linear(ninp, ntoken)

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, x, src_mask):
        x = self.embs(x)
        x = self.pos_encoder(x)
        x = self.encoder(x, src_mask)
        x = self.decoder(x)
        return x

def generate_square_subsequent_mask(sz):
    return torch.triu(torch.ones(sz, sz) * float('-inf'), diagonal=1)

In [14]:
ntokens = textDataset.vocab_size 
# emsize  = 512 
# nhid    = 256 
# nlayers = 5 
# nhead   = 8 
# dropout = 0.2
emsize = 512  # embedding dimension
nhid = 256  # dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 5  # number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 8  # number of heads in nn.MultiheadAttention
dropout = 0.2  # dropout probability
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)

## Model training
Using CrossEntropyLoss with the SGD (stochastic gradient descent) optimizer. Learning rate is initially set to 0.0001 and follows a StepLR schedule. During training, nn.utils.clip_grad_norm_ is used to prevent gradients from exploding.

In [22]:
criterion = nn.CrossEntropyLoss()
lr = 0.0001
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, gamma=0.9)

def train():
    model.train() 
    total_loss = 0.
    start_time = time.time()
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 100
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:5.5f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval()
    total_loss = 0.
    src_mask = generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [23]:
best_val_loss = float("inf")
epochs = 5 
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

test_loss = evaluate(best_model, test_data)
test_ppl = math.exp(test_loss)
print('=' * 89)
print(f'| End of training | test loss {test_loss:5.2f} | '
      f'test ppl {test_ppl:8.2f}')
print('=' * 89)

| epoch   1 |   100/ 1198 batches | lr 0.00010 | ms/batch 43.82 | loss  2.83 | ppl    16.99
| epoch   1 |   200/ 1198 batches | lr 0.00010 | ms/batch 41.59 | loss  2.57 | ppl    13.01
| epoch   1 |   300/ 1198 batches | lr 0.00010 | ms/batch 41.68 | loss  2.51 | ppl    12.34
| epoch   1 |   400/ 1198 batches | lr 0.00010 | ms/batch 41.46 | loss  2.45 | ppl    11.62
| epoch   1 |   500/ 1198 batches | lr 0.00010 | ms/batch 42.03 | loss  2.41 | ppl    11.15
| epoch   1 |   600/ 1198 batches | lr 0.00010 | ms/batch 42.11 | loss  2.36 | ppl    10.54
| epoch   1 |   700/ 1198 batches | lr 0.00010 | ms/batch 41.28 | loss  2.30 | ppl    10.00
| epoch   1 |   800/ 1198 batches | lr 0.00010 | ms/batch 41.51 | loss  2.28 | ppl     9.79
| epoch   1 |   900/ 1198 batches | lr 0.00010 | ms/batch 41.96 | loss  2.24 | ppl     9.39
| epoch   1 |  1000/ 1198 batches | lr 0.00010 | ms/batch 41.76 | loss  2.20 | ppl     9.05
| epoch   1 |  1100/ 1198 batches | lr 0.00010 | ms/batch 41.51 | loss  2.18 | p

  ## Text generation examples

In [8]:
def gen_char(text, temperature=0):
    d = torch.tensor([textDataset.vocab[c] for c in text]).unsqueeze(1).to(device)
    m = generate_square_subsequent_mask(len(text)).to(device)
    with torch.no_grad():
        outputs = best_model(d, m)[-1][0]
        if temperature != 0:
            outputs = nn.Softmax(0)(outputs / temperature)
            idx = torch.multinomial(outputs, 1).item()
        else:
            idx = outputs.argmax().item()
        return textDataset.vocab(idx)

def generate(text, n, temperature=0):
    for _ in range(n):
        text += gen_char(text[-eval_batch_size:], temperature)

    return text

In [16]:
best_model = model.eval()
best_model.load_state_dict(torch.load('/content/drive/MyDrive/AdvancedML/tf_model_final.pt'))  

<All keys matched successfully>

In [17]:
generate('фродо взял кол', 200, 0)

'фродо взял колзжпогкдхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжхкжзбфх злзй,покчхжх'

In [None]:
generate('фродо взял кол', 200, 0.3)

'фродо взял кольцо. я не возвращался и под в нем положил на восток с ним не было от было сображил старый своим своем становился на деревья по не последний в сторону и обеспокойной в от поднялся в стальный собой с по'

In [None]:
generate('гендальф побежал', 400, 0.3)

'гендальф побежал он пригорье. я не подняли на кольцо все собой собой страннных в подолжны вернули в помощь на не по от вернулся в под ними, что они поднимали и по не собой строными не поднялась в этом старый в от не собраться и показался в собой в западной воды. они поднялись в собой собенной в последников следующий в пришли по видели на следующий пришли и на стал он по высокой старый последники на запад не могут'

In [40]:
torch.save(best_model.state_dict(), '/content/drive/MyDrive/AdvancedML/tf_model_final.pt')  