https://mccormickml.com/2019/07/22/BERT-fine-tuning/

## Imports and Downloading Data

In [1]:
import tensorflow as tf
import torch

# Get the GPU device name.
device_name = tf.test.gpu_device_name()

# The device name should look like the following:
if device_name == '/device:GPU:0':
    print('Found GPU at: {}'.format(device_name))
else:
    raise SystemError('GPU device not found')

SystemError: GPU device not found

In [None]:
import torch

# If there's a GPU available...
if torch.cuda.is_available():    

    # Tell PyTorch to use the GPU.    
    device = torch.device("cuda")

    print('There are %d GPU(s) available.' % torch.cuda.device_count())

    print('We will use the GPU:', torch.cuda.get_device_name(0))

# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

In [3]:
!pip3 install wget



In [3]:
import wget
import os

print('Downloading dataset...')

# The URL for the dataset zip file.
url = 'https://competitions.codalab.org/my/datasets/download/4db8bf21-def7-4a86-99f5-7b23d5691bb3'

# Download the file (if we haven't already)
if not os.path.exists('multi-fc/'):
    !mkdir multi-fc
    wget.download(url, 'multi-fc/multi-fc.zip')
    !unzip multi-fc/multi-fc.zip -d multi-fc/

Downloading dataset...


## Reading in Data

In [35]:
import pandas as pd
import csv

# Load the dataset into a pandas dataframe.
df = pd.read_csv("multi-fc/train.tsv", delimiter='\t', header=None, quoting=csv.QUOTE_NONE, \
                 names= ['claimID', 'claim', 'label', 'claimURL', 'reason', 'categories', 'speaker', \
                  'checker', 'tags', 'articleTitle', 'publishDate', 'claimDate', 'entities'])

# Report the number of sentences.
print('Number of training sentences: {:,}\n'.format(df.shape[0]))

# Display 2 random rows from the data.
df.sample(2)

Number of training sentences: 27,940



