https://github.com/AnswerDotAI/fsdp_qlora/blob/main/train.py (ref)

In [None]:
# !pip install bitsandbytes
# !pip install hqq
# !pip install wandb
# !pip install datasets
# !pip install peft

In [None]:
import tqdm
from typing import List, Dict, Union
import torch
from torch import nn, Tensor
from bitsandbytes.nn import Linear4bit, Params4bit
import types
from torch.utils.data import Dataset, DataLoader, DistributedSampler
import copy
from transformers.tokenization_utils_fast import PreTrainedTokenizerFast
from torch.nn.utils.rnn import pad_sequence

try:
    from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig
except ImportError:
    HQQLinear = None
    pass

try:
    import wandb
except ImportError:
    pass

In [None]:
class Logger:
    """Rank-aware metrics logger that writes to stdout, tqdm, or Weights & Biases.

    Only rank 0 ever initialises wandb or emits log lines; every other rank
    is a silent no-op.

    Args:
        args: Run configuration, forwarded to ``wandb.init(config=...)``.
        log_to: One of ``'stdout'``, ``'tqdm'`` or ``'wandb'``.
        project_name, entity, group, name: Forwarded to ``wandb.init``.
        rank: Rank of the calling process.
    """

    def __init__(self, args, log_to='stdout', project_name='fsdp_qlora',
                 entity=None, group=None, name=None, rank=0):
        self.log_to = log_to
        if self.log_to == 'wandb' and rank == 0:
            # Imported lazily so wandb is only required when it is actually used.
            import wandb
            wandb.init(
                project=project_name,
                entity=entity,
                group=group,
                name=name,
                config=args,
            )

    def log(self, d: Dict, rank: int):
        """Emit every ``key: value`` pair of ``d`` to the configured sink (rank 0 only)."""
        if rank != 0:
            return
        if self.log_to == 'tqdm':
            # `write` lives on the tqdm class, not on the `tqdm` module.
            from tqdm import tqdm as tqdm_cls
            for k, v in d.items():
                tqdm_cls.write(f'{k}: {v}')
        elif self.log_to == 'wandb':
            # Local import: the module-level `import wandb` is allowed to fail.
            import wandb
            wandb.log(d)
        elif self.log_to == 'stdout':
            for k, v in d.items():
                print(f'{k}: {v}')


In [None]:
def update_progress_bar(progress_bar:tqdm, epoch:int, log_loss:float, log_lr:float, rank:int):
    """Refresh the progress-bar description with epoch, loss and (optionally) lr.

    Nothing happens on ranks other than 0. The learning rate is appended
    only when ``log_lr`` is non-negative.
    """
    if rank != 0:
        return
    description = f"epoch {epoch}, loss {log_loss:.3f}"
    if log_lr >= 0:
        description += f", lr {log_lr:.2e}"
    progress_bar.set_description(description, refresh=True)

In [None]:
def replace_linear(model:nn.Module, linear_replacement:type[nn.Module], quant_config:Union[dict,None]=None,
                  skip_modules:Union[List[str],None]=None, **kwargs):
    """Recursively swap every ``nn.Linear`` child of ``model`` for ``linear_replacement``.

    The model is modified in place and also returned.

    Args:
        model: Module whose children are walked recursively.
        linear_replacement: Replacement class; must subclass ``Linear4bit``
            or (when hqq is installed) ``HQQLinear``.
        quant_config: Passed to ``HQQLinear`` replacements only.
        skip_modules: Child names that are left untouched. Defaults to
            ``['lm_head']``.
        **kwargs: Forwarded to the replacement constructor.

    Raises:
        ValueError: If ``linear_replacement`` is not a supported class.
    """
    if skip_modules is None:
        skip_modules = ['lm_head']

    for name, module in model.named_children():
        # Descend first so nested containers are rewritten as well.
        if next(module.children(), None) is not None:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if not isinstance(module, nn.Linear) or name in skip_modules:
            continue

        if issubclass(linear_replacement, Linear4bit):
            model._modules[name] = linear_replacement(
                module.in_features,
                module.out_features,
                module.bias is not None,
                **kwargs
            )
        # HQQLinear is None when hqq is not installed; issubclass(cls, None) would raise TypeError.
        elif HQQLinear is not None and issubclass(linear_replacement, HQQLinear):
            model._modules[name] = linear_replacement(module, quant_config, **kwargs)
        else:
            # linear_replacement is a class, so type(linear_replacement) would always print <class 'type'>.
            raise ValueError(f"unsupported linear replacement: {linear_replacement}")
    return model

