In [1]:
!pip install torch transformers datasets

Collecting datasets
  Downloading datasets-2.19.2-py3-none-any.whl (542 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m542.1/542.1 kB[0m [31m12.0 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.

# Build & train model in Pytorch

## Download & prepare the data

In [2]:
import torch
from torch.utils.data import DataLoader, Dataset, random_split
from transformers import AutoTokenizer
from datasets import load_dataset

# Download data from Opus Books
dataset = load_dataset("Helsinki-NLP/opus_books", "en-hu")

# Split dataset into train, validation and test
train_size = int(0.7 * len(dataset['train']))
val_size = int(0.2 * len(dataset['train']))
test_size = len(dataset['train']) - train_size - val_size

train_subset, val_subset, test_subset = random_split(dataset['train'], [train_size, val_size, test_size])

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

class TranslationDataset(Dataset):
    def __init__(self, subset, tokenizer, max_length=32):
        self.subset = subset
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        item = self.subset[idx]['translation']
        src = item['en']
        tgt = item['hu']
        src_enc = self.tokenizer(src, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt")
        tgt_enc = self.tokenizer(tgt, truncation=True, padding='max_length', max_length=self.max_length, return_tensors="pt")
        return src_enc['input_ids'].squeeze(), tgt_enc['input_ids'].squeeze()

# Create tokenized datasets
train_dataset = TranslationDataset(train_subset, tokenizer)
val_dataset = TranslationDataset(val_subset, tokenizer)
test_dataset = TranslationDataset(test_subset, tokenizer)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Downloading readme:   0%|          | 0.00/28.1k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/23.1M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/137151 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

## Build Transformer model

In [None]:
import torch.nn as nn

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=512, nhead=8, num_encoder_layers=6, num_decoder_layers=6, dim_feedforward=2048, dropout=0.1):
        super(TransformerModel, self).__init__()

        self.transformer = nn.Transformer(d_model, nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout)
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)

    def forward(self, src, tgt):
        src = self.src_embedding(src)
        tgt = self.tgt_embedding(tgt)
        output = self.transformer(src, tgt)
        output = self.fc_out(output)
        return output

## Train model

In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
import torch.optim as optim

def collate_fn(batch):
    src, tgt = zip(*batch)
    return torch.stack(src), torch.stack(tgt)

train_loader = DataLoader(train_dataset, batch_size=64, collate_fn=collate_fn, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, collate_fn=collate_fn, shuffle=False)

src_vocab_size = tokenizer.vocab_size
tgt_vocab_size = tokenizer.vocab_size

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerModel(src_vocab_size, tgt_vocab_size).to(device)
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=1e-4)

def train(model, data_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for src, tgt in data_loader:
        src, tgt = src.to(device), tgt.to(device)
        optimizer.zero_grad()
        tgt_input = tgt[:-1, :]
        tgt_output = tgt[1:, :]

        output = model(src, tgt_input)
        loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(data_loader)

def evaluate(model, data_loader, criterion, device):
    model.eval()
    total_loss = 0

    with torch.no_grad():
        for src, tgt in data_loader:
            src, tgt = src.to(device), tgt.to(device)
            tgt_input = tgt[:-1, :]
            tgt_output = tgt[1:, :]

            output = model(src, tgt_input)
            loss = criterion(output.reshape(-1, output.size(-1)), tgt_output.reshape(-1))
            total_loss += loss.item()

    return total_loss / len(data_loader)

num_epochs = 3
for epoch in range(num_epochs):
    train_loss = train(model, train_loader, optimizer, criterion, device)
    val_loss = evaluate(model, val_loader, criterion, device)
    print(f'Epoch: {epoch+1}, Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')



Epoch: 1, Train Loss: 6.4356, Validation Loss: 6.2442
Epoch: 2, Train Loss: 6.1483, Validation Loss: 6.0560
Epoch: 3, Train Loss: 5.9291, Validation Loss: 5.8149


Thời gian huấn luyện: 45 phút 14 giây

## Evaluation

In [None]:
# Evaluate the model on the test data set
test_loader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn, shuffle=False)
test_loss = evaluate(model, test_loader, criterion, device)
print(f'Test Loss: {test_loss:.4f}')

Test Loss: 5.5370


# Apply jax and flax framework

In [3]:
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2023 NVIDIA Corporation
Built on Tue_Aug_15_22:02:13_PDT_2023
Cuda compilation tools, release 12.2, V12.2.140
Build cuda_12.2.r12.2/compiler.33191640_0


In [4]:
!pip install --upgrade pip

Collecting pip
  Downloading pip-24.0-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m26.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 23.1.2
    Uninstalling pip-23.1.2:
      Successfully uninstalled pip-23.1.2
Successfully installed pip-24.0


In [5]:
!pip install flax

[0m

In [6]:
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 96005
Validation samples: 27430
Test samples: 13716


## Build the Encoder and Decoder classes

In [7]:
import jax
import jax.numpy as jnp
from flax import linen as nn
from transformers import BertTokenizerFast

class MultiHeadSelfAttention(nn.Module):
    embed_dim: int
    num_heads: int

    def setup(self):
        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
        self.depth = self.embed_dim // self.num_heads
        self.wq = nn.Dense(self.embed_dim)
        self.wk = nn.Dense(self.embed_dim)
        self.wv = nn.Dense(self.embed_dim)
        self.dense = nn.Dense(self.embed_dim)

    def split_heads(self, x, batch_size):
        x = x.reshape(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(0, 2, 1, 3)

    def __call__(self, x):
        batch_size = x.shape[0]
        q = self.split_heads(self.wq(x), batch_size)
        k = self.split_heads(self.wk(x), batch_size)
        v = self.split_heads(self.wv(x), batch_size)

        matmul_qk = jnp.matmul(q, k.transpose(0,1,3,2))

        dk = jnp.array(k.shape[-1], dtype=jnp.float32)
        scaled_attention_logits = matmul_qk / jnp.sqrt(dk)

        attention_weights = nn.softmax(scaled_attention_logits, axis=-1)

        output = jnp.matmul(attention_weights, v)
        output = output.transpose(0, 2, 1, 3)

        concat_attention = output.reshape(batch_size, -1, self.embed_dim)
        return self.dense(concat_attention)


class TransformerEncoderLayer(nn.Module):
    embed_dim: int
    num_heads: int
    dropout_rate: float = 0.1

    def setup(self):
        self.mha = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.ffn = nn.Sequential([
            nn.Dense(self.embed_dim * 4),
            nn.relu,
            nn.Dense(self.embed_dim)
        ])
        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training):
        attn_output = self.mha(x)
        attn_output = self.dropout(attn_output, deterministic=not training)
        out1 = self.layernorm1(x + attn_output)

        ffn_output = self.ffn(out1)
        ffn_output = self.dropout(ffn_output, deterministic=not training)
        return self.layernorm2(out1 + ffn_output)


class TransformerDecoderLayer(nn.Module):
    embed_dim: int
    num_heads: int
    dropout_rate: float = 0.1

    def setup(self):
        self.mha1 = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.mha2 = MultiHeadSelfAttention(self.embed_dim, self.num_heads)
        self.ffn = nn.Sequential([
            nn.Dense(self.embed_dim * 4),
            nn.relu,
            nn.Dense(self.embed_dim)
        ])
        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()
        self.layernorm3 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, enc_output, training):
        attn1 = self.mha1(x)
        attn1 = self.dropout(attn1, deterministic=not training)
        out1 = self.layernorm1(x + attn1)

        attn2 = self.mha2(out1)
        attn2 = self.dropout(attn2, deterministic=not training)
        out2 = self.layernorm2(out1 + attn2)

        ffn_output = self.ffn(out2)
        ffn_output = self.dropout(ffn_output, deterministic=not training)
        return self.layernorm3(out2 + ffn_output)

## Build Transformer model

In [8]:
class Transformer(nn.Module):
    vocab_size: int
    num_heads: int
    num_layers: int
    hidden_dim: int
    dropout_rate: float

    def setup(self):
        self.token_embedding = nn.Embed(self.vocab_size, self.hidden_dim)
        self.position_embedding = nn.Embed(512, self.hidden_dim)
        self.encoder_layers = [
            TransformerEncoderLayer(self.hidden_dim, self.num_heads, self.dropout_rate) for _ in range(self.num_layers)
        ]
        self.decoder_layers = [
            TransformerDecoderLayer(self.hidden_dim, self.num_heads, self.dropout_rate) for _ in range(self.num_layers)
        ]
        self.output_dense = nn.Dense(self.vocab_size)

    def encode(self, x, training):
        x = self.token_embedding(x) + self.position_embedding(jnp.arange(x.shape[1]))
        for layer in self.encoder_layers:
            x = layer(x, training)
        return x

    def decode(self, x, enc_output, training):
        x = self.token_embedding(x) + self.position_embedding(jnp.arange(x.shape[1]))
        for layer in self.decoder_layers:
            x = layer(x, enc_output, training)
        return self.output_dense(x)

    def __call__(self, src, tgt, training):
        enc_output = self.encode(src, training)
        dec_output = self.decode(tgt, enc_output, training)
        return dec_output

## Train and evaluate model

In [9]:
# Initialize tokenizers
tokenizer_src = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")
tokenizer_tgt = BertTokenizerFast.from_pretrained("bert-base-multilingual-cased")

# Initialize DataLoaders from dataset
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

# Hyperparameters
vocab_size = tokenizer_src.vocab_size
hidden_dim = 512
num_heads = 8
num_layers = 6
dropout_rate = 0.1
learning_rate = 0.001

model = Transformer(vocab_size=vocab_size, num_heads=num_heads, num_layers=num_layers, hidden_dim=hidden_dim, dropout_rate=dropout_rate)

params = model.init(jax.random.PRNGKey(0), jnp.ones([2, 512], jnp.int32), jnp.ones([2, 512], jnp.int32), True)['params']

# Loss function
def loss_fn(params, src, tgt, labels, rng, training=True):
    dropout_rng, new_dropout_rng = jax.random.split(rng)
    logits = model.apply({'params': params}, src, tgt, training=training, rngs={'dropout': dropout_rng})

    # Reshape logits và labels
    batch_size, seq_len, vocab_size = logits.shape
    logits = logits.reshape(batch_size * seq_len, vocab_size)
    labels = labels.reshape(batch_size * seq_len)

    # Calculate cross-entropy loss
    labels_onehot = jax.nn.one_hot(labels, vocab_size)
    loss = jnp.mean(-jnp.sum(labels_onehot * jax.nn.log_softmax(logits), axis=-1))

    return loss, new_dropout_rng

# Optimizer update function manually
def sgd_update(params, grads, learning_rate):
    return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Training step
@jax.jit
def train_step(params, src, tgt, labels, rng):
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_rng), grads = grad_fn(params, src, tgt, labels, rng, training=True)
    params = sgd_update(params, grads, learning_rate)
    return params, loss, new_rng

