# AutoCasing - BERT

This notebook has one of the approaches used for training a NN for performing automatic casing for all lower case text.

In [1]:
%%time
import torch
import re
import random
import torch
from torch import nn
from torch.nn import functional as F

import numpy as np
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModel, BertTokenizerFast

from tqdm import trange
from sklearn.metrics import f1_score, confusion_matrix

from daily import *
from collate_methods import *

CPU times: user 1.31 s, sys: 763 ms, total: 2.07 s
Wall time: 2.13 s


In [2]:
%%time
dataset = load_dataset("reddit", cache_dir="../")
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased", cache_dir = "../hf-cache/")
# device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
# model = AutoModel.from_pretrained("bert-base-cased", cache_dir = "../hf-cache/", ).to(device)

Using custom data configuration default
Reusing dataset reddit (../reddit/default/1.0.0/98ba5abea674d3178f7588aa6518a5510dc0c6fa8176d9653a3546d5afcb3969)


CPU times: user 204 ms, sys: 22.6 ms, total: 227 ms
Wall time: 5.6 s


In [3]:
dataset

DatasetDict({
    train: Dataset({
        features: ['author', 'body', 'normalizedBody', 'subreddit', 'subreddit_id', 'id', 'content', 'summary'],
        num_rows: 3848330
    })
})

In [8]:
train_split_ratio = 0.99
seed = 4
dd = dataset["train"].train_test_split(
  test_size = 1 - train_split_ratio,
  train_size = train_split_ratio,
  seed = seed
)
dstrain = dd["train"]
dstest = dd["test"]

Loading cached split indices for dataset at ../reddit/default/1.0.0/98ba5abea674d3178f7588aa6518a5510dc0c6fa8176d9653a3546d5afcb3969/cache-e5dc09252e8e1642.arrow and ../reddit/default/1.0.0/98ba5abea674d3178f7588aa6518a5510dc0c6fa8176d9653a3546d5afcb3969/cache-015e7bd21d0fb866.arrow


In [9]:
dstrain, dstest

(Dataset({
     features: ['author', 'body', 'normalizedBody', 'subreddit', 'subreddit_id', 'id', 'content', 'summary'],
     num_rows: 3809846
 }),
 Dataset({
     features: ['author', 'body', 'normalizedBody', 'subreddit', 'subreddit_id', 'id', 'content', 'summary'],
     num_rows: 38484
 }))

In [10]:
# [tokenizer.decode(x) for x in tokenizer("Hello World!")["input_ids"]]
tokenizer.tokenize("Hello World", add_special_tokens = True)

['[CLS]', 'Hello', 'World', '[SEP]']

In [6]:
%%time
model.eval()
with torch.no_grad():
    x = tokenizer(
        "Luca Brasi held a gun to his head, and my father assured"
        " him that either his brains or his signature would be on"
        " the contract. That’s a true story. - Michael Corleone, Godfather",
        return_tensors = "pt"
    )
    out = model(**{k:v.to(device) for k,v in x.items()})

CPU times: user 389 ms, sys: 149 ms, total: 537 ms
Wall time: 536 ms


In [7]:
out.last_hidden_state.shape

torch.Size([1, 46, 768])

In [8]:
model.config

BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.5.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}

In [9]:
%%time

from autocase.collate_methods import *    

# NOTE: since these collator functions sit outside the dataset it is difficult
# to know what exactly will be the final batch size generally 25% smaller than
# the batch_size given as an input to the model.

COLLATE_METHODS = get_collate_fns(tokenizer, 512)
collate_fn = COLLATE_METHODS["fast_binary_flag"]

# collate_fn = FastBinaryCollater(tokenizer, 512)
    
loader = DataLoader(
    dstrain,
    batch_size=64,
    collate_fn=collate_fn,
    pin_memory = True if torch.cuda.is_available() else False,
    shuffle = False,
    num_workers = 6
)
print("Loader init ...")

for i, x in enumerate(loader):
    print(torch.sum(x["labels"]) / torch.numel(x["labels"]))
    if i == 100:
        break

Loader init ...


