In [None]:
import os
from typing import List, Tuple

import numpy as np
import torch
import torchdata
import torchtext
import torchtext.vocab
from torch import nn
from torchtext import transforms as T
from torchtext.vocab import vocab

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import math
from typing import Tuple
import numpy as np


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention followed by an output projection,
    dropout, a residual connection and LayerNorm.

    Despite the name, the module also serves as cross attention: ``forward``
    accepts separate key/value inputs and only falls back to the query when
    they are omitted.
    """

    @staticmethod
    def _attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                   d_k: int = None,
                   mask: torch.tensor = None, dropout: float = None,
                   training: bool = True):
        """
        Scaled dot-product attention.

        :param q: queries, tensor(..., L_q, d_k)
        :param k: keys, tensor(..., L_k, d_k)
        :param v: values, tensor(..., L_k, d_v)
        :param d_k: scaling dimension (defaults to the last dimension of ``k``)
        :param mask: tensor broadcastable to (..., L_q, L_k); non-zero / True entries
                     mark positions that must NOT be attended to
        :param dropout: dropout probability applied to the attention weights, or None
        :param training: whether dropout is active (pass the module's ``training`` flag)
        :return: (attention output (..., L_q, d_v), attention weights (..., L_q, L_k))
        """
        if d_k is None:
            d_k = k.shape[-1]
        # Scale the queries, then take the dot product with the transposed keys.
        scores = torch.matmul(q / math.sqrt(d_k), torch.transpose(k, -1, -2))

        if mask is not None:
            # Masked logits must be hugely negative so that softmax assigns them
            # (numerically) zero weight. The most negative finite value is used
            # instead of -inf so that a fully masked row yields a uniform
            # distribution rather than NaN.
            scores = scores.masked_fill(mask.to(torch.bool), torch.finfo(scores.dtype).min)
        attn = F.softmax(scores, dim=-1)

        if dropout is not None:
            # F.dropout defaults to training=True; forward the real mode so that
            # evaluation is deterministic.
            attn = F.dropout(attn, dropout, training=training)

        out = torch.matmul(attn, v)

        return out, attn

    linear_q: nn.Linear
    linear_k: nn.Linear
    linear_v: nn.Linear
    fc: nn.Linear
    dropout: float
    dropper: nn.Module
    norm: nn.LayerNorm

    head: int
    d_k: int
    d_v: int

    def __init__(self, d_emb: int, head: int, d_k: int,
                 d_v: int = None,
                 dropout: float = 0.1,
                 bias_qkv: bool = False):
        """
        :param d_emb: embedding (model) dimension of the inputs and the output
        :param head: number of attention heads
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension (defaults to ``d_k``)
        :param dropout: dropout probability for the attention weights and the
                        projected output; ``None`` disables dropout
        :param bias_qkv: whether the q/k/v and output projections use a bias
        """
        super().__init__()
        if d_v is None:
            d_v = d_k
        self.linear_q = nn.Linear(d_emb, head * d_k, bias=bias_qkv)
        self.linear_k = nn.Linear(d_emb, head * d_k, bias=bias_qkv)
        self.linear_v = nn.Linear(d_emb, head * d_v, bias=bias_qkv)
        self.fc = nn.Linear(head * d_v, d_emb, bias=bias_qkv)
        self.dropout = dropout

        if dropout is not None:
            self.dropper = nn.Dropout(dropout)
        else:
            self.dropper = nn.Identity()

        self.norm = nn.LayerNorm(d_emb, eps=1e-6)

        self.head = head
        self.d_k = d_k
        self.d_v = d_v

    def forward(self, q: torch.Tensor, k: torch.Tensor = None, v: torch.Tensor = None,
                mask=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward for transformer
        :param q: query, tensor(N, *, d_emb)
        :param k: key (default q), tensor(N, *, d_emb)
        :param v: value (default q), tensor(N, *, d_emb)
        :param mask: positions to ignore, tensor(N, L_q, L_k); a head dimension is
                     inserted here so the mask is shared by all heads
        :return: (normalised output tensor(N, L_q, d_emb),
                  attention weights tensor(N, head, L_q, L_k))
        """
        resid = q
        if k is None:
            k = q
        if v is None:
            v = q
        if mask is not None:
            mask = torch.unsqueeze(mask, -3)  # (N, 1, L_q, L_k): broadcast over heads

        q, k, v = self.linear_q(q), self.linear_k(k), self.linear_v(v)  # (N, L, head * d)
        # Split the projected features into heads and move the head axis before the sequence axis.
        q = torch.transpose(q.view(*q.shape[:-1], self.head, self.d_k), -2, -3)  # (N, head, L_q, d_k)
        k = torch.transpose(k.view(*k.shape[:-1], self.head, self.d_k), -2, -3)  # (N, head, L_k, d_k)
        v = torch.transpose(v.view(*v.shape[:-1], self.head, self.d_v), -2, -3)  # (N, head, L_k, d_v)

        weight, attn = self._attention(q, k, v, d_k=self.d_k, mask=mask,
                                       dropout=self.dropout, training=self.training)
        weight = weight.transpose(-2, -3)  # (N, L_q, head, d_v)
        weight = weight.contiguous().view(*weight.shape[:-2], -1)  # merge heads: (N, L_q, head * d_v)
        weight = self.dropper(self.fc(weight))

        # Residual connection followed by layer normalisation (post-norm).
        weight = weight + resid

        weight = self.norm(weight)

        return weight, attn


class FeedForward(nn.Module):
    """Position-wise feed-forward sub-layer.

    Applies Linear -> LeakyReLU -> Linear -> Dropout, adds the input back as a
    residual and normalises the sum with LayerNorm.
    """
    dropout: float
    w: nn.Module
    norm: nn.Module

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        """
        :param dim: input and output feature dimension
        :param hidden_dim: width of the hidden layer
        :param dropout: dropout probability applied after the second linear layer
        """
        super().__init__()
        self.dropout = dropout
        expand = nn.Linear(dim, hidden_dim)
        activation = nn.LeakyReLU()
        project = nn.Linear(hidden_dim, dim)
        self.w = nn.Sequential(expand, activation, project, nn.Dropout(dropout))
        self.norm = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x):
        """Return ``LayerNorm(w(x) + x)``; the output has the same shape as ``x``."""
        transformed = self.w(x)
        return self.norm(transformed + x)


