In [2]:
import torch
from torch import nn
from torch import optim
from torch.utils import data as D

import torch.nn.functional as F

import pytorch_lightning as pl

import pandas as pd

from data_utils.preprocess import process_tweet, batch_tokens
from data_utils.tokenization import SentencePieceTokenizer

import numpy as np
from tqdm import tqdm

from typing import List

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [3]:
device = torch.device("cpu")

In [4]:
teacher = torch.load("lm.pth", map_location=device)
teacher.aux_lm_loss = True

In [5]:
tokenizer = SentencePieceTokenizer(model_path="../models/ama_32k_tokenizer.model")

In [6]:
class DistillationDataset(D.Dataset):
    def __init__(self, text_csv_file: str, text_column: str, tokenizer: SentencePieceTokenizer, preprocess_fn):
        super(DistillationDataset, self).__init__()
        self.table: pd.Dataset = pd.read_csv(text_csv_file, memory_map=True)
        
        self.column = text_column
        self.tokenizer = tokenizer
        self.preprocess_fn = preprocess_fn
     
    def tokenize(self, tokenizer, text):
        """
        Tokenizes a text using SentencePiece tokenizer
        """
        input_text = self.tokenizer.EncodeAsIds(text, self.preprocess_fn).tokenization
    
        return input_text
    
    @staticmethod
    def get_collate_fn():
        def batch_tokens(token_lists, tensor_type=torch.LongTensor, fill_value=0):
            lens = torch.from_numpy(np.array(list(map(len, token_lists)), dtype=np.int64))
            batch_tensor = fill_value * torch.ones(len(lens), max(lens)).type(tensor_type)
            for i, string in enumerate(token_lists):
                _tokenize_str(string, tensor_type, batch_tensor[i])
            return batch_tensor.permute(1, 0), lens - 1

        def _tokenize_str(data, tensor_type, char_tensor=None):
            """
            Parses a utf-8 encoded string and assigns to ByteTensor char_tensor.
            If no char_tensor is provide one is created.
            Typically used internally by `tokenize_str_batch`.
            """
            if char_tensor is None:
                if isinstance(data, str):
                    # data could either be a string or a list of ids.
                    data = data.encode()
                char_tensor = tensor_type(len(data))
            for i, char in enumerate(data):
                char_tensor[i] = char
                
        return batch_tokens

        
    def __len__(self) -> int:
        return len(self.table)
    
    def __getitem__(self, idx: int) -> List[str]:
        sample = self.table[self.column].iloc[idx]
        sample = self.tokenize(self.tokenizer, sample)
        
        return sample

In [7]:
teacher.encoder.encoder.embed_tokens

Embedding(32001, 768, padding_idx=0)

In [8]:
base_dataset = DistillationDataset("../data/test.csv", 'Tweet', tokenizer=tokenizer, preprocess_fn=process_tweet)

In [135]:
train, valid, test = D.Subset(base_dataset, np.arange(0, 9500000)), \
                     D.Subset(base_dataset, np.arange(9500000, 9750000)), \
                     D.Subset(base_dataset, np.arange(9750000, len(base_dataset)))

In [9]:
train_dl = D.DataLoader(base_dataset, batch_size=10, collate_fn=base_dataset.get_collate_fn())

In [10]:
txt, ln = next(iter(train_dl))

In [11]:
txt.shape, ln.shape

(torch.Size([33, 10]), torch.Size([10]))

In [12]:
ln

tensor([29, 26,  8, 19, 31, 32, 19, 18, 27, 28])

In [13]:
with torch.no_grad():
    feats, lm = teacher(txt, ln)



In [14]:
lm = torch.softmax(lm, -1)

In [26]:
sub_lm = lm[:, :1]

In [27]:
sub_lm.shape

torch.Size([30, 1, 32001])

In [23]:
def decode(tensor, true_value, example_id=0, tokenizer = tokenizer):
    sequence = tensor[:, example_id]
    true_value = tokenizer.EncodeAsIds(true_value).tokenization
    print(sequence.shape)
    for i, dist in enumerate(sequence):
        idx   = torch.argmax(dist)
        value = torch.max(dist)
        if idx != 0:
            token = tokenizer.DecodeIds(idx.item())
        else:
            token = "<pad>"
        if i < len(true_value):
            print(f"{token:^20} [{value.item():.4f}] | {tokenizer.DecodeIds(true_value[i]):^20} [{dist[true_value[i]]:.4f}]")

In [26]:
decode(lm, base_dataset.table.Tweet.iloc[5], example_id=5, tokenizer=tokenizer)

