In [None]:
# !pip install bitsandbytes
# !pip install hqq

In [None]:
import tqdm
from typing import List, Dict, Union
import torch
from torch import nn
from bitsandbytes.nn import Linear4bit

try:
    from hqq.core.quantize import HQQLinear, HQQBackend, BaseQuantizeConfig
except ImportError:
    HQQLinear = None
    pass

try:
    import wandb
except ImportError:
    pass

In [None]:
class Logger:
    """Minimal metrics logger that writes to stdout, tqdm, or Weights & Biases.

    Only the process with ``rank == 0`` starts a wandb run or emits anything.
    Any `log_to` value other than `'stdout'`, `'tqdm'` or `'wandb'` makes
    `log` a silent no-op.
    """

    def __init__(self, args, log_to='stdout', project_name='fsdp_qlora',
                 entity=None, group=None, name=None, rank=0):
        """
        Parameters:
            args:
                Run configuration; passed to `wandb.init` as `config`.
            log_to (`str`, defaults to `'stdout'`):
                Destination for `log`: `'stdout'`, `'tqdm'` or `'wandb'`.
            project_name, entity, group, name:
                Forwarded unchanged to `wandb.init` when `log_to == 'wandb'`.
            rank (`int`, defaults to 0):
                Process rank; a wandb run is only started on rank 0.
        """
        self.log_to = log_to
        if self.log_to == 'wandb' and rank == 0:
            # Imported lazily so wandb is only required when it is actually used.
            import wandb
            wandb.init(
                project=project_name,
                entity=entity,
                group=group,
                name=name,
                config=args,
            )

    def log(self, d:Dict, rank:int):
        """Emit every key/value pair of `d` to the configured destination.

        Does nothing when `rank` is not 0.
        """
        if rank != 0:
            return
        if self.log_to == "tqdm":
            for k, v in d.items():
                # `tqdm` is imported as a module here; `write` is a method of
                # the `tqdm.tqdm` class, not of the module.
                tqdm.tqdm.write(f'{k}: {v}')
        elif self.log_to == 'wandb':
            # Local import: the module-level import is wrapped in try/except and
            # may have left the global name undefined.
            import wandb
            wandb.log(d)
        elif self.log_to == 'stdout':
            for k, v in d.items():
                print(f'{k}: {v}')
            

In [None]:
def update_progress_bar(progress_bar:"tqdm.tqdm", epoch:int, log_loss:float, log_lr:float, rank:int):
    """Refresh the progress bar description with the current training stats.

    Parameters:
        progress_bar:
            Bar whose `set_description` is called (the `tqdm.tqdm` class, not
            the `tqdm` module).
        epoch (`int`):
            Epoch number shown in the description.
        log_loss (`float`):
            Loss value, rendered with three decimals.
        log_lr (`float`):
            Learning rate, rendered in scientific notation. A negative value
            omits the learning rate from the description.
        rank (`int`):
            Process rank; only rank 0 touches the bar.
    """
    if rank != 0:
        return
    description = f"epoch {epoch}, loss {log_loss:.3f}"
    if log_lr >= 0:
        description += f", lr {log_lr:.2e}"
    progress_bar.set_description(description, refresh=True)

