# Polish Conversational AI Training on TPU v5e-1

## Overview
This notebook fine-tunes a Polish GPT-2 causal language model on the Polish sentences dataset (94M+ text fragments) using a Google Colab TPU v5e-1, as a foundation for a Polish conversational AI model.

**Dataset:** [adowu/polish_sentences](https://huggingface.co/datasets/adowu/polish_sentences)
- Total rows: 94,167,155
- Size: ~2.6 GB (original), 1.01 GB (Parquet)
- Content: Polish text fragments ranging from 3 to about 7,510 characters in length

**Hardware:** TPU v5e-1 (a single TPU v5e chip; v5e is Google's cost-efficient TPU generation for training and inference)

**Framework:** JAX/Flax (optimal for TPU) with Hugging Face Transformers

## 1. Environment Setup & TPU Configuration

In [None]:
# Install the libraries used by the rest of the notebook.
# `-f` (--find-links) points pip at Google's libtpu release page so the TPU build of
# JAX requested via the `jax[tpu]` extra can be resolved.
# NOTE(review): the data-loader cell imports TensorFlow, which is not installed here;
# this assumes the Colab TPU runtime already provides it — confirm.
!pip install -q --upgrade pip
!pip install -q datasets transformers jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html flax optax huggingface-hub sentencepiece
print('✓ Packages installed successfully')

In [None]:
import jax
import jax.numpy as jnp
from jax import random
# Verify that JAX is actually running on the TPU before doing any expensive work.
devices = jax.devices()
print(f'Available devices: {devices}')
print(f'Device count: {jax.device_count()}')
# JAX silently falls back to the CPU backend when no accelerator is available, so
# counting devices (always >= 1) cannot detect a missing TPU; check the active
# backend instead. Raise rather than assert so the check is not stripped by -O.
if jax.default_backend() != 'tpu':
    raise RuntimeError(
        f'No TPU devices found! JAX is using the {jax.default_backend()!r} backend; '
        'switch the Colab runtime type to TPU.'
    )
print('✓ TPU configured successfully')

## 2. Dataset Loading & Analysis

In [None]:
from datasets import load_dataset
import numpy as np
# Load the full training split (~94M rows).
dataset = load_dataset('adowu/polish_sentences', split='train')
print(f'Dataset loaded: {len(dataset):,} rows')

# Estimate fragment-length statistics from a random sample. Drawing the sample
# indices directly is much cheaper than dataset.shuffle(), which builds an
# indices mapping over every row just to take the first 10k of them.
sample_size = min(10000, len(dataset))
sample_indices = np.sort(np.random.default_rng(42).choice(len(dataset), size=sample_size, replace=False))
sample = dataset.select(sample_indices)
# Reading the whole column at once avoids materialising one dict per row.
lengths = [len(text) for text in sample['fragment']]
print(f'Mean length: {np.mean(lengths):.1f} chars')
print(f'Median length: {np.median(lengths):.1f} chars')
print('✓ Dataset analysis complete')

## 3. Tokenization & Preprocessing

In [None]:
from transformers import AutoTokenizer
MODEL_CHECKPOINT = 'sdadas/polish-gpt2-medium'
# Load the checkpoint's tokenizer. When it defines no padding token, register one
# together with explicit BOS/EOS markers.
# NOTE(review): newly added tokens may not have rows in the pretrained embedding
# matrix; the model-loading cell must account for len(tokenizer).
tokenizer = AutoTokenizer.from_pretrained(MODEL_CHECKPOINT)
needs_special_tokens = tokenizer.pad_token is None
if needs_special_tokens:
    special_tokens = {
        'pad_token': '<|pad|>',
        'bos_token': '<|startoftext|>',
        'eos_token': '<|endoftext|>',
    }
    tokenizer.add_special_tokens(special_tokens)
print(f'Tokenizer vocab size: {len(tokenizer)}')
print('✓ Tokenizer configured')

In [None]:
MAX_LENGTH = 512


def preprocess_function(examples):
    """Tokenize a batch of raw fragments for causal-LM training.

    Each fragment is wrapped as ``{bos}{fragment}{eos}``, then truncated or
    padded to exactly ``MAX_LENGTH`` tokens.

    Args:
        examples: batched mapping whose ``'fragment'`` entry is a list of texts.

    Returns:
        The tokenizer's encoding (plain Python lists) with an extra ``'labels'``
        entry that duplicates ``'input_ids'``.
    """
    bos, eos = tokenizer.bos_token, tokenizer.eos_token
    texts = [f"{bos}{fragment}{eos}" for fragment in examples['fragment']]
    encoded = tokenizer(
        texts,
        truncation=True,
        max_length=MAX_LENGTH,
        padding='max_length',
        return_tensors=None,
    )
    # Shallow copy of the outer list, same as list.copy().
    encoded['labels'] = list(encoded['input_ids'])
    return encoded
# Tokenize lazily rather than with dataset.map(): map() with padding='max_length'
# would write ~94M x 512 token ids for each of input_ids / attention_mask / labels
# into the Arrow cache (hundreds of GB per column), far more than Colab's disk,
# and would take many hours on two processes. with_transform() runs
# preprocess_function on each batch as rows are read, so nothing extra is stored.
tokenized_dataset = dataset.with_transform(preprocess_function)
# Split the raw rows (same test_size/seed as before), then attach the transform to
# each split so rows come out already tokenized.
split = dataset.train_test_split(test_size=0.01, seed=42)
train_dataset = split['train'].with_transform(preprocess_function)
eval_dataset = split['test'].with_transform(preprocess_function)
print(f'Train: {len(train_dataset):,}, Eval: {len(eval_dataset):,}')
print('✓ Dataset tokenized (on the fly)')

## 4. Model Initialization

In [None]:
from transformers import FlaxAutoModelForCausalLM, AutoConfig
from flax.core import unfreeze


def _load_pretrained_causal_lm(checkpoint, dtype):
    """Load a Flax causal LM, converting PyTorch weights if the repo has no Flax weights."""
    try:
        return FlaxAutoModelForCausalLM.from_pretrained(checkpoint, dtype=dtype)
    except OSError:
        # from_pretrained raises OSError/EnvironmentError when no Flax weight file
        # exists; retry by converting the PyTorch checkpoint.
        return FlaxAutoModelForCausalLM.from_pretrained(checkpoint, dtype=dtype, from_pt=True)


def _pad_with_mean(matrix, new_size, axis):
    """Extend ``matrix`` along ``axis`` to ``new_size``, filling new slices with the mean slice."""
    extra = new_size - matrix.shape[axis]
    mean_slice = jnp.mean(matrix, axis=axis, keepdims=True)
    return jnp.concatenate([matrix, jnp.repeat(mean_slice, extra, axis=axis)], axis=axis)


def _grow_vocab_params(params, new_vocab_size):
    """Return a copy of GPT-2-style params whose token embeddings cover ``new_vocab_size`` ids.

    Raises:
        NotImplementedError: if the parameter tree has no transformer/wte/embedding entry.
    """
    params = unfreeze(params)  # mutable deep copy; model.params is left untouched
    try:
        wte = params['transformer']['wte']
    except KeyError as err:
        raise NotImplementedError(
            'Embedding resize is only implemented for GPT-2-style parameter trees '
            '(transformer/wte/embedding).'
        ) from err
    wte['embedding'] = _pad_with_mean(wte['embedding'], new_vocab_size, axis=0)
    lm_head = params.get('lm_head')
    if lm_head is not None and 'kernel' in lm_head:
        # Untied output projection: Flax Dense kernels are (in, out), so the
        # vocabulary is the last axis.
        lm_head['kernel'] = _pad_with_mean(lm_head['kernel'], new_vocab_size, axis=1)
    return params


# Load the checkpoint with its OWN config. Passing a config whose vocab_size was
# changed to len(tokenizer) makes the pretrained embedding shapes mismatch on load.
model = _load_pretrained_causal_lm(MODEL_CHECKPOINT, jnp.bfloat16)
config = model.config
new_vocab_size = len(tokenizer)
if new_vocab_size > config.vocab_size:
    # The tokenizer gained special tokens that have no pretrained embedding rows:
    # build a model with the larger vocabulary and copy the pretrained weights in,
    # initialising the new rows to the mean embedding.
    # (A vocab_size larger than the tokenizer is harmless and is left as is.)
    grown_params = _grow_vocab_params(model.params, new_vocab_size)
    config.vocab_size = new_vocab_size
    model = FlaxAutoModelForCausalLM.from_config(config, dtype=jnp.bfloat16)
    model.params = grown_params
num_params = sum(x.size for x in jax.tree_util.tree_leaves(model.params))
print(f'Model parameters: {num_params:,} ({num_params/1e6:.1f}M)')
print('✓ Model initialized')

## 5. Training Configuration

In [None]:
import optax
from flax.training import train_state
LEARNING_RATE = 5e-5
WARMUP_STEPS = 2000
BATCH_SIZE = 32  # global batch size per step, split across local devices by shard()
NUM_EPOCHS = 3

# shard() reshapes each batch to (local_device_count, -1, ...), so the global
# batch must divide evenly across the local devices.
num_local_devices = jax.local_device_count()
if BATCH_SIZE % num_local_devices:
    raise ValueError(
        f'BATCH_SIZE ({BATCH_SIZE}) must be divisible by the number of local devices ({num_local_devices}).'
    )
steps_per_epoch = len(train_dataset) // BATCH_SIZE
if steps_per_epoch == 0:
    # Otherwise the LR schedule and per-epoch averages below are ill-defined.
    raise ValueError(f'Training set ({len(train_dataset):,} rows) is smaller than one batch ({BATCH_SIZE}).')
total_steps = steps_per_epoch * NUM_EPOCHS
print(f'Steps per epoch: {steps_per_epoch:,}')
print(f'Total steps: {total_steps:,}')
print('✓ Training config set')

In [None]:
# Linear warm-up from 0 to LEARNING_RATE over WARMUP_STEPS, then cosine decay to
# 10% of the peak (alpha=0.1) over the remaining steps.
# optax's cosine schedule rejects decay_steps <= 0, which happens when the run is
# shorter than the warm-up; clamp to at least one step.
decay_steps = max(1, total_steps - WARMUP_STEPS)
warmup_fn = optax.linear_schedule(init_value=0.0, end_value=LEARNING_RATE, transition_steps=WARMUP_STEPS)
decay_fn = optax.cosine_decay_schedule(init_value=LEARNING_RATE, decay_steps=decay_steps, alpha=0.1)
lr_schedule = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[WARMUP_STEPS])
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # clip gradients to global L2 norm 1.0 before AdamW
    optax.adamw(lr_schedule, b1=0.9, b2=0.999, weight_decay=0.01),
)
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with the PRNG key used for dropout."""

    # Dropout key carried alongside params/optimizer state so it can be split
    # and replaced on each training step.
    dropout_rng: jnp.ndarray
# Seed the run deterministically and derive a dedicated key for dropout.
rng, dropout_rng = random.split(random.PRNGKey(42))
state = TrainState.create(
    apply_fn=model.__call__,
    params=model.params,
    tx=optimizer,
    dropout_rng=dropout_rng,
)
print('✓ Optimizer created')

## 6. Training Loop

In [None]:
from functools import partial
from flax import jax_utils
def causal_lm_loss(logits, labels, pad_token_id):
    """Mean next-token cross-entropy over non-padding targets.

    The logits at position ``t`` are scored against ``labels[:, t + 1]``. A causal
    LM predicts the *next* token, and the Flax model returns raw logits without
    any shifting. Target positions equal to ``pad_token_id`` are ignored.

    Args:
        logits: ``(batch, seq_len, vocab)`` model outputs (any float dtype).
        labels: ``(batch, seq_len)`` integer token ids.
        pad_token_id: token id whose target positions are masked out.

    Returns:
        Scalar float32 loss; 0.0 (instead of NaN) when every target is padding.
    """
    # Compute in float32: the model runs in bfloat16, which is too coarse for the
    # log-softmax over a large vocabulary.
    shift_logits = logits[:, :-1, :].astype(jnp.float32)
    shift_labels = labels[:, 1:]
    log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
    token_nll = -jnp.take_along_axis(log_probs, shift_labels[..., None], axis=-1)[..., 0]
    mask = (shift_labels != pad_token_id).astype(jnp.float32)
    return (token_nll * mask).sum() / jnp.maximum(mask.sum(), 1.0)


@partial(jax.pmap, axis_name='batch', donate_argnums=(0,))
def train_step(state, batch):
    """Run one data-parallel optimisation step on each device.

    Args:
        state: replicated TrainState (donated; do not reuse after the call).
        batch: per-device dict with 'input_ids', 'attention_mask' and 'labels'.

    Returns:
        ``(new_state, metrics)``. ``metrics['loss']`` is averaged across devices, and
        ``metrics['learning_rate']`` is the schedule value at the pre-update step.
    """
    dropout_rng, new_dropout_rng = random.split(state.dropout_rng)

    def loss_fn(params):
        outputs = state.apply_fn(
            input_ids=batch['input_ids'],
            attention_mask=batch['attention_mask'],
            params=params,
            dropout_rng=dropout_rng,
            train=True,
        )
        return causal_lm_loss(outputs.logits, batch['labels'], tokenizer.pad_token_id)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    # Average gradients and loss over all devices in the pmap axis.
    grads = jax.lax.pmean(grads, axis_name='batch')
    loss = jax.lax.pmean(loss, axis_name='batch')
    new_state = state.apply_gradients(grads=grads, dropout_rng=new_dropout_rng)
    return new_state, {'loss': loss, 'learning_rate': lr_schedule(state.step)}


@partial(jax.pmap, axis_name='batch')
def eval_step(state, batch):
    """Compute the device-averaged next-token loss and perplexity for one batch (no dropout)."""
    outputs = state.apply_fn(
        input_ids=batch['input_ids'],
        attention_mask=batch['attention_mask'],
        params=state.params,
        train=False,
    )
    loss = causal_lm_loss(outputs.logits, batch['labels'], tokenizer.pad_token_id)
    loss = jax.lax.pmean(loss, axis_name='batch')
    # Perplexity is exp of the averaged loss, not the average of per-device exps.
    return {'loss': loss, 'perplexity': jnp.exp(loss)}
print('✓ Training functions defined')

In [None]:
import tensorflow as tf
# Hide any GPUs from TensorFlow; it is only used here for input batching, so it
# should not claim accelerator memory.
tf.config.set_visible_devices([], 'GPU')


def create_data_loader(dataset, batch_size, shuffle=False, seed=42):
    """Stream fixed-size numpy batches from a tokenized Hugging Face dataset.

    Args:
        dataset: rows providing 'input_ids', 'attention_mask' and 'labels', each
            exactly MAX_LENGTH token ids long.
        batch_size: rows per batch; a trailing partial batch is dropped.
        shuffle: permute the rows before batching.
        seed: shuffle seed; use a different value per epoch for a fresh order.

    Returns:
        An iterator of dicts mapping each key to an int32 array of shape
        ``(batch_size, MAX_LENGTH)``.
    """
    def _data_generator():
        # datasets.Dataset.shuffle() is a full permutation and has no buffer_size
        # parameter (that belongs to IterableDataset.shuffle); passing it raised
        # TypeError as soon as the loader was iterated.
        ds = dataset.shuffle(seed=seed) if shuffle else dataset
        for example in ds:
            yield {'input_ids': example['input_ids'], 'attention_mask': example['attention_mask'], 'labels': example['labels']}

    row_spec = tf.TensorSpec(shape=(MAX_LENGTH,), dtype=tf.int32)
    output_sig = {'input_ids': row_spec, 'attention_mask': row_spec, 'labels': row_spec}
    tf_dataset = tf.data.Dataset.from_generator(_data_generator, output_signature=output_sig)
    return tf_dataset.batch(batch_size, drop_remainder=True).prefetch(tf.data.AUTOTUNE).as_numpy_iterator()
print('✓ Data loader created')

In [None]:
from tqdm.auto import tqdm
import time
from flax.training.common_utils import shard
import math

# pmap expects a leading device axis on every argument: copy the state to each local device.
state = jax_utils.replicate(state)
# Cap on evaluation batches per epoch; the full 1% eval split (~940k rows) would
# otherwise take a sizeable fraction of a training epoch.
EVAL_MAX_BATCHES = 500
print('=== Starting Training ===')
global_step = 0
start_time = time.time()
for epoch in range(NUM_EPOCHS):
    print(f'\nEpoch {epoch + 1}/{NUM_EPOCHS}')
    # Reshuffle with a per-epoch seed so each epoch sees a different order (a fixed
    # seed replays the same permutation every epoch); the loader then streams it in order.
    epoch_dataset = train_dataset.shuffle(seed=42 + epoch)
    train_loader = create_data_loader(epoch_dataset, BATCH_SIZE, shuffle=False)
    epoch_loss = 0.0
    epoch_steps = 0
    progress_bar = tqdm(total=steps_per_epoch, desc=f'Epoch {epoch+1}')
    for batch in train_loader:
        # The loader already yields numpy arrays; shard() them directly instead of
        # first copying them to the default device with jnp.array.
        state, metrics = train_step(state, shard(batch))
        # The running sum stays on device; it is only copied to the host when displayed.
        epoch_loss += jax_utils.unreplicate(metrics['loss'])
        epoch_steps += 1
        global_step += 1
        if global_step % 100 == 0:
            progress_bar.set_postfix({'loss': f'{float(epoch_loss) / epoch_steps:.4f}'})
        progress_bar.update(1)
        if epoch_steps >= steps_per_epoch:
            break
    progress_bar.close()
    if epoch_steps == 0:
        # Guard against ZeroDivisionError when no full batch was produced.
        print(f'Epoch {epoch+1}: no full training batches; skipping.')
        continue
    print(f'Epoch {epoch+1} avg loss: {float(epoch_loss) / epoch_steps:.4f}')

    # Held-out evaluation on the first EVAL_MAX_BATCHES batches of the eval split
    # (already randomly ordered by train_test_split).
    eval_loss_sum = 0.0
    eval_batches = 0
    for batch in create_data_loader(eval_dataset, BATCH_SIZE):
        eval_metrics = eval_step(state, shard(batch))
        eval_loss_sum += float(jax_utils.unreplicate(eval_metrics['loss']))
        eval_batches += 1
        if eval_batches >= EVAL_MAX_BATCHES:
            break
    if eval_batches:
        eval_loss = eval_loss_sum / eval_batches
        print(f'Epoch {epoch+1} eval loss: {eval_loss:.4f}, perplexity: {math.exp(eval_loss):.2f}')
total_time = time.time() - start_time
print(f'\nTraining complete! Time: {total_time/3600:.2f}h')
print('✓ Training finished')

## 7. Save Model

In [None]:
final_model_dir = '/content/polish_conversational_model'
# `state` is replicated across devices after training; save a single copy of the params.
unreplicated_state = jax_utils.unreplicate(state)
model.save_pretrained(final_model_dir, params=unreplicated_state.params)
# Save the tokenizer first, then the config, next to the weights.
for artifact in (tokenizer, config):
    artifact.save_pretrained(final_model_dir)
print(f'✓ Model saved to {final_model_dir}')

## 8. Test Generation

In [None]:
def generate_text(prompt, max_length=100, seed=None):
    """Sample a continuation of ``prompt`` from the fine-tuned model.

    Args:
        prompt: text to continue.
        max_length: maximum total length in tokens, prompt included (passed to generate()).
        seed: PRNG seed for sampling; defaults to the current time (non-reproducible).

    Returns:
        The decoded sequence (prompt + continuation) with special tokens removed.
    """
    # Training examples were formatted as f"{bos}{text}{eos}" (preprocess_function),
    # so prefix the prompt with BOS to match what the model saw during training.
    inputs = tokenizer(f'{tokenizer.bos_token}{prompt}', return_tensors='np')
    # `state` is replicated across devices after training; take one copy.
    params = jax_utils.unreplicate(state.params)
    if seed is None:
        seed = int(time.time())
    rng = random.PRNGKey(seed)
    generated = model.generate(
        jnp.array(inputs['input_ids']),
        attention_mask=jnp.array(inputs['attention_mask']),
        params=params,
        max_length=max_length,
        do_sample=True,
        temperature=0.7,
        top_p=0.9,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        prng_key=rng,
    ).sequences
    return tokenizer.decode(generated[0], skip_special_tokens=True)


test_prompts = ['Bóg pobłogosławi', 'Witaj, jak się', 'Dzisiaj jest']
print('=== Test Generation ===')
for prompt in test_prompts:
    print(f'\nPrompt: {prompt}')
    print(f'Generated: {generate_text(prompt, 50)}')
print('\n✓ Generation test complete')

## Summary

This notebook provides a complete pipeline for training Polish conversational AI on TPU v5e-1:

1. **Dataset**: 94M+ Polish sentence fragments
2. **Model**: Polish GPT-2 medium (`sdadas/polish-gpt2-medium`), further fine-tuned on this dataset
3. **Hardware**: TPU v5e-1 with JAX/Flax
4. **Training**: 3 epochs with AdamW
5. **Output**: Fine-tuned model, tokenizer, and config saved in Hugging Face format (evaluate before production use)

### Resources
- [Dataset](https://huggingface.co/datasets/adowu/polish_sentences)
- [JAX Docs](https://jax.readthedocs.io/)
- [Flax Docs](https://flax.readthedocs.io/)