torch.Size([33, 32001])
         ⁇           [0.9977] |                      [0.0000]
         '           [0.1075] |          ⁇           [0.0000]
         .           [0.0675] |          P           [0.0006]
        ISH          [0.2197] |          OL          [0.0085]
        ICAL         [0.5260] |          IT          [0.0003]
         IT          [0.2743] |          IC          [0.0032]
         ⁇           [0.0936] |          O           [0.0007]
         ⁇           [0.0693] |          E           [0.0015]
         an          [0.9228] |        urope         [0.0000]
        Fact         [0.0861] |     Interesting      [0.0006]
         of          [0.6880] |        choice        [0.0002]
        the          [0.0280] |          of          [0.0001]
        and          [0.1063] |        words         [0.0001]
        from         [0.0229] |         ...          [0.0001]
       there         [0.0987] |         Are          [0.0000]
      familiar       [0.1036] |         you   

In [27]:
base_dataset.table.Tweet.iloc[5]

'@POLITICOEurope Interesting choice of words... Are you confirming that governments fund #terrorism? Bit of an open door, but still...'

In [19]:
ppls = []
for txt, ln in tqdm(train_dl):
    txt = txt.to(device)
    x = txt[:, :-1]
    y = txt[:, 1:]
    ln = ln.to(device) - 1
    with torch.no_grad():
        feats, outs = teacher(x, ln)
        loss = F.cross_entropy(outs, y.reshape(-1), ignore_index=0)
        ppls.append(loss.item())

  0%|          | 0/51 [00:00<?, ?it/s]


ValueError: Expected input batch_size (64) to match target batch_size (3648).

In [11]:
txt = txt.to(device)
ln = ln.to(device)

In [12]:
feats, outs = teacher(txt, ln)



In [84]:
from argparse import Namespace

In [85]:
args = Namespace()
args.vocab_size = 32001
args.hidden_size = 128
args.blocks_size = 256
args.n_blocks = 12

In [86]:
dummy = torch.randint(32001, size=(55, 32))

In [87]:
dummy.shape

torch.Size([55, 32])

In [88]:
class ResConvBlock(nn.Module):
        
    def __init__(self, input_size, output_size, kernel_size=3, activation=F.gelu):
        super(ResConvBlock, self).__init__()
        self.perform_residual = (input_size == output_size)
        self.activation = activation

        self.cnn = nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=kernel_size, padding=(kernel_size // 2))
        self.cnn_ln = nn.LayerNorm(output_size)
            
        self.ff = nn.Linear(in_features=output_size, out_features=output_size)
        self.ff_ln = nn.LayerNorm(output_size)
            
    def forward(self, x):
        """
        X is a tensor of shape [TimeSteps x BatchSize x InputSize]
        :return Tensor of shape [TimeSteps x BatchSize x OutputSize]
        """
        if self.perform_residual:
            residual = x
            x = x.permute(1, 2, 0) ## [Time x Batch x Embedding] => [Batch x Embedding x Time]
            x = self.activation(self.cnn(x))
            x = x.permute(2, 0, 1) ## [Batch x Embedding x Time] => [Time x Batch x Embedding]
            x = self.cnn_ln(residual + x)
        else:
            x = x.permute(1, 2, 0) ## [Time x Batch x Embedding] => [Batch x Embedding x Time]
            x = self.activation(self.cnn(x))
            x = x.permute(2, 0, 1) ## [Batch x Embedding x Time] => [Time x Batch x Embedding]
                
        residual = x
        x = self.ff(x)
        x = self.ff_ln(residual + x)
            
        return x
        
class DistillatedLanguageModel(nn.Module):
    
    def __init__(self, args):
        super(DistillatedLanguageModel, self).__init__()
        self.embed = nn.Embedding(args.vocab_size, args.hidden_size, padding_idx=0)
        
        self.entry_block = ResConvBlock(input_size=args.hidden_size, output_size=args.blocks_size)
        
        self.blocks = nn.ModuleList(modules=[ResConvBlock(input_size=args.blocks_size, output_size=args.blocks_size) for _ in range(args.n_blocks)])
        
        self.out_project = nn.Linear(in_features=args.blocks_size, out_features=args.hidden_size)
        
    def forward(self, x):
        x = self.embed(x)
        
        x = self.entry_block(x)
        
        for block in self.blocks:
            x = block(x)
        
        x = self.out_project(x)
        
        return F.linear(x, self.embed.weight)

In [89]:
model = DistillatedLanguageModel(args)

In [90]:
text, lens = next(iter(train_dl))

In [91]:
text.shape

torch.Size([74, 64])

In [92]:
text.shape

torch.Size([74, 64])

In [93]:
student_out = model(text)

In [94]:
_, teacher_out = teacher(text, lens)

In [96]:
student_out.shape, teacher_out.shape

(torch.Size([74, 64, 32001]), torch.Size([74, 64, 32001]))

In [95]:
temp = 10.0

In [97]:
student_probs = F.softmax(student_out/temp, -1)
teacher_probs = F.softmax(teacher_out/temp, -1)

In [98]:
F.mse_loss(student_probs, teacher_probs)

tensor(5.3630e-10, grad_fn=<MeanBackward0>)

In [127]:
F.mse_loss(

tensor(5.3630e-10, grad_fn=<MeanBackward0>)