In [74]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Pick the compute device once for the whole notebook.
# Apple's Metal backend (MPS) is preferred, as before; when it is not usable
# fall back to CUDA and finally to the CPU so that every later `.to(device)`
# still works on machines without MPS.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [75]:
""" sinusoid position embedding """
def get_sinusoid_encoding_table(n_seq, d_hidn):
    """Build the sinusoidal position-encoding table.

    Entry ``[pos, i]`` is ``sin(angle)`` for even ``i`` and ``cos(angle)`` for
    odd ``i``, where ``angle = pos / 10000 ** (2 * (i // 2) / d_hidn)``.

    Args:
        n_seq: number of positions (rows); positions run from 0 to n_seq - 1.
        d_hidn: encoding width (columns).

    Returns:
        A float64 ``numpy.ndarray`` of shape ``(n_seq, d_hidn)``. With
        ``n_seq == 0`` an empty ``(0, d_hidn)`` array is returned.
    """
    # Column vector of positions so the division below broadcasts to
    # (n_seq, d_hidn) without any Python-level loops.
    positions = np.arange(n_seq, dtype=np.float64)[:, np.newaxis]
    # Columns 2k and 2k + 1 share the same frequency, hence the floor division.
    pair_index = np.arange(d_hidn) // 2
    denominators = np.power(10000, 2 * pair_index / d_hidn)

    sinusoid_table = positions / denominators
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])  # even index: sin
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])  # odd index: cos

    return sinusoid_table

In [76]:
""" scale dot product attention """
class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(Q K^T * scale) V``.

    Reads two entries from ``config``:
        ``"dropout"``: dropout probability applied to the attention weights.
        ``"d_head"``: per-head width; ``scale`` is ``1 / sqrt(d_head)``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(config["dropout"])
        self.scale = 1 / (self.config["d_head"] ** 0.5)

    def forward(self, Q, K, V, attn_mask=None):
        """Attend from ``Q`` over ``K``/``V``.

        Args:
            Q: queries, ``(..., n_q_seq, d_head)``.
            K: keys, ``(..., n_k_seq, d_head)``.
            V: values, ``(..., n_k_seq, d_v)``.
            attn_mask: optional boolean tensor broadcastable to the score
                shape ``(..., n_q_seq, n_k_seq)``; ``True`` marks positions
                that must receive no attention. ``None`` (default) disables
                masking, which matches the previous behaviour.

        Returns:
            ``(context, attn_prob)`` with shapes ``(..., n_q_seq, d_v)`` and
            ``(..., n_q_seq, n_k_seq)``. ``attn_prob`` is returned after
            dropout has been applied.
        """
        # (..., n_q_seq, n_k_seq)
        scores = torch.matmul(Q, K.transpose(-1, -2)) * self.scale
        if attn_mask is not None:
            # Use the most negative finite value rather than -inf so that a
            # fully masked row yields a uniform distribution instead of NaN.
            scores = scores.masked_fill(attn_mask, torch.finfo(scores.dtype).min)
        # Functional softmax avoids building a new nn.Softmax module per call.
        attn_prob = F.softmax(scores, dim=-1)
        attn_prob = self.dropout(attn_prob)
        # (..., n_q_seq, d_v)
        context = torch.matmul(attn_prob, V)
        return context, attn_prob

