In [4]:
import os

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset, load_from_disk
import accelerate
from accelerate import (Accelerator,
                        DeepSpeedPlugin)
from accelerate.utils import (LoggerType, DummyOptim, DummyScheduler)
from transformers import (AdamW,
                          AutoTokenizer,
                          AutoModelForCausalLM,
                          get_linear_schedule_with_warmup,
                         get_cosine_schedule_with_warmup,
                          set_seed,
                          AutoConfig,
                          GPT2LMHeadModel)
from tqdm.auto import tqdm
import math
from modular_transformers.train.utils import Group_Texts
from pathlib import Path
import sys

from modular_transformers.models.gpt2.utils import initialize_gpt2_weights
import pickle

from modular_transformers.models.gpt2.configuration_gpt2 import GPT2Config
from modular_transformers.models import components
from modular_transformers.models.loss_utils import l2_reg

import matplotlib.pyplot as plt

import wandb

import random

import numpy as np

# Sequence length: used as the model's `n_ctx` and to pick the
# `*_context_len_<N>` dataset directories on disk.
CONTEXT_LENGTH = 1024
# Per-step training batch cap; a larger requested batch size is reached
# through gradient accumulation instead.
MAX_GPU_BATCH_SIZE = 16
# Batch size of the validation DataLoader.
EVAL_BATCH_SIZE = 32

In [1]:
# Evaluate function
def evaluate(model, eval_dataloader, accelerator):
    """Compute the mean validation loss and its perplexity.

    Args:
        model: Callable as ``model(input_ids, labels=..., attention_mask=...)``
            returning an object with a scalar ``.loss``; must provide ``eval()``.
        eval_dataloader: Sized iterable of batches. Each batch maps
            ``'input_ids'`` (and optionally ``'attention_mask'``) to a sequence
            of equally shaped tensors, which is stacked and transposed before
            being fed to the model (presumably yielding (batch, seq) -- confirm
            against the dataset format).
        accelerator: Object providing ``gather`` and ``print``
            (an ``accelerate.Accelerator`` in this notebook).

    Returns:
        tuple[float, float]: ``(loss, perplexity)``. ``perplexity`` is
        ``float("inf")`` when ``exp(loss)`` overflows.

    Raises:
        ValueError: If ``eval_dataloader`` yields no batches.
    """
    model.eval()
    losses = []
    for batch in tqdm(eval_dataloader, total=len(eval_dataloader)):
        input_ids = torch.stack(batch['input_ids']).transpose(1, 0)
        # Forward the attention mask when the batch carries one, so evaluation
        # sees the same masking as the training loop.
        model_kwargs = {}
        if 'attention_mask' in batch:
            model_kwargs['attention_mask'] = torch.stack(batch['attention_mask']).transpose(1, 0)
        with torch.no_grad():
            outputs = model(input_ids, labels=input_ids, **model_kwargs)
        # Flatten so a 0-dim loss (single process) and a gathered 1-D loss
        # (one entry per process) are handled the same way.
        losses.append(accelerator.gather(outputs.loss).reshape(-1))

    if not losses:
        raise ValueError("eval_dataloader yielded no batches; cannot compute validation loss")

    loss = torch.cat(losses).mean().item()
    # math.exp works on the Python float and really does raise OverflowError;
    # torch.exp silently returns inf, which made the old except branch dead
    # (and it would have crashed on `float.item()` had it ever run).
    try:
        perplexity = math.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    accelerator.print(f"validation loss: {loss}, validation perplexity {perplexity}")
    return loss, perplexity

In [2]:
# Size of the small experimental model; these feed the "n_layer" and
# "bottleneck" entries of the model config built in the training cell.
n_layer = 3
bottleneck_dim = 128

# Loss settings. `loss_hooks` is passed to the model config and also gates the
# extra-loss branch of the training loop (key is presumably a layer index and
# the value a regulariser name -- confirm in modular_transformers).
loss_hooks = {1: "l2_curvature"}
logit_multiplier = 1
normalize_loss = False

In [5]:
# Fetch the pretrained GPT-2 language model (Hugging Face 'gpt2' checkpoint).
model = GPT2LMHeadModel.from_pretrained('gpt2')

# Take its state dictionary and write it to a local file so that a later cell
# can load the same weights into the project's own model class.
state_dict = model.state_dict()
torch.save(obj=state_dict, f='gpt2_weights.pt')

In [15]:
# `use_fast` is the keyword AutoTokenizer recognises; the previous `fast=False`
# was not a recognised argument, so the default fast tokenizer was loaded anyway.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
# GPT-2 ships without a pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

