In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import TransformerEncoder, TransformerEncoderLayer



## 이 ipython  파일은 그냥 전체 흐름만
> Transformer의 공부는 세부모듈들이 중요해서 각각 어떻게 만들어지는지 쪼개서 파일 만들어 공부 계획

In [78]:
class PositionalEncoding(nn.Module):
    '''
    cos 및 sin의 함수로 position을 매핑해주기 위한 함수
    저자가 주장한 이런 position value의 특징
    -> 모델 내에서 seq의 길이에 관계없이 제한된범위의 사이 값을 반환해야한다.
    -> 같은 거리차의 두 토큰관겨 x(t1), x(t1+k) 및 x(t2), x(t2+k)의 pos(x)의 차이는 같아야한다.
    이 두가지 요건을 만족하기 위해 저자는 sin, cos를 바탕으로 구현하여서 0~1의 분포 time dist에 따른 값차이를 갖는 positional encoding이 성립
    '''
    def __init__(self, d_model, dropout, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
        
    def forward(self, x):
        x = x + self.pe[:x.size(0), :] # seq의 길이가 당연히 max len인 5000이 안되니까 pe를 잘라준다
        return self.dropout(x)

In [79]:
class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout=dropout)
        encoder_layers =TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp) # 이거 사실상 bias없는 linear인데 괜히 이름이 있어보임
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken) # embeding의 역변환인데 linear인게 embeding이 linear일 뿐인데 이름따로 있는게 신기함
        
        self.init_weights()
        
    def generate_square_subsequent_mask(self, sz):
        mask = (torch.tril(torch.ones(sz, sz)))
        mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask
        
    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)
        
    def forward(self, src, src_mask):
        src = self.encoder(src) # 일단 dim을 맞춰준다 그냥 linear
        src = self.pos_encoder(src) # pos vector를 더해준다
        output = self.transformer_encoder(src, src_mask) # 이부분은 transformer encoder layer를 일단 공부해야함 잠시 킵
        output = self.decoder(output) # 어째든 나오는 벡터가 있는데 이게 dim은 그대로이니 원상복귀할거임
        return output
        
        

## 텍스트 데이터를 조금 공부해보자
- torchtext를 바탕으로 실험을 할것임
  - 다 그렇듯 .dataset에 각종 데이터들의 dataloader가 들어있음 일단 슬적 구경이나 해보자
  - nlp의 가장 큰 특징은 tokenizer의 존재라 생각함. 단순하게 word나 char를 쓰는게 아니라 잘라서 sentencepiece를 사용
  - vocab의 관리가 중요한데 torchtext는 이런 vocab관련 모듈이 있음 일단 따라가본다.
  
- 그냥 따라가려했는데 자꾸 torchtext version문제로 dataset이 자꾸 맛이가서 미래 version torch text에만 있는 함수를 1.8.0에서 쓰려고 그냥 긁어왔다. 아래 2개블럭은 그냥 무시해도됨

In [42]:
import torch
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator, Vocab

In [32]:
import functools
import inspect
import os
import io
import json
import torch
from torchtext.utils import (
    validate_file,
    download_from_url,
    extract_archive,
    unicode_csv_reader,
)
import codecs
try:
    import defusedxml.ElementTree as ET
except ImportError:
    import xml.etree.ElementTree as ET

_CACHE_DIR = os.path.expanduser('~/.torchtext/cache')


def _clean_xml_file(f_xml):
    f_txt = os.path.splitext(f_xml)[0]
    with codecs.open(f_txt, mode='w', encoding='utf-8') as fd_txt:
        root = ET.parse(f_xml).getroot()[0]
        for doc in root.findall('doc'):
            for e in doc.findall('seg'):
                fd_txt.write(e.text.strip() + '\n')


def _clean_tags_file(f_orig):
    xml_tags = [
        '<url', '<keywords', '<talkid', '<description', '<reviewer',
        '<translator', '<title', '<speaker', '<doc', '</doc'
    ]
    f_txt = f_orig.replace('.tags', '')
    with codecs.open(f_txt, mode='w', encoding='utf-8') as fd_txt, \
            io.open(f_orig, mode='r', encoding='utf-8') as fd_orig:
        for line in fd_orig:
            if not any(tag in line for tag in xml_tags):
                # TODO: Fix utf-8 next line mark
                #                fd_txt.write(l.strip() + '\n')
                #                fd_txt.write(l.strip() + u"\u0085")
                #                fd_txt.write(l.lstrip())
                fd_txt.write(line.strip() + '\n')


