## Corpus

In [1]:
from corpus import download_and_unzip, catgeories, corpus_root, read_text_files

In [2]:
# Fetch and extract the review corpus. On re-runs the helper reports that the
# data is already downloaded and extracted (see the output below).
download_and_unzip()

Already downloaded and extracted!


In [3]:
import os
from tqdm import tqdm


# Collect the raw review strings together with their integer class labels.
# The label is the position of the category in `catgeories`: 0 -> neg, 1 -> pos.
# Texts stay untokenized here; the BERT tokenizer is applied later, per example,
# inside the Dataset.
reviews, labels = [], []

for idx, cat in enumerate(catgeories):
    texts = read_text_files(os.path.join(corpus_root, cat))
    for text in tqdm(texts, desc="prepare_corpus"):
        reviews.append(text)
        labels.append(idx)

prepare_corpus: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2295732.90it/s]
prepare_corpus: 100%|██████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2376376.20it/s]


In [4]:
from sklearn.model_selection import train_test_split

# Hold out 20% as the test set, then 20% of the remainder as validation
# (overall ~64% train / 16% validation / 20% test).
x_train, x_test, y_train, y_test = train_test_split(
    reviews, labels, train_size=0.8, random_state=42
)
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, train_size=0.8, random_state=42
)

## Dataloader


In [5]:
import numpy as np
import jax
from jax import numpy as jnp

In [6]:
from transformers import AutoTokenizer


# The same checkpoint name is reused for the Flax BERT weights in the model cell,
# so tokenizer vocabulary and model embeddings match.
model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [7]:
import torch
from torch.utils.data import Dataset

# custom dataset
class PolarityReviewDataset(Dataset):
    """Map-style dataset of (review text, polarity label) pairs, tokenized on access.

    Args:
        reviews: sequence of raw review strings.
        labels: sequence of integer class labels aligned with ``reviews``
            (0 -> neg, 1 -> pos in this notebook).
        tokenizer: a Hugging Face tokenizer exposing ``encode_plus``.
        max_len: token length every example is padded/truncated to (default 128).

    Raises:
        ValueError: if ``reviews`` and ``labels`` have different lengths.
    """

    def __init__(self, reviews, labels, tokenizer, max_len=128):
        # A length mismatch would otherwise only surface as an IndexError (or a
        # silently truncated dataset) much later, during iteration.
        if len(reviews) != len(labels):
            raise ValueError(
                f"reviews and labels must have the same length, "
                f"got {len(reviews)} and {len(labels)}"
            )
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer
        self.MAX_LEN = max_len  # attribute name kept for backward compatibility

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return a dict of NumPy arrays for example ``idx``.

        Keys: ``"input_ids"`` and ``"attention_mask"`` (1-D, flattened from the
        tokenizer's single-row output) and ``"label"`` (0-d array).
        """
        encoding = self.tokenizer.encode_plus(
            self.reviews[idx],
            add_special_tokens=True,  # e.g. [CLS] ... [SEP] for BERT tokenizers
            max_length=self.MAX_LEN,
            truncation=True,
            # Fixed length so examples stack into uniform batches and the jitted
            # train step always sees the same shape.
            padding="max_length",
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors="np",
        )
        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "label": np.array(self.labels[idx]),
        }

In [8]:
# Wrap each split; tokenization happens lazily, one example at a time, in __getitem__.
training_dataset = PolarityReviewDataset(reviews=x_train, labels=y_train, tokenizer=tokenizer)
val_dataset = PolarityReviewDataset(reviews=x_val, labels=y_val, tokenizer=tokenizer)

In [9]:
# Inspect one encoded training example to sanity-check the tokenizer output.
td = next(iter(training_dataset), None)
if td is not None:
    print(td)

{'input_ids': array([  101, 15349,  1159,   131,   178,  1138,  1309,   117,  1518,
        1562,  2065,  1114,  1103,  3223,   119,   178,  1274,   112,
         189,  1221,  1725,   117,  1541,   119,  3983,   112,   189,
        1458,  1106,  4031,  1122,  1149,  1113,  1888,   117,  3983,
         112,   189,  1151,  1120,  1313,  1103,  6823,  1122,  1108,
        1113,  2443,   189,  1964,   117,  1105,  1122,  1108,  1315,
        1677,  1106,  2797,  1103,  1314,  1159,  1122,  1108,  1113,
        1103,  1992,   118,  3251,   119,  1177,  1268,  1146,  1524,
         117,   178,   112,  1325,  5890,  1115,   178,  1274,   112,
         189,  1221,  1184,  1103, 26913,   178,   112,   182,  2520,
        1164,   117,  1133,  1303,  2947,   119,   119,   119,  1110,
         189,  5168,  7770,  1103,  2065,  1114,  1103,  3223,  1104,
        1103,  1997,   112,   188,   136,  2654,  1115,   112,   188,
        1280,   170,  1376,  2113,  1315,  1677,   119,  1112,  1363,
      

In [10]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy data (usable as a DataLoader ``collate_fn``).

    - ndarray samples are stacked along a new leading batch axis;
    - tuple/list samples are transposed and each field is collated recursively,
      giving a list with one collated entry per field;
    - anything else (Python scalars, dicts, ...) is passed to ``np.array`` as-is.
      NOTE: dict samples therefore come out as a 1-D object array of dicts.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)
    
    
def collate_fn(batch):
    """Collate a list of samples into ``jax.numpy`` arrays (JAX-flavoured collate).

    - NumPy or JAX array samples are stacked along a new leading batch axis;
    - dict samples (what ``PolarityReviewDataset`` returns) are collated key by
      key, using the keys of the first sample;
    - tuple/list samples are collated field by field, preserving the container
      type (namedtuples are rebuilt positionally);
    - anything else is converted with ``jnp.asarray``.

    Adapted from https://github.com/google/jax/issues/3382.
    """
    first = batch[0]
    # np.ndarray is not a jnp.ndarray (jax.Array), so both are checked explicitly.
    if isinstance(first, (jnp.ndarray, np.ndarray)):
        return jnp.stack(batch)
    if isinstance(first, dict):
        # Without this branch dicts would reach jnp.asarray, which rejects them.
        return {key: collate_fn([sample[key] for sample in batch]) for key in first}
    if isinstance(first, (tuple, list)):
        fields = [collate_fn(list(samples)) for samples in zip(*batch)]
        if hasattr(first, "_fields"):
            # namedtuple constructors take fields as positional arguments.
            return type(first)(*fields)
        return type(first)(fields)
    return jnp.asarray(batch)

In [11]:
from torch.utils.data import DataLoader

batch_size = 16

# Notes on the loader options:
# - The collate function yields NumPy data, and pin_memory only pins torch
#   tensors, so it is left at its default (False); enabling it merely triggers a
#   warning when no CUDA device is present.
# - A seeded generator makes the shuffle order reproducible across kernel restarts.
# - drop_last keeps every training batch the same shape, so the jitted
#   train_step is compiled only once.
train_loader = DataLoader(
    training_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    drop_last=True,
    generator=torch.Generator().manual_seed(42),
)
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=numpy_collate,
)

In [12]:
# Inspect the first collated training batch.
td = next(iter(train_loader), None)
if td is not None:
    print(td)

[{'input_ids': array([  101,  1103,  1263,  3041,  1363, 11683,   113,   187,   114,
         1143, 19944,   188, 14750,  1643,  1793,  1122,  1105,  2604,
          119,  1256,   185, 22962,  1161,  1105, 18608,  5837,  1162,
         1189,  1126,  2661,  1133,  2204,  3596,  1113,  1123,  1218,
          118, 26612,  6661,   119,  1649,   117,   176,  9561,  1161,
         5358,  9356,  1180,  1304,  1218,  1561,  1103,  1148,  3085,
         1895,  1821, 26237,  1389,  2130,  2168,  2851,  1114,  1103,
         1263,  3041,  1363, 11683,   117,   170,  3073, 21341, 10771,
         1361,  1133, 12170,  4106,  2168, 11826,  2002,  1118,  1123,
         2252,   117,  1231, 15863,  5871, 20754,   119,  5358,  9356,
         2399, 21718,  1399,  7702, 11019,  2042,   117,   170,   182,
         2285,  1183, 10786,  1278,  3218,  1105,  1534,  2133,  5628,
         1178,  1301,  1171,  2022,  1201,   119,  1114,  1103,  1494,
         1104, 15380,  1193, 23589,  2029,  9140, 26410,  1732,

## Model

In [26]:
import flax
from flax import linen as nn

from transformers import FlaxBertModel

# define the model
import functools


@functools.lru_cache(maxsize=None)
def _pretrained_bert(name):
    """Load the Hugging Face Flax BERT for checkpoint ``name`` once per kernel and reuse it."""
    return FlaxBertModel.from_pretrained(name)


class SentiBert(nn.Module):
    """BERT pooled output followed by a linear classification head.

    Attributes:
        num_labels: number of output logits (2 = neg/pos for this corpus).

    NOTE: the BERT encoder is a Hugging Face wrapper object that carries its own
    weights. Those weights are not registered in this module's ``params``
    (``model.init`` yields only ``Dense_0``), so they stay frozen and only the
    Dense head is trained.
    """
    num_labels: int = 2

    @nn.compact
    def __call__(self, x):
        """Map an ``(input_ids, attention_mask)`` pair to logits with ``num_labels`` columns."""
        input_ids, attention_mask = x
        # Cached so the checkpoint is read once, not on every call / jit trace.
        bert = _pretrained_bert(model_name)
        pooled = bert(input_ids=input_ids, attention_mask=attention_mask).pooler_output
        return nn.Dense(features=self.num_labels)(pooled)

In [14]:
# Fixed seed for reproducibility. The first sub-key initializes parameters; the
# second (`data_key`) is not used yet in this notebook.
masterkey = jax.random.PRNGKey(42)
model_key, data_key = jax.random.split(masterkey, num=2)

In [27]:
model = SentiBert()

# Initialization only needs a representative input: one short sentence, padded /
# truncated to the same sequence length (128) the dataset produces.
dummy = tokenizer.encode_plus(
    "this is some text",
    max_length=128,
    padding="max_length",
    truncation=True,
    return_tensors="jax",
)
params = model.init(model_key, (dummy["input_ids"], dummy["attention_mask"]))

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [16]:
# Shape of every parameter leaf. `jax.tree_map` is deprecated (and removed in
# newer JAX releases); `jax.tree_util.tree_map` is the stable spelling.
jax.tree_util.tree_map(lambda leaf: leaf.shape, params)

FrozenDict({
    params: {
        Dense_0: {
            bias: (1,),
            kernel: (768, 1),
        },
    },
})

In [17]:
@jax.jit
def cross_entropy(logits, labels):
    """Per-example softmax cross-entropy.

    Args:
        logits: unnormalized scores with the class axis last, shape ``(..., C)``.
        labels: either integer class indices with shape ``logits.shape[:-1]``
            (as produced by the corpus cell: 0 -> neg, 1 -> pos), or one-hot /
            probability targets with the same shape as ``logits``.

    Returns:
        Array of shape ``logits.shape[:-1]`` holding ``-sum(target * log_softmax(logits))``.
    """
    logits = jnp.asarray(logits)
    labels = jnp.asarray(labels)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # ndim is static under jit, so this branch is resolved at trace time.
    # Integer indices must be one-hot encoded: multiplied directly against
    # log_probs they would broadcast along the class axis instead (a shape error
    # or a silently wrong loss). Out-of-range indices one-hot to all zeros.
    if labels.ndim == logits.ndim - 1:
        labels = jax.nn.one_hot(labels, logits.shape[-1], dtype=log_probs.dtype)
    return -jnp.sum(labels * log_probs, axis=-1)

In [18]:
def loss_fn(params, inputs, labels):
    """Batch-mean cross-entropy of ``model``'s logits; the function differentiated during training."""
    return cross_entropy(model.apply(params, inputs), labels).mean()

In [19]:
import optax

# value_and_grad differentiates w.r.t. the first argument (params) and also
# returns the loss value itself, so one call serves both logging and the update.
loss_fn_grad = jax.value_and_grad(loss_fn)

# Adam with a small learning rate; its state is initialized to mirror `params`.
optimizer = optax.adam(learning_rate=2e-5)
optimizer_state = optimizer.init(params)

In [20]:
@jax.jit
def train_step(params, opt_state, inputs, labels):
    """Apply one optimizer update for a single batch.

    Returns:
        ``(new_params, new_opt_state, loss)``, where ``loss`` is the batch loss
        evaluated at the incoming ``params`` (i.e. before the update).
    """
    loss, grads = loss_fn_grad(params, inputs, labels)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    return optax.apply_updates(params, updates), new_opt_state, loss

In [23]:
def _unpack_batch(batch):
    """Turn one DataLoader batch into ``((input_ids, attention_mask), labels)`` JAX arrays.

    Accepts either a dict of already-stacked arrays, or a sequence of per-example
    dicts (e.g. the object array of dicts that ``numpy_collate`` yields for
    ``PolarityReviewDataset`` samples).
    """
    if isinstance(batch, dict):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        batch_labels = batch["label"]
    else:
        input_ids = np.stack([entry["input_ids"] for entry in batch])
        attention_mask = np.stack([entry["attention_mask"] for entry in batch])
        batch_labels = np.stack([entry["label"] for entry in batch])
    return (jnp.asarray(input_ids), jnp.asarray(attention_mask)), jnp.asarray(batch_labels)


epochs = 2
losses = []  # one scalar loss per training step, across all epochs

for e in range(epochs):
    for batch in tqdm(train_loader, desc=f"epoch {e + 1}/{epochs}"):
        # `batch_labels`, not `labels`: reusing `labels` here would overwrite the
        # corpus-wide label list that the train/test split cell reads.
        inputs, batch_labels = _unpack_batch(batch)
        params, optimizer_state, loss = train_step(
            params, optimizer_state, inputs, batch_labels
        )
        losses.append(loss)
        

100%|███████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:02<00:00, 34.22it/s]
100%|███████████████████████████████████████████████████████████████████████████████████████| 80/80 [00:02<00:00, 33.81it/s]


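#### Known problems with this training setup

1. **Labels are not one-hot encoded.**
   - The corpus labels are integer class indices (0 = neg, 1 = pos).
   - `cross_entropy` computes `-sum(labels * log_softmax(logits))`, which is only correct for one-hot targets.
   - With `batch_size = 16` and 2 logits, labels of shape `(batch,)` do not broadcast against logits of shape `(batch, 2)`.
   - Convert the labels first, e.g. with `jax.nn.one_hot(labels, 2)`.
2. **Batch loading needs a proper collate step.**
   - `PolarityReviewDataset` returns one dict per example, and `numpy_collate` has no dict branch.
   - So each batch is an array of per-example dicts that the training loop has to unpack by hand.
3. **Only the classification head is trained.**
   - The parameter printout above contains only `Dense_0`.
   - The BERT weights loaded inside `SentiBert.__call__` are not part of `params`, so the optimizer never updates them.

The training output above also came from an earlier version of the model cell. The execution counts are out of order, and the printout shows a 1-unit `Dense_0`. With a single logit, `log_softmax` is identically 0, so that run's loss was 0 and nothing was learned.

All of the points above need fixing before the results mean anything.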