In [1]:
import os

# Blank out CUDA_VISIBLE_DEVICES for this process before torch is imported
# further down; presumably this keeps the notebook on CPU -- TODO confirm intent.
os.environ.update({'CUDA_VISIBLE_DEVICES': ''})

In [5]:
from utils.copied_utils import (
    compute_input_and_target_lengths,
    DataCollatorForT5MLM,
    tokenize_function,
    DataCollatorForNI,
)
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import LocalDataset
import torch
import numpy as np


class UInt16(Encoding):
    """Column encoding that stores a NumPy array as raw ``uint16`` bytes."""

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as the raw bytes of a native ``uint16`` array.

        ``obj`` may be any integer array-like. Arrays that are already
        ``uint16`` are written unchanged; other integer dtypes are converted
        first so that ``decode`` (which always reads ``uint16``) sees the same
        values instead of mis-sized garbage.

        Raises:
            TypeError: if ``obj`` does not hold integers.
            ValueError: if any value does not fit in ``uint16``.
        """
        arr = np.asarray(obj)
        if arr.dtype != np.uint16:
            if not np.issubdtype(arr.dtype, np.integer):
                raise TypeError(
                    f'UInt16 encoding expects integer data, got dtype {arr.dtype}'
                )
            upper = np.iinfo(np.uint16).max
            # Guard against silent wrap-around when narrowing to 16 bits.
            if arr.size and (arr.min() < 0 or arr.max() > upper):
                raise ValueError(
                    f'UInt16 encoding expects values in [0, {upper}]'
                )
            arr = arr.astype(np.uint16)
        return arr.tobytes()

    def decode(self, data: bytes):
        """Interpret ``data`` as a 1-D ``uint16`` array.

        ``np.frombuffer`` does not copy, so the result shares memory with
        ``data``; callers that need a writable array should copy it.
        """
        return np.frombuffer(data, np.uint16)


# Register the custom codec under the name 'uint16' in the encoding table
# imported from streaming.base.format.mds.encodings. The leading underscore
# suggests this is a private attribute -- NOTE(review): may change between
# streaming versions.
_encodings['uint16'] = UInt16


class DatasetFixed(torch.utils.data.Dataset):
    """Map-style dataset over a ``LocalDataset`` that yields int64 columns.

    Each sample is the dict produced by the wrapped dataset, with the
    ``'token_type_ids'`` entry removed (if present) and every remaining value
    cast to ``np.int64``.
    """

    def __init__(self, local):
        """Open the dataset stored under the local path ``local``."""
        self.dataset = LocalDataset(local=local)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return sample ``idx`` with its columns widened to int64.

        The dict obtained from the wrapped dataset is modified in place and
        returned.
        """
        sample = self.dataset[idx]
        if 'token_type_ids' in sample:
            del sample['token_type_ids']
        # Build all converted columns first, then write them back; the keys
        # already exist, so insertion order is unchanged.
        sample.update(
            {name: column.astype(np.int64) for name, column in sample.items()}
        )
        return sample

In [6]:
# Local directory handed to LocalDataset. NOTE(review): machine-specific
# absolute path; adjust when running elsewhere.
local_dataset_dir = '/home/ubuntu/mosaic-nanot5-512'
dataset = DatasetFixed(local=local_dataset_dir)

In [3]:
# Ask the helper for the two lengths that go with a 512-token model input at
# 15% noise density and mean span length 3; the recorded output for these
# settings is (568, 114).
span_corruption_lengths = compute_input_and_target_lengths(
    inputs_length=512,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)
before_mask_input_length, target_length = span_corruption_lengths
(before_mask_input_length, target_length)

(568, 114)

In [9]:
from transformers import AutoTokenizer, AutoConfig

# Tokenizer is loaded from the same 'out-base-1.1' location as the model config.
tokenizer_source = 'out-base-1.1'
tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
# Model config; used below only for its pad_token_id when building the collator.
config_source = 'out-base-1.1'
config = AutoConfig.from_pretrained(config_source)

In [12]:
# Span-corruption collator. noise_density / mean_noise_span_length / input_length
# must match the values passed to compute_input_and_target_lengths above;
# target_length is taken from that call's result (114 for these settings)
# rather than repeated as a literal, so the two cannot drift apart.
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=512,
    target_length=target_length,
    pad_token_id=config.pad_token_id,
)

