In [1]:
import os
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from transformers.tokenization_utils_base import BatchEncoding
from peft import PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType
import re

CONFIG = {
    'FROM_PEFT': False,
    'MODEL_NAME': "facebook/opt-1.3b",
    'PEFT_MODEL_NAME': './opt_9_epoch_lora',
    'ADD_CONTEXT': True,
    'TEST_SIZE': .1,
    'LR': 2e-4,
    'NUM_EPOCHS': 1,
    'WANDB_PROJECT': 'OPT-1.3b-fine-tuning-context',
    'MODEL_FOLDER': 'opt_with_context',
    'ANSWERS_OUTPUT_FILE': 'answers_with_context.json'
}

if CONFIG['FROM_PEFT']:
    peft_model_id = CONFIG['PEFT_MODEL_NAME']
    config = PeftConfig.from_pretrained(peft_model_id)
    model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path)
    model = PeftModel.from_pretrained(model, peft_model_id)
    tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path)
else:
    model = AutoModelForCausalLM.from_pretrained(
        CONFIG['MODEL_NAME'], 
    )
    peft_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
    )

    model = get_peft_model(model, peft_config)

    tokenizer = AutoTokenizer.from_pretrained(CONFIG['MODEL_NAME'])
    
model.print_trainable_parameters()
model = model.to('mps')

trainable params: 1572864 || all params: 1317330944 || trainable%: 0.11939778740975206


In [2]:
import wandb
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mkonstantzts[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [3]:
def create_sample(data, session, sessions, add_context=True, phrase_idx=None, person_replacer=None):
    dialog = '\n'.join(data[session]['dialog'][:phrase_idx])
    context = '\n'.join([data[session]['context'] for session in sessions*add_context])
    answer = ''
    if phrase_idx:
        answer = data[session]['dialog'][phrase_idx]
    facts = dict()
    if add_context:
        for cur_session in sessions:
            for person, person_facts in data[cur_session]['facts'].items():
                if person not in facts:
                    facts[person] = []
                facts[person] += person_facts

    if person_replacer:
        for to_replace, (pattern, person_mapper) in person_replacer.items():
            dialog = re.sub(pattern, to_replace, dialog)
            context = re.sub(pattern, to_replace, context)
            answer = re.sub(pattern, to_replace, answer)
            for person in person_mapper:
                if person in facts:
                    facts[to_replace] = facts.pop(person)
            for person, person_facts in facts.items():
                facts[person] = [re.sub(pattern, to_replace, fact) for fact in facts[person]]
    
    sample = {'dialog': dialog, 'context': context, 'answer': answer, 'facts': facts}
    return sample

def create_dataset(dataset, full_dialog=True, add_context=True, person_replacer=None):
    new_dataset = []
    
    sessions = ['session_1', 'session_2', 'session_3']
    for data in dataset:
        for i, session in enumerate(sessions[1:]):
            if full_dialog:
                new_dataset.append(create_sample(data, session, sessions[:i+1], add_context, None, person_replacer))
            else:
                for phrase_idx in range(1, len(data[session]['dialog']), 2):
                    new_dataset.append(create_sample(data, session, sessions[:i+1], add_context, phrase_idx, person_replacer))
                
    return new_dataset

In [4]:
import json

with open('dialog_dataset.json', 'r') as f:
    dataset = list(json.load(f).values())
    
TEST_SIZE = .1

replacement_patterns = {'Human': (r'[Bb]ot[_\s]?0', ['bot_0']), 
                        'Assistant': (r'[Bb]ot[_\s]?1', ['bot_1']), 
                        'Person': (r'\b[Bb]ot\b', []), 'Persons': (r'\b[Bb]ots\b', [])}

train_data = create_dataset(dataset[:int(len(dataset)*(1-CONFIG['TEST_SIZE']))], full_dialog=True, 
                            add_context=CONFIG['ADD_CONTEXT'], person_replacer=replacement_patterns)
valid_data = create_dataset(dataset[int(len(dataset)*(1-CONFIG['TEST_SIZE'])):], full_dialog=True, 
                            add_context=CONFIG['ADD_CONTEXT'], person_replacer=replacement_patterns)
data_for_inference = create_dataset(dataset, full_dialog=False, 
                                    add_context=CONFIG['ADD_CONTEXT'], person_replacer=replacement_patterns)

In [5]:
from torch.utils.data import Dataset, DataLoader

class SequenceDataset(Dataset):

    def __init__(self, data):
        super().__init__()

        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):

        sample = self.data[index]

        dialog = 'Dialog:\n' + sample['dialog'] + ('\nAssistant:' if sample['answer'] else '')
        context = ('Context:\n' + sample['context'] + '\n\n')*(bool(sample['context']))
        answer = sample['answer'][len('Assistant:'):]
        facts = ('\n\n'.join([f'Facts about {person}:\n' + '\n'.join(facts) 
                             for person, facts in sample['facts'].items()]) + '\n\n')*(bool(sample['context']))

        return context, facts, dialog, answer