class TransformerEncodingLayer(nn.Module):
    """One encoder block: self-attention followed by the feed-forward sub-layer."""
    attention: nn.Module
    feed_forward: nn.Module

    def __init__(self, d_emb: int,
                 head: int = 1,
                 d_k: int = 1024, d_v: int = 1024,
                 attention_dropout: float = 0.1,
                 forward_hidden: int = 1024,
                 forward_dropout: float = 0.1):
        """
        :param d_emb: embedding dimension of the layer's input and output
        :param head: number of attention heads
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension
        :param attention_dropout: dropout probability inside the attention sub-layer
        :param forward_hidden: hidden width of the feed-forward sub-layer
        :param forward_dropout: dropout probability of the feed-forward sub-layer
        """
        super().__init__()
        self.attention = MultiHeadSelfAttention(
            d_emb, head, d_k, d_v,
            dropout=attention_dropout,
        )
        self.feed_forward = FeedForward(d_emb, forward_hidden, forward_dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Self-attend over ``x`` under ``mask`` and pass the result through the feed-forward block."""
        attended, _attn_weights = self.attention(x, mask=mask)  # attention weights are discarded
        return self.feed_forward(attended)


class TransformerEncoder(nn.Module):
    """A stack of ``layer`` identical ``TransformerEncodingLayer`` blocks applied in sequence."""
    layer: int
    layers: nn.ModuleList
    norm: nn.Module

    def __init__(self, layer: int, d_emb: int,
                 head: int = 1, d_k=1024, d_v=1024,
                 attention_dropout: float = 0.1,
                 forward_hidden: int = 1024, forward_dropout: float = 0.1):
        """
        :param layer: number of encoder blocks
        :param d_emb: embedding dimension
        :param head: number of attention heads per block
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension
        :param attention_dropout: dropout probability of the attention sub-layers
        :param forward_hidden: hidden width of the feed-forward sub-layers
        :param forward_dropout: dropout probability of the feed-forward sub-layers
        """
        super().__init__()
        blocks = [
            TransformerEncodingLayer(d_emb, head, d_k, d_v,
                                     attention_dropout, forward_hidden, forward_dropout)
            for _ in range(layer)
        ]
        self.layers = nn.ModuleList(blocks)
        self.layer = layer
        # Registered (and therefore part of the state dict) but not applied in forward().
        self.norm = nn.LayerNorm(d_emb, eps=1e-6)

    def forward(self, x: torch.Tensor, mask: torch.Tensor):
        """Run ``x`` through every block in order, passing the same ``mask`` to each."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, mask=mask)
        return hidden


class TransformerDecodingLayer(nn.Module):
    """One decoder block: masked self-attention, encoder-decoder (cross) attention,
    then the feed-forward sub-layer."""
    attention: nn.Module
    encoder_decoder_attention: nn.Module
    feed_forward: nn.Module

    def __init__(self, d_emb: int,
                 head: int = 1,
                 d_k: int = 1024, d_v: int = 1024,
                 attention_dropout: float = 0.1,
                 forward_hidden: int = 1024,
                 forward_dropout: float = 0.1):
        """
        :param d_emb: embedding dimension of the layer's input and output
        :param head: number of attention heads
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension
        :param attention_dropout: dropout probability of both attention sub-layers
        :param forward_hidden: hidden width of the feed-forward sub-layer
        :param forward_dropout: dropout probability of the feed-forward sub-layer
        """
        super().__init__()
        self.attention = MultiHeadSelfAttention(d_emb, head, d_k, d_v, dropout=attention_dropout)
        self.encoder_decoder_attention = MultiHeadSelfAttention(d_emb, head, d_k, d_v, dropout=attention_dropout)
        self.feed_forward = FeedForward(d_emb, forward_hidden, forward_dropout)

    def forward(self, x: torch.Tensor,
                encoder_kv: torch.Tensor,
                mask_encoder: torch.Tensor,
                mask_decoder: torch.Tensor) -> torch.Tensor:
        """
        :param x: decoder input, [N, seq, dim]
        :param encoder_kv: encoder output used as keys and values, [N, seq, dim]
        :param mask_encoder: mask for the encoder-decoder attention, [N, seq, seq]
        :param mask_decoder: mask for the decoder self-attention, [N, seq, seq]
        :return: decoder block output, same shape as ``x``
        """
        # Masked self-attention over the decoder's own sequence.
        x, _ = self.attention(x, mask=mask_decoder)
        # Cross attention: queries come from the decoder, keys/values from the encoder.
        # This must use its own weights (encoder_decoder_attention), not the
        # self-attention module a second time.
        x, _ = self.encoder_decoder_attention(x, encoder_kv, encoder_kv, mask=mask_encoder)
        x = self.feed_forward(x)
        return x


class TransformerDecoder(nn.Module):
    """A stack of ``layer`` identical ``TransformerDecodingLayer`` blocks applied in sequence."""
    layer: int
    layers: nn.ModuleList
    norm: nn.Module

    def __init__(self, layer: int, d_emb: int,
                 head: int = 1, d_k=1024, d_v=1024,
                 attention_dropout: float = 0.1,
                 forward_hidden: int = 1024, forward_dropout: float = 0.1):
        """
        :param layer: number of decoder blocks
        :param d_emb: embedding dimension
        :param head: number of attention heads per block
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension
        :param attention_dropout: dropout probability of the attention sub-layers
        :param forward_hidden: hidden width of the feed-forward sub-layers
        :param forward_dropout: dropout probability of the feed-forward sub-layers
        """
        super().__init__()
        blocks = [
            TransformerDecodingLayer(d_emb, head, d_k, d_v,
                                     attention_dropout, forward_hidden, forward_dropout)
            for _ in range(layer)
        ]
        self.layers = nn.ModuleList(blocks)
        self.layer = layer
        # Registered (and therefore part of the state dict) but not applied in forward().
        self.norm = nn.LayerNorm(d_emb, eps=1e-6)

    def forward(self, x: torch.Tensor,
                encoder_output: torch.Tensor,
                mask_encoder: torch.Tensor,
                mask_decoder: torch.Tensor):
        """
        Run the decoder input through every block in order.

        :param x: [N, seq, dim]
        :param encoder_output: [N, seq, dim]
        :param mask_encoder: [N, seq, seq]
        :param mask_decoder: [N, seq, seq]
        :return: output of the last block (``x`` itself when there are no blocks)
        """
        hidden = x
        for block in self.layers:
            hidden = block(hidden, encoder_output, mask_encoder, mask_decoder)
        return hidden


class PositionalEncoding(nn.Module):
    """Adds a fixed sinusoidal position table to a batch of embeddings.

    The table has shape ``(seq, dim)``. Row 0 (position 0) and column 0 are all
    zeros; for positions >= 1, even columns hold ``sin`` and odd columns ``cos``
    of ``pos / 10000 ** (2 * i / dim)``. Note the exponent uses the raw column
    index ``i``; the numeric values are kept exactly as they were so that
    previously trained weights remain valid.
    """
    pos: torch.Tensor

    def __init__(self, dim: int, seq: int, device: torch.device = torch.device("cpu:0")):
        """
        :param dim: embedding dimension
        :param seq: maximum sequence length covered by the table
        :param device: device on which the table is initially created
        """
        super().__init__()
        position = np.array([
            [0 if i == 0 else (pos / (10000 ** (2 * i / dim))) for i in range(dim)]
            for pos in range(seq)
        ], dtype=float)
        position[1:, 0::2] = np.sin(position[1:, 0::2])
        position[1:, 1::2] = np.cos(position[1:, 1::2])
        # A buffer follows the module through .to()/.cuda(); persistent=False keeps
        # it out of the state dict so existing checkpoints still load.
        self.register_buffer("pos", torch.from_numpy(position).float().to(device), persistent=False)

    def forward(self, x: torch.Tensor):
        """
        :param x: embeddings, [N, L, dim] with L <= seq
        :return: ``x`` plus the first L rows of the position table
        """
        if x.dim() < 2:
            # No sequence axis to slice along; fall back to plain broadcasting.
            return x + self.pos
        # Only use as many positions as the input actually has.
        return x + self.pos[: x.shape[-2]]


class Transformer(nn.Module):
    """Encoder-decoder transformer over token ids.

    Token ids are embedded, position-encoded, encoded/decoded by the layer
    stacks and finally projected onto the decoder vocabulary.
    """

    @staticmethod
    def get_attention_pad_mask(seq: torch.Tensor, seq_length: int, empty: int) -> torch.Tensor:
        """
        Get the padding mask of a batch of token ids.

        :param seq: token ids, [N, seq_length]
        :param seq_length: seq_length
        :param empty: empty (padding) token id
        :return: uint8 tensor [N, seq_length, seq_length]; entry (n, i, j) is 1 when
                 token j of sample n is the padding token (same for every i)
        """
        msk = torch.eq(seq, empty).byte().unsqueeze(1)
        return msk.expand(seq.shape[0], seq_length, seq_length)

    @staticmethod
    def get_attention_sequence_mask(batch_size: int, seq_length: int, dev: torch.device = torch.device("cpu:0")):
        """
        Get the causal ("no peeking ahead") mask.

        :param batch_size: N
        :param seq_length: sequence length
        :param dev: device on which the mask is created
        :return: uint8 tensor [N, seq_length, seq_length] that is 1 strictly above
                 the diagonal, i.e. where key position j > query position i
        """
        shape = (batch_size, seq_length, seq_length)
        # Built directly on the target device; triu works on the last two dims.
        return torch.triu(torch.ones(shape, device=dev), diagonal=1).byte()

    encoder: TransformerEncoder
    decoder: TransformerDecoder

    emb_encode: nn.Embedding
    emb_decode: nn.Embedding

    position_encode: PositionalEncoding
    position_decode: PositionalEncoding

    projection: nn.Module

    seq: int

    empty: int

    device: torch.device

    def __init__(self,
                 dict_size_encode: int, dict_size_decode: int,
                 d_emb: int, seq: int,
                 layer: int = 6,
                 head: int = 1, d_k: int = 1024, d_v: int = 1024,
                 attention_dropout: float = 0.1,
                 forward_hidden: int = 2048,
                 forward_dropout: float = 0.1,
                 emb_encode: torch.Tensor = None,
                 emb_decode: torch.Tensor = None,
                 empty: int = 0,
                 device: torch.device = torch.device("cpu:0")
                 ):
        """
        :param dict_size_encode: source vocabulary size (used when ``emb_encode`` is None)
        :param dict_size_decode: target vocabulary size; also the projection output size
        :param d_emb: embedding dimension
        :param seq: sequence length used for the positional tables
        :param layer: number of encoder and of decoder blocks
        :param head: number of attention heads
        :param d_k: per-head query/key dimension
        :param d_v: per-head value dimension
        :param attention_dropout: dropout probability of the attention sub-layers
        :param forward_hidden: hidden width of the feed-forward sub-layers
        :param forward_dropout: dropout probability of the feed-forward sub-layers
        :param emb_encode: optional pretrained source embedding matrix (kept frozen)
        :param emb_decode: optional pretrained target embedding matrix (kept frozen)
        :param empty: id of the padding token
        :param device: device on which the positional tables are created
        """
        super().__init__()

        self.seq = seq

        self.encoder = TransformerEncoder(layer, d_emb, head, d_k, d_v, attention_dropout, forward_hidden,
                                          forward_dropout)
        self.decoder = TransformerDecoder(layer, d_emb, head, d_k, d_v, attention_dropout, forward_hidden,
                                          forward_dropout)

        self.projection = nn.Sequential(
            nn.Linear(d_emb, dict_size_decode, bias=False)
        )

        self.position_encode = PositionalEncoding(d_emb, seq, device=device)
        self.position_decode = PositionalEncoding(d_emb, seq, device=device)

        self.emb_encode = nn.Embedding(dict_size_encode, d_emb) if emb_encode is None \
            else nn.Embedding.from_pretrained(emb_encode, freeze=True)

        self.emb_decode = nn.Embedding(dict_size_decode, d_emb) if emb_decode is None \
            else nn.Embedding.from_pretrained(emb_decode, freeze=True)

        self.empty = empty

        self.device = device

    def encode(self, encoder_input: torch.Tensor) -> torch.Tensor:
        """
        :param encoder_input: source token ids, [N, S]
        :return: encoder output, [N, S, d_emb]
        """
        # Size the mask from the actual input instead of the configured length.
        src_len = encoder_input.shape[1]
        mask = self.get_attention_pad_mask(encoder_input, src_len, self.empty) == 1
        return self.encoder(self.position_encode(self.emb_encode(encoder_input)), mask)

    def decode(self, decoder_input: torch.Tensor, encoder_input: torch.Tensor, encoder_output: torch.Tensor):
        """
        :param decoder_input: target token ids, [N, T]
        :param encoder_input: source token ids, [N, S] (only used to locate padding)
        :param encoder_output: encoder output, [N, S, d_emb]
        :return: vocabulary logits, [N, T, dict_size_decode]
        """
        batch_size, tgt_len = decoder_input.shape[0], decoder_input.shape[1]
        # Build the masks on the input's device: the device passed to the
        # constructor may be stale after the module has been moved with .to().
        causal = self.get_attention_sequence_mask(batch_size, tgt_len, decoder_input.device)
        dec_pad = self.get_attention_pad_mask(decoder_input, tgt_len, self.empty)
        # A position is masked when it lies in the future OR is padding.
        dec_mask = torch.gt(causal + dec_pad, 0)
        # Cross-attention mask: [N, T, S], hiding padded source positions from every target query.
        enc_mask = torch.eq(encoder_input, self.empty).unsqueeze(1).expand(-1, tgt_len, -1)
        decoded = self.decoder(
            self.position_decode(self.emb_decode(decoder_input)),
            encoder_output, enc_mask, dec_mask
        )
        return self.projection(decoded)

    def forward(self, encoder_input: torch.Tensor, decoder_input: torch.Tensor):
        """Encode the source ids, then decode the target ids against the encoder output."""
        enc = self.encode(encoder_input)
        return self.decode(decoder_input, encoder_input, enc)


In [None]:
# Seed PyTorch's global random number generator so runs are repeatable.
SEED = 1234
torch.manual_seed(SEED)

<torch._C.Generator at 0x7f4fcc0484d0>

In [None]:
from torchtext.data.utils import get_tokenizer

# torchtext's 'basic_english' tokenizer; used below both to count token
# frequencies for the vocabulary and to encode each review.
tokenizer = get_tokenizer('basic_english')

In [None]:
# Load the IMDB train/test splits (cached under ./data) and materialise both
# as lists of (label, text) pairs so they can be iterated more than once.
imdb_splits = torchtext.datasets.IMDB(root='./data')
train_iter, test_iter = imdb_splits
train_list, test_list = list(train_iter), list(test_iter)

In [None]:
# Number of frequent tokens kept in the vocabulary (special tokens not included).
MOST_COMMON_SIZE = 30000
# Maximum number of tokens kept per review; shorter reviews are padded up to it.
SENTENCE_LENGTH = 300
# Dimensionality of the GloVe vectors and hence of the embedding layer.
EMBED_SIZE = 50
# Path of a saved state dict to load into the model, or None to skip loading.
MODEL_LOADING = "model"

In [None]:
from collections import Counter, OrderedDict

# Token frequencies over every training review (the labels are ignored).
counter = Counter(
    token
    for _label, review in train_list
    for token in tokenizer(review)
)

In [None]:
# Special tokens: padding, out-of-vocabulary, begin-of-sequence, end-of-sequence.
PAD, UNK, BOS, EOS = '<PAD>', '<UNK>', '<BOS>', '<EOS>'

# Drop the 10 most frequent tokens, then keep the next MOST_COMMON_SIZE
# (token, count) pairs in descending frequency order.
most_common_words = counter.most_common(MOST_COMMON_SIZE + 10)[10:]

eng_vocab = vocab(
    OrderedDict(most_common_words),
    specials=[PAD, UNK, BOS, EOS],
)
# Any token that is not in the vocabulary maps to UNK.
eng_vocab.set_default_index(eng_vocab[UNK])

In [None]:
# Pretrained GloVe '6B' word vectors with EMBED_SIZE dimensions
# (fetched into .vector_cache when not already present).
glove = torchtext.vocab.GloVe(name='6B', dim=EMBED_SIZE)

.vector_cache/glove.6B.zip: 862MB [02:41, 5.33MB/s]                           
100%|█████████▉| 399999/400000 [00:10<00:00, 37325.28it/s]


In [None]:
# Embedding table aligned with eng_vocab: one all-zero row per special token
# (PAD, UNK, BOS, EOS), followed by the GloVe vector of every vocabulary word
# in vocabulary order.
NUM_SPECIAL_TOKENS = 4
# A single batched lookup returns a (num_words, EMBED_SIZE) float tensor, which
# avoids converting a mixed list of Python lists and per-word tensors.
embedding_weight_matrix = torch.cat([
    torch.zeros(NUM_SPECIAL_TOKENS, EMBED_SIZE),
    glove.get_vecs_by_tokens([word for word, _count in most_common_words]),
])

In [None]:
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

# Total number of vocabulary entries, special tokens included.
VOCAB_SIZE = len(eng_vocab)

class IMDBDataset(torch.utils.data.Dataset):
    """In-memory IMDB dataset of fixed-length token-id tensors and one-hot labels.

    Every review is tokenised, truncated to SENTENCE_LENGTH tokens, wrapped in
    BOS/EOS and right-padded with PAD, so each sample has SENTENCE_LENGTH + 2 ids.
    """
    size: int
    ret: list
    targs: list
    device: torch.device

    def __init__(self, lst: list, dev: torch.device = None):
        """
        :param lst: list of (label, text) pairs
        :param dev: device the sample tensors are moved to (no move when None)
        """
        super().__init__()
        self.size = len(lst)
        self.device = torch.device('cpu:0') if dev is None else dev
        self.ret = []
        self.targs = []
        print('Loading Data Set')
        total = len(lst)
        bos_id, eos_id, pad_id = eng_vocab[BOS], eng_vocab[EOS], eng_vocab[PAD]
        for done, (lbl, comment) in enumerate(lst, start=1):
            tokens = tokenizer(comment)
            ids = [bos_id]
            ids.extend(eng_vocab[word] for word in tokens[:SENTENCE_LENGTH])
            ids.append(eos_id)
            # Pad short reviews; long ones were already truncated above.
            ids.extend([pad_id] * max(SENTENCE_LENGTH - len(tokens), 0))
            self.ret.append(torch.tensor(ids, dtype=torch.long).to(device=dev))
            # One-hot target: label 1 -> [1, 0], any other label -> [0, 1].
            one_hot = [1, 0] if lbl == 1 else [0, 1]
            self.targs.append(torch.tensor(one_hot, dtype=torch.float).to(device=dev))
            print(f"\r{done} / {total}         ", end='')
        print()

    def __len__(self):
        """Number of samples."""
        return self.size

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the (token ids, one-hot label) pair at ``idx``."""
        return self.ret[idx], self.targs[idx]

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Tokenise and encode both splits up front. No device is passed here, so the
# sample tensors are not moved at load time; batches are transferred to
# `device` inside the training and evaluation loops.
imdb_train_set = IMDBDataset(train_list)
imdb_test_set = IMDBDataset(test_list)

Loading Data Set
25000 / 25000         
Loading Data Set
25000 / 25000         


In [None]:
# Both loaders use the same batch size and reshuffle their data every epoch.
BATCH_SIZE = 64
loader_options = dict(batch_size=BATCH_SIZE, shuffle=True)
train_data_loader = DataLoader(imdb_train_set, **loader_options)
test_data_loader = DataLoader(imdb_test_set, **loader_options)

In [None]:
class Classifier(nn.Module):
    """Sentiment classifier: frozen pretrained embeddings -> transformer encoder
    -> GRU -> two-layer head ending in a softmax over ``out_dim`` classes."""

    emb: nn.Module
    transformer: TransformerEncoder
    recurrent_part: nn.Module
    predict: nn.Module

    empty: int

    def __init__(self, out_dim: int,
                 transformer_layer: int = 6,
                 gru_layer: int = 2, gru_dim: int = 512,
                 forward_dim: int = 1024,
                 empty: int = 0):
        """
        :param out_dim: number of output classes
        :param transformer_layer: number of transformer encoder blocks
        :param gru_layer: number of stacked GRU layers
        :param gru_dim: GRU hidden size
        :param forward_dim: hidden width of the prediction head
        :param empty: id of the padding token (used to build the attention mask)
        """
        super().__init__()
        self.empty = empty

        # Pretrained vectors stay fixed during training.
        self.emb = nn.Embedding.from_pretrained(embedding_weight_matrix, freeze=True)
        self.transformer = TransformerEncoder(
            transformer_layer, EMBED_SIZE,
            attention_dropout=0.2, forward_dropout=0.2,
        )

        input_dropout = nn.Dropout(0.2)
        gru = nn.GRU(EMBED_SIZE, gru_dim, num_layers=gru_layer, batch_first=True)
        self.recurrent_part = nn.Sequential(input_dropout, gru)

        hidden = nn.Linear(gru_dim, forward_dim)
        output = nn.Linear(forward_dim, out_dim)
        self.predict = nn.Sequential(hidden, nn.LeakyReLU(0.2), output, nn.Softmax(dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: token ids, [N, L]
        :return: class probabilities, [N, out_dim]
        """
        pad_mask = Transformer.get_attention_pad_mask(x, x.shape[1], self.empty) == 1
        encoded = self.transformer(self.emb(x), pad_mask)
        # The GRU (last module of the Sequential) returns (all outputs, final hidden states).
        _all_outputs, hidden_states = self.recurrent_part(encoded)
        # Classify from the final hidden state of the top GRU layer.
        return self.predict(hidden_states[-1])

In [None]:
# Two-class sentiment classifier, placed on the selected device.
model = Classifier(out_dim=2, empty=eng_vocab[PAD]).to(device)
if MODEL_LOADING is not None:
    # map_location remaps the stored tensors onto `device`, so a checkpoint saved
    # on a GPU can also be restored on a different device.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(torch.load(MODEL_LOADING, map_location=device))

In [None]:
# Binary cross-entropy between the classifier's softmax probabilities and the one-hot targets.
loss = nn.BCELoss().to(device)
# Adam over the model's parameters with a small learning rate.
optim = torch.optim.Adam(model.parameters(), lr=1e-5)

In [None]:

@torch.no_grad()
def test(i: int, name: str, data_set: torch.utils.data.Dataset, data_loader: torch.utils.data.DataLoader):
    """
    Evaluate the global ``model`` on ``data_loader`` and print a summary.

    :param i: epoch index (accepted for the caller's convenience; not used)
    :param name: heading printed above the report
    :param data_set: dataset behind the loader; its length is the sample total
    :param data_loader: loader yielding (token ids, one-hot label) batches

    The model is switched to eval mode for the pass and back to train mode afterwards.
    """
    model.eval()
    correct = 0
    total = len(data_set)
    loss_sum = 0
    for batch, target in data_loader:
        batch, target = batch.to(device), target.to(device)
        probs = model(batch)
        hits = torch.argmax(probs, dim=1) == torch.argmax(target, dim=1)
        correct += hits.sum().item()
        loss_sum += loss(probs, target).item()
    # Mean of the per-batch losses.
    mean_loss = loss_sum / len(data_loader)

    report = [
        "",
        name,
        f'Correct: {correct}',
        f'Wrong: {total - correct}',
        f'Total: {total}',
        f'Correctness: {round(correct / total, 5)}',
        f'Loss: {round(mean_loss, 5)}',
        "=" * 40,
    ]
    for line in report:
        print(line)
    model.train()


In [None]:
EPOCH = 400
CHECKPOINT_DIR = "models"
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first epoch finishes.
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

for epoch in range(1, EPOCH + 1):
    print(f"Epoch {epoch}:")
    ls_avg = 0
    cnt = 0
    tot = len(train_data_loader)
    for data, lbl in train_data_loader:
        optim.zero_grad()
        data = data.to(device)
        lbl = lbl.to(device)
        y = model(data)
        ls = loss(y, lbl)
        ls.backward()
        optim.step()
        ls_avg += ls.item()
        cnt += 1
        # Overwrite the same console line with the running batch count and loss.
        print(f"\r{cnt} / {tot} : Loss: {round(ls.item(), 5)}                       ", end="")
    # Mean of the per-batch losses for this epoch.
    print(f"\rLoss: {round(ls_avg / len(train_data_loader), 5)}")
    # One checkpoint per epoch.
    torch.save(model.state_dict(), f"{CHECKPOINT_DIR}/model_{epoch}")
    # Evaluate on the test split every third epoch.
    if epoch % 3 == 0:
        test(epoch, "Test:", imdb_test_set, test_data_loader)

Epoch 1:
Loss: 0.28398
Epoch 2:
Loss: 0.28529
Epoch 3:
Loss: 0.28434

Test:
Correct: 21228
Wrong: 3772
Total: 25000
Correctness: 0.84912
Loss: 0.37428
Epoch 4:
Loss: 0.28041
Epoch 5:
Loss: 0.27898
Epoch 6:
Loss: 0.28042

Test:
Correct: 21211
Wrong: 3789
Total: 25000
Correctness: 0.84844
Loss: 0.36295
Epoch 7:
Loss: 0.27835
Epoch 8:
Loss: 0.27488
Epoch 9:
Loss: 0.27316

Test:
Correct: 21465
Wrong: 3535
Total: 25000
Correctness: 0.8586
Loss: 0.34163
Epoch 10:
Loss: 0.27294
Epoch 11:
Loss: 0.27268
Epoch 12:
Loss: 0.27432

Test:
Correct: 21243
Wrong: 3757
Total: 25000
Correctness: 0.84972
Loss: 0.35947
Epoch 13:
Loss: 0.27005
Epoch 14:
Loss: 0.27088
Epoch 15:
Loss: 0.26949

Test:
Correct: 21530
Wrong: 3470
Total: 25000
Correctness: 0.8612
Loss: 0.33814
Epoch 16:
Loss: 0.26874
Epoch 17:
Loss: 0.26605
Epoch 18:
Loss: 0.26609

Test:
Correct: 21442
Wrong: 3558
Total: 25000
Correctness: 0.85768
Loss: 0.35382
Epoch 19:
Loss: 0.26486
Epoch 20:
Loss: 0.26124
Epoch 21:
Loss: 0.26345

Test:
Correct:

KeyboardInterrupt: ignored