In [1]:
#imports
from accelerate import Accelerator, notebook_launcher
import torch
from torch.utils.data import DataLoader
from transformers import DataCollatorForLanguageModeling, OPTForCausalLM, Trainer, TrainingArguments
from tqdm import tqdm
from transformers import AutoTokenizer
from datasets import load_dataset
from utils.save_utils import load_masked_model, load_masked_model_single
from utils.prehook_utils import remove_all_hooks

In [2]:
def training_step():
    """Fine-tune a (possibly pruned) OPT causal LM on streamed WikiText-2.

    Intended to be executed once per process via ``notebook_launcher``.
    The function:

    1. streams and tokenizes the WikiText-2 training split,
    2. loads ``facebook/opt-1.3b`` and, unless ``SPARSITY == 1``, applies the
       pruning state stored under ``pruned_models/`` through
       ``load_masked_model_single``,
    3. runs a short training loop (``MAX_STEPS_PER_EPOCH`` batches per epoch)
       with the model, optimizer and dataloader wrapped by ``Accelerator``,
    4. saves the fine-tuned ``state_dict`` from the main process only.

    Takes no arguments and returns nothing; its result is the checkpoint file
    ``pruned_models/{model_name}-{SPARSITY}-finetuned.pt``.
    """
    accelerator = Accelerator()

    model_name = 'opt-1.3b'
    EPOCH_COUNT = 10
    SPARSITY = 0.2
    LEARNING_RATE = 1e-5
    # Smoke-test cap: only this many batches are consumed in each epoch.
    MAX_STEPS_PER_EPOCH = 5

    tokenizer = AutoTokenizer.from_pretrained(f'facebook/{model_name}', padding_side='left')

    def encode_tok(examples):
        """Tokenize a batch of raw text rows, truncated/padded to the max length."""
        return tokenizer(examples['text'], truncation=True, padding='max_length')

    # Stream the training split (C4 alternative kept for reference):
    # training_data = load_dataset('c4', 'en', split='train', streaming=True)
    training_data = load_dataset('wikitext', 'wikitext-2-raw-v1', split='train', streaming=True)
    # Tokenize while streaming and drop the raw column. WikiText-2 only has
    # "text"; "timestamp"/"url" exist in C4 only and must be added back to
    # this list if the dataset is switched.
    training_data = training_data.map(encode_tok,
                                      batched=True,
                                      remove_columns=["text"])
    # Yield torch tensors instead of Python lists.
    training_data = training_data.with_format("torch")

    # mlm=False -> causal-LM collation (labels are a copy of input_ids).
    dataloader = DataLoader(training_data,
                            collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False),
                            batch_size=1)

    # NOTE(review): output_attentions/output_hidden_states make every forward
    # pass return extra tensors that the loop below never reads; consider
    # dropping them if memory is tight.
    loaded_model = OPTForCausalLM.from_pretrained(f'facebook/{model_name}',
                                                  output_attentions=True,
                                                  output_hidden_states=True)
    if SPARSITY != 1:
        load_masked_model_single(loaded_model, f'pruned_models/{model_name}-{SPARSITY}.pt')

    optimizer = torch.optim.AdamW(params=loaded_model.parameters(), lr=LEARNING_RATE)

    # Everything used in the loop must be the object returned by prepare();
    # the prepared dataloader also takes care of device placement of batches.
    loaded_model, optimizer, dataloader = accelerator.prepare(loaded_model, optimizer, dataloader)

    loaded_model.train()
    # Show a single progress bar per node instead of one per process.
    for epoch in tqdm(range(EPOCH_COUNT), disable=not accelerator.is_local_main_process):
        for i, batch in enumerate(dataloader):
            if i == MAX_STEPS_PER_EPOCH:
                break
            outputs = loaded_model(**batch)
            loss = outputs.loss
            print(loss)
            accelerator.backward(loss)
            optimizer.step()
            optimizer.zero_grad()

    # Make sure every process has finished training, then write the checkpoint
    # exactly once, from the unwrapped model so the keys carry no wrapper prefix.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        torch.save(accelerator.unwrap_model(loaded_model).state_dict(),
                   f'pruned_models/{model_name}-{SPARSITY}-finetuned.pt')

In [3]:
# Run training_step in 4 worker processes with bf16 mixed precision.
notebook_launcher(
    training_step,
    args=(),
    num_processes=4,
    mixed_precision='bf16',
)

Launching training on 4 GPUs.