In [None]:
# Toy model for exercising replace_linear: one top-level Linear plus a nested
# Sequential, so both the direct and the recursive replacement paths are hit.
m = nn.Sequential(
    nn.Linear(5, 10),
    nn.Sequential(nn.Linear(10, 50), nn.ReLU())
)
# m = m.to('cuda')
# Bare expression: the notebook displays the module repr as the cell output.
m

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): Sequential(
    (0): Linear(in_features=10, out_features=50, bias=True)
    (1): ReLU()
  )
)

In [None]:
# Inspect each parameter of the unmodified model before any layer replacement.
for p in m.parameters():
    print(p.shape, p.dtype, p.requires_grad, p.device)

torch.Size([10, 5]) torch.float32 True cpu
torch.Size([10]) torch.float32 True cpu
torch.Size([50, 10]) torch.float32 True cpu
torch.Size([50]) torch.float32 True cpu


In [None]:
# replace_linear rewrites `m` in place and returns it, so `r` is the same
# object as `m`, now holding Linear4bit layers.
r = replace_linear(m, linear_replacement=Linear4bit)
# r = r.to('cuda')
# Bare expression: the notebook displays the module repr as the cell output.
r

Sequential(
  (0): Linear4bit(in_features=5, out_features=10, bias=True)
  (1): Sequential(
    (0): Linear4bit(in_features=10, out_features=50, bias=True)
    (1): ReLU()
  )
)

In [None]:
# Same inspection after replacement, for comparison with the earlier listing.
for p in r.parameters():
    print(p.shape, p.dtype, p.requires_grad, p.device)

torch.Size([10, 5]) torch.float32 False cpu
torch.Size([10]) torch.float32 True cpu
torch.Size([50, 10]) torch.float32 False cpu
torch.Size([50]) torch.float32 True cpu


In [None]:
def setup_quantized_meta_for_peft(model:nn.Module):
    """Temporarily make ``quant_state.to(...)`` a no-op on every ``Params4bit``.

    The real ``to`` is stashed on ``quant_state._orig_to`` so it can be put
    back later.
    """
    def temp_to_method(self, *args, **kwargs):
        # Ignore the requested device/dtype and hand back the same object.
        return self

    quantized_params = (p for p in model.parameters() if isinstance(p, Params4bit))
    for quantized in quantized_params:
        state = quantized.quant_state
        state._orig_to = state.to
        state.to = types.MethodType(temp_to_method, state)

In [None]:
def setup_quantized_peft_meta_for_training(model:nn.Module):
    """Restore the ``quant_state.to`` methods stashed on ``Params4bit`` parameters.

    For every ``Params4bit`` whose ``quant_state`` carries a saved
    ``_orig_to``, that method is re-installed as ``to`` and the stash is
    cleared. Calling this more than once is safe: an already-cleared stash
    is skipped rather than overwriting ``to`` with ``None``.
    """
    for param in model.parameters():
        if not isinstance(param, Params4bit):
            continue
        state = param.quant_state
        original_to = getattr(state, '_orig_to', None)
        # hasattr() alone stays true after the stash is reset to None, which
        # would make a second call set `to = None`.
        if original_to is None:
            continue
        state.to = original_to
        state._orig_to = None

In [None]:
# Demo of str.rpartition, which load_and_quantize uses to split a dotted
# name into the part before the last '.' and the part after it.
name = "nn.Conv2d"
module_key, _, value_key = name.rpartition('.')

print("module_key:", module_key)
print("value_key:", value_key)
# `_` holds the separator that rpartition returns as the middle element.
print('_', _)

module_key: nn
value_key: Conv2d
_ .


