In [8]:
from transformers import BertTokenizer, BertModel
from transformers import pipeline
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")

In [33]:
import os

folder_path = './corpus-processed'
authors_data = {}

for filename in os.listdir(folder_path):
    args = filename.split(",")
    if args[0] not in authors_data:
        authors_data[args[0]] = []

    file_path = os.path.join(folder_path, filename)
    if os.path.isfile(file_path):
        with open(file_path, 'r') as file:
            file_lines = [line.rstrip('\n') for line in file]
            authors_data[args[0]].extend(file_lines)

In [41]:
authors = list(authors_data.keys())
author_ids = {}
for i in range(len(authors)):
    author_ids[authors[i]] = i
print(author_ids)

{'shakespeare': 0, 'jonson': 1, 'fletcher': 2, 'ford': 3, 'rowley': 4, 'middleton': 5, 'massinger': 6, 'dekker': 7, 'webster': 8}


In [51]:
import json

data = []
author_data_tokenized = {}
for author in authors:
    author_data = authors_data[author]
    print(author, len(author_data))

    tokenized_sentences = [tokenizer.tokenize(s) for s in author_data]
    author_data_tokenized[author] = tokenized_sentences

with open('tokenized_author_data.json', 'w') as file:
    json.dump(author_data_tokenized, file)

shakespeare 59780
jonson 22195
fletcher 19846
ford 5299
rowley 3136
middleton 21782
massinger 12615
dekker 7174
webster 4867


In [90]:
MAX_SEQUENCE_LEN = 128
bert_inputs = []
bert_inputs_readable = []
bert_input_masks = []
data_labels = []

for author in authors:
    sentences = author_data_tokenized[author]
    label = author_ids[author]

    current_input = ["CLS"]
    for s in sentences:
        if len(s) + len(current_input) <= MAX_SEQUENCE_LEN - 1:
            current_input.extend(s)
        else:
            current_input.append("[SEP]")
            mask = [1 for _ in range(len(current_input))]

            while len(current_input) != MAX_SEQUENCE_LEN:
                current_input.append("[PAD]")
                mask.append(0)

            bert_inputs.append(tokenizer.convert_tokens_to_ids(current_input))
            bert_inputs_readable.append(current_input)
            bert_input_masks.append(mask)
            data_labels.append(label)
            current_input = ["CLS"]

print(f"Total number of inputs: {len(bert_inputs)}")

Total number of inputs: 19043


In [113]:
import torch
from torch.utils.data import TensorDataset, random_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler


x_inputs = torch.tensor(bert_inputs)
x_masks = torch.tensor(bert_input_masks)
y_labels = torch.tensor(data_labels)

print("INPUT + LABEL SHAPES: ", x_inputs.shape, x_masks.shape, y_labels.shape)

dataset = TensorDataset(x_inputs, x_masks, y_labels)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print('{:>5,} training samples'.format(train_size))
print('{:>5,} validation samples'.format(val_size))

batch_size = 16
train_dataloader = DataLoader(
    train_dataset,
    sampler=RandomSampler(train_dataset),
    batch_size=batch_size
)
validation_dataloader = DataLoader(
    val_dataset,
    sampler=SequentialSampler(val_dataset),
    batch_size=batch_size
)

INPUT + LABEL SHAPES:  torch.Size([19043, 128]) torch.Size([19043, 128]) torch.Size([19043])
17,138 training samples
1,905 validation samples


In [114]:
from transformers import BertForSequenceClassification, AdamW, BertConfig

model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", # Use the 12-layer BERT model, with an uncased vocab.
    num_labels=len(authors),
    output_attentions=False,
    output_hidden_states=False,
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [115]:
from transformers import get_linear_schedule_with_warmup

optimizer = AdamW(model.parameters(), lr=2e-5, eps=1e-8)

epochs = 2 # shoudl be 2-4
total_steps = len(train_dataloader) * epochs
scheduler = get_linear_schedule_with_warmup(optimizer,  num_warmup_steps = 0, num_training_steps = total_steps)



In [116]:
import numpy as np
import time
import datetime
import random


def flat_accuracy(preds, labels):
    pred_flat = np.argmax(preds, axis=1).flatten()
    labels_flat = labels.flatten()
    return np.sum(pred_flat == labels_flat) / len(labels_flat)

def format_time(elapsed):
    elapsed_rounded = int(round((elapsed)))
    return str(datetime.timedelta(seconds=elapsed_rounded))

In [119]:
# This training code is based on the `run_glue.py` script here:
# https://github.com/huggingface/transformers/blob/5bfcd0485ece086ebcbed2d008813037968a9e58/examples/run_glue.py#L128

seed_val = 42
device = 'cuda' if torch.cuda.is_available() else 'cpu'

random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
torch.cuda.manual_seed_all(seed_val)

training_stats = []
total_t0 = time.time()

for epoch_i in range(0, epochs):
    
    # ========================================
    #               Training
    # ========================================
    print("")
    print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))
    print('Training...')

    t0 = time.time()
    total_train_loss = 0
    model.train()

    for step, batch in enumerate(train_dataloader):
        if step % 40 == 0 and not step == 0:
            elapsed = format_time(time.time() - t0)
            print('  Batch {:>5,}  of  {:>5,}.    Elapsed: {:}.'.format(step, len(train_dataloader), elapsed))

        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        model.zero_grad()        

        outputs = model(b_input_ids, 
                             token_type_ids=None, 
                             attention_mask=b_input_mask, 
                             labels=b_labels)
        loss = outputs.loss
        logits = outputs.logits
        total_train_loss += loss.item()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()

    avg_train_loss = total_train_loss / len(train_dataloader)   
    training_time = format_time(time.time() - t0)

    print("")
    print("  Average training loss: {0:.2f}".format(avg_train_loss))
    print("  Training epcoh took: {:}".format(training_time))
        
    # ========================================
    #               Validation
    # ========================================
    print("")
    print("Running Validation...")

    t0 = time.time()
    model.eval()

    total_eval_accuracy = 0
    total_eval_loss = 0
    nb_eval_steps = 0

    for batch in validation_dataloader:
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)
        
        with torch.no_grad():        
            outputs = model(b_input_ids, 
                                   token_type_ids=None, 
                                   attention_mask=b_input_mask,
                                   labels=b_labels)
            loss = outputs.loss
            logits = outputs.logits
            
        total_eval_loss += loss.item()

        logits = logits.detach().cpu().numpy()
        label_ids = b_labels.to('cpu').numpy()
        total_eval_accuracy += flat_accuracy(logits, label_ids)
        
    avg_val_accuracy = total_eval_accuracy / len(validation_dataloader)
    print("  Accuracy: {0:.2f}".format(avg_val_accuracy))

    avg_val_loss = total_eval_loss / len(validation_dataloader)
    validation_time = format_time(time.time() - t0)
    
    print("  Validation Loss: {0:.2f}".format(avg_val_loss))
    print("  Validation took: {:}".format(validation_time))

    training_stats.append(
        {
            'epoch': epoch_i + 1,
            'Training Loss': avg_train_loss,
            'Valid. Loss': avg_val_loss,
            'Valid. Accur.': avg_val_accuracy,
            'Training Time': training_time,
            'Validation Time': validation_time
        }
    )

print("")
print("Training complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))


Training...


KeyboardInterrupt: 