Token indices sequence length is longer than the specified maximum sequence length for this model (600 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (615 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1465 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (766 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (851 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for th

tensor(0.0949)
tensor(0.1015)
tensor(0.0876)
tensor(0.0824)
tensor(0.0921)
tensor(0.0957)
tensor(0.0906)
tensor(0.0941)
tensor(0.0898)
tensor(0.0900)
tensor(0.1102)
tensor(0.1002)
tensor(0.1070)
tensor(0.0961)
tensor(0.0874)
tensor(0.0943)
tensor(0.0843)
tensor(0.0965)
tensor(0.0954)
tensor(0.1089)
tensor(0.0915)
tensor(0.1004)
tensor(0.1052)
tensor(0.0946)
tensor(0.0808)
tensor(0.0995)
tensor(0.1000)
tensor(0.1029)
tensor(0.0906)
tensor(0.0981)
tensor(0.0959)
tensor(0.0969)
tensor(0.1043)
tensor(0.0914)
tensor(0.1023)
tensor(0.1000)
tensor(0.0952)
tensor(0.1028)
tensor(0.0799)
tensor(0.1018)
tensor(0.1104)
tensor(0.0948)
tensor(0.1025)
tensor(0.0957)
tensor(0.0916)
tensor(0.1083)
tensor(0.0892)
tensor(0.0974)
tensor(0.0855)
tensor(0.1040)
tensor(0.1026)
tensor(0.0926)
tensor(0.0944)
tensor(0.1074)
tensor(0.0811)
tensor(0.0905)
tensor(0.0843)
tensor(0.0766)
tensor(0.1031)
tensor(0.1125)
tensor(0.1021)
tensor(0.0960)
tensor(0.0991)
tensor(0.1013)
tensor(0.0863)
tensor(0.0877)
tensor(0.0

In [10]:
auto_cased = []
for s,d in zip(x["input_ids"], x["labels"]):
    _s = []
    s = tokenizer.tokenize(
        tokenizer.decode(s, skip_special_tokens = True),
        add_special_tokens=False
    )
    for i,t in enumerate(s):
        if i in d:
            _s.append(t.capitalize())
        else:
            _s.append(t)
    auto_cased.append(_s)
    
fseq = ""
for t in _s:
    fseq += t if t[0] == "#" else " " + t
fseq = re.sub(r"#", "", fseq.strip())
fseq = re.sub(r"\s'\s", "'", fseq)
fseq = re.sub(r"\[\s", "[", fseq)
fseq = re.sub(r"\s\]", "]", fseq)
fseq

'As The title says i haven\'t had pretty much any sexual contact except for maybe a kiss since i was a college freshman back when i was a naive 18 year old after a few setbacks and a long battle with depression i finally decided to get back into the game maybe i\'ve lost my touch or i\'m just too busy with work but i don\'t have time for a relationship and just want a pal who can make my toes curl do guys do that too ? i\'m just you\'re average 21 year old i won\'t say i\'m something special or a model i have an average to slim build kind of have resting bitch face i\'ve tried fixing it sorry and as for my " measurements " i would say pretty average as well but doesn\'t every guy think that ? i don\'t really have an set conditions except be clean please sane please don\'t stalk me and hopefully be relatively close looking to host or uber and not looking to make this experience expensive for either of us as for age and appearance i have no preference but if you\'re older i might like yo

In [13]:
tseq = ""
for t in x["target_str"][-1]:
    tseq += t if t[0] == "#" else " " + t
tseq = re.sub(r"#", "", tseq.strip())
tseq = re.sub(r"\s'\s", "'", tseq)
tseq = re.sub(r"\[\s", "[", tseq)
tseq = re.sub(r"\s\]", "]", tseq)
tseq

'As the title says I haven\'t had pretty much any sexual contact except for maybe a kiss since I was a college freshman back when I was a naive 18 year old After a few setbacks and a long battle with depression I finally decided to get back into the game Maybe I\'ve lost my touch or I\'m just too busy with work but I don\'t have time for a relationship and just want a pal who can make my toes curl do guys do that too ? I\'m just you\'re average 21 year old I won\'t say I\'m something special or a model I have an average to slim build Kind of have resting bitch face I\'ve tried fixing it sorry and as for my " measurements " I would say pretty average as well but doesn\'t every guy think that ? I don\'t really have an set conditions except be clean please Sane please don\'t stalk me And hopefully be relatively close Looking to host or Uber and not looking to make this experience expensive for either of us As for age and appearance I have no preference but if you\'re older I might like yo

In [18]:
x.keys()

dict_keys(['labels', 'target_str', 'input_ids', 'attention_mask'])

In [5]:
# this is not perfect but still okay for this demo case also we'll add a whole bunch of regex's
# and get all this sorted out. So now we make the model for this case

class AutoCaseModel(nn.Module):
    def __init__(self, hf_backbone = "bert-base-cased", cache_dir = "../hf-cache/"):
        super().__init__()
        assert "bert" in hf_backbone.lower(), "Supports only BERT Models"

        self.tokenizer = AutoTokenizer.from_pretrained(hf_backbone, cache_dir = cache_dir)
        self.device = torch.device("cuda:0") if torch.cuda.is_available() else "cpu"
        self.model = AutoModel.from_pretrained(hf_backbone, cache_dir = cache_dir).to(self.device)
        self.model.eval()
        self.model_config = self.model.config
        
        self.head = nn.Linear(self.model.config.hidden_size, 1).to(self.device)
        
    def forward(self, input_ids, attention_mask, labels = None, *args, **kwargs):
        with torch.no_grad():
            attention_mask = attention_mask.to(self.device)
            out = self.model(**{
                "input_ids": input_ids.to(self.device),
                "attention_mask": attention_mask
            }).last_hidden_state
        logits_2d = torch.sigmoid(self.head(out))
        
        if labels is not None:
            # load labels on the model device
            labels = labels.to(self.device)
            
            # we need to remove the tokens where attention_mask != 1
            attn_flat = attention_mask.contiguous().view(-1)
            attn_mask = attn_flat > 0
            logits = logits_2d.contiguous().view(-1)[attn_mask]
            labels = labels.contiguous().view(-1)[attn_mask].float()
            
            # create a weight matrix for the loss
            weight_mat = torch.zeros_like(logits).float()
            weight_mat[labels == 1] = 4
            weight_mat += 1
            loss = F.binary_cross_entropy(
                logits,
                labels,
                weight_mat
            )
            
            # calculate accuracy right in here
            # if x > 0.5 -> 1
            logits_act = torch.zeros_like(logits)
            logits_act[logits > 0.5] = 1
            corr = logits_act == labels
            corr = torch.sum(corr) / torch.numel(corr)

            f1 = f1_score(labels.cpu(), logits_act.cpu())
            cm = confusion_matrix(labels.cpu(), logits_act.cpu())
            
            return (logits_2d, loss, corr.item(), f1, cm)
        return logits_2d

In [6]:
%%time
model = AutoCaseModel()

CPU times: user 22.2 s, sys: 1.41 s, total: 23.6 s
Wall time: 28.7 s


In [35]:
logits, loss, acc, f1, cm = model(**x)
logits.shape, loss, acc, f1, cm



(torch.Size([64, 512, 1]),
 tensor(1.2001, device='cuda:0', grad_fn=<BinaryCrossEntropyBackward>),
 0.7595351934432983,
 0.12797502926258292,
 array([[13791,  1688],
        [ 2782,   328]]))

In [28]:
# fuck the warnings
import warnings
warnings.filterwarnings("ignore")

# create optimizer
optim = torch.optim.Adam(model.head.parameters())

# create train and test data loader
COLLATE_METHODS = get_collate_fns(model.tokenizer, 512)
collate_fn = COLLATE_METHODS["fast_binary_flag"]
train_dl = DataLoader(
    dstrain,
    batch_size=128,
    collate_fn=collate_fn,
    pin_memory = True if torch.cuda.is_available() else False,
    shuffle = True,
    num_workers = 6
)
dl = iter(train_dl) # and create multiple iterators
test_dl = DataLoader(
    dstest,
    batch_size = 256,
    collate_fn = collate_fn,
    pin_memory = True if torch.cuda.is_available() else False,
    num_workers = 6
)
print("Loader init ...")

# create lists
train_acc = []
train_loss = []
train_cm = []
train_f1 = []

test_acc = []
test_loss = []
test_cm = []
test_f1 = []

# create controls
n_steps = 100000
n_steps_test = 1000

pbar = trange(n_steps)
for gs in pbar:
    # train
    try:
        x = next(dl)
    except StopIteration:
        dl = iter(train_dl)
        x = next(dl)
    
    # update progress string
    if gs:
        # tn = accuracy anyways
        _desc = " | ".join([
            f"Tr Ac: {np.mean(train_acc[-50:]):.3f}",
            f"Tr Lo: {np.mean(train_loss[-50:]):.3f}",
            f"Tr F1: {np.mean(train_f1[-50:]):.3f}",
            f"Te Ac: {np.mean(test_acc[-50:]):.3f}",
            f"Te Lo: {np.mean(test_loss[-50:]):.3f}",
            f"Te F1: {np.mean(test_f1[-50:]):.3f}",
        ])
        pbar.set_description(_desc)
        
    # perform training
    logits, loss, acc, f1, cm = model(**x)
    
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    train_acc.append(acc)
    train_loss.append(loss.item())
    train_f1.append(f1)
    train_cm.append(cm)
    
    # test when the time comes
    if gs and gs % n_steps_test == 0:
        model.eval()
        _test_acc = []
        _test_loss = []
        _test_cm = []
        _test_f1 = []
        for x in test_dl:
            logits, loss, acc, f1, cm = model(**x)
            _test_acc.append(acc)
            _test_loss.append(loss.item())
            _test_cm.append(cm)
            _test_f1.append(f1)
        test_cm.append(np.mean(_test_cm, 0))
        test_acc.append(np.mean(_test_acc))
        test_loss.append(np.mean(_test_loss))
        test_f1.append(np.mean(_text_f1))
        model.train()

  0%|          | 0/100000 [00:00<?, ?it/s]

Loader init ...


Token indices sequence length is longer than the specified maximum sequence length for this model (2378 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (517 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (719 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (544 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (826 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for th

KeyboardInterrupt: 

In [49]:
with torch.no_grad():
    string = '''and while india struggles to shake off the virus the developed world
    is taking off the masks the countries who've been able to vaccinate a sizeable number
    of people those living in the united states for instance they can ditch the masks at
    most public places now if they have taken both the shots'''
    tokenized = tokenizer(string, return_tensors = "pt")
    out = model(**tokenized)
    out[out > 0.5] = 1
    out[out <= 0.5] = 0
    out = out.view(-1).cpu().tolist()
    up_tok = [i for i,x in enumerate(out) if x == 1.]
up_tok

[0, 1, 2, 3, 4, 10, 18, 19, 36, 37, 41, 42, 43, 44, 45, 55]

In [59]:
# [tokenizer.decode(x) for x in tokenized["input_ids"].view(-1)]
def autocase(s, d):
    auto_cased = []
    s = s.view(-1).tolist()[1:]
    s = [tokenizer.decode(x) for x in s]
    for i,t in enumerate(s):
        if i in d:
            auto_cased.append(t.capitalize())
        else:
            auto_cased.append(t)
#     auto_cased = auto_cased[1:]

    fseq = ""
    for t in auto_cased:
        fseq += t if t[0] == "#" else " " + t
    fseq = re.sub(r"#", "", fseq.strip())
    fseq = re.sub(r"\s'\s", "'", fseq)
    fseq = re.sub(r"\[\s", "[", fseq)
    fseq = re.sub(r"\s\]", "]", fseq)
    return fseq

autocase(tokenized["input_ids"], up_tok)

"And While India Struggles to shake off the virus The developed world is taking off the masks The Countries who've been able to vaccinate a sizeable number of people Those Living in the united States For Instance They Can ditch the masks at most public places now if They have taken both the shots [SEP]"