In [None]:
def load_and_quantize(module:nn.Module, name:str, value:Tensor, device:torch.device=None, dtype:torch.dtype=None,
                      skip_names:Union[list[str],None]=None, is_meta_rank:bool=False, low_memory:bool=True, verbose:bool=False, quant_method:str='bnb'):
    """Load ``value`` into the parameter/buffer ``name`` of ``module``, quantizing where needed.

    ``name`` is a dotted path; everything before the last ``.`` selects the
    submodule and the remainder is the attribute to set on it.

    Args:
        module: Root module containing the target submodule.
        name: Dotted name of the tensor to load (e.g. a state-dict key).
        value: Tensor to load.
        device: Device used for quantization, and final placement when
            neither ``is_meta_rank`` nor ``low_memory`` is set.
        dtype: Dtype applied to weights before quantization.
        skip_names: Substrings; any ``name`` containing one is skipped.
        is_meta_rank: Place the final tensor on the ``'meta'`` device.
        low_memory: Place the final tensor on ``'cpu'`` (used when
            ``is_meta_rank`` is False).
        verbose: Print a message when a name is skipped.
        quant_method: ``'bnb'`` or ``'hqq'``.

    Raises:
        ValueError: When asked to load a bias into an ``HQQLinear``.
    """
    def place_on_device(value):
        # A separate local is used on purpose: assigning to `device` here
        # would shadow the outer argument and leave it unbound when both
        # flags are False.
        if is_meta_rank:
            target = 'meta'
        elif low_memory:
            target = 'cpu'
        else:
            target = device
        return value.to(device=target)

    if any(skip_name in name for skip_name in (skip_names or [])):
        if verbose:
            print(f'skipping {name} because it is in skip_names')
        return

    module_key, _, value_key = name.rpartition('.')
    try:
        submodule = module.get_submodule(module_key)
    except AttributeError as e:
        print(f'module {module_key} not found:\n{e}')
        return

    # HQQLinear is None when hqq is not installed; isinstance(x, None) would raise TypeError.
    is_hqq_layer = HQQLinear is not None and isinstance(submodule, HQQLinear)

    try:
        if quant_method == 'bnb':
            param = submodule.get_parameter(value_key)
            if isinstance(param, Params4bit):
                # Rebuild the Params4bit around the new data and move it via .cuda(device),
                # then relocate the result for meta / low-memory ranks.
                value = type(param)(value.to(device=device, dtype=dtype).data, **param.__dict__).cuda(device)
                if is_meta_rank:
                    value = type(param)(value.data.to('meta'), **value.__dict__)
                elif low_memory:
                    value = type(param)(value.data.to("cpu"), **value.__dict__)
            else:
                value = type(param)(place_on_device(value).data)
        elif quant_method == 'hqq':
            if is_hqq_layer:
                if value_key == 'weight':
                    # Materialise the wrapped linear layer, copy the weights in,
                    # then let HQQ initialise itself from them.
                    submodule.linear_layer.to_empty(device=device)
                    submodule.linear_layer.weight.data.copy_(value.to(device=device, dtype=dtype))
                    submodule.initialize()

                    if is_meta_rank:
                        setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to('meta')))
                    elif low_memory:
                        setattr(submodule, "W_q", nn.Parameter(submodule.W_q.to("cpu")))
                    submodule.in_gpu = False

                if value_key == "bias":
                    raise ValueError('bias not supported in HQQLinear yet')
            else:
                param = submodule.get_parameter(value_key)
                value = type(param)(place_on_device(value).data)

    except AttributeError:
        # `value_key` is not a registered parameter (e.g. a buffer): just move the raw tensor.
        value = place_on_device(value)

    # HQQLinear layers were already updated in place above.
    if not is_hqq_layer:
        setattr(submodule, value_key, value)


In [None]:
# Alpaca-style prompt templates, filled with str.format_map from a dataset row.
# "prompt_input" is for rows with a non-empty "input" field, "prompt_no_input"
# for rows without one. Both must use the same "### instruction:" header.
PROMPT_DICT = {
    "prompt_input": (
        "below is an instruction that describes a task, paired with an input that provides further context. "
        "write a response that appropriately completes the request.\n\n"
        "### instruction:\n{instruction}\n\n### input:\n{input}\n\n### response:"
    ),
    "prompt_no_input": (
        "below is an instruction that describes a task. "
        "write a response that appropriately completes the request.\n\n"
        "### instruction:\n{instruction}\n\n### response:"
    )
}