In [None]:
def replace_linear(model:nn.Module, linear_replacement:nn.Module, quant_config:Union[dict,None]=None,
                   skip_modules:List[str]=["lm_head"], **kwargs):
    """
    Replace linear modules with a new Linear module.
    Parameters:
        model (`torch.nn.Module`):
            Input model or `torch.nn.Module` as the function is run recursively.
        linear_replacement (`type`):
            The class that replaces each `torch.nn.Linear`. It must be a subclass
            of `Linear4bit` or `HQQLinear`. Extra constructor arguments are
            passed through `**kwargs`.
        quant_config (`dict`, *optional*):
            Passed as the second positional argument to an `HQQLinear`
            replacement. Not used for `Linear4bit` replacements.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            List of modules names not to convert. Defaults to `lm_head`.
        **kwargs:
            Forwarded to the `linear_replacement` constructor.
    Returns:
        The same `model` object, with matching children replaced in place.
    Raises:
        `ValueError`: if a linear layer has to be converted and
            `linear_replacement` is not a supported class (including the case
            where hqq is not installed and `HQQLinear` is `None`).
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if isinstance(module, torch.nn.Linear) and name not in skip_modules:
            # `issubclass` raises TypeError for non-classes and for a `None`
            # second argument, so guard both before calling it.
            is_class = isinstance(linear_replacement, type)
            if is_class and issubclass(linear_replacement, Linear4bit):
                model._modules[name] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs
                )
            elif is_class and HQQLinear is not None and issubclass(linear_replacement, HQQLinear):
                model._modules[name] = linear_replacement(module, quant_config, **kwargs)
            else:
                # Show the replacement itself; `type()` of a class is always `type`.
                raise ValueError(f"Unsupported linear replacement: {linear_replacement!r}")
    return model


In [None]:
def replace_linear(model:nn.Module, linear_replacement:nn.Module, quant_config:Union[dict,None]=None,
                   skip_modules:List[str]=['lm_head'], **kwargs):
    """
    Recursively swap every `nn.Linear` child of `model` for `linear_replacement`.
    Parameters:
        model (`torch.nn.Module`):
            Module whose children are scanned; nested containers are handled
            by recursion.
        linear_replacement (`type`):
            Replacement class; must be a subclass of `Linear4bit` or `HQQLinear`.
        quant_config (`dict`, *optional*):
            Passed as the second positional argument to an `HQQLinear`
            replacement. Not used for `Linear4bit` replacements.
        skip_modules (`List[str]`, *optional*, defaults to `lm_head`):
            Child names that are left untouched.
        **kwargs:
            Forwarded to the `linear_replacement` constructor.
    Returns:
        The same `model` object, with matching children replaced in place.
    Raises:
        `ValueError`: if a linear layer has to be converted and
            `linear_replacement` is not a supported class.
    """
    for name, module in model.named_children():
        if len(list(module.children())) > 0:
            replace_linear(module, linear_replacement, quant_config, skip_modules, **kwargs)

        if isinstance(module, nn.Linear) and name not in skip_modules:
            # `issubclass` raises TypeError for non-classes and for a `None`
            # second argument (hqq not installed), so guard both.
            is_class = isinstance(linear_replacement, type)
            if is_class and issubclass(linear_replacement, Linear4bit):
                # Write into the *parent's* `_modules`; the Linear's own
                # `_modules` is empty, so assigning there would only attach the
                # new layer as a child of the old one and replace nothing.
                model._modules[name] = linear_replacement(
                    module.in_features,
                    module.out_features,
                    module.bias is not None,
                    **kwargs
                )
            elif is_class and HQQLinear is not None and issubclass(linear_replacement, HQQLinear):
                model._modules[name] = linear_replacement(module, quant_config, **kwargs)
            else:
                # Show the replacement itself; `type()` of a class is always `type`.
                raise ValueError(f"unsupported linear replacement: {linear_replacement!r}")
    return model

In [None]:
# Small nested model used by the inspection cell below: one Linear directly
# under the root and a second Linear inside an inner Sequential.
m = nn.Sequential(
    nn.Linear(5, 10),
    nn.Sequential(nn.Linear(10, 50), nn.ReLU())
)
# Bare expression so the notebook displays the module's repr.
m

Sequential(
  (0): Linear(in_features=5, out_features=10, bias=True)
  (1): Sequential(
    (0): Linear(in_features=10, out_features=50, bias=True)
    (1): ReLU()
  )
)

In [None]:
# For each direct child of `m`, print its name, then each of its own children,
# then its `_modules` dict. The leaf Linear ('0') has no children and an empty
# `_modules`; the inner Sequential ('1') holds its layers keyed by name.
for name, module in m.named_children():
    print(name)
    for l in module.children():
        print(l)
    print(module._modules)

0
OrderedDict()
1
Linear(in_features=10, out_features=50, bias=True)
ReLU()
OrderedDict([('0', Linear(in_features=10, out_features=50, bias=True)), ('1', ReLU())])