In [77]:
""" multi head attention """
class MultiHeadAttention(nn.Module):
    """Multi-head attention: project, split into heads, attend, merge, project.

    Reads ``"d_hidn"``, ``"n_head"``, ``"d_head"`` and ``"dropout"`` from
    ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        model_width = self.config["d_hidn"]
        heads_width = self.config["n_head"] * self.config["d_head"]

        self.W_Q = nn.Linear(model_width, heads_width)
        self.W_K = nn.Linear(model_width, heads_width)
        self.W_V = nn.Linear(model_width, heads_width)
        self.scaled_dot_attn = ScaledDotProductAttention(self.config)
        self.linear = nn.Linear(heads_width, model_width)
        self.dropout = nn.Dropout(config['dropout'])

    def _split_heads(self, projected, batch_size):
        """Reshape (bs, n_seq, n_head * d_head) into (bs, n_head, n_seq, d_head)."""
        per_head = projected.view(batch_size, -1, self.config['n_head'], self.config['d_head'])
        return per_head.transpose(1, 2)

    def forward(self, Q, K, V):
        """Return ``(output, attn_prob)`` for the given query/key/value inputs.

        ``output`` is (bs, n_q_seq, d_hidn); ``attn_prob`` is whatever the
        inner attention module returns as its second value.
        """
        batch_size = Q.size(0)

        # Each is (bs, n_head, n_*_seq, d_head).
        q_s = self._split_heads(self.W_Q(Q), batch_size)
        k_s = self._split_heads(self.W_K(K), batch_size)
        v_s = self._split_heads(self.W_V(V), batch_size)

        # (bs, n_head, n_q_seq, d_head), (bs, n_head, n_q_seq, n_k_seq)
        context, attn_prob = self.scaled_dot_attn(q_s, k_s, v_s)

        # Merge the heads back together: (bs, n_q_seq, n_head * d_head).
        merged_width = self.config['n_head'] * self.config['d_head']
        merged = context.transpose(1, 2).contiguous().view(batch_size, -1, merged_width)

        # (bs, n_q_seq, d_hidn)
        output = self.dropout(self.linear(merged))
        return output, attn_prob


In [78]:
""" feed forward """
class PoswiseFeedForwardNet(nn.Module):
    """Position-wise feed-forward block built from two kernel-size-1 convolutions.

    Reads ``"d_hidn"``, ``"d_ff"`` and ``"dropout"`` from ``config``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        model_width = self.config['d_hidn']
        inner_width = self.config['d_ff']

        self.conv1 = nn.Conv1d(in_channels=model_width, out_channels=inner_width, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=inner_width, out_channels=model_width, kernel_size=1)
        self.active = F.gelu
        self.dropout = nn.Dropout(config['dropout'])

    def forward(self, inputs):
        """Map (bs, n_seq, d_hidn) inputs to an output of the same shape."""
        # Conv1d expects channels before length: (bs, d_hidn, n_seq).
        channels_first = inputs.transpose(1, 2)
        # (bs, d_ff, n_seq)
        hidden = self.active(self.conv1(channels_first))
        # (bs, d_hidn, n_seq) -> (bs, n_seq, d_hidn)
        projected = self.conv2(hidden).transpose(1, 2)
        return self.dropout(projected)