In [None]:
class InstructionDataset(Dataset):
    """Wraps a raw dataset and tokenizes each row into input ids, labels and mask.

    ``style`` selects how a row becomes text: ``'guanaco'`` (a single
    ``text`` field split at ``'### Assistant: '``), ``'qna'``
    (context/question/answer fields) or anything else for Alpaca-style rows
    rendered through ``PROMPT_DICT``.
    """

    def __init__(self, dataset, tokenizer, style='alpaca'):
        self.dataset, self.tokenizer, self.style = dataset, tokenizer, style

    def __len__(self):
        return len(self.dataset)

    def _render(self, index):
        """Return ``(prompt_text, full_text)`` for the row at ``index``."""
        if self.style == 'guanaco':
            full_text = self.dataset[index]['text']
            return full_text.split('### Assistant: ')[0], full_text

        row = self.dataset[index]
        if self.style == 'qna':
            qna_prompt = "###Context:\n{context}\n###Question:\n{question}\n###Answer:\n".format_map(row)
            return qna_prompt, qna_prompt + row['answer']

        # alpaca: choose the template depending on whether the row has an input
        template_key = 'prompt_no_input' if row.get("input", "") == "" else 'prompt_input'
        alpaca_prompt = PROMPT_DICT[template_key].format_map(row)
        return alpaca_prompt, alpaca_prompt + row['output']

    def __getitem__(self, index):
        IGNORE_INDEX = -100
        prompt_text, full_text = self._render(index)

        prompt_ids = torch.tensor(self.tokenizer.encode(prompt_text), dtype=torch.int64)
        token_ids = self.tokenizer.encode(full_text)
        token_ids.append(self.tokenizer.eos_token_id)
        example = torch.tensor(token_ids, dtype=torch.int64)

        # Mark the prompt positions with -1 so the masking step below turns
        # them into IGNORE_INDEX.
        labels = example.clone()
        labels[: len(prompt_ids)] = -1

        example_mask = example >= 0
        label_mask = labels >= 0
        example[~example_mask] = 0
        labels[~label_mask] = IGNORE_INDEX

        return {
            'input_ids': example.tolist(),
            'labels': labels.tolist(),
            'attention_mask': example_mask.tolist(),
        }


In [None]:
def get_dataloader(tokenizer:PreTrainedTokenizerFast, args:Dict):
    """Build the training ``DataLoader`` for the dataset named in ``args['dataset']``.

    Reads ``args['dataset']``, ``args['batch_size']``,
    ``args['gradient_accumulation_steps']``, ``args['context_length']`` and
    ``args['seed']``. The dataset is truncated to a multiple of
    ``batch_size * gradient_accumulation_steps`` before being wrapped in an
    ``InstructionDataset``.

    Raises:
        ValueError: If ``args['dataset']`` is not a supported name.
    """
    # Imported here because `Dataset` would otherwise clash with torch's Dataset.
    from datasets import Dataset, load_dataset

    if args['dataset'] == 'alpaca':
        dataset = load_dataset('yahma/alpaca-cleaned')['train']
    elif args['dataset'] == 'alpaca_sample':
        dataset = load_dataset("yahma/alpaca-cleaned", split='train[:512]')
    elif args['dataset'] == 'dummy':
        dataset = Dataset.from_dict({
            'instruction': ["instruction"]*512,
            'input': ["input"]*512,
            'output': ['output'*10000]*512
        })
    elif args['dataset'] == 'guanaco':
        dataset = load_dataset("timdettmers/openassistant-guanaco", split='train')
    elif args['dataset'] == 'sql':
        dataset = load_dataset("knowrohit07/know_sql")['validation']
        dataset = dataset.shuffle(seed=args['seed'])
        dataset = dataset.select(range(1000,len(dataset)))
    else:
        # Without this branch an unknown name surfaced as an UnboundLocalError below.
        raise ValueError(f"unsupported dataset: {args['dataset']}")

    # Drop the tail so the sample count is a multiple of one full accumulation step.
    samples_per_step = args['batch_size']*args['gradient_accumulation_steps']
    dataset = dataset.select(range(0, len(dataset)-len(dataset)%samples_per_step))

    if args['dataset'] == 'guanaco':
        dataset = InstructionDataset(dataset, tokenizer, style='guanaco')
    elif args['dataset'] == 'sql':
        dataset = InstructionDataset(dataset, tokenizer, style='qna')
    else: #alpaca
        dataset = InstructionDataset(dataset, tokenizer, style='alpaca')

    def collate_fn(batch, with_attention_mask=False):
        """Pad a list of samples to a common length and clip to the context length."""
        max_len = args['context_length']
        input_ids = [torch.tensor(item['input_ids']) for item in batch]
        labels = [torch.tensor(item['labels']) for item in batch]

        input_ids = pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)[:,:max_len]
        if with_attention_mask:
            attention_masks = [torch.tensor(item['attention_mask']) for item in batch]
            attention_masks = pad_sequence(attention_masks, batch_first=True, padding_value=0)[:,:max_len]
        else:
            attention_masks = None
        labels = pad_sequence(labels, batch_first=True, padding_value=-100)[:,:max_len]
        return {
            'input_ids': input_ids, 'attention_mask': attention_masks, 'labels': labels
        }

    # NOTE(review): DistributedSampler presumably needs an initialised process group here — confirm against caller.
    sampler = DistributedSampler(dataset, seed=args['seed'])

    dataloader = DataLoader(dataset, batch_size=args['batch_size'], collate_fn=collate_fn, sampler=sampler)

    return dataloader



