In [1]:
!pip install jax-dataloader -q

[0m

In [2]:
!pip install -U "jax[cuda12]" -q

[0m

In [3]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from jax import jit
import optax
from transformers import MarianTokenizer, MarianMTModel, FlaxMarianMTModel
from datasets import load_dataset, load_metric, DatasetDict
from flax.training import train_state
import flax
from jax import random
from functools import partial
import time
from typing import Callable
import jax_dataloader as jdl
from tqdm import tqdm
from time import time
import numpy as np
from flax.training.common_utils import get_metrics, onehot, shard, shard_prng_key

2024-06-03 07:00:07.795434: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-06-03 07:00:07.795597: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-06-03 07:00:07.915651: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [4]:
print('JAX is running on', jax.lib.xla_bridge.get_backend().platform)

JAX is running on gpu


In [5]:
jax.local_devices()

[cuda(id=0)]

In [6]:
def get_config():
    config = {
        'model_name': 'Helsinki-NLP/opus-mt-en-hu', # 't5-small', 
        'max_length': 64,
        'batch_size': 64,
        'lr': 10 ** -5,
        'epochs': 10,
        'seed': 42,
        'metric_name': 'sacrebleu',
        'save_model': '/kaggle/working/model.pth',
        'per_device_batch_size': 64
        
    }
    return config

In [7]:
config = get_config()

In [8]:
# Load dataset and metric
dataset = load_dataset('Helsinki-NLP/opus_books', 'en-hu')
val_test_set = dataset['train'].train_test_split(test_size=0.2, seed=42)
test_set = val_test_set['test'].train_test_split(test_size=0.5, seed=42)

dataset = DatasetDict({
    'train': val_test_set['train'],
    'val': test_set['test'],
    'test': test_set['train']
})

Downloading readme:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/137151 [00:00<?, ? examples/s]

In [9]:
tokenizer = MarianTokenizer.from_pretrained(config['model_name'])
# Load the model
model = FlaxMarianMTModel.from_pretrained(config['model_name'], from_pt=True)

tokenizer_config.json:   0%|          | 0.00/42.0 [00:00<?, ?B/s]

source.spm:   0%|          | 0.00/792k [00:00<?, ?B/s]

target.spm:   0%|          | 0.00/850k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.57M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/1.38k [00:00<?, ?B/s]



pytorch_model.bin:   0%|          | 0.00/307M [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()
Some weights of the model checkpoint at Helsinki-NLP/opus-mt-en-hu were not used when initializing FlaxMarianMTModel: {('model', 'encoder', 'embed_positions', 'kernel'), ('model', 'decoder', 'embed_tokens', 'kernel'), ('model', 'encoder', 'embed_tokens', 'kernel'), ('model', 'decoder', 'embed_positions', 'kernel')}
- This IS expected if you are initializing FlaxMarianMTModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxMarianMTModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


generation_config.json:   0%|          | 0.00/293 [00:00<?, ?B/s]

In [10]:
def preprocess_fn(dataset: DatasetDict):
    inputs = [ex['en'] for ex in dataset['translation']]
    targets = [ex['hu'] for ex in dataset['translation']]
    model_inputs = tokenizer(inputs, max_length=config['max_length'], truncation=True, padding='longest')

    # tokenizer targets 
    #with tokenizer.as_target_tokenizer():
    labels =  tokenizer(targets, max_length=config['max_length'], truncation=True, padding='longest')
        
    model_inputs['labels'] = labels.input_ids

    return model_inputs

In [11]:
tokenized_datasets = dataset.map(preprocess_fn, batched=True, remove_columns=dataset["train"].column_names)

Map:   0%|          | 0/109720 [00:00<?, ? examples/s]

Map:   0%|          | 0/13716 [00:00<?, ? examples/s]

Map:   0%|          | 0/13715 [00:00<?, ? examples/s]

In [12]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 109720
    })
    val: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 13716
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 13715
    })
})

### Dataloader

In [13]:
train_loader = jdl.DataLoader(tokenized_datasets['train'], 'jax', batch_size=config['batch_size'], shuffle=True, )
val_loader = jdl.DataLoader(tokenized_datasets['val'], 'jax', batch_size=config['batch_size'], shuffle=False)
test_loader = jdl.DataLoader(tokenized_datasets['test'], 'jax', batch_size=config['batch_size'], shuffle=False)

In [14]:
for i in train_loader:
    print(type(i['input_ids'][0]))
    break

