In [1]:
import random
import numpy as np
from transformers import T5Tokenizer, FlaxT5ForConditionalGeneration
import jax
import jax.numpy as jnp
import optax

  from .autonotebook import tqdm as notebook_tqdm
2023-04-27 06:46:51.835528: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-04-27 06:46:51.837998: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-04-27 06:46:51.900062: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2023-04-27 06:46:51.901602: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
with open('Y.txt', 'r') as f:
    Y = f.read().splitlines()
with open('X.txt', 'r') as f:
    X = f.read().splitlines()

In [3]:
# 准备数据
input_texts = X  # 输入字符串列表（X）
target_texts = Y  # 目标字符串列表（Y）

In [4]:
tokenizer = T5Tokenizer.from_pretrained('t5-small')

In [5]:
def tokenize_function(input_texts, target_texts):
    inputs = tokenizer(input_texts, padding="max_length", truncation=True, max_length=128, return_tensors="np")
    targets = tokenizer(target_texts, padding="max_length", truncation=True, max_length=128, return_tensors="np")
    return inputs, targets

In [6]:
# 使用列表推导式来对数据进行tokenize
tokenized_data = [tokenize_function(input_text, target_text) for input_text, target_text in zip(input_texts, target_texts)]

In [7]:
# 定义模型
model = FlaxT5ForConditionalGeneration.from_pretrained('t5-small')

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [8]:
# 定义超参数
learning_rate = 0.001
batch_size = 8
num_epochs = 3
num_train_steps = len(input_texts) * num_epochs // batch_size

In [9]:
# 创建模型及优化器
tokenizer = T5Tokenizer.from_pretrained("t5-small", max_length=1024)
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small", max_length=1024)

In [10]:
model = FlaxT5ForConditionalGeneration.from_pretrained("t5-small")
optimizer = optax.adamw(learning_rate)

In [11]:
# 初始化优化器状态
params = model.params
optimizer_state = optimizer.init(params)

In [12]:
X = np.array(X)
Y = np.array(Y)

In [13]:
# 分割数据集
input_texts = [str(x) for x in X.tolist()]
target_texts = [str(y) for y in Y.tolist()]

In [14]:
len(input_texts)

6584

In [15]:
tokenized_data = [(tokenizer(input_text, return_tensors="np"), tokenizer(target_text, return_tensors="np")) for input_text, target_text in zip(input_texts, target_texts)]

In [16]:
# 训练参数
num_epochs = 10
batch_size = 8

