<a href="https://colab.research.google.com/github/ayami-n/Flax_text_prediction/blob/main/Main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/Flax_text_prediction

/content/drive/MyDrive/Flax_text_prediction


# Import libs

In [None]:
%%capture
!pip install datasets
!pip install git+https://github.com/huggingface/transformers.git
!pip install flax
!pip install git+https://github.com/deepmind/optax.git

In [None]:
import jax
from jax import random  # to create random values for initalizing a model (Flax requires)
import jax.numpy as jnp

# Flax for building model
try:
    import flax
except ModuleNotFoundError: # Install flax if missing
    !pip install --quiet flax
    import flax

from flax import linen as nn
from flax.training import train_state, checkpoints

# Optax for optimizor 
import optax

# Transformers
!pip install transformers
from transformers import FlaxAutoModelForSequenceClassification, AutoConfig
from transformers import RobertaTokenizer, RobertaConfig # as we use Roberta model

# others
import pandas as pd

# Config

In [None]:
model_checkpoint = 'siebert/sentiment-roberta-large-english'
num_labels = 3 
seed = 0
max_len = 128

# Tokenazation

In [None]:
df = pd.read_csv("./kaggle/train.csv")  # import train datasets

In [None]:
def bert_encode(texts, tokenizer, max_len):
    input_ids = []
    token_type_ids = []
    attention_mask = []
    
    for text in texts:
        token = tokenizer(text, max_length=max_len, truncation=True, padding='max_length',
                         add_special_tokens=True,  return_tensors='jax')
        input_ids.append(token['input_ids'])
        attention_mask.append(token['attention_mask'])
    
    return jnp.array(input_ids), jnp.array(token_type_ids), jnp.array(attention_mask)

# Create a model

In [14]:
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels)
model = FlaxAutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, seed=seed, ignore_mismatched_sizes=True)  # ignore_mismatched_sizes=True: arrow to have arbitary number of outputs

Some weights of FlaxRobertaForSequenceClassification were not initialized from the model checkpoint at siebert/sentiment-roberta-large-english and are newly initialized because the shapes did not match:
- ('classifier', 'out_proj', 'bias'): found shape (2,) in the checkpoint and (3,) in the model instantiated
- ('classifier', 'out_proj', 'kernel'): found shape (1024, 2) in the checkpoint and (1024, 3) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [6]:
tokenizer = RobertaTokenizer.from_pretrained('siebert/sentiment-roberta-large-english')
inputs = tokenizer("I love you.", max_length=128, truncation=True, padding='max_length',
                         add_special_tokens=True,  return_tensors='jax')

Downloading:   0%|          | 0.00/780k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/150 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/256 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/687 [00:00<?, ?B/s]

In [10]:
print(inputs)

{'input_ids': DeviceArray([[  0, 100, 657,  47,   4,   2,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
                1,   1,   1,   1,   1,   1,   1,   1]], dtype=int32), 'attention_mask': DeviceArray([[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
              0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
          

In [17]:
out = model(**inputs)
print(out)
print(out.logits)

FlaxSequenceClassifierOutput(logits=DeviceArray([[-0.4148041 , -0.48419115,  0.02517768]], dtype=float32), hidden_states=None, attentions=None)
[[-0.4148041  -0.48419115  0.02517768]]