# Evaluation
@jax.jit
def eval_step(params, src, tgt, labels, rng):
    dropout_rng, new_dropout_rng = jax.random.split(rng)
    logits = model.apply({'params': params}, src, tgt, training=False, rngs={'dropout': dropout_rng})

    # Reshape logits và labels
    batch_size, seq_len, vocab_size = logits.shape
    logits = logits.reshape(batch_size * seq_len, vocab_size)
    labels = labels.reshape(batch_size * seq_len)

    # Calculate cross-entropy loss
    labels_onehot = jax.nn.one_hot(labels, vocab_size)
    loss = jnp.mean(-jnp.sum(labels_onehot * jax.nn.log_softmax(logits), axis=-1))

    return loss, new_dropout_rng

# Training loop
num_epochs = 3
rng = jax.random.PRNGKey(0)

for epoch in range(num_epochs):
    train_losses = []
    for src_batch, tgt_batch in train_loader:
        src = src_batch.squeeze().detach().cpu().numpy()
        tgt = tgt_batch.squeeze().detach().cpu().numpy()
        params, loss, rng = train_step(params, src, tgt, tgt, rng)
        train_losses.append(loss)

    val_losses = []
    for src_batch, tgt_batch in val_loader:
        src = src_batch.squeeze().detach().cpu().numpy()
        tgt = tgt_batch.squeeze().detach().cpu().numpy()
        val_loss, rng = eval_step(params, src, tgt, tgt, rng)
        val_losses.append(val_loss)

    print(f'Epoch {epoch+1}, Loss: {jnp.mean(jnp.array(train_losses))}, Val Loss: {jnp.mean(jnp.array(val_losses))}')

  return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)
  return jax.tree_map(lambda p, g: p - learning_rate * g, params, grads)


Epoch 1, Loss: 8.339764595031738, Val Loss: 7.890867233276367
Epoch 2, Loss: 7.69901704788208, Val Loss: 7.447364807128906
Epoch 3, Loss: 7.220347881317139, Val Loss: 6.930772304534912


Thời gian chạy: 29 phút

# Run JAX code on multiple devices using pmap

## Benchmark the speed-up

### Single-device benchmark

In [10]:
import time

# Benchmarking single-device training
start_time = time.time()

for src_batch, tgt_batch in train_loader:
    src = src_batch.squeeze().detach().cpu().numpy()
    tgt = tgt_batch.squeeze().detach().cpu().numpy()
    params, loss, rng = train_step(params, src, tgt, tgt, rng)

single_device_duration = time.time() - start_time
print(f"Single-device training duration: {single_device_duration} seconds")

Single-device training duration: 517.1738803386688 seconds


### Multi-device benchmark

## Compute the speed-up