In [79]:
""" encoder layer """
class EncoderLayer(nn.Module):
    """One encoder block: self-attention and feed-forward, each followed by a
    residual connection and layer normalisation.

    Reads ``"d_hidn"`` and ``"layer_norm_epsilon"`` from ``config`` and passes
    ``config`` on to the attention and feed-forward sub-modules.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        self.self_attn = MultiHeadAttention(self.config)
        self.layer_norm1 = nn.LayerNorm(self.config['d_hidn'], eps=self.config['layer_norm_epsilon'])
        self.pos_ffn = PoswiseFeedForwardNet(self.config)
        self.layer_norm2 = nn.LayerNorm(self.config['d_hidn'], eps=self.config['layer_norm_epsilon'])

    def forward(self, inputs, attn_mask=None):
        """Run the block on ``inputs``.

        Args:
            inputs: (bs, n_enc_seq, d_hidn) hidden states.
            attn_mask: optional. Kept in the signature so callers may pass a
                mask, but it is NOT applied: self-attention is invoked
                without a mask argument. It defaults to ``None`` so the layer
                can also be called with ``inputs`` alone.

        Returns:
            ``(outputs, attn_prob)``: the (bs, n_enc_seq, d_hidn) block output
            and the attention probabilities returned by the attention module.
        """
        # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
        att_outputs, attn_prob = self.self_attn(inputs, inputs, inputs)
        # Residual connection around attention, then normalise.
        att_outputs = self.layer_norm1(inputs + att_outputs)
        # (bs, n_enc_seq, d_hidn)
        ffn_outputs = self.pos_ffn(att_outputs)
        # Residual connection around the feed-forward net, then normalise.
        ffn_outputs = self.layer_norm2(ffn_outputs + att_outputs)
        # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
        return ffn_outputs, attn_prob

In [80]:
""" encoder """
class Encoder(nn.Module):
    """Token + sinusoidal position embedding followed by a stack of encoder layers.

    Reads ``"n_enc_vocab"``, ``"d_hidn"``, ``"n_enc_seq"``, ``"n_layer"`` and
    ``"i_pad"`` from the ``config`` dict. The position table has
    ``n_enc_seq + 1`` rows because real tokens use positions 1..n_seq while
    padding tokens are assigned position 0.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.enc_emb = nn.Embedding(self.config["n_enc_vocab"], self.config["d_hidn"])
        sinusoid_table = torch.FloatTensor(get_sinusoid_encoding_table(self.config['n_enc_seq'] + 1, self.config['d_hidn']))
        # Fixed (non-trainable) position embedding.
        self.pos_emb = nn.Embedding.from_pretrained(sinusoid_table, freeze=True)

        self.layers = nn.ModuleList([EncoderLayer(self.config) for _ in range(self.config['n_layer'])])

    def forward(self, inputs):
        """Encode a batch of token ids.

        Args:
            inputs: (bs, n_enc_seq) integer token ids.

        Returns:
            ``(outputs, attn_probs)``: the final (bs, n_enc_seq, d_hidn)
            hidden states and a list with one attention-probability tensor
            per layer.
        """
        batch_size, n_seq = inputs.size(0), inputs.size(1)

        # config is a plain dict, so it must be indexed by key.
        pad_mask = inputs.eq(self.config['i_pad'])

        # 1-based positions, always int64 so they are valid embedding indices;
        # padding tokens are sent to position 0.
        positions = torch.arange(1, n_seq + 1, device=inputs.device, dtype=torch.long)
        positions = positions.expand(batch_size, n_seq).contiguous()
        positions.masked_fill_(pad_mask, 0)

        # (bs, n_enc_seq, d_hidn)
        outputs = self.enc_emb(inputs) + self.pos_emb(positions)

        # (bs, n_enc_seq, n_enc_seq): True wherever the key position is padding.
        # Handed to every layer; whether it is applied is up to the layer.
        attn_mask = pad_mask.unsqueeze(1).expand(batch_size, n_seq, n_seq)

        attn_probs = []
        for layer in self.layers:
            # (bs, n_enc_seq, d_hidn), (bs, n_head, n_enc_seq, n_enc_seq)
            outputs, attn_prob = layer(outputs, attn_mask)
            attn_probs.append(attn_prob)
        # (bs, n_enc_seq, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        return outputs, attn_probs