def _create_data_from_json(data_path):
    with open(data_path) as json_file:
        raw_json_data = json.load(json_file)['data']
        for layer1 in raw_json_data:
            for layer2 in layer1['paragraphs']:
                for layer3 in layer2['qas']:
                    _context, _question = layer2['context'], layer3['question']
                    _answers = [item['text'] for item in layer3['answers']]
                    _answer_start = [item['answer_start'] for item in layer3['answers']]
                    if len(_answers) == 0:
                        _answers = [""]
                        _answer_start = [-1]
                    # yield the raw data in the order of context, question, answers, answer_start
                    yield (_context, _question, _answers, _answer_start)


def _create_data_from_iob(data_path, separator='\t'):
    with open(data_path, encoding="utf-8") as input_file:
        columns = []
        for line in input_file:
            line = line.strip()
            if line == "":
                if columns:
                    yield columns
                columns = []
            else:
                for i, column in enumerate(line.split(separator)):
                    if len(columns) < i + 1:
                        columns.append([])
                    columns[i].append(column)
        if len(columns) > 0:
            yield columns


def _read_text_iterator(path):
    with io.open(path, encoding="utf8") as f:
        for row in f:
            yield row


def _create_data_from_csv(data_path):
    with io.open(data_path, encoding="utf8") as f:
        reader = unicode_csv_reader(f)
        for row in reader:
            yield int(row[0]), ' '.join(row[1:])


def _check_default_set(split, target_select, dataset_name):
    # Check whether given object split is either a tuple of strings or string
    # and represents a valid selection of options given by the tuple of strings
    # target_select.
    if isinstance(split, str):
        split = (split,)
    if isinstance(target_select, str):
        target_select = (target_select,)
    if not isinstance(split, tuple):
        raise ValueError("Internal error: Expected split to be of type tuple.")
    if not set(split).issubset(set(target_select)):
        raise TypeError('Given selection {} of splits is not supported for dataset {}. Please choose from {}.'.format(
            split, dataset_name, target_select))
    return split


def _wrap_datasets(datasets, split):
    # Wrap return value for _setup_datasets functions to support singular values instead
    # of tuples when split is a string.
    if isinstance(split, str):
        if len(datasets) != 1:
            raise ValueError("Internal error: Expected number of datasets is not 1.")
        return datasets[0]
    return datasets


def _find_match(match, lst):
    """
    Searches list of strings and returns first entry that partially or fully
    contains the given string match.
    """
    for element in lst:
        if match in element:
            return element
    return None


def _dataset_docstring_header(fn, num_lines=None, num_classes=None):
    """
    Returns docstring for a dataset based on function arguments.
    Assumes function signature of form (root='.data', split=<some tuple of strings>, **kwargs)
    """
    argspec = inspect.getfullargspec(fn)
    if not (argspec.args[0] == "root" and
            argspec.args[1] == "split"):
        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))
    default_split = argspec.defaults[1]

    if not (isinstance(default_split, tuple) or isinstance(default_split, str)):
        raise ValueError("default_split type expected to be of string or tuple but got {}".format(type(default_split)))

    header_s = fn.__name__ + " dataset\n"

    if isinstance(default_split, tuple):
        header_s += "\nSeparately returns the {} split".format("/".join(default_split))

    if isinstance(default_split, str):
        header_s += "\nOnly returns the {} split".format(default_split)

    if num_lines is not None:
        header_s += "\n\nNumber of lines per split:"
        for k, v in num_lines.items():
            header_s += "\n    {}: {}\n".format(k, v)

    if num_classes is not None:
        header_s += "\n\nNumber of classes"
        header_s += "\n    {}\n".format(num_classes)

    args_s = "\nArgs:"
    args_s += "\n    root: Directory where the datasets are saved."
    args_s += "\n        Default: .data"

    if isinstance(default_split, tuple):
        args_s += "\n    split: split or splits to be returned. Can be a string or tuple of strings."
        args_s += "\n        Default: {}""".format(str(default_split))

    if isinstance(default_split, str):
        args_s += "\n     split: Only {default_split} is available."
        args_s += "\n         Default: {default_split}.format(default_split=default_split)"

    return "\n".join([header_s, args_s]) + "\n"