# GPT-2-small sized configuration (768 wide, 12 layers, 12 heads, no loss hooks)
# so the saved 'gpt2' weights can be loaded into the project's LM class.
config = {
    'regsize': 768,
    'vocab_size': len(tokenizer),
    'n_ctx': CONTEXT_LENGTH,
    'bos_token_id': tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    "bottleneck": 768,
    "n_layer": 12,
    "loss_hooks": None,
    "normalize_loss": normalize_loss,
    "logit_multiplier": logit_multiplier,
    "inter_multiplier": 1,
    "n_heads": 12,
}

config = GPT2Config(config)
model = components.LM(config)
# Kept so the next cell can compare parameter names with the checkpoint's.
model_state_dict = model.state_dict()
state_dict = torch.load('gpt2_weights.pt')
# strict=False tolerates missing/unexpected keys; the returned report is the
# cell's displayed output, so check it to confirm that every key matched.
model.load_state_dict(state_dict, strict=False)

<All keys matched successfully>

In [14]:
# Print the checkpoint's parameter names, then the project model's, one per
# line, so the two key sets can be compared by eye.
print(state_dict.keys(), model_state_dict.keys(), sep="\n")

odict_keys(['transformer.wte.weight', 'transformer.wpe.weight', 'transformer.h.0.ln_1.weight', 'transformer.h.0.ln_1.bias', 'transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.0.attn.c_attn.weight', 'transformer.h.0.attn.c_attn.bias', 'transformer.h.0.attn.c_proj.weight', 'transformer.h.0.attn.c_proj.bias', 'transformer.h.0.ln_2.weight', 'transformer.h.0.ln_2.bias', 'transformer.h.0.mlp.c_fc.weight', 'transformer.h.0.mlp.c_fc.bias', 'transformer.h.0.mlp.c_proj.weight', 'transformer.h.0.mlp.c_proj.bias', 'transformer.h.1.ln_1.weight', 'transformer.h.1.ln_1.bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.1.attn.c_attn.weight', 'transformer.h.1.attn.c_attn.bias', 'transformer.h.1.attn.c_proj.weight', 'transformer.h.1.attn.c_proj.bias', 'transformer.h.1.ln_2.weight', 'transformer.h.1.ln_2.bias', 'transformer.h.1.mlp.c_fc.weight', 'transformer.h.1.mlp.c_fc.bias', 'transformer.h.1.mlp.c_proj.weight', 'transformer.h.1.mlp.c_pr

In [None]:
#Set checkpoint if needed ------------------------------------------------
chkpoint = None
wandb_id = "amvcstc4"
epoch_buffer = 12

#Set training config --------------------------------------------

data='10M'
batch_size = 64

train_config = {"lr": 0.0006, "num_epochs": 20, "correct_bias": True, "seed": 42, "batch_size": batch_size}

# Apply the configured seed (it was previously defined but never used) so that
# weight initialisation and DataLoader shuffling are repeatable across runs.
set_seed(train_config['seed'])

# `use_fast` is the keyword AutoTokenizer recognises; the previous `fast=False`
# was not a recognised argument, so the default fast tokenizer was loaded anyway.
tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
# GPT-2 ships without a pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

# Pre-tokenised, grouped datasets stored on disk per corpus size and context length.
path = '/om/weka/evlab/ehoseini/MyData/miniBERTa_v2/'
grouped_pad_train = load_from_disk(
    os.path.join(path, f'miniBERTa-{data}-crunched',
                    f'train_context_len_{CONTEXT_LENGTH}'))
grouped_pad_valid = load_from_disk(
    os.path.join(path, f'miniBERTa-{data}-crunched',
                    f'valid_context_len_{CONTEXT_LENGTH}'))

# If the batch size is too big we use gradient accumulation
# NOTE(review): integer division means a requested batch size that is not a
# multiple of MAX_GPU_BATCH_SIZE yields a smaller effective batch -- confirm.
gradient_accumulation_steps = 1
if train_config['batch_size'] > MAX_GPU_BATCH_SIZE:
    gradient_accumulation_steps = train_config['batch_size'] // MAX_GPU_BATCH_SIZE
    batch_size = MAX_GPU_BATCH_SIZE
    accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps)
else:
    accelerator = Accelerator()

eval_dataloader = DataLoader(grouped_pad_valid, shuffle=False, batch_size=EVAL_BATCH_SIZE)
train_dataloader = DataLoader(grouped_pad_train, shuffle=True, batch_size=batch_size)
# The loaders hold their own references; drop the notebook-level names.
del grouped_pad_train, grouped_pad_valid

#set model config ---------------------------------------

# Small experimental model: width 128, 4 heads; depth, bottleneck and loss
# settings come from the hyper-parameter cell.
config = {
    'regsize': 128,
    'vocab_size': len(tokenizer),
    'n_ctx': CONTEXT_LENGTH,
    'bos_token_id': tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    "bottleneck": bottleneck_dim,
    "n_layer": n_layer,
    "loss_hooks": loss_hooks,
    "normalize_loss": normalize_loss,
    "logit_multiplier": logit_multiplier,
    "inter_multiplier": 1,
    "n_heads": 4,
}

config = GPT2Config(config)
model = components.LM(config)

torch.cuda.empty_cache()

model = model.to(accelerator.device)

# Define optimizer
# Creates Dummy Optimizer if `optimizer` was specified in the config file else creates AdamW Optimizer
optimizer_cls = (torch.optim.AdamW
    if accelerator.state.deepspeed_plugin is None
        or "optimizer" not in accelerator.state.deepspeed_plugin.deepspeed_config
    else DummyOptim
)
optimizer = optimizer_cls(params=model.parameters(), lr=train_config['lr'])

# Cosine schedule with 100 warm-up steps, unless DeepSpeed's config supplies
# its own scheduler (that path is not implemented in this notebook).
# NOTE(review): the schedule length uses train_config['num_epochs'] while the
# training loop runs `real_epochs` epochs -- confirm this is intended.
if (
        accelerator.state.deepspeed_plugin is None
        or "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
):
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=100,
        num_training_steps=(len(train_dataloader) * train_config['num_epochs']),
    )
