In [None]:
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from transformers import AutoTokenizer, get_linear_schedule_with_warmup

from ava.config import AvaConfig
from ava.data.data_utils import prepare_data_from_json
from ava.data.tokenizer import SimpleTokenizer
from ava.model import AvaForCausalLM

In [2]:
# Fix torch's global RNG before anything stochastic runs: model weight init in
# this cell and DataLoader shuffling later both draw from it.
SEED = 42
torch.manual_seed(SEED)  # seeds the CPU and all CUDA devices

# Wrap the pretrained GPT-2 tokenizer; the model's vocabulary size is taken from it.
# NOTE(review): `vocab_size` does not count tokens added after pretraining (e.g. a
# pad token). If SimpleTokenizer or the data prep adds any, `len(tokenizer.tokenizer)`
# may be required instead — confirm before training.
tokenizer = SimpleTokenizer(AutoTokenizer.from_pretrained('gpt2'))
config = AvaConfig(vocab_size=tokenizer.tokenizer.vocab_size)
config.apply_for('small')  # presumably applies a named size preset — see AvaConfig

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AvaForCausalLM(config).to(device)

In [3]:
# Count every element of every parameter tensor registered on the model
# (trainable and frozen alike) and report it with thousands separators.
num_params = sum(map(torch.numel, model.parameters()))
print(f'Model has {num_params:,} parameters')

Model has 129,109,248 parameters


In [4]:
# Build the train/validation datasets from the sample dialogue file, passing the
# model's maximum position count as the sequence-length limit.
DATA_PATH = './sample_dialogue_data.json'
train_dataset, val_dataset = prepare_data_from_json(
    DATA_PATH,
    tokenizer.tokenizer,
    max_length=config.max_position_embeddings,
)

In [8]:
# Both splits use the same batch size; only the training split is shuffled.
BATCH_SIZE = 4

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Peak learning rate for AdamW. Defined here as a notebook constant: there is no
# `args` namespace in a notebook, so reading `args.learning_rate` raises NameError.
# NOTE(review): 3e-4 is a placeholder — tune for this model/dataset.
LEARNING_RATE = 3e-4
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Training hyperparameters as notebook constants (no `args` namespace exists here).
# NOTE(review): NUM_EPOCHS is a placeholder value — tune as needed.
NUM_EPOCHS = 3
WARMUP_RATIO = 0.1  # fraction of all scheduler steps spent warming up
CHECKPOINT_PATH = 'ava_model_final.pt'

# Total scheduler steps, assuming train_model steps the scheduler once per
# training batch — TODO confirm against train_model.
total_steps = len(train_dataloader) * NUM_EPOCHS

# LR rises linearly from 0 to the optimizer's LR during warmup, then decays
# linearly to 0 by `total_steps`.
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(WARMUP_RATIO * total_steps),
    num_training_steps=total_steps,
)

# NOTE(review): `train_model` is neither defined nor imported in this notebook;
# import it from the project's training module before running this cell.
model = train_model(
    model=model,
    train_dataloader=train_dataloader,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=NUM_EPOCHS,
    device=device,
    eval_dataloader=val_dataloader,
)

# Save the trained weights together with the config dict used to build the model.
torch.save({
    'model_state_dict': model.state_dict(),
    'config': config.to_dict(),
}, CHECKPOINT_PATH)

print(f"Training complete! Model saved as '{CHECKPOINT_PATH}'")