In [6]:
class Collator:

    def __init__(self, tokenizer, inference_mode=False, max_length=1024):

        self.tokenizer = tokenizer

        self.max_length = max_length
        
        self.inference_mode = inference_mode

    def __call__(self, batch):

        inputs = []
        targets = []

        max_len = 0
        for context, facts, dialog, answer in batch:
            tokenized_context = tokenizer.encode(context, add_special_tokens=False)
            tokenized_facts = tokenizer.encode(facts, add_special_tokens=False)
            tokenized_dialog = tokenizer.encode(dialog, add_special_tokens=False)
            tokenized_answer = tokenizer.encode(answer, add_special_tokens=False)*(not self.inference_mode)
            
            input_tokens = [self.tokenizer.bos_token_id] + tokenized_context + tokenized_facts + tokenized_dialog + tokenized_answer
            
            if tokenized_answer:
                pad_sequence_len = len(input_tokens) - 1 - len(tokenized_answer)
                target = tokenized_answer
            else:
                pad_sequence_len = len(input_tokens) - 1 - len(tokenized_dialog)
                target = tokenized_dialog
            target_tokens = [self.tokenizer.pad_token_id]*pad_sequence_len + target + [self.tokenizer.eos_token_id]
            
            assert len(input_tokens) == len(target_tokens)
            
            inputs.append(input_tokens)
            targets.append(target_tokens)
            
            if len(inputs[-1]) > max_len:
                max_len = len(input_tokens)

        inputs = [tokens + [self.tokenizer.pad_token_id]*(max_len - len(tokens)) for tokens in inputs]
        targets = [tokens + [self.tokenizer.pad_token_id]*(max_len - len(tokens)) for tokens in targets]

        inputs = torch.tensor([tokens[max(0, len(tokens) - self.max_length):] for tokens in inputs])
        targets = torch.LongTensor([tokens[max(0, len(tokens) - self.max_length):] for tokens in targets])

        inputs = BatchEncoding({'input_ids': inputs, 'attention_mask': (inputs != self.tokenizer.pad_token_id).long()})
        
        return inputs, targets

In [7]:
train_dataset = SequenceDataset(data=train_data)
valid_dataset = SequenceDataset(data=valid_data)
dataset = SequenceDataset(data=data_for_inference)

train_collator = Collator(tokenizer=tokenizer)
inference_collator = Collator(tokenizer=tokenizer, inference_mode=True)

train_loader = DataLoader(train_dataset, batch_size=1, collate_fn=train_collator, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=1, collate_fn=train_collator, shuffle=False)
data_loader = DataLoader(dataset, batch_size=1, collate_fn=inference_collator, shuffle=False)

In [8]:
LR = 2e-4
NUM_EPOCHS = 1

wandb.init(
    # set the wandb project where this run will be logged
    project=CONFIG['WANDB_PROJECT'],
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": CONFIG['LR'],
    "epochs": CONFIG['NUM_EPOCHS'],
    }
)

criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)

optimizer = torch.optim.AdamW(params=model.parameters(), lr=CONFIG['LR'])

In [9]:
from tqdm import tqdm
import numpy as np