else:
    # Explicit error instead of `assert False`, which is stripped under
    # `python -O` and gives no hint about what went wrong.
    raise NotImplementedError(
        "A scheduler defined in the DeepSpeed config is not supported by this notebook"
    )

# Pass everything to accelerator
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler)

# Running counters: examples seen and batches processed.
data_count = 0
absolute_step = 0

# Number of epochs actually run in this session.
real_epochs = 7

for epoch in tqdm(range(real_epochs)):

    model.train()
    torch.cuda.empty_cache()
    for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
        # Each field arrives as a sequence of tensors; stack and transpose it
        # into one tensor per field: batch[0] = input_ids, batch[1] = attention_mask.
        batch = [torch.stack(batch[x]).transpose(1, 0) for x in ['input_ids', 'attention_mask']]

        with accelerator.accumulate(model):
            outputs = model(batch[0], labels=batch[0], attention_mask=batch[1])
            logit_loss = outputs.loss

            if loss_hooks is not None:
                extra_losses = model.output_extra_losses()
                # Sum only the hooks that produced a value. With no active hook
                # this stays the int 0, which the float() below handles
                # (the old `.item()` call crashed on it).
                extra_loss = sum(
                    hook_loss for hook_loss in extra_losses.values() if hook_loss is not None
                )

                logit_loss = logit_loss * logit_multiplier
                loss = logit_loss + extra_loss
                # Keep plain floats of both components for inspection/logging.
                logit_loss = logit_loss.item()
                extra_loss = float(extra_loss)
            else:
                loss = logit_loss

            accelerator.backward(loss)
            # The optimizer must step before the scheduler; the reverse order
            # shifts the learning-rate schedule by one step and triggers a
            # PyTorch warning.
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

        data_count += batch[0].shape[0]

        absolute_step += 1

        torch.cuda.empty_cache()

accelerator.end_training()


### Testing one sample

In [None]:
# NOTE(review): this cell rebinds the notebook-level names `loss_hooks`,
# `config`, `model` and `data`; re-running the training cell afterwards will
# pick up these values (and `data` will no longer be the corpus-size string).
loss_hooks = {4: "l0_curvature"}
config = {
    'regsize': 768,
    'vocab_size': len(tokenizer),
    'n_ctx': CONTEXT_LENGTH,
    'bos_token_id': tokenizer.bos_token_id,
    'eos_token_id': tokenizer.eos_token_id,
    "bottleneck": 768,
    "n_layer": 12,
    "loss_hooks": loss_hooks,
    "normalize_loss": False,
    "logit_multiplier": 1,
    "inter_multiplier": 1,
    "n_heads": 12,
    "pretrained": False,
    "warmup": False,
}

config = GPT2Config(config)
model = components.LM(config)

# `device` was never defined in this notebook; choose it explicitly so the cell
# runs on a fresh kernel.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model = model.eval()

path = '/om/weka/evlab/ehoseini/MyData/miniBERTa_v2/'
data_size = "10M"
data = load_from_disk(
    os.path.join(path, f'miniBERTa-{data_size}-crunched',
                    f'train_context_len_{CONTEXT_LENGTH}'))

# Draw one random batch of two examples and run a single forward pass.
dataloader = DataLoader(data, shuffle=True, batch_size=2)
batch = next(iter(dataloader))
# Stack/transpose each field as in the training loop, and move the tensors to
# the model's device (no accelerator-prepared loader does that here).
batch = [torch.stack(batch[x]).transpose(1, 0).to(device) for x in ['input_ids', 'attention_mask']]
outputs = model(batch[0], labels=batch[0], attention_mask=batch[1])
