In [2]:
# 1. Data Loading

from datasets import load_dataset

# Load WikiText-2
wiki = load_dataset("wikitext", "wikitext-2-raw-v1")

# Load OpenWebText
openweb = load_dataset("openwebtext", trust_remote_code=True)

Downloading data:   0%|          | 0/21 [00:00<?, ?files/s]

urlsf_subset00.tar:   0%|          | 0.00/633M [00:00<?, ?B/s]

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


urlsf_subset01.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset02.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset03.tar:   0%|          | 0.00/628M [00:00<?, ?B/s]

urlsf_subset04.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

urlsf_subset05.tar:   0%|          | 0.00/630M [00:00<?, ?B/s]

urlsf_subset06.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

urlsf_subset07.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset08.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset09.tar:   0%|          | 0.00/626M [00:00<?, ?B/s]

urlsf_subset10.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset11.tar:   0%|          | 0.00/625M [00:00<?, ?B/s]

urlsf_subset12.tar:   0%|          | 0.00/624M [00:00<?, ?B/s]

urlsf_subset13.tar:   0%|          | 0.00/629M [00:00<?, ?B/s]

urlsf_subset14.tar:   0%|          | 0.00/627M [00:00<?, ?B/s]

urlsf_subset15.tar:   0%|          | 0.00/621M [00:00<?, ?B/s]

urlsf_subset16.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset17.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset18.tar:   0%|          | 0.00/618M [00:00<?, ?B/s]

urlsf_subset19.tar:   0%|          | 0.00/619M [00:00<?, ?B/s]

urlsf_subset20.tar:   0%|          | 0.00/377M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8013769 [00:00<?, ? examples/s]

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

KeyError: 'validation'

In [4]:
from datasets import concatenate_datasets

# Manually split OpenWebText into train + validation
openweb_split = openweb["train"].train_test_split(test_size=0.1, seed=42)

# Combine train splits
combined_train = concatenate_datasets([wiki["train"], openweb_split["train"]])
combined_valid = concatenate_datasets([wiki["validation"], openweb_split["test"]])

# Shuffle + select a subset
combined_train = combined_train.shuffle(seed=42).select(range(100_000))
combined_valid = combined_valid.shuffle(seed=42).select(range(10_000))

In [7]:
# Make the dataset
from datasets import DatasetDict

dataset = DatasetDict({
    "train": combined_train,       # from earlier steps
    "validation": combined_valid,
})

In [12]:
# 2. Tokenization

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token
def tokenize(text):
    return tokenizer(
        text["text"], 
        return_special_tokens_mask=True,
        truncation=True,           
        max_length=1024)

tokenized_dataset = dataset.map(tokenize, batched=True, remove_columns=["text"])

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [13]:
block_size = 128

def group_texts(examples):
    concatenated = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = (len(concatenated["input_ids"]) // block_size) * block_size
    
    result = {}
    for k in concatenated.keys():
        chunks = [concatenated[k][i:i+block_size] for i in range(0, total_length, block_size)]
        result[k] = chunks
    
    result["labels"] = result["input_ids"].copy()
    return result

lm_dataset = tokenized_dataset.map(group_texts, batched=True)

Map:   0%|          | 0/100000 [00:00<?, ? examples/s]

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]

In [14]:
# 3. Model Selection

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")
model.resize_token_embeddings(len(tokenizer)) 

Embedding(50257, 768)

In [15]:
import torch, transformers
print("PyTorch device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")
print("transformers device map:", model.hf_device_map if hasattr(model, 'hf_device_map') else "n/a")

PyTorch device: CPU
transformers device map: n/a


In [18]:
# 4. Fine-Tuning

from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

training_args = TrainingArguments(
    output_dir="./next-word-model",
    eval_strategy="epoch",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=3,
    logging_dir='./logs',
    learning_rate=3e-5,
    fp16=False,
    report_to="none",
    dataloader_num_workers=2,
)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_dataset["train"],
    eval_dataset=lm_dataset["validation"],
    data_collator=data_collator,
)

trainer.args.max_steps = 6000
trainer.train()

Epoch,Training Loss,Validation Loss
0,3.4878,3.368439


TrainOutput(global_step=6000, training_loss=3.4882898763020833, metrics={'train_runtime': 26605.6773, 'train_samples_per_second': 0.451, 'train_steps_per_second': 0.226, 'total_flos': 783876096000000.0, 'train_loss': 3.4882898763020833, 'epoch': 0.022189102931550315})

In [19]:
# 4. Evaluation

import math
eval_results = trainer.evaluate()
print(f"Perplexity: {math.exp(eval_results['eval_loss']):.2f}")

Perplexity: 29.03


In [20]:
# Optional Extension: Gradio

import gradio as gr
import torch

def predict_next_word(text):
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits[:, -1, :]
    predicted_id = torch.argmax(logits, dim=-1).item()
    return tokenizer.decode([predicted_id])

gr.Interface(fn=predict_next_word, inputs="text", outputs="text", title="Next Word Predictor").launch()

* Running on local URL:  http://127.0.0.1:7860
* To create a public link, set `share=True` in `launch()`.