In [81]:
class encoderclf(nn.Module):
    """Encoder-only sequence classifier.

    The encoder output is max-pooled over the sequence dimension and mapped
    to ``config['n_output']`` logits by a bias-free linear layer.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.encoder = Encoder(self.config)
        self.projection = nn.Linear(self.config['d_hidn'], self.config['n_output'], bias=False)

    def forward(self, enc_inputs, dec_inputs=None):
        """Classify a batch of token-id sequences.

        Args:
            enc_inputs: (bs, n_enc_seq) token ids passed to the encoder.
            dec_inputs: unused. Kept (now optional) so existing two-argument
                calls keep working while ``model(enc_inputs)`` is also valid.

        Returns:
            ``(logits, enc_self_attn_probs)``: (bs, n_output) logits and the
            per-layer attention probabilities returned by the encoder.
        """
        # (bs, n_enc_seq, d_hidn), [(bs, n_head, n_enc_seq, n_enc_seq)]
        enc_outputs, enc_self_attn_probs = self.encoder(enc_inputs)
        # Max-pool over the sequence dimension -> (bs, d_hidn).
        # NOTE(review): the max is taken over every position, padding included.
        enc_outputs, _ = torch.max(enc_outputs, dim=1)
        # (bs, n_output)
        logits = self.projection(enc_outputs)
        return logits, enc_self_attn_probs

In [82]:
def make_model(n_enc_vocab=None):
    """Build an ``encoderclf`` with the notebook's default hyper-parameters.

    Args:
        n_enc_vocab: size of the input vocabulary. When ``None`` (default) it
            is taken from the notebook-level ``vocab_src``, which must then
            already be defined when this function is called.

    Returns:
        A freshly initialised ``encoderclf`` (not yet moved to any device).
    """
    if n_enc_vocab is None:
        # Previous behaviour: rely on the global vocabulary built elsewhere
        # in the notebook.
        n_enc_vocab = len(vocab_src)

    config={
        "n_enc_vocab": n_enc_vocab,
        "n_enc_seq": 256,
        "n_dec_seq": 256,
        "n_layer": 6,
        "d_hidn": 256,
        "i_pad": 0,
        "d_ff": 1024,
        "n_head": 4,
        "d_head": 64,
        "dropout": 0.1,
        "layer_norm_epsilon": 1e-12,
        "n_output":2
    }

    model=encoderclf(config)
    return model

In [83]:
from torch.utils.data import DataLoader
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator

# Tokenizer shared by the vocabulary builder and the text pipeline.
tokenizer = get_tokenizer('basic_english')
# Training split of IMDB; iterated below to build the vocabulary.
train_iter = IMDB(split='train')

def yield_tokens(data_iter, tokenizer, index):
    """Lazily yield ``tokenizer(record[index])`` for every record in ``data_iter``."""
    yield from (tokenizer(record[index]) for record in data_iter)

# torchtext's IMDB yields (label, text) pairs, so the review text is field 1.
# Field 0 would build the vocabulary from the labels only.
vocab_src = build_vocab_from_iterator(yield_tokens(train_iter, tokenizer, 1), specials=["<unk>"])
# Out-of-vocabulary tokens map to <unk>.
vocab_src.set_default_index(vocab_src["<unk>"])


def text_pipeline(x):
    """Tokenize a raw text string and map its tokens to vocabulary ids."""
    return vocab_src(tokenizer(x))


def label_pipeline(x):
    """Map a dataset label to a float target: 1.0 for 'pos', 0.0 otherwise.

    NOTE(review): assumes labels are the strings 'pos' / 'neg'; some
    torchtext releases emit integer labels instead - confirm against the
    installed version.
    """
    return 1.0 if x == 'pos' else 0.0


In [84]:
def collate_batch(
    batch,
    max_padding=128,
    pad_id=2,
):
    """Turn a list of dataset records into padded model inputs and targets.

    Args:
        batch: iterable of ``(label, text)`` records, the order in which
            torchtext's IMDB dataset yields them.
        max_padding: fixed output length. Shorter sequences are right-padded
            with ``pad_id``; longer ones are truncated (a negative pad
            amount makes ``F.pad`` crop).
        pad_id: token id used for padding.
            NOTE(review): the default is 2, while the model config in this
            notebook uses ``i_pad = 0`` and the vocabulary reserves only
            index 0 (``<unk>``) - confirm which id is meant for padding.

    Returns:
        ``(src, tgt)`` on ``device``: ``src`` is an int64 tensor of shape
        (batch, max_padding); ``tgt`` is a float tensor of shape (batch,).
    """
    src_list, label_list = [], []
    for (label, text) in batch:
        # Token ids must be integers to be usable as nn.Embedding indices.
        processed_src = torch.tensor(text_pipeline(text), dtype=torch.long)
        src_list.append(
            # warning - a negative pad amount truncates the sequence
            F.pad(
                processed_src,
                (0, max_padding - len(processed_src)),
                value=pad_id,
            )
        )
        label_list.append(label_pipeline(label))
    # Stack (not cat) so the batch dimension is kept: (batch, max_padding).
    src = torch.stack(src_list).to(device)
    tgt = torch.tensor(label_list, dtype=torch.float).to(device)
    return (src, tgt)

In [85]:
# Build the classifier, then move its parameters to the selected device.
model = make_model()
model = model.to(device)