In [None]:
import tensorflow as tf
from datasets import load_dataset
from transformers import BertTokenizer, TFBertForSequenceClassification

# Load the dataset
# load_dataset is called without a split argument, and the result is indexed
# by split name below (dataset['train']).
dataset = load_dataset("nvidia/OpenMathInstruct-1")

# Print column names of the train split
# This is a manual inspection aid for choosing text_column below.
print(dataset['train'].column_names)

# Use the correct column name based on inspection
# text_column is read by tokenize_function and convert_to_tf_dataset as a
# module-level name, so it must be set before either of them runs.
text_column = 'question'  # Update this if you find the correct column name is different

# Load the tokenizer
# Uses the same 'bert-base-uncased' checkpoint name as the model loaded later.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the dataset
def tokenize_function(examples):
    """Tokenize the ``text_column`` entries of a batch of examples.

    Args:
        examples: Mapping from column name to the batch of values for that
            column; only ``examples[text_column]`` is read.

    Returns:
        The tokenizer output for the selected texts, truncated and padded
        with ``padding="max_length"``.
    """
    texts = examples[text_column]
    return tokenizer(texts, truncation=True, padding="max_length")

# Apply tokenize_function to the loaded dataset. batched=True hands the
# function batches of examples rather than single rows, which is why
# tokenize_function indexes examples[text_column] as a collection of texts.
# NOTE(review): no max_length is passed to the tokenizer, so the
# "max_length" padding presumably pads to the tokenizer's model maximum;
# confirm this is affordable for the full dataset.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Convert the Hugging Face dataset to TensorFlow format
def convert_to_tf_dataset(tokenized_dataset, batch_size=8):
    """Build a batched ``tf.data.Dataset`` of (features, labels) for training.

    Only the tokenizer outputs are used as model features. The remaining
    columns of the source dataset are free-form strings that cannot be fed to
    the model, so they are left out instead of being converted to tensors.

    Args:
        tokenized_dataset: A tokenized Hugging Face ``Dataset`` split that
            contains the tokenizer output columns and an ``is_correct``
            label column.
        batch_size: Number of examples per batch.

    Returns:
        A ``tf.data.Dataset`` yielding ``(features, labels)`` batches, where
        ``features`` is a dict of the tokenizer output tensors and ``labels``
        holds ``is_correct`` as integer class ids.
    """
    # Keep only the tensor-valued tokenizer outputs that are actually present.
    feature_columns = [
        name
        for name in ("input_ids", "token_type_ids", "attention_mask")
        if name in tokenized_dataset.column_names
    ]

    # to_tf_dataset reads batches lazily from the underlying dataset instead
    # of materializing every row in memory as from_tensor_slices would.
    tf_dataset = tokenized_dataset.to_tf_dataset(
        columns=feature_columns,
        label_cols=["is_correct"],
        batch_size=batch_size,
        shuffle=False,
    )

    # The sparse categorical loss expects integer class ids, so make the
    # label dtype explicit regardless of how the column was stored.
    return tf_dataset.map(
        lambda features, labels: (features, tf.cast(labels, tf.int64))
    )

# Build the TensorFlow datasets for each split using the default batch size
# of convert_to_tf_dataset (8).
# NOTE(review): assumes the loaded dataset exposes a "validation" split in
# addition to "train" — confirm against the dataset card.
train_dataset = convert_to_tf_dataset(tokenized_datasets["train"])
eval_dataset = convert_to_tf_dataset(tokenized_datasets["validation"])

# Enable GPU usage if available
# Memory growth can only be configured while the GPUs are still
# uninitialized, so this runs before the model is loaded. It is applied to
# every visible GPU because TensorFlow requires the setting to match across
# devices.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print("Using GPU")
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when an earlier TensorFlow operation already initialized the
        # GPUs; training can still proceed with the default allocation policy.
        print(f"Could not enable GPU memory growth: {err}")
else:
    print("Using CPU")

# Load the BERT model
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')

# Compile the model
# from_logits=True: the loss is applied to the raw classifier outputs, with
# integer class ids as labels.
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)

# Define training arguments
epochs = 3
# NOTE: batch_size is not consumed below; train_dataset and eval_dataset were
# already batched when they were built by convert_to_tf_dataset.
batch_size = 8

# Prepare the training loop
print("Setup complete. Training process is prepared but not started.")

# To start the training, uncomment the next line
# history = model.fit(train_dataset, validation_data=eval_dataset, epochs=epochs)



['question', 'expected_answer', 'predicted_answer', 'error_message', 'is_correct', 'generation_type', 'dataset', 'generated_solution']


Map:   0%|          | 0/7321344 [00:00<?, ? examples/s]