In [None]:
import math

def _get_cosine_one_cycle_lr_lambda(
    current_step:int, *, num_warmup_steps:int, num_training_steps:int, min_lr_fraction=0.1
) :
    if current_step < num_warmup_steps:
        return  float(current_step) / float(max(1, num_warmup_steps))
    scale_term = (1-min_lr_fraction)
    progress = float(current_step-num_warmup_steps) / float(max(1, num_training_steps-num_warmup_steps))
    return (math.cos(math.pi*progress)+1) * 0.5 * scale_term + min_lr_fraction

In [None]:
from torch import optim
import functools
from torch.optim.lr_scheduler import LambdaLR

def get_cosine_one_cycle_scheduler(optimizer:optim.Optimizer, num_warmup_steps:int,
                                  num_training_steps:int, min_lr_fraction:float=0.1):
    """Wrap ``optimizer`` in a ``LambdaLR`` driven by the warmup + cosine multiplier."""
    schedule_kwargs = {
        'num_warmup_steps': num_warmup_steps,
        'num_training_steps': num_training_steps,
        'min_lr_fraction': min_lr_fraction,
    }
    # Bind the schedule settings so LambdaLR only has to supply the step index.
    lr_lambda = functools.partial(_get_cosine_one_cycle_lr_lambda, **schedule_kwargs)
    return LambdaLR(optimizer, lr_lambda, last_epoch=-1)

In [None]:
from transformers.optimization import get_linear_schedule_with_warmup

def get_lr_scheduler(optimizer:optim.Optimizer, dataloader:DataLoader,
                    gradient_accumulation_steps:int, args:Dict):
    """Return ``(lr_scheduler, num_training_steps)`` for ``args['lr_scheduler']``.

    The step count is ``num_epochs * len(dataloader) // gradient_accumulation_steps``
    and the first 10% of those steps are warmup. ``'constant'`` yields no
    scheduler (``None``).

    Raises:
        NotImplementedError: For an unrecognised scheduler name.
    """
    total_batches = args['num_epochs'] * len(dataloader)
    num_training_steps = total_batches // gradient_accumulation_steps
    num_warmup_steps = int(num_training_steps*0.1)

    scheduler_name = args['lr_scheduler']
    if scheduler_name == 'constant':
        return None, num_training_steps
    if scheduler_name == 'linear':
        linear = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps)
        return linear, num_training_steps
    if scheduler_name == 'cosine':
        cosine = get_cosine_one_cycle_scheduler(
            optimizer, num_warmup_steps, num_training_steps, min_lr_fraction=0.1
        )
        return cosine, num_training_steps
    raise NotImplementedError(f"{args['lr_scheduler']} lr scheduler not implemented yet")

In [None]:
def get_optimizer(model:nn.Module, args:Dict):
    """Create the optimizer named by ``args['optimizer']`` over ``model.parameters()``.

    Supported names are ``'adam'``, ``'sgd'``, ``'adadelta'`` and
    ``'adamw'``. All use ``args['lr']``; AdamW additionally reads
    ``args['wd']`` as its weight decay.

    Raises:
        ValueError: If the optimizer name is not supported.
    """
    if args['optimizer'] == 'adam':
        return optim.Adam(model.parameters(), lr=args['lr'])
    elif args['optimizer'] == 'sgd':
        return optim.SGD(model.parameters(), lr=args['lr'])
    elif args['optimizer'] == 'adadelta':
        return optim.Adadelta(model.parameters(), lr=args['lr'])
    elif args['optimizer'] == 'adamw':
        return optim.AdamW(model.parameters(), lr=args['lr'], betas=(0.9, 0.95),
                          eps=1e-5, weight_decay=args['wd'])
    else:
        # Must be raised: returning the exception object would hand the caller
        # a ValueError instance in place of an optimizer.
        raise ValueError('invalid optimizer')

