In [1]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 5.3 MB/s 
[?25hCollecting tokenizers!=0.11.3,>=0.11.1
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 30.6 MB/s 
[?25hCollecting sacremoses
  Downloading sacremoses-0.0.47-py2.py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 47.7 MB/s 
[?25hCollecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 24.7 MB/s 
Collecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.4.0-py3-none-any.whl (67 kB)
[K     |████████████████████████████████| 67 kB 5.1 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyy

In [2]:
from transformers import MobileBertForPreTraining
from transformers import MobileBertTokenizer
import torch
import pandas as pd

import requests
from tqdm import tqdm
import random

In [3]:
# The pre-training model and its tokenizer are loaded from the same checkpoint name.
checkpoint = "google/mobilebert-uncased"

tokenizer = MobileBertTokenizer.from_pretrained(checkpoint)
model = MobileBertForPreTraining.from_pretrained(checkpoint)

Downloading:   0%|          | 0.00/847 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/140M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

In [4]:
# Labels for the three tab-separated columns; later cells read the abstracts
# positionally via data.values[:, 2].
abstract_columns = [
    'Article identifier',
    'Title of the article',
    'Abstract of the article',
]

# NOTE(review): header=0 together with names= makes pandas treat the first line
# of the file as a header row and replace it with the names above. Confirm the
# TSV really has a header line; otherwise its first abstract is silently dropped.
data = pd.read_csv(
    "chemprot_training_abstracts.tsv",
    sep='\t',
    header=0,
    names=abstract_columns,
)

In [5]:
# Pool of sentence fragments used as random "not next sentence" partners.
# Abstracts are split naively on '.', so fragments that are empty or contain
# only whitespace (e.g. after a trailing ". " or newline) are discarded;
# comparing against '' alone would let the whitespace-only ones through.
# Only the first 64 surviving fragments are kept.
bag = [
    item
    for sentence in data.values[:,2]
    for item in sentence.split('.')
    if item.strip()
][:64]
bag_size = len(bag)

In [6]:
# Build next-sentence-prediction pairs: one (sentence_a, sentence_b, label)
# triple per abstract that has at least two sentences.
#   label 0 -> sentence_b is the sentence that follows sentence_a
#   label 1 -> sentence_b is a random sentence drawn from `bag`
# `random` is already imported in the notebook's import cell.

# Seed the sampler so the generated pairs are the same after a kernel restart.
random.seed(42)

sentence_a = []
sentence_b = []
label = []

for paragraph in data.values[:,2]:
    # A missing abstract is read by pandas as NaN (a float); skip it instead of
    # crashing on .split().
    if not isinstance(paragraph, str):
        continue
    # Naive split on '.'; drop fragments that are empty or whitespace-only.
    sentences = [
        sentence for sentence in paragraph.split('.') if sentence.strip()
    ]
    num_sentences = len(sentences)
    if num_sentences < 2:
        continue

    start = random.randint(0, num_sentences-2)
    first, true_next = sentences[start], sentences[start+1]

    # 50/50 whether is IsNextSentence or NotNextSentence
    if random.random() >= 0.5:
        # this is IsNextSentence
        sentence_a.append(first)
        sentence_b.append(true_next)
        label.append(0)
    else:
        # this is NotNextSentence: never pick the sentence itself or its real
        # successor, otherwise the pair would carry a wrong label. Fall back to
        # the whole bag if that filter leaves nothing to choose from.
        negatives = [s for s in bag if s != first and s != true_next] or bag
        sentence_a.append(first)
        sentence_b.append(random.choice(negatives))
        label.append(1)


In [7]:
# Encode every (sentence_a, sentence_b) pair as one sequence, truncated or
# padded to exactly 512 tokens and returned as PyTorch tensors.
inputs = tokenizer(
    sentence_a,
    sentence_b,
    padding='max_length',
    truncation=True,
    max_length=512,
    return_tensors='pt',
)

# NSP targets as a column vector: one row (a single 0/1 value) per pair.
inputs['next_sentence_label'] = torch.tensor(label, dtype=torch.long).unsqueeze(1)
# Independent copy of the token ids, stored under the key the training loop
# later passes to the model as `labels`.
inputs['labels'] = inputs.input_ids.clone().detach()


In [8]:
torch.manual_seed(42)

# One uniform draw per token position.
rand = torch.rand(inputs.input_ids.shape)

# A position is selected when its draw is <= 0.15 ...
selected = rand <= 0.15
# ... and its token id is none of 101, 102 or 0
# (presumably [CLS], [SEP] and [PAD] in this vocabulary; confirm against the tokenizer).
not_special = (
    (inputs.input_ids != 101)
    & (inputs.input_ids != 102)
    & (inputs.input_ids != 0)
)
mask_arr = selected & not_special

In [9]:
# Mask the tokens: overwrite every selected position with token id 103
# (presumably the [MASK] id in this vocabulary; confirm against the tokenizer).
# A single boolean-indexed assignment writes the same values as visiting each
# (row, column) in Python, without one tensor indexing call per token.
# NOTE(review): `labels` still holds the real id at every position, so the MLM
# loss is not limited to the masked tokens. Confirm that is intended.
inputs.input_ids[mask_arr] = 103

In [10]:
class OurDataset(torch.utils.data.Dataset):
    """Map-style dataset over a mapping of equally long per-field sequences.

    `encodings` maps field names (it must contain "input_ids") to tensors or
    nested lists whose first dimension indexes the samples. A tokenizer
    BatchEncoding and a plain dict both work.
    """
    def __init__(self, encodings):
        self.encodings = encodings
    def __getitem__(self, idx):
        """Return sample `idx` as a dict of independent tensors, one per field."""
        # The stored values are usually tensors already. Calling torch.tensor()
        # on a tensor emits a copy-construct UserWarning for every item, so
        # copy with detach().clone() instead. as_tensor keeps list-valued
        # encodings working.
        return {
            key: torch.as_tensor(val[idx]).detach().clone()
            for key, val in self.encodings.items()
        }
    def __len__(self):
        """Number of samples, taken from the "input_ids" field."""
        # Key access works for both BatchEncoding and dict; attribute access
        # (.input_ids) would only work for the former.
        return len(self.encodings["input_ids"])

# Wrap the encoded pairs in the dataset and serve them in shuffled batches of four.
dataset = OurDataset(inputs)
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=4,
)

In [4]:
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Switch the model to training mode and move its parameters to the device.
model.train()
model.to(device)


MobileBertForPreTraining(
  (mobilebert): MobileBertModel(
    (embeddings): MobileBertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 512)
      (token_type_embeddings): Embedding(2, 512)
      (embedding_transformation): Linear(in_features=384, out_features=512, bias=True)
      (LayerNorm): NoNorm()
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): MobileBertEncoder(
      (layer): ModuleList(
        (0): MobileBertLayer(
          (attention): MobileBertAttention(
            (self): MobileBertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=512, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): MobileBertSelfOutput(
              (dense): Linear(in_features=128, out_fea

In [16]:
# transformers.AdamW is deprecated in favour of the PyTorch implementation.
# eps and weight_decay are passed explicitly because torch's defaults differ
# from those documented for the transformers class (1e-6 and 0.0); confirm
# against the transformers optimizer docs. The two implementations apply eps
# at slightly different points, so updates are close but not bit-identical.
from torch.optim import AdamW
optim = AdamW(model.parameters(), lr=5e-5, eps=1e-6, weight_decay=0.0)



In [17]:
epochs = 2

# Names of the batch entries that are forwarded to the model as keyword arguments.
model_input_names = (
    'input_ids',
    'attention_mask',
    'token_type_ids',
    'next_sentence_label',
    'labels',
)

for epoch in range(epochs):
    # progress bar over the dataloader; it stays on screen once the epoch ends
    loop = tqdm(loader, leave=True)
    for batch in loop:
        # drop the gradients left over from the previous step
        optim.zero_grad()

        # move the tensors this step needs onto the training device
        model_inputs = {name: batch[name].to(device) for name in model_input_names}

        # forward pass; the loss is read from the returned output object
        outputs = model(**model_inputs)
        loss = outputs.loss

        # backpropagate, then let the optimizer update the parameters
        loss.backward()
        optim.step()

        # show the epoch number and the current batch loss on the progress bar
        loop.set_description(f'Epoch {epoch}')
        loop.set_postfix(loss=loss.item())


  """
Epoch 0: 100%|██████████| 255/255 [03:54<00:00,  1.09it/s, loss=0.908]
Epoch 1: 100%|██████████| 255/255 [03:52<00:00,  1.10it/s, loss=0.755]
