In [1]:
import jax as J
import jax.numpy as jnp

In [2]:
from jax.lib import xla_bridge
xla_bridge.get_backend().platform

'gpu'

In [3]:
import os
import sys
import urllib
import tarfile

corpus_url = "http://www.cs.cornell.edu/people/pabo/movie-review-data/review_polarity.tar.gz"

corpus_root = os.path.join(os.getcwd(), "review_polarity", "txt_sentoken")
catgeories = ["pos", "neg"]


def download_and_unzip():
    file_name = corpus_url.split("/")[-1]
    download_path = os.path.join(os.getcwd(), file_name)
    # where the zip will get extracted
    extracted_path = os.path.join(os.getcwd(), "review_polarity")

    if os.path.exists(extracted_path):
        print("Already downloaded and extracted!")
    else:
        # ============================================ download
        print("Downloading, sit tight!")

        def _progress(count, block_size, total_size):
            sys.stdout.write(
                f"\r>> Downloading {file_name} {float(count * block_size) / float(total_size) * 100.0}%")
            sys.stdout.flush()

        file_path, _ = urllib.request.urlretrieve(
            corpus_url, download_path, _progress)
        print()
        print(
            f"Successfully downloaded {file_name} {os.stat(file_path).st_size} bytes")

        # ======================================= unzip
        print()
        print("Unzipping ...")
        # create dir at extracted_path
        os.mkdir(extracted_path)
        tarfile.open(file_path, "r:gz").extractall(extracted_path)

        # =========================================== clean up
        # delete the downloaded zip file
        print("Deleting downloaded zip file")
        os.remove(file_path)

In [4]:
def read_text_files(path):
    file_list = os.listdir(path)
    texts = []

    for fname in file_list:
        fpath = os.path.join(path, fname)

        f = open(fpath, mode="r")
        lines = f.read()
        texts.append(lines)
        f.close()

    return texts

In [5]:
from tqdm import tqdm

download_and_unzip()

reviews = []
labels = []

# idx 0 -> neg, 1 -> pos
for idx, cat in enumerate(catgeories):
    path = os.path.join(corpus_root, cat)
    texts = read_text_files(path)

    for i in tqdm(range(len(texts)), desc="prepare_corpus"):
        text = texts[i]
        reviews.append(text)
        labels.append(idx)

print()
print(len(reviews))
print(len(labels))



Already downloaded and extracted!


prepare_corpus: 100%|████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2570039.22it/s]
prepare_corpus: 100%|████████████████████████████████████████████████████████████| 1000/1000 [00:00<00:00, 2235769.72it/s]


2000
2000





In [6]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

In [7]:
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(
    reviews, labels, random_state=42, train_size=0.8
)

x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, train_size=0.8, random_state=42)

In [8]:
import torch
from torch.utils.data import Dataset

import numpy as np

# custom dataset
class PolarityReviewDataset(Dataset):

    def __init__(self, reviews, labels, tokenizer):
        self.reviews = reviews
        self.labels = labels
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        review = self.reviews[idx]
        label = self.labels[idx]

        # encode review text
        encoding = self.tokenizer.encode_plus(
            review,
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="np"
        )

        return encoding, label

training_dataset = PolarityReviewDataset(x_train, y_train, tokenizer)
val_dataset = PolarityReviewDataset(x_val, y_val, tokenizer)

In [9]:
from torch.utils.data import DataLoader

batch_size = 16

## test
def collate_fn(data):
    inputs, labels = zip(*data)
    return inputs, jnp.array(labels)
    
    #print(data)

# loader from custom dataset
train_loader = DataLoader(training_dataset, shuffle=True, batch_size=batch_size, pin_memory=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=batch_size, pin_memory=True, collate_fn=collate_fn)


In [10]:
import flax
from flax import linen as nn

In [11]:
masterkey = J.random.PRNGKey(42)
masterkey

DeviceArray([ 0, 42], dtype=uint32)

In [12]:
from transformers import FlaxBertModel

class Classifier(nn.Module):
    def setup(self):
        self.bert = FlaxBertModel.from_pretrained('bert-base-cased')
        self.fc = nn.Dense(features=2)
        
    def __call__(self, x):
        out = self.bert(**x)
        
        # last layer output
        out = out.pooler_output
        # pass through dense layer
        out = self.fc(out)
        out = J.nn.log_softmax(out)
        
        return out
    
    
    
model = Classifier()

# dummy input
dummy = tokenizer.encode_plus(
            "This is some dummy text",
            add_special_tokens=True,
            max_length=512,
            truncation=True,
            return_token_type_ids=False,
            padding="max_length",
            return_attention_mask=True,
            return_tensors="np"
        )


params = model.init(masterkey, dummy)
params

Downloading:   0%|          | 0.00/413M [00:00<?, ?B/s]

Some weights of FlaxBertModel were not initialized from the model checkpoint at bert-base-cased and are newly initialized: {('pooler', 'dense', 'bias'), ('pooler', 'dense', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


FrozenDict({
    params: {
        fc: {
            kernel: DeviceArray([[-0.02475933, -0.0485843 ],
                         [-0.04283394, -0.00822861],
                         [ 0.01661486,  0.0353759 ],
                         ...,
                         [ 0.04673802,  0.0578084 ],
                         [-0.05068683, -0.01737976],
                         [ 0.05191344, -0.06784151]], dtype=float32),
            bias: DeviceArray([0., 0.], dtype=float32),
        },
    },
})