In [None]:
import evaluate
import torch
from torch.utils.data import DataLoader
from torch.optim import AdamW
# from transformers import GPT2Config, GPT2Model
from transformers import AutoTokenizer, GPT2LMHeadModel, DataCollatorWithPadding, DataCollatorForLanguageModeling

In [None]:
# Single source of truth for the checkpoint so the tokenizer and model can never diverge.
# TODO change to gpt2-large
MODEL_NAME = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = GPT2LMHeadModel.from_pretrained(MODEL_NAME)
model

In [None]:
# Display the model's configuration object (architecture hyper-parameters).
model.config

In [None]:
# Report the parameter count in millions (unrounded).
print(" Num parameters = {} million".format(model.num_parameters() / 1e6))

In [None]:
from datasets import load_dataset, load_dataset_builder
# ds_builder = load_dataset_builder('wikitext', 'wikitext-2-v1')
# ds_builder.info

In [None]:
# WikiText-2, "v1" config. NOTE(review): GPT-2 fine-tuning examples usually use the
# "wikitext-2-raw-v1" config, since byte-level BPE handles raw text — confirm this choice.
dataset = load_dataset(path="wikitext", name="wikitext-2-v1")
dataset


In [None]:
# Display one raw training example (a dict whose "text" field is tokenized below).
dataset['train'][3]

In [None]:
# Show the column schema of the training split.
dataset['train'].features

In [None]:
# Tokenize every split in batches and drop the raw "text" column so only tokenizer outputs remain.
# TODO: return_overflowing_tokens=True? return_length=True for filtering?
def tokenize_function(examples):
    """Return the tokenizer's encoding of the batch's "text" list."""
    texts = examples["text"]
    return tokenizer(texts)


tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"],
)

In [None]:
# Display one tokenized training example (tokenizer outputs such as input_ids; "text" was removed).
tokenized_datasets['train'][1]

In [None]:
# Sanity check: push one tokenized example (as a batch of 1) through the LM and print the logits shape.
with torch.no_grad():
    print(model(torch.tensor([tokenized_datasets["train"][1]["input_ids"]])).logits.shape)

In [None]:
# Length (in tokens) of each training block: a quarter of the tokenizer's max length
# (1024 for GPT-2 -> 256). TODO: change to tokenizer.model_max_length.
# Floor division keeps the result exact; float division loses precision for very large
# model_max_length values (above 2**53).
block_size = tokenizer.model_max_length // 4

def group_texts(examples, chunk_size=None):
    """Concatenate a batch of tokenized texts and re-split it into fixed-size blocks.

    Args:
        examples: batch dict as passed by ``Dataset.map(batched=True)``, mapping each
            column name (e.g. ``input_ids``) to a list of token lists. Must contain
            ``input_ids``.
        chunk_size: block length in tokens; defaults to the notebook-level ``block_size``.

    Returns:
        dict with the same keys, each a list of ``chunk_size``-long lists, plus
        ``labels`` (a shallow copy of the list of ``input_ids`` blocks). The tail that
        does not fill a whole block is dropped, so a batch with fewer than
        ``chunk_size`` tokens yields empty lists.
    """
    size = block_size if chunk_size is None else chunk_size
    # Flatten each column in linear time; sum(lists, []) re-copies the accumulator on
    # every addition and is quadratic in the batch's token count.
    concatenated = {k: [tok for seq in seqs for tok in seq] for k, seqs in examples.items()}
    total_length = len(next(iter(concatenated.values()), []))
    # Drop the small remainder; padding could be added instead if the model supported it.
    total_length = (total_length // size) * size
    # Split every column into chunks of `size` tokens.
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

# With batched=True, map hands group_texts 1,000 tokenized texts at a time, so the
# sub-block remainder is discarded once per group of 1,000. Passing a larger batch_size
# to map would discard less data, but preprocessing may be slower.
lm_datasets = tokenized_datasets.map(
    function=group_texts,
    batched=True,
    num_proc=1,
    desc=f"Grouping texts in chunks of {block_size}",
)

In [None]:
# Pick out the grouped splits and report how many fixed-size blocks each one holds.
train_dataset, eval_dataset = lm_datasets['train'], lm_datasets['validation']
print("train: {}, val: {}".format(len(train_dataset), len(eval_dataset)))

In [None]:
# GPT-2's tokenizer has no pad token; reuse EOS so the collator is able to pad.
tokenizer.pad_token = tokenizer.eos_token
# mlm=False -> causal-LM collation: labels are built from input_ids, with pad positions
# set to -100. NOTE(review): because pad == EOS, any genuine EOS token in the inputs is
# also masked out of the loss.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
# Shared loader settings, so train and eval loaders cannot drift apart.
dl_kwargs = {
    'batch_size': 8,
    'collate_fn': data_collator,
    'num_workers': 2,
    # Pinned host memory only helps when batches are copied to a CUDA device.
    'pin_memory': torch.cuda.is_available(),
}
train_dl = DataLoader(train_dataset, shuffle=True, **dl_kwargs)
eval_dl = DataLoader(eval_dataset, **dl_kwargs)

In [None]:
# Peek at the first collated training batch: field name, tensor shape and values.
for batch in train_dl:
    for field, values in batch.items():
        print(field, values.shape, values)
    break  # only the first batch is needed

In [None]:
# Hyperparams
WEIGHT_DECAY = 0.1
LEARNING_RATE = 5e-5

In [None]:
# FSDP
import os

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
    CPUOffload,
    BackwardPrefetch,
)
from torch.utils.data.distributed import DistributedSampler
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy, # TODO migrate to this.
    enable_wrap,
    wrap,
)

from functools import partial

# FSDP requires an initialised default process group. When the notebook runs as a single
# interactive process nothing has created one, so start a 1-rank group here.
# NOTE(review): the gloo (CPU-only) fallback depends on the installed torch's FSDP CPU
# support — confirm on the target environment.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "12355")
    dist.init_process_group(
        "nccl" if torch.cuda.is_available() else "gloo", rank=0, world_size=1
    )

# Auto-wrap submodules into their own FSDP units once their (not yet wrapped)
# parameter count reaches 1e6.
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e6))
# Guard against wrapping an already-wrapped model when this cell is re-run.
if not isinstance(model, FSDP):
    model = FSDP(model, auto_wrap_policy=auto_wrap_policy)
model

In [None]:
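# Optimizer
# Split the weights into two groups: one with weight decay and one without (biases and LayerNorm weights).
# NOTE: GPT-2 names its LayerNorm parameters `ln_1`, `ln_2` and `ln_f`, so "layer_norm.weight" below matches
# nothing. Also, once the model is wrapped in FSDP (without use_orig_params=True), named_parameters() yields
# flattened parameters, so name-based grouping must happen before wrapping or with use_orig_params=True.
# no_decay = ["bias", "layer_norm.weight"]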
# optimizer_grouped_parameters = [
#     {
#         "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
#         "weight_decay": WEIGHT_DECAY,
#     },
#     {
#         "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
#         "weight_decay": 0.0,
#     },
# ]
# optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)


In [None]:
# Smoke test: one forward pass on the first training batch, without building an autograd graph.
for batch in train_dl:
    with torch.no_grad():
        out = model(**batch)
    break


In [None]:
def train():
    """Training-loop placeholder: not implemented yet; does nothing and returns None."""

def validate(model: torch.nn.Module, val_dl: DataLoader) -> dict:
    """Evaluate a causal language model on a validation loader.

    Args:
        model: module whose forward accepts a collated batch as keyword arguments and
            returns an object with ``.loss`` (scalar tensor) and ``.logits``
            (batch, seq, vocab) — e.g. GPT2LMHeadModel, possibly FSDP-wrapped.
        val_dl: iterable of dicts containing at least ``labels`` (batch, seq); label
            value -100 marks positions excluded from accuracy.

    Returns:
        dict with ``val_accuracy`` (top-1 next-token accuracy over non-ignored label
        positions), ``val_loss`` (mean of per-batch losses) and ``val_perplexity``
        (exp of ``val_loss``). Values are NaN when there is nothing to average.
        The model's original train/eval mode is restored afterwards.
    """
    was_training = model.training
    # Move batches to wherever the model's parameters live (no-op for a CPU model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None

    total_loss = 0.0
    num_batches = 0
    num_correct = 0
    num_counted = 0
    model.eval()
    try:
        with torch.no_grad():
            for batch in val_dl:
                if device is not None:
                    batch = {k: v.to(device) for k, v in batch.items()}
                out = model(**batch)
                total_loss += out.loss.item()
                num_batches += 1
                # Causal LM: logits at position t predict token t + 1, so compare
                # predictions[:, :-1] with labels[:, 1:], skipping -100 positions.
                predictions = out.logits[:, :-1].argmax(dim=-1)
                targets = batch["labels"][:, 1:]
                mask = targets != -100
                num_correct += (predictions[mask] == targets[mask]).sum().item()
                num_counted += mask.sum().item()
    finally:
        model.train(was_training)

    val_loss = total_loss / num_batches if num_batches else float("nan")
    return {
        "val_accuracy": num_correct / num_counted if num_counted else float("nan"),
        "val_loss": val_loss,
        # torch.exp yields inf/nan instead of raising on overflow/NaN.
        "val_perplexity": torch.tensor(val_loss).exp().item(),
    }


    

In [None]:
import os
import torch.distributed as dist

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = "localhost"
    os.environ['MASTER_PORT'] = 12355
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

In [None]:
# Check which class the HF pipeline produced for train_dataset (e.g. to compare with torch's Dataset).
from torch.utils.data import Dataset

type(train_dataset)