def _add_docstring_header(docstring=None, num_lines=None, num_classes=None):
    def docstring_decorator(fn):
        old_doc = fn.__doc__
        fn.__doc__ = _dataset_docstring_header(fn, num_lines, num_classes)
        if docstring is not None:
            fn.__doc__ += docstring
        if old_doc is not None:
            fn.__doc__ += old_doc
        return fn
    return docstring_decorator


def _wrap_split_argument_with_fn(fn, splits):
    """
    Wraps given function of specific signature to extend behavior of split
    to support individual strings. The given function is expected to have a split
    kwarg that accepts tuples of strings, e.g. ('train', 'valid') and the returned
    function will have a split argument that also accepts strings, e.g. 'train', which
    are then turned single entry tuples. Furthermore, the return value of the wrapped
    function is unpacked if split is only a single string to enable behavior such as
    train = AG_NEWS(split='train')
    train, valid = AG_NEWS(split=('train', 'valid'))
    """
    argspec = inspect.getfullargspec(fn)
    if not (argspec.args[0] == "root" and
            argspec.args[1] == "split" and
            argspec.varargs is None and
            argspec.varkw is None and
            len(argspec.kwonlyargs) == 0 and
            len(argspec.annotations) == 0
            ):
        raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

    @functools.wraps(fn)
    def new_fn(root=_CACHE_DIR, split=splits, **kwargs):
        result = []
        for item in _check_default_set(split, splits, fn.__name__):
            result.append(fn(root, item, **kwargs))
        return _wrap_datasets(tuple(result), split)

    new_sig = inspect.signature(new_fn)
    new_sig_params = new_sig.parameters
    new_params = []
    new_params.append(new_sig_params['root'].replace(default='.data'))
    new_params.append(new_sig_params['split'].replace(default=splits))
    new_params += [entry[1] for entry in list(new_sig_params.items())[2:]]
    new_sig = new_sig.replace(parameters=tuple(new_params))
    new_fn.__signature__ = new_sig

    return new_fn


def _wrap_split_argument(splits):
    def new_fn(fn):
        return _wrap_split_argument_with_fn(fn, splits)
    return new_fn


def _create_dataset_directory(dataset_name):
    def decorator(func):
        argspec = inspect.getfullargspec(func)
        if not (argspec.args[0] == "root" and
                argspec.args[1] == "split" and
                argspec.varargs is None and
                argspec.varkw is None and
                len(argspec.kwonlyargs) == 0 and
                len(argspec.annotations) == 0
                ):
            raise ValueError("Internal Error: Given function {} did not adhere to standard signature.".format(fn))

        @functools.wraps(func)
        def wrapper(root=_CACHE_DIR, *args, **kwargs):
            new_root = os.path.join(root, dataset_name)
            if not os.path.exists(new_root):
                os.makedirs(new_root)
            return func(root=new_root, *args, **kwargs)

        return wrapper

    return decorator


def _download_extract_validate(root, url, url_md5, downloaded_file, extracted_file, extracted_file_md5,
                               hash_type="sha256"):
    root = os.path.abspath(root)
    downloaded_file = os.path.abspath(downloaded_file)
    extracted_file = os.path.abspath(extracted_file)
    if os.path.exists(extracted_file):
        with open(os.path.join(root, extracted_file), 'rb') as f:
            if validate_file(f, extracted_file_md5, hash_type):
                return extracted_file

    dataset_tar = download_from_url(url, path=os.path.join(root, downloaded_file),
                                    hash_value=url_md5, hash_type=hash_type)
    extracted_files = extract_archive(dataset_tar)
    assert os.path.exists(extracted_file), "extracted_file [{}] was not found in the archive [{}]".format(extracted_file, extracted_files)

    return extracted_file