In [30]:
# Collate the first three dataset samples into one batch. The batch keeps the
# name `b` because later cells read it.
first_samples = [dataset[index] for index in range(3)]
b = data_collator(first_samples)
b

{'input_ids': tensor([[    1,  3536,  3146,  ...,  2309, 32072,     2],
         [ 9584,  3146, 18376,  ..., 10035, 32072,     2],
         [  743,  2886,   790,  ...,   530, 32072,     2]]),
 'labels': tensor([[32099,   227,   170,   101,   124, 18963,   101,   120, 32098,  1676,
           2802,  2309, 32097,  2309,  7159,   295, 21329,    15, 32096, 29589,
            352,  5090,  1810, 25920, 32095,    12, 32094,     5,   449,  6342,
          32093,     5,   647, 32092,  2244, 11414, 32091,  2602,   313,  1305,
           1427,  2200,   939, 32090,   352, 32089,   901,  2150,  2309, 11130,
             15,  2309,  9404, 32088,  7159, 32087,  2309,  3146, 32086,  2309,
           1627, 32085,  1206,  9584, 32084, 10404,   586,  3065,   275, 32083,
           3146,  5777,    12,   313,   847,   327, 32082, 12849, 32081, 21847,
            383, 32080,   287, 32079,    17,  1733, 20348,  8466, 32078,   758,
            492, 32077,    17,  3536,  3146, 32076,  5378, 32075,   449,   526

In [31]:
from transformers import T5ForConditionalGeneration
from pytorch_lightning import LightningModule

In [32]:
class Module(LightningModule):
    """LightningModule that holds a T5 model under the attribute ``model``.

    The attribute name matters: it gives every state-dict key the ``model.``
    prefix that the checkpoint keys loaded later in this notebook also carry.

    Args:
        pretrained_path: location passed to ``from_pretrained`` for both the
            config and the model. Defaults to ``'./out-base-1.1'``.
    """

    def __init__(self, pretrained_path='./out-base-1.1'):
        super().__init__()
        config = AutoConfig.from_pretrained(pretrained_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            pretrained_path,
            config=config,
        )

In [33]:
# Shell escape: list the checkpoint files under logs/base with sizes and timestamps.
!ls -lh logs/base

total 8.4G
-rw-r--r-- 1 ubuntu ubuntu 2.8G Apr 14 03:39 'model-epoch=00-step=1000.ckpt'
-rw-r--r-- 1 ubuntu ubuntu 2.8G Apr 14 04:19 'model-epoch=00-step=2000.ckpt'
-rw-r--r-- 1 ubuntu ubuntu 2.8G Apr 14 04:50 'model-epoch=00-step=3000.ckpt'


In [34]:
# Fresh module plus its current state dict, which the next cell overwrites
# entry by entry with the checkpoint tensors.
model = Module()
weights = model.state_dict()

# Only the 'state_dict' entry of the checkpoint is kept; tensors are mapped to CPU.
# NOTE(review): torch.load unpickles the file -- only use on trusted checkpoints.
checkpoint_path = 'logs/base/model-epoch=00-step=3000.ckpt'
cpu_device = torch.device('cpu')
checkpoint_state = torch.load(checkpoint_path, map_location=cpu_device)['state_dict']
old_weights = checkpoint_state.items()

In [35]:
# Copy every checkpoint tensor into `weights`, dropping any '._orig_mod'
# segment from the key, and print the old and new key side by side.
for checkpoint_key, tensor in old_weights:
    cleaned_key = checkpoint_key.replace('._orig_mod', '')
    print(checkpoint_key, cleaned_key)
    weights[cleaned_key] = tensor

model.shared.weight model.shared.weight
model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight
model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight
model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight
model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight
model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight
model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.w

In [36]:
# strict=True is the default, spelled out here: any missing or unexpected key
# raises instead of being silently skipped.
model.load_state_dict(weights, strict=True)

<All keys matched successfully>

In [42]:
# Inspection-only forward pass: the cells that follow only read the loss and
# logits, so gradient tracking is disabled to avoid building an autograd graph.
with torch.no_grad():
    o = model.model(**b)
o

Seq2SeqLMOutput(loss=tensor(3.0237, grad_fn=<NllLossBackward0>), logits=tensor([[[-0.9907,  2.8476,  1.9752,  ..., -2.8603, -2.6576, -1.2661],
         [-1.8009,  3.8409,  3.7604,  ..., -4.5560, -2.4641, -3.2539],
         [-1.6907,  0.0464,  7.7736,  ..., -2.7157, -4.4345, -3.7622],
         ...,
         [-1.6395,  1.4647,  4.4011,  ..., -2.5878, -0.5852, -3.5563],
         [-2.6389,  2.4625,  5.2757,  ..., -6.3230, -4.1409, -2.0036],
         [ 0.3584,  2.3890, 35.3724,  ..., -1.7918, -3.8099, -0.6635]],

        [[-0.6767,  2.5914,  2.1937,  ..., -2.3515, -2.3793, -1.3622],
         [-4.3737,  5.6130,  4.0856,  ..., -2.0719, -2.5653, -0.1653],
         [-3.6748, -0.1740,  2.9097,  ..., -2.9931, -3.7182, -3.8249],
         ...,
         [-1.5893,  0.0879,  4.5392,  ..., -2.3501, -1.0870, -5.1463],
         [-3.1135,  0.4900,  6.3222,  ..., -3.0063, -1.8968, -4.1890],
         [-0.3985,  2.4653, 27.7797,  ..., -2.1855, -2.1259, -2.1988]],

        [[-0.6180,  2.8315,  2.3604,  ..., -

In [48]:
# Most likely token at each position for the first sequence in the batch,
# taken from the single forward pass above (not autoregressive generation).
first_sequence_logits = o.logits[0]
predicted_token_ids = first_sequence_logits.argmax(dim=-1)
tokenizer.decode(predicted_token_ids)

'<extra_id_99>�<extra_id_98>��<extra_id_98>��<extra_id_98> wilayah<extra_id_97><extra_id_97><extra_id_97> bahasa<extra_id_96> bahasa<extra_id_96>,<extra_id_96> orang bahasa<extra_id_95>angan)<extra_id_95>) yang" dan men<extra_id_93>" ialah<extra_id_92> menjadi dalam<extra_id_91> bahasa yang terdiri<extra_id_90> salah satu<extra_id_90> (<extra_id_89> bahasa<extra_id_88> bahasa Inggeris ( bahasa Melayu<extra_id_88> bahasa bahasa bahasa Melayu<extra_id_86> bahasa<extra_id_85><extra_id_85> "bahasa<extra_id_84>c bahasa<extra_id_83>ak<extra_id_83> Melayu).). dan ditut<extra_id_82> bahasa<extra_id_81>ong R<extra_id_80>or<extra_id_79>. Di Malaysia<extra_id_78><extra_id_78>orang Melayu<extra_id_77>. Di<extra_id_76><extra_id_76> Johor<extra_id_75> dan bahasaezakan bahasa bahasa bahasa<extra_id_74>bahasa<extra_id_74>en<extra_id_73><extra_id_73> lain<extra_id_72>-</s>'

In [49]:
# Reference target for the first sequence, for comparison with the prediction above.
first_sequence_labels = b['labels'][0]
tokenizer.decode(first_sequence_labels)

'<extra_id_99>�ꤼ ꤸ<extra_id_98> bawah keluarga bahasa<extra_id_97> bahasa rasmi di Brunei,<extra_id_96> penutur (seramai 260<extra_id_95>)<extra_id_94>" untuk pent<extra_id_93>" atau<extra_id_92> digunakan mewakili<extra_id_91> waktu yang sama merupakan salah satu<extra_id_90> (<extra_id_89> lain ialah bahasa Inggeris, bahasa Cina<extra_id_88> rasmi<extra_id_87> bahasa Melayu<extra_id_86> bahasa Indonesia<extra_id_85> "bahasa<extra_id_84>bentuk vernak<extra_id_83> Melayu tempatan) yang ditut<extra_id_82> asli<extra_id_81>ulauan R<extra_id_80>or<extra_id_79>. Di selatan Thailand<extra_id_78>orang dari<extra_id_77>. Bahasa Melayu<extra_id_76> Melaka<extra_id_75> untuk membezakannya daripada bahasa-bahasa<extra_id_74>nolog<extra_id_73> terpisah<extra_id_72>-</s>'