Unnamed: 0,claimID,claim,label,claimURL,reason,categories,speaker,checker,tags,articleTitle,publishDate,claimDate,entities
9879,pose-01284,“One of my first acts as president will be to ...,stalled,https://www.politifact.com/truth-o-meter/promi...,,trumpometer,Donald Trump,,,Establish a commission on radical Islam,2017-01-17T09:10:41,,['Islam']
20501,pose-00490,"""Support the principle of network neutrality t...",promise kept,https://www.politifact.com/truth-o-meter/promi...,,obameter,Barack Obama,,,Support network neutrality on the Internet,2010-01-07T13:27:01,,['None']


In [36]:
# Contains empty claim
indexNames = df[ df['claimID'] == 'bove-00197' ].index
 
# Delete these row indexes from dataFrame
df.drop(indexNames , inplace=True)
sentences = df.claim.values
labels = df.label.values

### Adding Evidence Snippets to Claims

In [None]:
'''
pre_instances[] is a list of claim+snippet pairs

claimsnippet_labels[] is a list of the labels (since every claim has one label and is being expanded to claim+label 
pairs), essentially just an expanded list of the original labels

[SEP] token is used not as a BERT separator token, but so that when we tokenize we can properly split up the 
claim and snippet to pass into encode_plus as separate arguments (see tokenizer section for context)

[UNK] token in the exception is used for claims that do not have evidence snippets, BERT can handle this inherently


'''

In [40]:
pre_instances = []
claimsnippet_labels = []
count = 0
for a in range(len(list(df.claim))):
    claim,claimID,label = list(df.claim)[a], list(df.claimID)[a], list(df.label)[a]
    try:
        f=open("multi-fc/snippets/{claimID}".format(claimID=claimID), "r")
        for line in f.readlines():
            split = line.split("\t")
            snippet = split[2]
            pre_instance =  claim +" [SEP] "+snippet
            pre_instances.append(pre_instance)
            claimsnippet_labels.append(label)
    except FileNotFoundError:
            pre_instance = claim + "[SEP]" + "[UNK]"
            pre_instances.append(pre_instance)
            claimsnippet_labels.append(label)

    

In [41]:
pre_instances[0]

'"Six out of 10 of the highest unemployment rates are also in so-called right to work states." [SEP] May 8, 2013 ... Ron Maag and Kristina Roegner, claiming that "six out of 10 of the highest  unemployment rates are also in so-called right to work states.'

### Encoding Labels and Importing Tokenizer

In [42]:
from sklearn import preprocessing

le = preprocessing.LabelEncoder()
elongated_labels = le.fit_transform(elongated_labels)

In [None]:
!pip3 install transformers

In [43]:
from transformers import BertTokenizer
# Load the BERT tokenizer.
print('Loading BERT tokenizer...')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

Loading BERT tokenizer...


### Viewing Example

In [44]:
# Print the original sentence.
print(' Original: ', pre_instances[0])

# Print the sentence split into tokens.
print('Tokenized: ', tokenizer.tokenize(pre_instances[0]))

# Print the sentence mapped to token ids.
print('Token IDs: ', tokenizer.convert_tokens_to_ids(tokenizer.tokenize(pre_instances[0])))

 Original:  "Six out of 10 of the highest unemployment rates are also in so-called right to work states." [SEP] May 8, 2013 ... Ron Maag and Kristina Roegner, claiming that "six out of 10 of the highest  unemployment rates are also in so-called right to work states.
Tokenized:  ['"', 'six', 'out', 'of', '10', 'of', 'the', 'highest', 'unemployment', 'rates', 'are', 'also', 'in', 'so', '-', 'called', 'right', 'to', 'work', 'states', '.', '"', '[SEP]', 'may', '8', ',', '2013', '.', '.', '.', 'ron', 'ma', '##ag', 'and', 'kristina', 'roe', '##gne', '##r', ',', 'claiming', 'that', '"', 'six', 'out', 'of', '10', 'of', 'the', 'highest', 'unemployment', 'rates', 'are', 'also', 'in', 'so', '-', 'called', 'right', 'to', 'work', 'states', '.']
Token IDs:  [1000, 2416, 2041, 1997, 2184, 1997, 1996, 3284, 12163, 6165, 2024, 2036, 1999, 2061, 1011, 2170, 2157, 2000, 2147, 2163, 1012, 1000, 102, 2089, 1022, 1010, 2286, 1012, 1012, 1012, 6902, 5003, 8490, 1998, 28802, 20944, 10177, 2099, 1010, 6815, 20

### Tokenizer

Now that our instances are pre_processed, we can pass them into the BERT tokenizer

In [None]:
'''
every "sent" in pre_sentences is broken up into its constituent claim and snippet using the dummu SEP token
      AKA .split("SEP")
'''

In [45]:
# Tokenize all of the sentences and map the tokens to thier word IDs.
input_ids = []
attention_masks = []

# For every sentence...
for sent in pre_instances:
    
    claim, snippet = sent.split("[SEP]")[0], sent.split("[SEP]")[1]
   
    encoded_dict = tokenizer.encode_plus(
                        claim, #claim to encode
                        snippet,# snippet to encode
                        add_special_tokens = True, # Add '[CLS]' and '[SEP]'
                        max_length = 512,           # Pad & truncate all sentences.
                        pad_to_max_length = True,
                        return_attention_mask = True,   # Construct attn. masks.
                        return_tensors = 'pt',     # Return pytorch tensors.
                    )

    # Add the encoded sentence to the list.    
    input_ids.append(encoded_dict['input_ids'])
    
    # And its attention mask (simply differentiates padding from non-padding).
    attention_masks.append(encoded_dict['attention_mask'])

# Convert the lists into tensors.
input_ids = torch.cat(input_ids, dim=0)
attention_masks = torch.cat(attention_masks, dim=0)
elongated_labels = torch.tensor(elongated_labels)

# Print sentence 0, now as a list of IDs.
print('Original: ', pre_instances[0])
print('Token IDs:', input_ids[0])

Original:  "Six out of 10 of the highest unemployment rates are also in so-called right to work states." [SEP] May 8, 2013 ... Ron Maag and Kristina Roegner, claiming that "six out of 10 of the highest  unemployment rates are also in so-called right to work states.
Token IDs: tensor([  101,  1000,  2416,  2041,  1997,  2184,  1997,  1996,  3284, 12163,
         6165,  2024,  2036,  1999,  2061,  1011,  2170,  2157,  2000,  2147,
         2163,  1012,  1000,   102,  2089,  1022,  1010,  2286,  1012,  1012,
         1012,  6902,  5003,  8490,  1998, 28802, 20944, 10177,  2099,  1010,
         6815,  2008,  1000,  2416,  2041,  1997,  2184,  1997,  1996,  3284,
        12163,  6165,  2024,  2036,  1999,  2061,  1011,  2170,  2157,  2000,
         2147,  2163,  1012,   102,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
            0,     0,