class _RawTextIterableDataset(torch.utils.data.IterableDataset):
    """Defines an abstraction for raw text iterable datasets.
    """

    def __init__(self, description, full_num_lines, iterator):
        """Initiate the dataset abstraction.
        """
        super(_RawTextIterableDataset, self).__init__()
        self.description = description
        self.full_num_lines = full_num_lines
        self._iterator = iterator
        self.num_lines = full_num_lines
        self.current_pos = None

    def __iter__(self):
        return self

    def __next__(self):
        if self.current_pos == self.num_lines - 1:
            raise StopIteration
        item = next(self._iterator)
        if self.current_pos is None:
            self.current_pos = 0
        else:
            self.current_pos += 1
        return item

    def __len__(self):
        return self.num_lines

    def pos(self):
        """
        Returns current position of the iterator. This returns None
        if the iterator hasn't been used yet.
        """
        return self.current_pos

    def __str__(self):
        return self.description

In [33]:
import logging
from torchtext.utils import download_from_url, extract_archive
# from torchtext.data.datasets_utils import (
#     _RawTextIterableDataset,
#     _wrap_split_argument,
#     _add_docstring_header,
#     _find_match,
#     _create_dataset_directory,
#     _read_text_iterator,
# )

URL = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'

MD5 = '542ccefacc6c27f945fb54453812b3cd'

NUM_LINES = {
    'train': 36718,
    'valid': 3760,
    'test': 4358,
}

DATASET_NAME = "WikiText2"


@_add_docstring_header(num_lines=NUM_LINES)
@_create_dataset_directory(dataset_name=DATASET_NAME)
@_wrap_split_argument(('train', 'valid', 'test'))
def WikiText2(root, split):
    dataset_tar = download_from_url(URL, root=root, hash_value=MD5, hash_type='md5')
    extracted_files = extract_archive(dataset_tar)
    path = _find_match(split, extracted_files)
    logging.info('Creating {} data'.format(split))
    return _RawTextIterableDataset(DATASET_NAME,
                                   NUM_LINES[split], _read_text_iterator(path))

In [40]:
import os
tokenizer = get_tokenizer('basic_english')
train_iter = WikiText2(root='/raid/jskim/data/nlp',split='train')

In [22]:
tokenizer('hello my name is jskim') # 딱 봐도 음절분리기

['hello', 'my', 'name', 'is', 'jskim']

In [26]:
import torchtext
torchtext.__version__

'0.8.0a0+0f911ec'

In [46]:
vocab = build_vocab_from_iterator(map(tokenizer, train_iter)) # special 이 안먹어서 따로 해줘야한다..

36718lines [00:01, 21509.11lines/s]


In [54]:
vocab['my']

448

## vocab의 call 과정이 바꼇다..
- 1.8.0 에서는 단일 token만 넣을 수 있는데 최신 버젼에서는 리스트를 통으로 넣을 수 있게 바뀌엇나보다.