In [None]:
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES, LlamaMLP, LlamaDecoderLayer
from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy
from peft.tuners import PrefixEncoder, PromptEncoder, PromptEmbedding

def get_wrapping_policy(custom_policy:bool=False):
    """Build the combined auto-wrap policy (an ``_or_policy`` partial) for FSDP.

    The result always combines a lambda policy with a transformer-layer
    policy covering ``PrefixEncoder``, ``PromptEncoder``, ``PromptEmbedding``
    and ``LlamaDecoderLayer``. With ``custom_policy=True`` the lambda policy
    matches ``nn.Sequential`` containers whose members all have trainable
    weights, and self-attention and MLP policies are added as well.
    """
    def _sequential_all_trainable(module):
        return isinstance(module, nn.Sequential) and all(m.weight.requires_grad for m in module)

    def _trainable_leaf(module):
        # A leaf module (no children) that owns a weight requiring gradients.
        if len(list(module.named_children())) != 0:
            return False
        if getattr(module, 'weight', None) is None:
            return False
        return module.weight.requires_grad

    def _is_self_attn(module):
        return isinstance(module, tuple(LLAMA_ATTENTION_CLASSES.values()))

    def _is_mlp(module):
        return isinstance(module, LlamaMLP)

    def _as_lambda_policy(predicate):
        return functools.partial(lambda_auto_wrap_policy, lambda_fn=predicate)

    lambda_policy_fn = _sequential_all_trainable if custom_policy else _trainable_leaf
    transformer_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=(
            PrefixEncoder,
            PromptEncoder,
            PromptEmbedding,
            LlamaDecoderLayer,
        ),
    )

    policies = [_as_lambda_policy(lambda_policy_fn), transformer_wrap_policy]
    if custom_policy:
        policies += [_as_lambda_policy(_is_self_attn), _as_lambda_policy(_is_mlp)]
    return functools.partial(_or_policy, policies=policies)

In [None]:
class LORA(nn.Module):
    """Adds a trainable low-rank (LoRA) update on top of an existing linear layer.

    The output is ``base_layer(x) + scaling * lora_B(lora_A(dropout(x)))``
    with ``scaling = lora_alpha / lora_rank``. ``lora_B`` starts at zero, so
    a fresh wrapper reproduces the base layer exactly.

    Args:
        base_layer: Layer to wrap; must expose ``in_features``,
            ``out_features`` and at least one parameter.
        lora_rank: Inner dimension of the low-rank pair.
        lora_alpha: Numerator of the scaling factor.
        lora_dropout: Dropout probability applied to the adapter input.
    """

    def __init__(self, base_layer:nn.Module, lora_rank:int, lora_alpha:float, lora_dropout:float):
        super().__init__()
        self.base_layer = base_layer
        first_param = next(base_layer.parameters())
        # Use the layer's `compute_dtype` when it has one; otherwise follow its parameters.
        dtype = getattr(base_layer, 'compute_dtype', first_param.dtype)
        device = first_param.device
        lora_A = nn.Linear(base_layer.in_features, lora_rank, bias=False, device=device, dtype=dtype)
        lora_B = nn.Linear(lora_rank, base_layer.out_features, bias=False, device=device, dtype=dtype)
        # Zero B so the adapter contributes nothing until it is trained.
        lora_B.weight.data.zero_()

        self.lora_AB = nn.Sequential(lora_A, lora_B)

        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        self.scaling = self.lora_alpha/lora_rank

    def forward(self, x:torch.Tensor, *args, **kwargs):
        """Return the base layer output plus the scaled low-rank update."""
        result = self.base_layer(x, *args, **kwargs)
        # Clone so the in-place `+=` below never writes into the tensor the base layer returned.
        result = result.clone()

        # Outside autocast, dtypes are matched by hand: run the adapter in its
        # own weight dtype, then cast back to the base output dtype.
        requires_conversion = not torch.is_autocast_enabled()
        if requires_conversion:
            expected_dtype = result.dtype
            x = x.to(next(iter(self.lora_AB)).weight.dtype)

        output = self.lora_AB(self.lora_dropout(x))
        if requires_conversion:
            output = output.to(expected_dtype)
        output = output * self.scaling

        result += output

        return result
        