<class 'numpy.ndarray'>


### Tokenizer and model

In [15]:
total_batch_size = config['per_device_batch_size'] * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

The overall batch size (both for training and eval) is 64


In [16]:
num_train_steps = len(dataset['train']) // total_batch_size * config['epochs']

learning_rate_function = optax.linear_schedule(init_value=config['lr'], end_value=0, transition_steps=num_train_steps)

In [17]:
class TrainState(train_state.TrainState):
    loss_function: Callable = flax.struct.field(pytree_node=False)

In [18]:
def loss_function(logits, labels):
    padding_mask = (labels != tokenizer.pad_token_id)
 
    # One-hot encode the labels
    one_hot_labels = jax.nn.one_hot(labels, num_classes=logits.shape[-1])
    
    # Compute the cross-entropy loss
    cross_entropy_loss = optax.softmax_cross_entropy(logits, one_hot_labels)

    # Apply the padding mask to the loss
    loss = jnp.sum(cross_entropy_loss * padding_mask, axis=-1) / jnp.sum(padding_mask, axis=-1)
    return loss.mean()

In [19]:
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optax.adamw(learning_rate=learning_rate_function),
    loss_function=loss_function,
)

In [20]:
@jit
def train_step(state, batch, dropout_rng):
    dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
    
    def compute_loss(params):
        logits = state.apply_fn(
            input_ids=batch['input_ids'], 
            attention_mask=batch['attention_mask'], 
            params=params, 
            dropout_rng=dropout_rng, 
            train=True
        ).logits
        loss = state.loss_function(logits, batch['labels'])
        return loss
    
    loss, grads = jax.value_and_grad(compute_loss)(state.params)
    state = state.apply_gradients(grads=grads)
    return state, loss, new_dropout_rng

In [21]:
@jit
def eval_step(state, batch):
    outputs = state.apply_fn(
        input_ids=batch['input_ids'], 
        attention_mask=batch['attention_mask'], 
        params=state.params, 
        train=False
    )
    logits = outputs[0]
    loss = state.loss_function(logits, batch['labels'])
    return loss

# Training

In [22]:
# Training loop
rng = jax.random.PRNGKey(0)
dropout_rngs = rng

In [23]:
total_batch_size = config['per_device_batch_size'] * jax.local_device_count()
print("The overall batch size (both for training and eval) is", total_batch_size)

start_time = time()
print('======== Start training ======== ')
for epoch in range(config['epochs']):
    training_loss = 0
    with tqdm(total=len(train_loader), desc="Training...", leave=False) as progress_bar_train:
        for batch in train_loader: 
            state, loss, dropout_rngs = train_step(state, batch, dropout_rngs)
            training_loss += jax.device_get(loss)
            progress_bar_train.update(1)
        training_loss = training_loss / len(train_loader)
        
    eval_loss = 0
    with tqdm(total=len(val_loader), desc="Evaluating...", leave=False) as progress_bar_eval:
        for batch in val_loader: 
            loss = eval_step(state, batch)
            eval_loss += jax.device_get(loss)
            progress_bar_eval.update(1)
        eval_loss = eval_loss / len(val_loader)
    
    print(f"Epoch {epoch + 1}: Training loss = {training_loss}, Val loss = {eval_loss}")
    
print('======== End training ========')
total_training_time = time() - start_time
print(f"Total training time: {total_training_time}")

The overall batch size (both for training and eval) is 64


                                                                

Epoch 1: Training loss = 5.285857452804076, Val loss = 5.023292985073356


                                                                

Epoch 2: Training loss = 5.003126142740944, Val loss = 4.915793815878935


                                                                

Epoch 3: Training loss = 4.908296616501433, Val loss = 4.854043594626493


                                                                

Epoch 4: Training loss = 4.84413925593518, Val loss = 4.817263567724893


                                                                

Epoch 5: Training loss = 4.802144354216907, Val loss = 4.794698072034259


                                                                

Epoch 6: Training loss = 4.771681739707035, Val loss = 4.777872708786366


                                                                

Epoch 7: Training loss = 4.749162184045197, Val loss = 4.766490097933037


                                                                

Epoch 8: Training loss = 4.732483908316832, Val loss = 4.762231349945068


                                                                

Epoch 9: Training loss = 4.720306410664372, Val loss = 4.754921088107797


                                                                

Epoch 10: Training loss = 4.713775796306376, Val loss = 4.754215644126715
Total training time: 5586.224370002747


