In [None]:
import torch
from datasets import load_dataset
from transformers import (
    BertTokenizer, 
    BertForSequenceClassification, 
    TrainingArguments, 
    Trainer
)

# Check if GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load pretrained model and tokenizer
model_name = "bert-base-uncased"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, num_labels=2)
model.to(device)  # Move model to appropriate device

# Load dataset.
# The IMDB train split is stored ordered by label (all negative reviews first,
# then all positive ones), so taking a leading slice such as "train[:2000]"
# produces a single-class subset and the classifier has nothing to learn.
# Shuffle with a fixed seed before subsampling so both classes are present
# and the subset (and the train/test split below) is reproducible.
SEED = 42
NUM_EXAMPLES = 2000
dataset = (
    load_dataset("imdb", split="train")
    .shuffle(seed=SEED)
    .select(range(NUM_EXAMPLES))
)
dataset = dataset.train_test_split(test_size=0.2, seed=SEED)
print(f"Train dataset size: {len(dataset['train'])}")
print(f"Test dataset size: {len(dataset['test'])}")

# Tokenize function used with Dataset.map(batched=True) below.
def tokenize_function(examples, max_length=512):
    """Tokenize a batch of reviews for BERT.

    Every example is padded/truncated to the same fixed length. The Trainer
    in this notebook is built without a padding collator, so equal-length
    rows are what allow batches to be stacked.

    Args:
        examples: Batch dict supplied by ``Dataset.map(batched=True)``; its
            ``"text"`` entry is a list of review strings.
        max_length: Target sequence length for padding and truncation
            (512 is the largest input BERT-base accepts).

    Returns:
        The tokenizer output (``input_ids``, ``token_type_ids``,
        ``attention_mask``) as plain Python lists. ``Dataset.map`` stores the
        result in Arrow, so asking for PyTorch tensors here would only add a
        tensor -> Arrow conversion; rows come back as lists either way.
    """
    return tokenizer(
        examples["text"],
        padding="max_length",
        truncation=True,
        max_length=max_length,
    )

# Tokenize in batches; the raw review text is not needed after tokenization.
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

# BertForSequenceClassification.forward() takes its targets as ``labels``.
# rename_column does this without a second map() pass that would rewrite
# every row (including the 512-token id columns) just to copy one field.
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")

# Check the structure of our processed dataset
print("\nSample from processed dataset:")
sample = tokenized_datasets["train"][0]
print(f"Type: {type(sample)}")
print(f"Keys: {sample.keys()}")
print(f"Example item: {sample}")

def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer.evaluate``.

    Args:
        eval_pred: ``(predictions, label_ids)`` pair passed in by the Trainer;
            ``predictions`` holds the classifier logits, one row per example.

    Returns:
        ``{"accuracy": float}`` (reported by the Trainer as ``eval_accuracy``).
    """
    logits, label_ids = eval_pred
    # argmax over the class axis turns logits into predicted class ids.
    predictions = logits.argmax(axis=-1)
    return {"accuracy": float((predictions == label_ids).mean())}


# Define training arguments
training_args = TrainingArguments(
    output_dir="./results",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
)

# Define trainer; compute_metrics makes evaluate() report accuracy, not only loss.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    compute_metrics=compute_metrics,
)

# Fine-tune, score on the held-out split, then write the model and its
# tokenizer into one directory.
model_path = "./imdb-bert-classifier"

print("\nStarting training...")
trainer.train()

print("\nEvaluating model...")
eval_results = trainer.evaluate()
print(f"Evaluation results: {eval_results}")

# Model first, then tokenizer - both expose save_pretrained().
for artifact in (model, tokenizer):
    artifact.save_pretrained(model_path)
print(f"\nModel saved to {model_path}")

Using device: cpu


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Train dataset size: 1600
Test dataset size: 400


Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]

Map:   0%|          | 0/1600 [00:00<?, ? examples/s]

Map:   0%|          | 0/400 [00:00<?, ? examples/s]


Sample from processed dataset:
Type: <class 'dict'>
Keys: dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
Example item: {'input_ids': [101, 2034, 1024, 1045, 4149, 2009, 2012, 1996, 2678, 3573, 1012, 2117, 1024, 1045, 3427, 2009, 1012, 2353, 1024, 2009, 2001, 11771, 1012, 2959, 1024, 2009, 2001, 2025, 6057, 1012, 3587, 1024, 2087, 1997, 1996, 27440, 2020, 20342, 1012, 1998, 2197, 1010, 2021, 2025, 2560, 1024, 2009, 1005, 1055, 2025, 2069, 1037, 2919, 3185, 1010, 2009, 1005, 1055, 1037, 2561, 19807, 9363, 1012, 1026, 7987, 1013, 1028, 1026, 7987, 1013, 1028, 1045, 2572, 1037, 4121, 4205, 5472, 3917, 5470, 2750, 2023, 15640, 1998, 6404, 2143, 1012, 1045, 12063, 2009, 2138, 2009, 2001, 2010, 2034, 3185, 1012, 2130, 2065, 2017, 2024, 1037, 4121, 4205, 5472, 3917, 5470, 1010, 2123, 1005, 1056, 8572, 3666, 2023, 3185, 1012, 2612, 1010, 2074, 2202, 1996, 2678, 1010, 2604, 1037, 12187, 1010, 1998, 5466, 2009, 27089, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0