In [2]:
import os
import sys
import random
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

In [3]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |▍                               | 10kB 7.4MB/s eta 0:00:01[K     |▉                               | 20kB 3.2MB/s eta 0:00:01[K     |█▎                              | 30kB 4.3MB/s eta 0:00:01[K     |█▊                              | 40kB 3.6MB/s eta 0:00:01[K     |██▏                             | 51kB 4.3MB/s eta 0:00:01[K     |██▋                             | 61kB 4.8MB/s eta 0:00:01[K     |███                             | 71kB 5.2MB/s eta 0:00:01[K     |███▍                            | 81kB 5.7MB/s eta 0:00:01[K     |███▉                            | 92kB 5.7MB/s eta 0:00:01[K     |████▎                           | 102kB 5.7MB/s eta 0:00:01[K     |████▊                           | 112kB 5.7MB/s eta 0:00:01[K     |█████▏                          | 122kB 5.7MB

In [4]:
from transformers import BertTokenizer
from transformers import BertForSequenceClassification
from transformers import BertConfig

In [5]:
!pip install -U spacy[cuda92]
!python -m spacy download en_core_web_sm
import spacy
import en_core_web_sm
spacy.prefer_gpu()
spacy_nlp = en_core_web_sm.load()

Collecting spacy[cuda92]
[?25l  Downloading https://files.pythonhosted.org/packages/10/b5/c7a92c7ce5d4b353b70b4b5b4385687206c8b230ddfe08746ab0fd310a3a/spacy-2.3.2-cp36-cp36m-manylinux1_x86_64.whl (9.9MB)
[K     |████████████████████████████████| 10.0MB 5.1MB/s 
Collecting thinc==7.4.1
[?25l  Downloading https://files.pythonhosted.org/packages/10/ae/ef3ae5e93639c0ef8e3eb32e3c18341e511b3c515fcfc603f4b808087651/thinc-7.4.1-cp36-cp36m-manylinux1_x86_64.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 38.8MB/s 
Collecting cupy-cuda92<9.0.0,>=5.0.0b4; extra == "cuda92"
[?25l  Downloading https://files.pythonhosted.org/packages/df/ed/7aee0f78919d02b5f607f62a1abe9ca3a4a7c3bdc55099ed910a58d3972e/cupy_cuda92-8.0.0b4-cp36-cp36m-manylinux1_x86_64.whl (325.1MB)
[K     |████████████████████████████████| 325.1MB 32kB/s 
Installing collected packages: thinc, cupy-cuda92, spacy
  Found existing installation: thinc 7.4.0
    Uninstalling thinc-7.4.0:
      Successfully uninstalled thinc

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

Using device: cuda


In [7]:
DIR = "/content/drive/My Drive/ml_hw/NLP/question_generator/"
PRETRAINED_MODEL = 'bert-base-cased'
BATCH_SIZE = 16
SEQ_LENGTH = 512

tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL)

class QAEvalDataset(Dataset):
    def __init__(self, csv):
        self.df = pd.read_csv(csv, engine='python')
        self.transforms = [self.shuffle, self.corrupt]

    def __len__(self):
         return len(self.df)

    def __getitem__(self, idx): 
        _, question, answer = self.df.iloc[idx]
        label = random.choice([0, 1])

        if label == 0:
            question, answer = random.choice(self.transforms)(question, answer)

        encoded_data = tokenizer(
            text=question,
            text_pair=answer,
            pad_to_max_length=True, 
            max_length=SEQ_LENGTH,
            truncation=True,
            return_tensors="pt"
        )

        encoded_data['input_ids'] = torch.squeeze(encoded_data['input_ids'])
        encoded_data['token_type_ids'] = torch.squeeze(encoded_data['token_type_ids'])
        encoded_data['attention_mask'] = torch.squeeze(encoded_data['attention_mask'])
        return (encoded_data.to(device), torch.tensor(label).to(device))
    
    def shuffle(self, question, answer):
        shuffled_answer = answer
        while shuffled_answer == answer:
            shuffled_answer = self.df.sample(1)['answer'].item()
        return question, shuffled_answer
    
    def corrupt(self, question, answer):
        doc = spacy_nlp(question)
        if len(doc.ents) > 1:
            # Replace all entities in the sentence with the same thing
            copy_ent = str(random.choice(doc.ents))
            for ent in doc.ents:
                question = question.replace(str(ent), copy_ent)
        elif len(doc.ents) == 1:
            # Replace the answer with an entity from the question
            answer = str(doc.ents[0])
        else:
            question, answer = self.shuffle(question, answer)
        return question, answer


train_set = QAEvalDataset(os.path.join(DIR, 'qa_eval_train_2.csv')) 
train_loader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
valid_set = QAEvalDataset(os.path.join(DIR, 'qa_eval_valid_2.csv')) 
valid_loader = DataLoader(valid_set, batch_size=BATCH_SIZE, shuffle=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=213450.0, style=ProgressStyle(descripti…




In [8]:
LR = 0.001
EPOCHS = 10
LOG_INTERVAL = 500

model = BertForSequenceClassification.from_pretrained(PRETRAINED_MODEL)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=LR)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=433.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=435779157.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

In [9]:
SAVED_MODEL_PATH = "/content/drive/My Drive/ml_hw/NLP/question_generator/qa_eval_model_trained_2.pth"

def train():
    model.train()
    total_loss = 0.
    for batch_index, batch in enumerate(train_loader):
        data, labels = batch
        optimizer.zero_grad()
        output = model(**data, labels=labels)
        loss = output[0]
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        total_loss += loss.item()
        
        if batch_index % LOG_INTERVAL == 0 and batch_index > 0:
            cur_loss = total_loss / LOG_INTERVAL
            print('| epoch {:3d} | ' 
                  '{:5d}/{:5d} batches | '
                  'loss {:5.2f}'.format(
                    epoch, 
                    batch_index, len(train_loader), 
                    cur_loss))
            total_loss = 0

def evaluate(eval_model, data_loader):
    eval_model.eval()
    total_score = 0.
    with torch.no_grad():
        for batch_index, batch in enumerate(data_loader):
            data, labels = batch
            output = eval_model(**data, labels=labels)
            preds = np.argmax(output[1].cpu(), axis=1)
            total_score += (preds == labels.cpu()).sum()
    return total_score / (len(data_loader) * BATCH_SIZE)

def save(epoch, model_state_dict, optimizer_state_dict, loss):
    torch.save({
            'epoch': epoch,
            'model_state_dict': model_state_dict,
            'optimizer_state_dict': optimizer_state_dict,
            'best_loss': loss,
            }, SAVED_MODEL_PATH)

    print("| Model saved.")
    print_line()

def load():
    return torch.load(SAVED_MODEL_PATH)

def print_line():
    LINE_WIDTH = 60
    print('-' * LINE_WIDTH)

In [None]:
highest_accuracy = 0

accuracy = evaluate(model, valid_loader)
print_line()
print('| Before training | accuracy on valid set: {:5.2f}%'.format(accuracy))
print_line()

for epoch in range(1, EPOCHS + 1):

    train()
    accuracy = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | accuracy on valid set: {:5.2f}%'.format(
        epoch,
        accuracy)
    )
    print_line()

    if accuracy > highest_accuracy:
        highest_accuracy = accuracy
        save(
             epoch, 
             model.state_dict(), 
             optimizer.state_dict(), 
             highest_accuracy
        )

------------------------------------------------------------
| Before training | accuracy on valid set:  0.66%
------------------------------------------------------------
| epoch   1 |   500/13007 batches | loss  0.58
| epoch   1 |  1000/13007 batches | loss  0.54
| epoch   1 |  1500/13007 batches | loss  0.51
| epoch   1 |  2000/13007 batches | loss  0.46
| epoch   1 |  2500/13007 batches | loss  0.42
| epoch   1 |  3000/13007 batches | loss  0.41
| epoch   1 |  3500/13007 batches | loss  0.40
| epoch   1 |  4000/13007 batches | loss  0.39
| epoch   1 |  4500/13007 batches | loss  0.37
| epoch   1 |  5000/13007 batches | loss  0.37
| epoch   1 |  5500/13007 batches | loss  0.36
| epoch   1 |  6000/13007 batches | loss  0.35
| epoch   1 |  6500/13007 batches | loss  0.34
| epoch   1 |  7000/13007 batches | loss  0.34
| epoch   1 |  7500/13007 batches | loss  0.34
| epoch   1 |  8000/13007 batches | loss  0.33
| epoch   1 |  8500/13007 batches | loss  0.32
| epoch   1 |  9000/13007 bat

In [None]:
checkpoint = load()
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
highest_accuracy = checkpoint['best_loss']
model.to(device)

for epoch in range(1, EPOCHS + 1):

    train()
    accuracy = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | accuracy on valid set: {:5.2f}%'.format(
        epoch,
        accuracy)
    )
    print_line()

    if accuracy > highest_accuracy:
        highest_accuracy = accuracy
        save(
             epoch, 
             model.state_dict(), 
             optimizer.state_dict(), 
             highest_accuracy
        )

| epoch   1 |   500/13007 batches | loss  0.24
| epoch   1 |  1000/13007 batches | loss  0.23
| epoch   1 |  1500/13007 batches | loss  0.25
| epoch   1 |  2000/13007 batches | loss  0.24
| epoch   1 |  2500/13007 batches | loss  0.26
| epoch   1 |  3000/13007 batches | loss  0.24
| epoch   1 |  3500/13007 batches | loss  0.24
| epoch   1 |  4000/13007 batches | loss  0.25
| epoch   1 |  4500/13007 batches | loss  0.24
| epoch   1 |  5000/13007 batches | loss  0.24
| epoch   1 |  5500/13007 batches | loss  0.24
| epoch   1 |  6000/13007 batches | loss  0.22
| epoch   1 |  6500/13007 batches | loss  0.24
| epoch   1 |  7000/13007 batches | loss  0.24
| epoch   1 |  7500/13007 batches | loss  0.23
| epoch   1 |  8000/13007 batches | loss  0.24
| epoch   1 |  8500/13007 batches | loss  0.25
| epoch   1 |  9000/13007 batches | loss  0.24
| epoch   1 |  9500/13007 batches | loss  0.25
| epoch   1 | 10000/13007 batches | loss  0.22
| epoch   1 | 10500/13007 batches | loss  0.24
| epoch   1 |

In [None]:
checkpoint = load()
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
current_epoch = checkpoint['epoch']
highest_accuracy = checkpoint['best_loss']
model.to(device)

for epoch in range(1, EPOCHS + 1):

    train()
    accuracy = evaluate(model, valid_loader)
    print_line()
    print('| end of epoch {:3d} | accuracy on valid set: {:5.2f}%'.format(
        epoch,
        accuracy)
    )
    print_line()

    if accuracy > highest_accuracy:
        highest_accuracy = accuracy
        save(
             epoch, 
             model.state_dict(), 
             optimizer.state_dict(), 
             highest_accuracy
        )

| epoch   1 |   500/13007 batches | loss  0.21
| epoch   1 |  1000/13007 batches | loss  0.20
| epoch   1 |  1500/13007 batches | loss  0.20
| epoch   1 |  2000/13007 batches | loss  0.20
| epoch   1 |  2500/13007 batches | loss  0.21
| epoch   1 |  3000/13007 batches | loss  0.22
| epoch   1 |  3500/13007 batches | loss  0.21
| epoch   1 |  4000/13007 batches | loss  0.21
| epoch   1 |  4500/13007 batches | loss  0.21
| epoch   1 |  5000/13007 batches | loss  0.21
| epoch   1 |  5500/13007 batches | loss  0.20
| epoch   1 |  6000/13007 batches | loss  0.20
| epoch   1 |  6500/13007 batches | loss  0.21
| epoch   1 |  7000/13007 batches | loss  0.20
| epoch   1 |  7500/13007 batches | loss  0.21
| epoch   1 |  8000/13007 batches | loss  0.21
| epoch   1 |  8500/13007 batches | loss  0.21
| epoch   1 |  9000/13007 batches | loss  0.21
| epoch   1 |  9500/13007 batches | loss  0.21
| epoch   1 | 10000/13007 batches | loss  0.21
| epoch   1 | 10500/13007 batches | loss  0.21
| epoch   1 |