In [15]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset, DatasetDict, Dataset

# Report which accelerator torch can see. `device` is only printed here;
# nothing in this cell moves tensors onto it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# SciQ rows carry: question, distractor1-3, correct_answer and a support passage.
dataset = load_dataset(path='sciq')

# Pretrained checkpoint shared by this tokenizer and the classifier built later.
model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

cuda


In [16]:
from datasets import DatasetDict
# Define a function to tokenize the data
def preprocess_function(examples, max_length=512, tok=None):
    """Expand each SciQ question into four (question + option) classification rows.

    Every question yields one row per answer option, in the fixed order
    distractor1, distractor2, distractor3, correct_answer. Each row's text is
    the question and the option joined by a single space.

    Args:
        examples: A batch from ``Dataset.map(..., batched=True)``: a mapping of
            equal-length lists keyed by 'question', 'distractor1',
            'distractor2', 'distractor3' and 'correct_answer'.
        max_length: Token length every row is truncated / padded to.
        tok: Tokenizer to call; defaults to the notebook-level ``tokenizer``.

    Returns:
        Dict with 'input_ids', 'attention_mask' and 'labels', each holding
        ``4 * len(examples['question'])`` rows. 'labels' is 1 for the
        correct-answer row and 0 for every distractor row.
    """
    if tok is None:
        tok = tokenizer

    option_keys = ('distractor1', 'distractor2', 'distractor3', 'correct_answer')
    option_columns = [examples[key] for key in option_keys]

    texts, labels = [], []
    for question, *options in zip(examples['question'], *option_columns):
        for key, option in zip(option_keys, options):
            texts.append(f"{question} {option}")
            # Label by correctness, not by option position: nothing in the
            # text distinguishes distractor1 from distractor2/3, so
            # position-based labels (0-3) are not learnable.
            labels.append(1 if key == 'correct_answer' else 0)

    if not texts:
        return {'input_ids': [], 'attention_mask': [], 'labels': []}

    # One tokenizer call for the whole batch. Without return_tensors the
    # encodings come back as plain lists, which Dataset.map stores directly.
    # NOTE(review): the sample row shows question + answer are short, so a
    # much smaller max_length would cut training cost; 512 kept as default.
    encoded = tok(texts, truncation=True, padding='max_length', max_length=max_length)
    return {
        'input_ids': list(encoded['input_ids']),
        'attention_mask': list(encoded['attention_mask']),
        'labels': labels,
    }

# Limit each split to at most its first `num_rows` rows
def slice_dataset(dataset, num_rows=1000):
    """Return a DatasetDict holding the leading rows of the train and validation splits.

    Args:
        dataset: Mapping with 'train' and 'validation' splits that support
            ``len()`` and ``select(indices)`` (e.g. a ``datasets.DatasetDict``).
        num_rows: Maximum number of rows to keep from each split. A split
            shorter than this is kept whole instead of failing.

    Returns:
        DatasetDict with 'train' and 'validation' keys.
    """
    def _head(split):
        # Clamp so an out-of-range index is never requested from select().
        return split.select(range(min(num_rows, len(split))))

    return DatasetDict({
        'train': _head(dataset['train']),
        'validation': _head(dataset['validation']),
    })

# Keep only the leading 1000 rows of train and validation to shorten experiments
limited_dataset = slice_dataset(dataset, num_rows=1000)

# Convert the dataset to a format suitable for classification
def prepare_dataset(dataset, splits=('train', 'validation')):
    """Apply ``preprocess_function`` to each requested split.

    ``preprocess_function`` emits four rows per input question. Any original
    column left in place would then have a mismatched length, so each split's
    full ``column_names`` list is removed rather than a hard-coded subset.

    Args:
        dataset: Mapping of split name to a ``datasets.Dataset``.
        splits: Names of the splits to process (default: train and validation).

    Returns:
        DatasetDict mapping each name in ``splits`` to its tokenized split.
    """
    return DatasetDict({
        name: dataset[name].map(
            preprocess_function,
            batched=True,
            remove_columns=dataset[name].column_names,
        )
        for name in splits
    })

formatted_datasets = prepare_dataset(limited_dataset)

# Sanity check: each source question should expand to exactly four option rows.
for split_name, split in formatted_datasets.items():
    assert len(split) == 4 * len(limited_dataset[split_name]), split_name

In [17]:
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
# Set up the training arguments
training_args = TrainingArguments(
    output_dir='./results',
    # NOTE(review): recent transformers releases rename this argument to
    # `eval_strategy`; switch if your installed version warns or rejects it.
    evaluation_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
)


def compute_metrics(eval_pred):
    """Return classification accuracy for the Trainer's evaluation loop.

    Assumes ``eval_pred.predictions`` is the (rows, num_labels) logits array
    and ``eval_pred.label_ids`` the matching integer labels. TODO: confirm
    this if the model is configured to return extra outputs.
    """
    predicted = eval_pred.predictions.argmax(axis=-1)
    return {'accuracy': float((predicted == eval_pred.label_ids).mean())}


# Size the classification head from the labels actually present in the data
# rather than hard-coding it, so it always matches preprocess_function.
num_labels = max(formatted_datasets['train']['labels']) + 1

# Initialize the Trainer
trainer = Trainer(
    model=BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels),
    args=training_args,
    train_dataset=formatted_datasets['train'],
    eval_dataset=formatted_datasets['validation'],
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

# Train the model
trainer.train()

# Save the fine-tuned model and its tokenizer side by side
trainer.model.save_pretrained('./sciqa-bert')
tokenizer.save_pretrained('./sciqa-bert')

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


KeyboardInterrupt: 

In [8]:
# Inspect one raw training example to see the SciQ schema (rich display, no print)
dataset['train'][0]

{'question': 'What type of organism is commonly used in preparation of foods such as cheese and yogurt?', 'distractor3': 'viruses', 'distractor1': 'protozoa', 'distractor2': 'gymnosperms', 'correct_answer': 'mesophilic organisms', 'support': 'Mesophiles grow best in moderate temperature, typically between 25°C and 40°C (77°F and 104°F). Mesophiles are often found living in or on the bodies of humans or other animals. The optimal growth temperature of many pathogenic mesophiles is 37°C (98°F), the normal human body temperature. Mesophilic organisms have important uses in food preparation, including cheese, yogurt, beer and wine.'}