In [55]:
def data_process(raw_text_iter):
    data = [torch.tensor([vocab[t] for t in tokenizer(item)], dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

In [89]:
train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

In [90]:
print(train_data[:5])
print(train_data.shape)
print(val_data.shape)
print(test_data.shape)

tensor([  10, 3850, 3870,  882,   10])
torch.Size([2049990])
torch.Size([214417])
torch.Size([241859])


In [91]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    # 데이터셋을 bsz 파트들로 나눕니다.
    nbatch = data.size(0) // bsz
    # 깔끔하게 나누어 떨어지지 않는 추가적인 부분(나머지들) 은 잘라냅니다.
    data = data.narrow(0, 0, nbatch * bsz) # 0 번 dim을 자를것이고, 0번 부터시작해서 nbatch * bsz개 만큼 살릴것이다 라는 뜻
    # 데이터에 대하여 bsz 배치들로 동등하게 나눕니다.
    data = data.view(bsz, -1).t().contiguous() # bsz개씩 묶어서 묶어지는대로 만들것이고. transpose 하고 이거가 정립될 수 있게 contiguous해준다 (contiguout안해주면 접근이 좀 맛이가서 view가 안먹히게된다.)
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)

In [93]:
print(train_data[:5])
print(train_data.shape)
print(val_data.shape)
print(test_data.shape)

tensor([[   10,    60,   565,   224,   444, 13628,     3,   540,  2873,  2465,
             0,   314,  4514,     2,     6,    48,    67, 11653,  2436,     2],
        [ 3850,    13,   301,  6303,  3990,  1931, 10560,   452,     5,     8,
             3,  1512, 10116,   943,  2440,   573,     2,    48,    31,  1991],
        [ 3870,   316,    20,    30,   940,     3,    11,  2140,  4917, 16616,
           236,     4,    14,     8,    25,    18, 13738,    98,  7721,     5],
        [  882,    68,   808,  5403,     7,    39, 28189,    26,     3,    78,
             8,  2395,    18,   517,    15, 16404,  3715,  4619,    13,  1109],
        [   10,   197,  6042,   191,   219, 11777,    18,     2,  1201,     3,
             0,    11,   592,    41,  6005,     3,    51,     4,  3132,  3782]],
       device='cuda:0')
torch.Size([102499, 20])
torch.Size([21441, 10])
torch.Size([24185, 10])


## get_batch
- 트랜스포머 학습을 하기 위한 입력 & 타겟 묶음을 생성.
 - 소스 데이터를 bptt길이의 덩어리로 바꾸는데, 언어모델은 다음 단어가 필요하다.
 - bptt의 값이 2라면 i=0일때 다음 2개의 변수를 갖게된다.

In [66]:
bptt = 35
def get_batch(source, i):
    '''
    그냥 뒤에꺼랑 묶는것, i에서 시작한다는 의미
    '''
    seq_len = min(bptt, len(source) -1 - i)
    data = source[i:i+seq_len]
    target = source[i+1: i + 1 + seq_len].reshape(-1)
    return data, target



In [80]:
ntokens = len(vocab) # 단어 사전(어휘집)의 크기
emsize = 200 # 임베딩 차원
nhid = 200 # nn.TransformerEncoder 에서 피드포워드 네트워크(feedforward network) 모델의 차원
nlayers = 2 # nn.TransformerEncoder 내부의 nn.TransformerEncoderLayer 개수
nhead = 2 # 멀티헤드 어텐션(multi-head attention) 모델의 헤드 개수
dropout = 0.2 # 드랍아웃(dropout) 값
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)


In [81]:
import time

criterion = nn.CrossEntropyLoss()
lr = 5.0 # 학습률
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)


In [82]:
def train():
    model.train() # 학습 모드를 시작합니다.
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval() # 평가 모드를 시작합니다.
    total_loss = 0.
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

In [83]:
best_val_loss = float("inf")
epochs = 3 # 에포크 수
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = model

    scheduler.step()

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 13.98 | loss  8.49 | ppl  4849.98
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch  8.04 | loss  7.49 | ppl  1783.24
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch  8.40 | loss  7.20 | ppl  1344.36
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch  7.88 | loss  7.12 | ppl  1231.12
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch  7.36 | loss  7.08 | ppl  1192.80
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch  7.18 | loss  7.05 | ppl  1148.62
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch  8.06 | loss  7.01 | ppl  1107.12
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch  8.25 | loss  7.00 | ppl  1099.75
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch  8.25 | loss  7.00 | ppl  1091.58
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch  8.22 | loss  6.98 | ppl  1079.41
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch  7.92 | loss  6.95 | ppl  1046.07
| epoch   1 |  2400/ 

In [109]:
data, targets = get_batch(train_data, 0)
# print(data)
print(data.shape)
embeded = model.encoder(data)
print(embeded.shape)
posed = model.pos_encoder(embeded)
print(posed.shape)

src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
output = self.transformer_encoder(posed, src_mask)

print(output.shape)
# print(targets.shape)

# print(src_mask)
# print(src_mask.shape)

torch.Size([35, 20])
torch.Size([35, 20, 200])
torch.Size([35, 20, 200])


In [103]:
print(model)

TransformerModel(
  (pos_encoder): PositionalEncoding(
    (dropout): Dropout(p=0.2, inplace=False)
  )
  (transformer_encoder): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=200, out_features=200, bias=True)
        )
        (linear1): Linear(in_features=200, out_features=200, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
        (linear2): Linear(in_features=200, out_features=200, bias=True)
        (norm1): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((200,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.2, inplace=False)
        (dropout2): Dropout(p=0.2, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): Linear(in_features=200, out_features=200, bias=True)
        )
        (linear1): Linear(in_features=200, out_features=20