In [None]:
from google.colab import drive

# Mount Google Drive; the project folder with the corpus files lives under MyDrive.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import os

# Work from the project folder on Drive so relative paths such as './trainset.txt' resolve.
PROJECT_DIR = os.path.join('/content/drive', 'MyDrive', 'Colab Notebooks', 'Transformer')
os.chdir(PROJECT_DIR)

In [None]:
!pip install pytorch-crf



In [None]:
import torchtext
import torch
from torchcrf import CRF
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from collections import Counter
from torchtext.vocab import Vocab
import io
import re
from typing import Iterable, List

In [None]:
# Keys selecting the word (source) side and the tag (target) side of each corpus line.
SRC_LANGUAGE, TGT_LANGUAGE = 'word', 'tag'

In [None]:
token_transform={}

def _split_word_tag_pairs(line):
  """Split one corpus line into its space-separated ``word/tag`` items.

  Strips the trailing newline (and a Windows '\\r') and drops empty items, so
  doubled or trailing spaces neither create an empty token nor make
  ``tag_tokenizer`` raise IndexError. Only the ASCII space is a separator.
  """
  return [item for item in line.rstrip('\r\n').split(' ') if item]

def word_tokenizer(line):
  """Return the words of a ``word/tag word/tag ...`` line (text before the last '/')."""
  return [item.rsplit('/', 1)[0] for item in _split_word_tag_pairs(line)]

def tag_tokenizer(line):
  """Return the tags of a ``word/tag word/tag ...`` line (text after the last '/')."""
  return [item.rsplit('/', 1)[1] for item in _split_word_tag_pairs(line)]
# Register the tokenizer for each side: words feed the encoder, tags are the targets.
token_transform.update({SRC_LANGUAGE: word_tokenizer, TGT_LANGUAGE: tag_tokenizer})

In [None]:
def yield_tokens(filepath,ln):
  """Yield the token list of every line of ``filepath`` using ``token_transform[ln]``.

  Args:
    filepath: UTF-8 corpus file, one ``word/tag ...`` sentence per line.
    ln: key into token_transform (SRC_LANGUAGE for words, TGT_LANGUAGE for tags).

  Yields:
    One list of tokens per line.
  """
  tokenizer = token_transform[ln]
  # Iterate the file lazily so the whole corpus is never held as a list of lines.
  with io.open(filepath, encoding="utf8") as f:
    for line in f:
      yield tokenizer(line)

In [None]:
# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']

In [None]:
# Build one torchtext Vocab per side (words, tags) from the training corpus.
vocab_transform = {}
filepath = './trainset.txt'
for ln in (SRC_LANGUAGE, TGT_LANGUAGE):
    vocab_transform[ln] = build_vocab_from_iterator(
        yield_tokens(filepath, ln),
        min_freq=1,
        specials=special_symbols,
        special_first=True,
    )
    # Unknown tokens map to UNK_IDX; without a default index, looking up an
    # unseen token raises RuntimeError.
    vocab_transform[ln].set_default_index(UNK_IDX)

In [None]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')  # GPU when available

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a (seq_len, batch, emb_size) input, then apply dropout.

    Even feature columns hold sin(pos / 10000^(2i/emb_size)); odd columns hold the
    matching cos. Odd ``emb_size`` is supported: the cosine half has one column fewer.

    Args:
        emb_size: feature dimension of the embeddings.
        dropout: dropout probability applied after adding the encoding.
        maxlen: longest sequence the precomputed table covers.
    """
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        # With odd emb_size there are ceil(emb_size/2) even columns but only
        # floor(emb_size/2) odd ones, so the cosine part uses that many frequencies.
        pos_embedding[:, 1::2] = torch.cos(pos * den[: emb_size // 2])
        # (maxlen, 1, emb_size) so the table broadcasts over the batch dimension.
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        """Return dropout(token_embedding + encoding[:seq_len]).

        Raises:
            ValueError: if the sequence is longer than ``maxlen``.
        """
        seq_len = token_embedding.size(0)
        max_len = self.pos_embedding.size(0)
        if seq_len > max_len:
            raise ValueError(f"sequence length {seq_len} exceeds PositionalEncoding maxlen {max_len}")
        return self.dropout(token_embedding + self.pos_embedding[:seq_len, :])

# helper Module: token ids -> embedding vectors scaled by sqrt(emb_size)
class TokenEmbedding(nn.Module):
    """Embed integer token ids and multiply by sqrt(emb_size).

    Input ids are cast to long first, so float-typed id tensors are accepted too.
    """
    def __init__(self, vocab_size: int, emb_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        ids = tokens.long()
        scale = math.sqrt(self.emb_size)
        return self.embedding(ids) * scale

# Seq2Seq Network: Transformer encoder-decoder emitting per-position tag scores, plus a CRF layer
class Seq2SeqTransformer(nn.Module):
    """Transformer encoder-decoder whose decoder outputs are projected to tag scores.

    ``forward`` returns emission scores of shape (tgt_len, batch, tgt_vocab_size);
    the training loop feeds them to ``self.crf``. ``encode``/``decode`` expose the
    two halves for step-wise decoding and optionally accept padding masks, so they
    can also be used on padded batches.
    """
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        # Submodules are created in this exact order: it fixes parameter order and
        # therefore the seeded initial weights.
        self.transformer = Transformer(d_model=emb_size,
                        nhead=nhead,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers,
                        dim_feedforward=dim_feedforward,
                        dropout=dropout)
        self.crf=CRF(tgt_vocab_size)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Return emission scores (tgt_len, batch, tgt_vocab_size) for teacher-forced ``trg``."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb,
                                src_mask=src_mask,
                                tgt_mask=tgt_mask,
                                memory_mask=None,
                                src_key_padding_mask=src_padding_mask,
                                tgt_key_padding_mask=tgt_padding_mask,
                                memory_key_padding_mask=memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor, src_padding_mask=None):
        """Run the encoder; ``src_padding_mask`` (batch, src_len), True at PAD, is optional."""
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), mask=src_mask,
                            src_key_padding_mask=src_padding_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor,
               tgt_padding_mask=None, memory_key_padding_mask=None):
        """Run the decoder over ``tgt`` attending to ``memory``; padding masks are optional."""
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask=tgt_mask,
                          tgt_key_padding_mask=tgt_padding_mask,
                          memory_key_padding_mask=memory_key_padding_mask)

In [None]:
def generate_square_subsequent_mask(sz, device=None):
    """Return an (sz, sz) float causal mask for nn.Transformer.

    Entry (i, j) is 0.0 when j <= i (position i may attend to j) and -inf when j > i.

    Args:
        sz: number of target positions.
        device: where to allocate the mask; defaults to the notebook-wide DEVICE.
    """
    if device is None:
        device = DEVICE
    # Same construction as nn.Transformer.generate_square_subsequent_mask.
    return torch.triu(torch.full((sz, sz), float('-inf'), device=device), diagonal=1)


def create_mask(src, tgt):
    """Build the masks nn.Transformer needs for (seq_len, batch) ``src``/``tgt`` id tensors.

    Returns:
        (src_mask, tgt_mask, src_padding_mask, tgt_padding_mask):
        - src_mask: (src_len, src_len) all-False bool mask (encoder attends everywhere)
        - tgt_mask: (tgt_len, tgt_len) causal mask from generate_square_subsequent_mask
        - src_padding_mask / tgt_padding_mask: (batch, seq_len) bool, True where the id is PAD_IDX
    """
    src_len, tgt_len = src.shape[0], tgt.shape[0]
    src_mask = torch.zeros((src_len, src_len), device=DEVICE, dtype=torch.bool)
    tgt_mask = generate_square_subsequent_mask(tgt_len)
    src_padding_mask = src.eq(PAD_IDX).transpose(0, 1)
    tgt_padding_mask = tgt.eq(PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

In [None]:
# Seed before construction so the initial weights are reproducible. Parameters are
# created and xavier-initialised in a fixed order, so reordering these statements
# would change the initial weights.
torch.manual_seed(1)

# Model hyper-parameters.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 2
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 2

transformer = Seq2SeqTransformer(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

# Xavier-uniform init for every parameter with more than one dimension (weight
# matrices and embedding tables; presumably also the CRF's 2-D transition matrix).
# 1-D parameters keep their default initialisation.
for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

# NOTE(review): loss_fn is not used by the CRF training loop; it only appears in
# commented-out code there.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [None]:
from torch.nn.utils.rnn import pad_sequence

# helper function to chain several transforms into one callable
def sequential_transforms(*transforms):
    """Compose ``transforms`` left to right into a single callable.

    ``sequential_transforms(f, g, h)(x)`` returns ``h(g(f(x)))``; with no
    transforms the returned callable is the identity.
    """
    def _apply_all(value):
        result = value
        for step in transforms:
            result = step(result)
        return result

    return _apply_all

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    """Wrap ``token_ids`` with BOS/EOS and return them as a 1-D int64 tensor.

    The dtype is explicit so an empty ``token_ids`` still gives an integer tensor
    ``[BOS_IDX, EOS_IDX]``. Concatenating with ``torch.tensor([])``, which is
    float, would have promoted the whole result to float.
    """
    return torch.tensor([BOS_IDX, *token_ids, EOS_IDX], dtype=torch.long)

# src and tgt text transforms turning a raw corpus line into a tensor of indices:
# tokenize -> numericalize with that side's Vocab -> add BOS/EOS and build a tensor.
text_transform = {
    ln: sequential_transforms(token_transform[ln], vocab_transform[ln], tensor_transform)
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]
}


# function to collate raw lines into padded batch tensors
def collate_fn(batch):
    """Turn a list of raw ``word/tag ...`` lines into padded (max_len, batch) tensors.

    Returns:
        (src_batch, tgt_batch, tgtinput_batch): word ids, tag ids, and a tensor
        holding BOS_IDX at every position of each tag sequence. Shorter sequences
        are padded with PAD_IDX.
    """
    sources, targets, decoder_inputs = [], [], []
    for raw_line in batch:
        line = raw_line.rstrip("\n")
        sources.append(text_transform[SRC_LANGUAGE](line))
        tag_ids = text_transform[TGT_LANGUAGE](line)
        targets.append(tag_ids)
        decoder_inputs.append(torch.tensor([BOS_IDX] * tag_ids.shape[0]))

    return (pad_sequence(sources, padding_value=PAD_IDX),
            pad_sequence(targets, padding_value=PAD_IDX),
            pad_sequence(decoder_inputs, padding_value=PAD_IDX))

In [None]:
# Training and validation corpora, relative to the project directory.
TRAINPATH, VALPATH = './trainset.txt', './valset.txt'

In [None]:
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience.

    Each call compares the new validation loss with the best one seen so far. On
    an improvement (by at least ``delta``) the model is saved to ``path`` and the
    counter resets. Otherwise the counter increments, and ``early_stop`` becomes
    True once it reaches ``patience``.
    """
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How many calls to wait after the last validation-loss improvement.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            path (str): File the best model is written to.
                            Default: 'checkpoint.pt'
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # float('inf') rather than np.Inf: that alias was removed in NumPy 2.0, and
        # this class should not depend on numpy being imported first.
        self.val_loss_min = float('inf')
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''
        Saves model when validation loss decrease.

        The whole module is pickled (torch.save(model, ...)), not just its state_dict.
        '''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        # To store only the parameters instead: torch.save(model.state_dict(), self.path)
        torch.save(model, self.path)
        self.val_loss_min = val_loss
        self.val_loss_min = val_loss

In [None]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    """Run one training epoch of the Seq2SeqTransformer+CRF over TRAINPATH.

    The decoder is teacher-forced with the gold tags shifted right (``tgt[:-1]``).
    The loss is the negated CRF log-likelihood of ``tgt[1:]``, with PAD positions
    masked out.

    Returns:
        Mean of the per-batch losses ``-model.crf(...)``.
    """
    model.train()
    losses = 0
    with io.open(TRAINPATH, encoding="utf8") as f:
        train_iter = f.readlines()
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    # The third collated tensor (an all-BOS decoder input) is not used: the
    # decoder is teacher-forced with the gold tag sequence instead.
    for src, tgt, _ in train_dataloader:
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        tgt_out = tgt[1:, :]
        # CRF mask, (seq_len, batch) like tgt_out: 1 on real tokens, 0 on PAD.
        crf_mask = (tgt_out != PAD_IDX).type(torch.uint8)
        loss = -model.crf(logits, tgt_out, mask=crf_mask)
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model):
    """Mean validation loss of the Seq2SeqTransformer+CRF over VALPATH.

    Computes the same loss as train_epoch: teacher-forced decoder input
    ``tgt[:-1]``, and the CRF NLL of ``tgt[1:]`` with PAD masked. It takes no
    optimizer step and runs under ``torch.no_grad()``, so no autograd graph is built.
    """
    model.eval()
    losses = 0
    with io.open(VALPATH, encoding="utf8") as f:
        val_iter = f.readlines()
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():
        for src, tgt, _ in val_dataloader:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)

            tgt_input = tgt[:-1, :]
            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]
            # CRF mask, (seq_len, batch) like tgt_out: 1 on real tokens, 0 on PAD.
            crf_mask = (tgt_out != PAD_IDX).type(torch.uint8)
            loss = -model.crf(logits, tgt_out, mask=crf_mask)
            losses += loss.item()

    return losses / len(val_dataloader)

In [None]:
import numpy as np
# Count trainable parameters. Tensor.numel() gives a plain int; np.prod over a
# torch.Size returns a NumPy scalar (and the float 1.0 for a 0-d parameter).
model_parameters = [p for p in transformer.parameters() if p.requires_grad]
params = sum(p.numel() for p in model_parameters)

In [None]:
# Display the number of trainable parameters of `transformer`.
params

27361780

In [None]:
import time
from timeit import default_timer as timer
# Train with early stopping on the validation loss. EarlyStopping saves the best
# model to 'checkpoint.pt', which is reloaded once training ends.
NUM_EPOCHS = 18
patience=3
early_stopping = EarlyStopping(patience=patience, verbose=True)
start=time.time()
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(transformer)
    # Report the epoch before the early-stopping check; printing it after the
    # check dropped the metrics of the epoch that triggered the stop.
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")

    early_stopping(val_loss, transformer)
    if early_stopping.early_stop:
      print("Early stopping")
      break
end=time.time()
# NOTE(review): checkpoint.pt is a whole pickled module (torch.save(model, ...)).
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects it; there,
# pass weights_only=False (only for checkpoints you created yourself).
transformer=torch.load('checkpoint.pt')
traintime=end-start
traintime

Validation loss decreased (inf --> 1950.334077).  Saving model ...
Epoch: 1, Train loss: 6128.863, Val loss: 1950.334, Epoch time = 45.896s
Validation loss decreased (1950.334077 --> 693.494476).  Saving model ...
Epoch: 2, Train loss: 1726.247, Val loss: 693.494, Epoch time = 44.411s
Validation loss decreased (693.494476 --> 512.690792).  Saving model ...
Epoch: 3, Train loss: 810.242, Val loss: 512.691, Epoch time = 44.521s
Validation loss decreased (512.690792 --> 463.232431).  Saving model ...
Epoch: 4, Train loss: 550.888, Val loss: 463.232, Epoch time = 44.179s
Validation loss decreased (463.232431 --> 441.909554).  Saving model ...
Epoch: 5, Train loss: 428.677, Val loss: 441.910, Epoch time = 44.167s
EarlyStopping counter: 1 out of 3
Epoch: 6, Train loss: 351.130, Val loss: 447.185, Epoch time = 44.100s
Validation loss decreased (441.909554 --> 423.246785).  Saving model ...
Epoch: 7, Train loss: 302.581, Val loss: 423.247, Epoch time = 43.754s
EarlyStopping counter: 1 out of 3

490.3692970275879

In [None]:
# function to generate the tag sequence for one sentence by iterative CRF decoding
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    """Predict tags for one source sentence.

    The encoder runs once. Then, for ``src.shape[0]`` steps, the current target
    prefix is decoded, its emissions are CRF-decoded, and the prefix is replaced
    by ``[start_symbol] + prediction``. The prefix grows by one each step, and
    earlier tags may be revised.

    Args:
        model: Seq2SeqTransformer (uses ``encode``, ``decode``, ``generator``, ``crf``).
        src: source ids; translate passes shape (src_len, 1).
        src_mask: (src_len, src_len) encoder attention mask.
        max_len: unused; kept for interface compatibility (the loop runs ``src.shape[0]`` times).
        start_symbol: id placed at position 0 of the decoder input.

    Returns:
        Long tensor whose first row is ``start_symbol`` followed by the predicted
        tag ids; for a single-column ``src`` its shape is (src_len + 1, 1).
    """
    src = src.to(DEVICE)
    src_mask = src_mask.to(DEVICE)
    ys_ini = torch.full((1, 1), start_symbol, dtype=torch.long, device=DEVICE)
    ys = ys_ini
    # Inference only: no_grad avoids building an autograd graph at every step.
    with torch.no_grad():
        memory = model.encode(src, src_mask).to(DEVICE)
        for _ in range(src.shape[0]):
            tgt_mask = generate_square_subsequent_mask(ys.size(0)).type(torch.bool).to(DEVICE)
            out = model.decode(ys, memory, tgt_mask)
            emissions = model.generator(out)
            prediction = model.crf.decode(emissions)
            # crf.decode gives one tag list per batch element; transpose to (seq_len, batch).
            ys = torch.cat([ys_ini, torch.tensor(prediction).transpose(0, 1).to(DEVICE)], dim=0)
    return ys


# tag one raw corpus line with the trained model
def translate(model: torch.nn.Module, src_sentence: str):
    """Tag ``src_sentence`` and return the predicted tags joined by single spaces.

    "<bos>" and "<eos>" are removed by plain string replacement. The result
    therefore keeps a leading space where the initial <bos> was, and an empty
    slot wherever <eos> was predicted; the prediction loop splits on ' ' and
    trims empty items at both ends.
    """
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    n_tokens = src.shape[0]
    src_mask = torch.zeros(n_tokens, n_tokens).type(torch.bool)
    tag_ids = greedy_decode(
        model, src, src_mask, max_len=n_tokens + 5, start_symbol=BOS_IDX).flatten()
    tag_strings = vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tag_ids.cpu().numpy()))
    joined = " ".join(tag_strings)
    return joined.replace("<bos>", "").replace("<eos>", "")

In [None]:
TESTPATH='./sents.answer'
# Read the gold test file; `with` closes the handle once the lines are loaded.
with io.open(TESTPATH, encoding="utf8") as f:
    test_iter = f.readlines()
# NOTE(review): test_dataloader is not used by the per-sentence prediction loop.
test_dataloader = DataLoader(test_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

In [None]:
# Gold tag sequence of every test line, from the same tag tokenizer used in training.
truetag = [token_transform[TGT_LANGUAGE](line) for line in test_iter]


In [None]:
# Tag every test sentence one at a time and time the whole pass.
start = time.time()
predtag = []
for i, line in enumerate(test_iter):
    # Report progress every 100 sentences; printing every line floods the output.
    if i % 100 == 0:
        print(f'{i}/{len(test_iter)}')
    tmpli = translate(transformer, line).split(' ')
    # translate() leaves empty strings where <bos>/<eos> were removed; drop them at
    # the ends. The emptiness checks guard against an all-empty prediction.
    if tmpli and tmpli[0] == '':
        tmpli = tmpli[1:]
    if tmpli and tmpli[-1] == '':
        tmpli = tmpli[:-1]
    predtag.append(tmpli)
end = time.time()
predtime = end - start

In [None]:
# Wall-clock seconds spent tagging the whole test set (includes progress printing).
predtime

336.0183491706848

In [None]:
from sklearn.metrics import accuracy_score

In [None]:
# Token-level tagging accuracy over the test set. Predictions longer than the gold
# sequence are truncated to its length. Gold tokens with no prediction (short
# prediction, or a sentence never tagged) count as errors. Previously a bare
# `except` dropped such sentences from the denominator, which inflated accuracy.
correctnum = 0
allnum = 0
short_preds = []
for i, gold in enumerate(truetag):
    pred = predtag[i] if i < len(predtag) else []
    if len(pred) < len(gold):
        short_preds.append(i)
    correctnum += sum(g == p for g, p in zip(gold, pred))
    allnum += len(gold)
if short_preds:
    print(f'{len(short_preds)} sentences had fewer predicted than gold tags, e.g. {short_preds[:20]}')

In [None]:
# Token-level accuracy: correctly predicted tags / gold tags counted.
correctnum/allnum

0.957865345453782

#### Transformer+CRF
NHead | Encoder layer | Decoder layer | Number of Param. | Train time | Predict time | Accuracy
--- | --- | --- | --- | --- | --- | ---
1 | 1 | 1 | 24732148 | 401.999s | 254.1364s | 95.28%
2 | 1 | 1 | 24732148 | 405.998s | 256.352s | 95.73%
4 | 1 | 1 | 24732148 | 323.886s | 257.431s | 95.47%
1 | 2 | 1 | 26310132 | 450.551s | 254.890s | 94.93%
1 | 4 | 1 | 29466100 | 602.610s | 258.773s | 94.85%
1 | 1 | 2 | 27361780 | 530.378s | 327.641s | 95.56%
1 | 1 | 4 | 32621044 | 708.147s | 479.063s | 95.88%


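#### Transformer without CRF (heading added in review)
This table was unlabeled. Each parameter count is exactly 2,499 lower than the matching Transformer+CRF row. That matches a `torchcrf.CRF` over 49 tags (49² transitions + 49 start + 49 end). It is also consistent with the hand-written Transformer_CRF count below (24,729,649 + 49² = 24,732,050). So this is presumably the same model without the CRF layer — confirm.

NHead | Encoder layer | Decoder layer | Number of Param. | Train time | Predict time | Accuracy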
--- | --- | --- | --- | --- | --- | ---
1 | 1 | 1 | 24729649 | 281.136s | 88.057s | 95.02%
2 | 1 | 1 | 24729649 | 203.541s | 91.099s | 95.42%
4 | 1 | 1 | 24729649 | 281.731s | 88.984s | 95.51%
1 | 2 | 1 | 26307633 | 292.877s | 93.071s | 95.12%
1 | 4 | 1 | 29463601 | 430.538s | 92.004s | 94.63%
1 | 1 | 2 | 27359281 | 414.209s | 144.412s | 95.66%
1 | 1 | 4 | 32618545 | 755.313s | 251.045s | 95.81%

In [None]:
def argmax(vec):
    """Return, for each row of 2-D ``vec``, the column index of its maximum.

    The result is a 1-D int64 tensor with one index per row, not a Python int.
    """
    return torch.max(vec, dim=1).indices


def prepare_sequence(seq, to_ix):
    """Map every item of ``seq`` through ``to_ix`` and return the ids as a 1-D int64 tensor."""
    return torch.tensor([to_ix[item] for item in seq], dtype=torch.long)


# Compute log sum exp in a numerically stable way for the forward algorithm
def log_sum_exp(vec):
    """Row-wise log(sum(exp(vec))).

    Args:
        vec: 2-D tensor of shape [batch, tgt_size].

    Returns:
        1-D tensor of shape [batch].

    torch.logsumexp subtracts each row's maximum internally, as the hand-written
    version did. It also returns -inf, instead of NaN, for a row that is entirely -inf.
    """
    return torch.logsumexp(vec, dim=1)


class Transformer_CRF(nn.Module):
    """Transformer encoder-decoder producing per-position tag scores, plus a hand-written linear-chain CRF.

    ``self.transitions[next_tag, prev_tag]`` scores moving from ``prev_tag`` to
    ``next_tag``. Transitions into BOS and out of EOS are initialised to -10000,
    i.e. effectively forbidden.

    Tag tensors passed to the CRF methods are the full gold sequences, (seq_len + 1, batch),
    starting with BOS. The decoder is teacher-forced with ``tags[:-1]``, so emission
    row ``feats[i]`` is the prediction for ``tags[i + 1]``.
    """

    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Transformer_CRF, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                        nhead=nhead,
                        num_encoder_layers=num_encoder_layers,
                        num_decoder_layers=num_decoder_layers,
                        dim_feedforward=dim_feedforward,
                        dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

        # transitions[next_tag, prev_tag]
        self.transitions = nn.Parameter(
            torch.randn(tgt_vocab_size, tgt_vocab_size))

        # Effectively forbid moving into BOS and leaving EOS.
        self.transitions.data[BOS_IDX, :] = -10000
        self.transitions.data[:, EOS_IDX] = -10000
        self.tagset_size=tgt_vocab_size

    def transformer_forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        """Return emission scores (tgt_len, batch, tagset_size) for teacher-forced ``trg``."""
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(
                            self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(
                          self.tgt_tok_emb(tgt)), memory,
                          tgt_mask)

    def _forward_alg(self, feats,tags):
        """Log partition function, averaged over the batch (CRF forward algorithm).

        Every path starts at BOS. Step i mirrors _score_sentence. A step whose
        target ``tags[i + 1]`` is PAD leaves the forward variables unchanged. An
        EOS target adds the transition score but no emission. Any other target
        adds both.

        Args:
            feats: emission scores, [seq_len, batch, tagset_size].
            tags: gold tag ids, [seq_len + 1, batch]; used only to locate PAD/EOS targets.
        """
        batchsize = tags.shape[1]
        forward_var = torch.full((batchsize, self.tagset_size), -10000., device=feats.device)
        forward_var[:, BOS_IDX] = 0.
        # trans[0, j, k] = transitions[j, k]: score of prev tag k -> next tag j.
        trans = self.transitions.unsqueeze(0)

        for i in range(feats.shape[0]):
            target = tags[i + 1]
            notpad = target != PAD_IDX
            notpadeos = notpad & (target != EOS_IDX)
            # scores[b, j, k] = forward_var[b, k] + transitions[j, k] + feats[i, b, j]
            # (transition term only for non-PAD rows, emission only for non-PAD/EOS rows).
            # A fresh tensor every step: nothing is modified in place.
            step_trans = trans * notpad.view(batchsize, 1, 1).to(feats.dtype)
            step_emit = feats[i].unsqueeze(2) * notpadeos.view(batchsize, 1, 1).to(feats.dtype)
            scores = forward_var.unsqueeze(1) + step_trans + step_emit
            stepped = torch.logsumexp(scores, dim=2)
            forward_var = torch.where(notpad.unsqueeze(1), stepped, forward_var)

        alpha = torch.logsumexp(forward_var, dim=1)
        return torch.mean(alpha)

    def _get_transformer_features(self, src,tgt):
        """Emission scores for ``src`` with the decoder teacher-forced on ``tgt[:-1]``."""
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)
        tgt_input = tgt[:-1, :]
        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)
        transformer_feats = self.transformer_forward(src, tgt_input, src_mask, tgt_mask,src_padding_mask, tgt_padding_mask, src_padding_mask)
        return transformer_feats

    def _score_sentence(self, feats, tags):
        """Score of the gold tag path, averaged over the batch.

        Step i adds transitions[tags[i + 1], tags[i]] when tags[i + 1] is not PAD.
        It adds the emission feats[i][b, tags[i + 1]] when tags[i + 1] is neither
        PAD nor EOS. The path therefore starts from tags[0].

        Args:
            feats: emission scores, [seq_len, batch, tagset_size].
            tags: gold tag ids, [seq_len + 1, batch].
        """
        score = torch.zeros(tags.shape[1]).to(DEVICE)
        for i, feat in enumerate(feats):
            notpad=torch.where(tags[i+1]!=PAD_IDX)[0]
            notpadeos=torch.where((tags[i+1]!=PAD_IDX) & (tags[i+1]!=EOS_IDX))[0]
            score=score.index_add_(0, notpad, self.transitions[tags[i + 1], tags[i]][notpad])
            score=score.index_add_(0, notpadeos, torch.gather(feat,1,tags[i + 1].unsqueeze(1)).squeeze(1)[notpadeos])
        return torch.mean(score)

    def _viterbi_decode(self, feats):
        """Best tag path for a single sequence (Viterbi).

        ``forward_var`` is (1, tagset_size), so this decodes batch size 1 only.

        Returns:
            (path_score, best_path): best_path holds the tag ids after BOS.
        """
        backpointers = []

        # Initialize the viterbi variables in log space, on the parameters' device
        # (a CPU tensor here would clash with CUDA transitions).
        init_vvars = torch.full((1, self.tagset_size), -10000., device=self.transitions.device)
        init_vvars[0][BOS_IDX] = 0

        # forward_var at step i holds the viterbi variables for step i-1
        forward_var = init_vvars
        for feat in feats:
            bptrs_t = []  # holds the backpointers for this step
            viterbivars_t = []  # holds the viterbi variables for this step

            for next_tag in range(self.tagset_size):
                # next_tag_var[i]: viterbi variable of tag i at the previous step plus
                # the transition i -> next_tag. Emission scores are added below,
                # since they do not affect the argmax over the previous tag.
                next_tag_var = forward_var + self.transitions[next_tag]
                best_tag_id = argmax(next_tag_var)
                bptrs_t.append(best_tag_id)
                viterbivars_t.append(next_tag_var[0][best_tag_id].view(1))
            # Add the emission scores and move on to the next step.
            forward_var = (torch.cat(viterbivars_t) + feat).view(1, -1)
            backpointers.append(bptrs_t)

        # Transition to EOS
        terminal_var = forward_var + self.transitions[EOS_IDX]
        best_tag_id = argmax(terminal_var)
        path_score = terminal_var[0][best_tag_id]

        # Follow the back pointers to decode the best path.
        best_path = [best_tag_id]
        for bptrs_t in reversed(backpointers):
            best_tag_id = bptrs_t[best_tag_id]
            best_path.append(best_tag_id)
        # Pop off the start tag (we dont want to return that to the caller)
        start = best_path.pop()
        assert start == BOS_IDX  # Sanity check
        best_path.reverse()
        return path_score, best_path

    def neg_log_likelihood(self, sentence, tags):
        """Mean log-partition minus mean gold-path score over the batch."""
        feats = self._get_transformer_features(sentence,tags)
        forward_score = self._forward_alg(feats, tags)
        gold_score = self._score_sentence(feats, tags)
        return forward_score - gold_score

    def forward(self, sentence,tags):  # dont confuse this with _forward_alg above.
        """Viterbi-decode ``sentence`` (batch size 1); ``tags[:-1]`` still feeds the decoder."""
        # Emission scores from the Transformer.
        feats = self._get_transformer_features(sentence,tags)

        # Find the best path, given the features.
        score, tag_seq = self._viterbi_decode(feats)
        return score, tag_seq

In [None]:
# Seed before construction so the initial weights are reproducible.
torch.manual_seed(0)

# Model hyper-parameters.
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 1
FFN_HID_DIM = 512
BATCH_SIZE = 128
NUM_ENCODER_LAYERS = 1
NUM_DECODER_LAYERS = 1

crf_transformer = Transformer_CRF(NUM_ENCODER_LAYERS, NUM_DECODER_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in crf_transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

# The loop above also re-initialises the 2-D `transitions` matrix. That wipes the
# -10000 entries Transformer_CRF.__init__ sets for transitions into BOS and out of
# EOS, so they are re-applied here. This consumes no randomness, so every other
# initial weight is unchanged.
with torch.no_grad():
    crf_transformer.transitions[BOS_IDX, :] = -10000
    crf_transformer.transitions[:, EOS_IDX] = -10000

crf_transformer = crf_transformer.to(DEVICE)

# NOTE(review): loss_fn is unused; this model is trained with its own CRF NLL.
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(crf_transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

In [None]:
from torch.utils.data import DataLoader

def train_epoch(model, optimizer):
    """Run one training epoch of Transformer_CRF over TRAINPATH; return the mean batch loss.

    The full tag tensor, including its leading BOS row, is passed to
    neg_log_likelihood, the same as evaluate does. neg_log_likelihood feeds
    ``tags[:-1]`` to the decoder and scores the gold path from ``tags[0]``, while
    the partition function starts every path at BOS. Slicing with ``tgt[1:]``
    would start the gold path at the first real tag instead, and would make the
    training loss differ from the validation loss.

    NOTE: this redefines the Seq2SeqTransformer ``train_epoch`` from earlier in the notebook.
    """
    model.train()
    losses = 0
    with io.open(TRAINPATH, encoding="utf8") as f:
        train_iter = f.readlines()
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)
    num_batches = len(train_dataloader)
    for i, (src, tgt, _) in enumerate(train_dataloader):
        # Report progress every 50 batches; printing every batch floods the output.
        if i % 50 == 0:
            print(f'{i}/{num_batches}')
        src = src.to(DEVICE)
        tgt = tgt.to(DEVICE)

        optimizer.zero_grad()

        loss = model.neg_log_likelihood(src, tgt)
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / num_batches


def evaluate(model):
    """Mean Transformer_CRF negative log-likelihood over VALPATH.

    Runs under ``torch.no_grad()``: nothing is updated here, so building an
    autograd graph would only waste memory.

    NOTE: this redefines the Seq2SeqTransformer ``evaluate`` from earlier in the notebook.
    """
    model.eval()
    losses = 0
    with io.open(VALPATH, encoding="utf8") as f:
        val_iter = f.readlines()
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():
        for src, tgt, _ in val_dataloader:
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)
            # Full tag sequence including the leading BOS row: neg_log_likelihood
            # feeds tags[:-1] to the decoder and scores transitions from tags[0].
            loss = model.neg_log_likelihood(src, tgt)
            losses += loss.item()

    return losses / len(val_dataloader)

In [None]:
import numpy as np
# Count trainable parameters of the CRF model. Tensor.numel() gives a plain int;
# np.prod over a torch.Size returns a NumPy scalar (and 1.0 for a 0-d parameter).
model_parameters = [p for p in crf_transformer.parameters() if p.requires_grad]
params = sum(p.numel() for p in model_parameters)
params

24732050

In [None]:
import time
from timeit import default_timer as timer
# Train Transformer_CRF with early stopping on the validation loss. EarlyStopping
# saves the best model to 'checkpoint.pt' (overwriting the earlier model's file),
# and it is reloaded once training ends.
NUM_EPOCHS = 18
patience=3
early_stopping = EarlyStopping(patience=patience, verbose=True)
start=time.time()
for epoch in range(1, NUM_EPOCHS+1):
    start_time = timer()
    train_loss = train_epoch(crf_transformer, optimizer)
    end_time = timer()
    val_loss = evaluate(crf_transformer)
    # Report the epoch before the early-stopping check; printing it after the
    # check dropped the metrics of the epoch that triggered the stop.
    print(f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, Epoch time = {(end_time - start_time):.3f}s")

    early_stopping(val_loss, crf_transformer)
    if early_stopping.early_stop:
      print("Early stopping")
      break
end=time.time()
# NOTE(review): checkpoint.pt is a whole pickled module (torch.save(model, ...)).
# PyTorch >= 2.6 defaults torch.load to weights_only=True, which rejects it; there,
# pass weights_only=False (only for checkpoints you created yourself).
crf_transformer=torch.load('checkpoint.pt')
traintime=end-start
traintime

0/249
1/249
2/249
3/249
4/249
5/249
6/249
7/249
8/249
9/249
10/249
11/249
12/249
13/249
14/249
15/249
16/249
17/249
18/249
19/249
20/249
21/249
22/249
23/249
24/249
25/249
26/249
27/249
28/249
29/249
30/249
31/249
32/249
33/249
34/249
35/249
36/249
37/249
38/249
39/249
40/249
41/249
42/249
43/249
44/249
45/249
46/249
47/249
48/249
49/249
50/249
51/249
52/249
53/249
54/249
55/249
56/249
57/249
58/249
59/249
60/249
61/249
62/249
63/249
64/249
65/249
66/249
67/249
68/249
69/249
70/249
71/249
72/249
73/249
74/249
75/249
76/249
77/249
78/249
79/249
80/249
81/249
82/249
83/249
84/249
85/249
86/249
87/249
88/249
89/249
90/249
91/249
92/249
93/249
94/249
95/249
96/249
97/249
98/249
99/249
100/249
101/249
102/249
103/249
104/249
105/249
106/249
107/249
108/249
109/249
110/249
111/249
112/249
113/249
114/249
115/249
116/249
117/249
118/249
119/249
120/249
121/249
122/249
123/249
124/249
125/249
126/249
127/249
128/249
129/249
130/249
131/249
132/249
133/249
134/249
135/249
136/249
137/249
138/24

11963.448099851608