def loop(n_epoch, is_train, loader, grad_acum_steps=1):

    if is_train:
        model.train()
    else:
        model.eval()

    all_predictions = list()
    all_targets = list()

    losses = list()

    progress_bar = tqdm(total=len(loader) // grad_acum_steps if is_train else len(loader), 
                        desc="Train" if is_train else "Valid")

    if is_train:
        model.train()
    else:
        model.eval()

    losses = list()

    for n_step, (batch, targets) in enumerate(loader):
        torch.mps.empty_cache()

        batch = batch.to(model.model.device)
        targets = targets.to(model.model.device)

        if is_train:
            logits = model(**batch).logits
        else:
            with torch.inference_mode():
                logits = model(**batch).logits

        loss = criterion(logits.view(-1, logits.size(-1)), targets.contiguous().view(-1))

        losses.append(loss.item())

        if is_train:
            loss.backward()
            if n_step > 0 and n_step % grad_acum_steps == 0:
                torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=0.7)
                optimizer.step()
                progress_bar.update()
                progress_bar.set_postfix(loss=np.mean(losses[-100:]))
        else:
            progress_bar.update()
            progress_bar.set_postfix(loss=np.mean(losses[-100:]))

    progress_bar.close()

    return losses

In [10]:
for n_epoch in range(CONFIG['NUM_EPOCHS']):

    train_losses = loop(n_epoch, is_train=True, loader=train_loader, grad_acum_steps=10)
    valid_losses = loop(n_epoch, is_train=False, loader=valid_loader, grad_acum_steps=10)
    
    train_mean_loss = np.mean(train_losses)
    valid_mean_loss = np.mean(valid_losses)
    
    wandb.log({"train perplexity": np.exp(train_mean_loss), "train loss": train_mean_loss, 
               "validation perplexity": np.exp(valid_mean_loss), "validation loss": valid_mean_loss})

    epoch_message = [
        f"Epoch {n_epoch} done",
        "",
        "Train",
        f"\tLoss: {train_mean_loss:.3f}",
        "Valid",
        f"\tLoss: {valid_mean_loss:.3f}",
    ]

    print("\n".join(epoch_message))

    model.save_pretrained(f"./{CONFIG['MODEL_FOLDER']}/{n_epoch}_epoch_lora") 

Train: 100%|██████████| 10/10 [17:09<00:00, 102.93s/it, loss=2.38]
Valid: 100%|██████████| 12/12 [00:17<00:00,  1.42s/it, loss=2.24]

Epoch 0 done

Train
	Loss: 2.380
Valid
	Loss: 2.242





In [11]:
wandb.run.finish()

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train loss,▁
train perplexity,▁
validation loss,▁
validation perplexity,▁

0,1
train loss,2.37959
train perplexity,10.80047
validation loss,2.24155
validation perplexity,9.4079


In [12]:
model.to('cpu')

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): OPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 2048, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 2048)
          (final_layer_norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0-23): 24 x OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
                (v_proj): Linear(
                  in_features=2048, out_features=2048, bias=True
                  (lora_dropout): Dropout(p=0.1, inplace=False)
                  (lora_A): Linear(in_features=2048, out_features=8, bias=False)
                  (lora_B): Linear(in_features=8, out_features=2048, bias=False)
                )
                (q_proj): Linear(
                  in_features=2048, out_features=2048, bias=True
        

In [13]:
from tqdm import tqdm

answers = []

model.eval()
for x, _ in tqdm(data_loader):
    with torch.inference_mode():
        outputs = model.generate(**x, max_new_tokens=64, do_sample=True, temperature=.9, eos_token_id=50118)
        outputs = outputs[0].detach().cpu()
    answers.append(tokenizer.decode(outputs, skip_special_tokens=True))

  8%|▊         | 55/677 [13:04<2:27:54, 14.27s/it]


KeyboardInterrupt: 

In [14]:
if os.path.exists(CONFIG['ANSWERS_OUTPUT_FILE']):
    raise Exception(f"{CONFIG['ANSWERS_OUTPUT_FILE']} is already exists")
else:
    with open(CONFIG['ANSWERS_OUTPUT_FILE'], 'w') as f:
        json.dump(answers, f)