In [17]:
# 训练循环
@jax.jit
def train_step(params, optimizer_state, inputs, targets):
    def loss_fn(params):
        with jax.disable_jit():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], decoder_input_ids=targets, params=params).logits
        loss = jnp.mean(jax.nn.sparse_cross_entropy(logits=logits, labels=targets))
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(params)
    updates, new_optimizer_state = optimizer.update(grad, optimizer_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_optimizer_state, loss

In [18]:
# 训练模型
for epoch in range(num_epochs):
    for batch_idx in range(len(input_texts) // batch_size):
        batch_input_data = [tokenized_data[i][0] for i in range(batch_idx * batch_size, (batch_idx + 1) * batch_size)]
        batch_target_data = [tokenized_data[i][1] for i in range(batch_idx * batch_size, (batch_idx + 1) * batch_size)]
        
        batch_inputs = tokenizer.pad(batch_input_data, return_tensors="np")
        batch_targets = tokenizer.pad(batch_target_data, return_tensors="np")
        
        model.params, optimizer_state, train_loss = train_step(
            model.params, optimizer_state,
            {k: jnp.array(v) for k, v in batch_inputs.items()},
            jnp.array(batch_targets['input_ids'])
        )
        print(f"Epoch {epoch + 1}/{num_epochs} | Batch {batch_idx + 1}/{len(input_texts) // batch_size} | Loss: {train_loss:.4f}")

ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).

In [20]:
tokenized_data = tokenizer.batch_encode_plus(list(zip(input_texts, target_texts)), padding=True, truncation=True)

In [25]:
for epoch in range(num_epochs):
    for batch_idx in range(len(input_texts) // batch_size):
        batch_encoding = tokenizer(batch_input_data, batch_target_data, padding=True, truncation=True, return_tensors='jax')
        batch_inputs = jax.tree_map(lambda x: x[:, :max_seq_length], batch_encoding['input_ids'])
        batch_targets = jax.tree_map(lambda x: x[:, :max_seq_length], batch_encoding['attention_mask'])
        
        model.params, optimizer_state, train_loss = train_step(
            model.params, optimizer_state,
            {k: jnp.array(v) for k, v in batch_inputs.items()},
            jnp.array(batch_targets['input_ids'])
        )
        print(f"Epoch {epoch + 1}/{num_epochs} | Batch {batch_idx + 1}/{len(input_texts) // batch_size} | Loss: {train_loss:.4f}")

ValueError: text input must of type `str` (single example), `List[str]` (batch or single pretokenized example) or `List[List[str]]` (batch of pretokenized examples).

In [26]:
input_texts = [str(x) for x in X_filtered]
target_texts = [str(y) for y in Y_filtered]

In [28]:
tokenized_data = [(tokenizer(input_text, return_tensors="np"), tokenizer(target_text, return_tensors="np")) for input_text, target_text in zip(input_texts, target_texts)]

In [29]:
# 训练参数
num_epochs = 10
batch_size = 8

In [30]:
# 训练循环
@jax.jit
def train_step(params, optimizer_state, inputs, targets):
    def loss_fn(params):
        with jax.disable_jit():
            logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], decoder_input_ids=targets, params=params).logits
        loss = jnp.mean(jax.nn.sparse_cross_entropy(logits=logits, labels=targets))
        return loss

    grad_fn = jax.value_and_grad(loss_fn)
    loss, grad = grad_fn(params)
    updates, new_optimizer_state = optimizer.update(grad, optimizer_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_optimizer_state, loss

In [34]:
def tokenize_function(input_texts, target_texts, max_length=512):
    inputs = tokenizer(input_texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="np")
    targets = tokenizer(target_texts, padding="max_length", truncation=True, max_length=max_length, return_tensors="np")
    return inputs, targets

In [37]:
def pad_and_truncate(inputs, max_length):
    # 将输入填充或截断到指定的最大长度
    padded_inputs = []
    for seq in inputs:
        if len(seq) > max_length:
            padded_seq = seq[:max_length]
        else:
            padded_seq = seq + [0] * (max_length - len(seq))
        padded_inputs.append(padded_seq)

    # 转换为numpy数组
    padded_inputs = np.array(padded_inputs)

    # 创建一个掩码以标识填充的位置
    mask = np.zeros_like(padded_inputs)
    mask[padded_inputs != 0] = 1

    return padded_inputs, mask

def tokenize_function(input_texts, target_texts, max_length=512):
    # 分别对输入和目标进行编码
    inputs = tokenizer.batch_encode_plus(
        input_texts,
        max_length=max_length,
        padding=False,
        truncation=True,
        return_attention_mask=False,
        return_token_type_ids=False,
        return_tensors="np"
    )["input_ids"]

    targets = tokenizer.batch_encode_plus(
        target_texts,
        max_length=max_length,
        padding=False,
        truncation=True,
        return_attention_mask=False,
        return_token_type_ids=False,
        return_tensors="np"
    )["input_ids"]

    # 对输入和目标进行填充和截断
    padded_inputs, input_mask = pad_and_truncate(inputs, max_length)
    padded_targets, target_mask = pad_and_truncate(targets, max_length)

    # 将填充和截断后的输入和目标转换为字典格式
    input_dict = {"input_ids": padded_inputs, "attention_mask": input_mask}
    target_dict = {"input_ids": padded_targets, "attention_mask": target_mask}

    return input_dict, target_dict

In [39]:
for batch_idx in range(len(input_texts) // batch_size):
    batch_input_texts = input_texts[batch_idx * batch_size:(batch_idx + 1) * batch_size]
    batch_target_texts = target_texts[batch_idx * batch_size:(batch_idx + 1) * batch_size]

    batch_inputs, batch_targets = tokenize_function(batch_input_texts, batch_target_texts, max_length=max_length)

    model.params, optimizer_state, train_loss = train_step(
        model.params, optimizer_state,
        {k: jnp.array(v) for k, v in batch_inputs.items()},
        jnp.array(batch_targets['input_ids'])
    )
    print(f"Epoch {epoch + 1}/{num_epochs} | Batch {batch_idx + 1}/{len(input_texts) // batch_size} | Loss: {train_loss:.4f}")


ValueError: Unable to create tensor, you should probably activate truncation and/or padding with 'padding=True' 'truncation=True' to have batched tensors with the same length. Perhaps your features (`input_ids` in this case) have excessive nesting (inputs type `list` where type `int` is expected).