In [1]:
%%writefile lora_layer.py

import math
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F

# The base class for LoRA layers
class LoraLayer:
    """Mixin that stores the per-adapter LoRA state of a layer.

    It is combined with a torch layer (``nn.Linear`` / ``nn.Embedding``) whose ``__init__`` must
    run first: the ``nn.ModuleDict`` / ``nn.ParameterDict`` attributes created here need
    ``nn.Module.__init__`` to have been called, and ``update_layer*`` read ``self.weight`` and
    call ``self.to``. Every per-adapter attribute is a dict keyed by adapter name.
    """

    def __init__(
        self,
        in_features: int,  # The number of input features
        out_features: int,  # The number of output features
    ):
        # Per-adapter hyper-parameters
        self.r = {}  # The rank of the low-rank matrices
        self.lora_alpha = {}  # The scaling numerator
        self.scaling = {}  # lora_alpha / r, only set for adapters with r > 0

        # Dropout applied to the adapter input, one module per adapter
        self.lora_dropout = nn.ModuleDict({})

        # Low-rank factors used by dense (Linear) adapters
        self.lora_A = nn.ModuleDict({})
        self.lora_B = nn.ModuleDict({})

        # Low-rank factors used by embedding adapters
        self.lora_embedding_A = nn.ParameterDict({})
        self.lora_embedding_B = nn.ParameterDict({})

        # Whether the active adapter's update has been folded into the base weight
        self.merged = False

        # Whether the adapters are bypassed in forward
        self.disable_adapters = False

        self.in_features = in_features
        self.out_features = out_features

    @staticmethod
    def _make_lora_dropout(lora_dropout):
        """Return ``nn.Dropout(p=lora_dropout)`` for a positive rate, otherwise ``nn.Identity()``."""
        if lora_dropout > 0.0:
            return nn.Dropout(p=lora_dropout)
        return nn.Identity()

    def update_layer(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Register (or overwrite) a dense LoRA adapter named ``adapter_name``.

        Args:
            adapter_name: key used in every per-adapter dict.
            r: rank; ``lora_A``/``lora_B`` and ``scaling`` are only created when ``r > 0``.
            lora_alpha: scaling numerator (``scaling = lora_alpha / r``).
            lora_dropout: dropout rate for the adapter input; ``<= 0`` uses ``nn.Identity``.
            init_lora_weights: if truthy, re-initialise the adapter via ``reset_lora_parameters``.
        """
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        self.lora_dropout.update(nn.ModuleDict({adapter_name: self._make_lora_dropout(lora_dropout)}))

        if r > 0:
            # A projects in_features -> r, B projects r -> out_features (both without bias)
            self.lora_A.update(nn.ModuleDict({adapter_name: nn.Linear(self.in_features, r, bias=False)}))
            self.lora_B.update(nn.ModuleDict({adapter_name: nn.Linear(r, self.out_features, bias=False)}))
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        # Put the freshly created adapter modules on the same device as the base weight.
        self.to(self.weight.device)

    def update_layer_embedding(self, adapter_name, r, lora_alpha, lora_dropout, init_lora_weights):
        """Register (or overwrite) an embedding LoRA adapter named ``adapter_name``.

        Same arguments as ``update_layer``; the factors are raw parameters of shape
        ``(r, in_features)`` (A) and ``(out_features, r)`` (B), created with the base weight's
        dtype/device via ``self.weight.new_zeros``.
        """
        self.r[adapter_name] = r
        self.lora_alpha[adapter_name] = lora_alpha
        self.lora_dropout.update(nn.ModuleDict({adapter_name: self._make_lora_dropout(lora_dropout)}))

        if r > 0:
            self.lora_embedding_A.update(
                nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((r, self.in_features)))})
            )
            self.lora_embedding_B.update(
                nn.ParameterDict({adapter_name: nn.Parameter(self.weight.new_zeros((self.out_features, r)))})
            )
            self.scaling[adapter_name] = lora_alpha / r

        if init_lora_weights:
            self.reset_lora_parameters(adapter_name)

        # Put the freshly created adapter parameters on the same device as the base weight.
        self.to(self.weight.device)

    def reset_lora_parameters(self, adapter_name):
        """Re-initialise adapter ``adapter_name`` so that its initial contribution is zero.

        Dense adapters: A uses nn.Linear's default init (Kaiming-uniform, ``a=sqrt(5)``), B is zeroed.
        Embedding adapters: A is zeroed and B is drawn from a standard normal distribution.
        Adapters without factors (unknown name or ``r == 0``) are left untouched.
        """
        if adapter_name in self.lora_A.keys():
            nn.init.kaiming_uniform_(self.lora_A[adapter_name].weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B[adapter_name].weight)
        if adapter_name in self.lora_embedding_A.keys():
            # For embeddings the roles are swapped: A starts at zero, B is random.
            nn.init.zeros_(self.lora_embedding_A[adapter_name])
            nn.init.normal_(self.lora_embedding_B[adapter_name])

# LoRA implemented in an Embedding layer
class Embedding(nn.Embedding, LoraLayer):
    """
    The Embedding class is an extension of the PyTorch nn.Embedding class
    and LoraLayer class to incorporate the LoRA method.

    The pre-trained embedding table is frozen; the active adapter adds
    ``scaling * (B @ A).T`` to it, either on the fly in ``forward`` or permanently via ``merge``.
    """
    def __init__(
        self,
        adapter_name: str,
        num_embeddings: int,
        embedding_dim: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        """Create the frozen embedding table and register the first adapter.

        Args:
            adapter_name: name of the first adapter (also becomes the active adapter).
            num_embeddings / embedding_dim: size of the embedding table.
            r, lora_alpha, lora_dropout: adapter hyper-parameters (see ``LoraLayer.update_layer_embedding``).
            **kwargs: ``init_lora_weights`` (default True) is consumed here; the rest go to ``nn.Embedding``.
        """
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs)
        LoraLayer.__init__(self, in_features=num_embeddings, out_features=embedding_dim)

        # Freezing the pre-trained weight matrix
        self.weight.requires_grad = False

        nn.Embedding.reset_parameters(self)
        self.update_layer_embedding(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)

        self.active_adapter = adapter_name

    def _embedding_delta(self) -> torch.Tensor:
        """Return ``scaling * (B @ A).T`` for the active adapter (same shape as ``self.weight``)."""
        adapter = self.active_adapter
        return (
            transpose(self.lora_embedding_B[adapter] @ self.lora_embedding_A[adapter], True)
            * self.scaling[adapter]
        )

    def unmerge(self, mode: bool = True):
        """Subtract the active adapter's update from the embedding table.

        ``mode`` is accepted for backward compatibility and ignored. Adapters without
        embedding factors (unknown or ``r == 0``) are skipped, mirroring ``Linear.unmerge``.
        """
        if self.active_adapter not in self.lora_embedding_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data -= self._embedding_delta()
            self.merged = False

    def merge(self):
        """Add the active adapter's update to the embedding table (inverse of ``unmerge``)."""
        if self.active_adapter not in self.lora_embedding_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[self.active_adapter] > 0:
            self.weight.data += self._embedding_delta()
            self.merged = True

    def forward(self, x: torch.Tensor):
        """Look up ``x`` in the embedding table, adding the active adapter's update when it is live.

        * adapter unknown / without factors -> plain embedding lookup;
        * adapters disabled -> undo any merge first, then plain lookup;
        * active, unmerged adapter with ``r > 0`` -> lookup + ``scaling * A.T[x] @ B.T``;
        * otherwise (already merged) -> plain lookup on the merged table.
        """
        adapter = self.active_adapter
        if adapter not in self.lora_embedding_A.keys():
            return nn.Embedding.forward(self, x)

        if self.disable_adapters:
            # The original code used `self.active.adapter` and `.weight` on raw Parameters here,
            # both of which raised; reuse unmerge() instead.
            if self.r[adapter] > 0 and self.merged:
                self.unmerge()
            return nn.Embedding.forward(self, x)

        if self.r[adapter] > 0 and not self.merged:
            result = nn.Embedding.forward(self, x)
            # Rows of A.T act as an (num_embeddings, r) lookup table for the low-rank path.
            after_A = F.embedding(
                x,
                self.lora_embedding_A[adapter].T,
                self.padding_idx,
                self.max_norm,
                self.norm_type,
                self.scale_grad_by_freq,
                self.sparse,
            )
            return result + (after_A @ self.lora_embedding_B[adapter].T) * self.scaling[adapter]

        return nn.Embedding.forward(self, x)


# Lora is implemented in a dense (Linear) layer
class Linear(nn.Linear, LoraLayer):
    """Dense layer with a frozen base weight and trainable low-rank LoRA adapters.

    While the active adapter is live and unmerged, the output is
    ``F.linear(x, W, b) + scaling * B(A(dropout(x)))``.
    """

    def __init__(
        self,
        adapter_name: str,
        in_features: int,
        out_features: int,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        fan_in_fan_out: bool = False,  # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
        **kwargs,
    ):
        """Build the frozen base layer and register the first adapter as the active one."""
        # `init_lora_weights` is ours; every other kwarg belongs to nn.Linear.
        init_lora_weights = kwargs.pop("init_lora_weights", True)

        nn.Linear.__init__(self, in_features, out_features, **kwargs)
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        # The pre-trained weight is never trained.
        self.weight.requires_grad = False

        # Conv1D-style layers keep their weight as (fan_in, fan_out).
        self.fan_in_fan_out = fan_in_fan_out
        if fan_in_fan_out:
            self.weight.data = self.weight.data.T

        nn.Linear.reset_parameters(self)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def _lora_delta(self):
        """Scaled ``B @ A`` of the active adapter, laid out like ``self.weight``."""
        name = self.active_adapter
        product = self.lora_B[name].weight @ self.lora_A[name].weight
        return transpose(product, self.fan_in_fan_out) * self.scaling[name]

    def merge(self):
        """Fold the active adapter's update into the base weight."""
        name = self.active_adapter
        if name not in self.lora_A.keys():
            return
        if self.merged:
            warnings.warn("Already merged. Nothing to do.")
            return
        if self.r[name] <= 0:
            return
        self.weight.data += self._lora_delta()
        self.merged = True

    def unmerge(self):
        """Remove the active adapter's update from the base weight (inverse of ``merge``)."""
        name = self.active_adapter
        if name not in self.lora_A.keys():
            return
        if not self.merged:
            warnings.warn("Already unmerged. Nothing to do.")
            return
        if self.r[name] <= 0:
            return
        self.weight.data -= self._lora_delta()
        self.merged = False

    def forward(self, x: torch.Tensor):
        """Apply the base projection plus, when live and unmerged, the LoRA branch.

        The result is cast back to ``x``'s original dtype (except when the active adapter is
        unknown, in which case the base output is returned as-is).
        """
        original_dtype = x.dtype
        name = self.active_adapter
        if name not in self.lora_A.keys():
            return F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        adapter_live = self.r[name] > 0
        if self.disable_adapters and adapter_live and self.merged:
            # Disabled adapters must not leave their update baked into the weight.
            self.unmerge()

        out = F.linear(x, transpose(self.weight, self.fan_in_fan_out), bias=self.bias)

        if not self.disable_adapters and adapter_live and not self.merged:
            # Run the low-rank path in the adapter's dtype.
            lora_in = x.to(self.lora_A[name].weight.dtype)
            dropped = self.lora_dropout[name](lora_in)
            out += self.lora_B[name](self.lora_A[name](dropped)) * self.scaling[name]

        return out.to(original_dtype)
    
def transpose(weight, fan_in_fan_out):
    """Return ``weight.T`` when ``fan_in_fan_out`` is truthy, otherwise ``weight`` itself.

    Bridges weights stored as (fan_in, fan_out) and the layout expected by ``F.linear``.
    """
    if not fan_in_fan_out:
        return weight
    return weight.T



Overwriting lora_layer.py


In [2]:
%%writefile lora_model.py

import importlib
import inspect
import sys, os

from contextlib import contextmanager

import math
import re
import copy
import warnings
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import List, Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from peft import (
    PeftConfig,
    LoraConfig,
    get_peft_model_state_dict,
) 

from transformers.pytorch_utils import Conv1D
from transformers.utils import PushToHubMixin

from accelerate import dispatch_model, infer_auto_device_map
from accelerate.hooks import (
    AlignDevicesHook,
    add_hook_to_module,
    remove_hook_from_submodules,
)
from accelerate.utils import get_balanced_memory

import bitsandbytes as bnb

from lora_layer import LoraLayer, Embedding, Linear

class LoraModel(torch.nn.Module):
    """Inject LoRA adapters into a (transformers) model.

    ``config`` is a dict mapping adapter names to LoRA configs; the wrapper keeps that same dict
    as ``self.peft_config`` and adds to it when adapters are registered. Target sub-modules are
    replaced in place by the LoRA ``Linear`` / ``Embedding`` layers, and attribute lookups that
    fail on the wrapper fall through to the wrapped model.
    """

    def __init__(self, model, config, adapter_name='default'):
        super().__init__()
        self.model = model
        # Calling the wrapper runs the wrapped model's forward directly.
        self.forward = self.model.forward
        self.peft_config = config
        self.add_adapter(adapter_name, self.peft_config[adapter_name])

    def add_adapter(self, adapter_name, config=None):
        """Register adapter ``adapter_name`` and inject it into the target modules.

        Args:
            adapter_name: key of the adapter in ``self.peft_config``.
            config: optional config; when given it is completed by ``_prepare_lora_config``
                and stored under ``adapter_name`` before injection.

        Raises:
            ValueError: several adapters are registered and this one's ``bias`` is not ``"none"``.
        """
        if config is not None:
            model_config = self.model.config.to_dict() if hasattr(self.model.config, "to_dict") else self.model.config
            config = self._prepare_lora_config(config, model_config)
            self.peft_config[adapter_name] = config
        self._find_and_replace(adapter_name)
        if len(self.peft_config) > 1 and self.peft_config[adapter_name].bias != "none":
            raise ValueError(
                "LoraModel supports only 1 adapter with bias. When using multiple adapters, set bias to 'none' for all adapters."
            )
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        if self.peft_config[adapter_name].inference_mode:
            _freeze_adapter(self.model, adapter_name)

    def _find_and_replace(self, adapter_name):
        """Attach adapter ``adapter_name`` to every module selected by its ``target_modules``.

        ``target_modules`` is either a regex (full-matched against the module path) or a list of
        path suffixes. Existing LoRA layers get an extra adapter; plain ``Linear8bitLt`` (8-bit),
        ``nn.Embedding``, ``nn.Linear`` and ``Conv1D`` modules are replaced by LoRA modules.

        Raises:
            ImportError: the model is quantized but bitsandbytes is unavailable.
            ValueError: a matched module has an unsupported type, or nothing matched.
        """
        lora_config = self.peft_config[adapter_name]
        loaded_in_4bit = getattr(self.model, "is_loaded_in_4bit", False)
        loaded_in_8bit = getattr(self.model, "is_loaded_in_8bit", False)
        if (loaded_in_4bit or loaded_in_8bit) and not is_bnb_available():
            raise ImportError(
                "To use Lora with 8-bit or 4-bit quantization, please install the `bitsandbytes` package. "
                "You can install it with `pip install bitsandbytes`."
            )
        is_target_modules_in_base_model = False
        kwargs = {
            "r": lora_config.r,
            "lora_alpha": lora_config.lora_alpha,
            "lora_dropout": lora_config.lora_dropout,
            "fan_in_fan_out": lora_config.fan_in_fan_out,
            "init_lora_weights": lora_config.init_lora_weights,
        }
        # Snapshot the names first: modules are replaced while we iterate.
        key_list = [key for key, _ in self.model.named_modules()]
        for key in key_list:
            if isinstance(lora_config.target_modules, str):
                target_module_found = re.fullmatch(lora_config.target_modules, key)
            else:
                target_module_found = any(key.endswith(target_key) for target_key in lora_config.target_modules)
            if not target_module_found:
                continue

            is_target_modules_in_base_model = True
            parent, target, target_name = _get_submodules(self.model, key)
            # Computed per target so it can never be undefined or left over from a previous target.
            bias = getattr(target, "bias", None) is not None

            if isinstance(target, LoraLayer):
                # Embedding adapters live in lora_embedding_A/B, dense ones in lora_A/B.
                update = target.update_layer_embedding if isinstance(target, Embedding) else target.update_layer
                update(
                    adapter_name,
                    lora_config.r,
                    lora_config.lora_alpha,
                    lora_config.lora_dropout,
                    lora_config.init_lora_weights,
                )
                continue

            if loaded_in_8bit and isinstance(target, bnb.nn.Linear8bitLt):
                eightbit_kwargs = kwargs.copy()
                eightbit_kwargs.update(
                    {
                        "has_fp16_weights": target.state.has_fp16_weights,
                        "memory_efficient_backward": target.state.memory_efficient_backward,
                        "threshold": target.state.threshold,
                        "index": target.index,
                    }
                )
                new_module = Linear8bitLt(
                    adapter_name, target.in_features, target.out_features, bias=bias, **eightbit_kwargs
                )
            elif isinstance(target, torch.nn.Embedding):
                embedding_kwargs = kwargs.copy()
                embedding_kwargs.pop("fan_in_fan_out", None)
                in_features, out_features = target.num_embeddings, target.embedding_dim
                new_module = Embedding(adapter_name, in_features, out_features, **embedding_kwargs)
            else:
                if isinstance(target, torch.nn.Linear):
                    in_features, out_features = target.in_features, target.out_features
                    if kwargs["fan_in_fan_out"]:
                        warnings.warn(
                            "fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
                            "Setting fan_in_fan_out to False."
                        )
                        kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = False
                elif isinstance(target, Conv1D):
                    # Conv1D stores its weight as (in_features, out_features).
                    in_features, out_features = (
                        target.weight.ds_shape if hasattr(target.weight, "ds_shape") else target.weight.shape
                    )
                    if not kwargs["fan_in_fan_out"]:
                        warnings.warn(
                            "fan_in_fan_out is set to False but the target module is `Conv1D`. "
                            "Setting fan_in_fan_out to True."
                        )
                        kwargs["fan_in_fan_out"] = lora_config.fan_in_fan_out = True
                else:
                    raise ValueError(
                        f"Target module {target} is not supported. "
                        f"Currently, only `torch.nn.Linear` and `Conv1D` are supported."
                    )
                new_module = Linear(adapter_name, in_features, out_features, bias=bias, **kwargs)

            self._replace_module(parent, target_name, new_module, target)
        if not is_target_modules_in_base_model:
            raise ValueError(
                f"Target modules {lora_config.target_modules} not found in the base model. "
                f"Please check the target modules and try again."
            )

    def _replace_module(self, parent_module, child_name, new_module, old_module):
        """Install ``new_module`` as ``parent_module.<child_name>``, reusing ``old_module``'s weight/bias."""
        setattr(parent_module, child_name, new_module)
        new_module.weight = old_module.weight
        if getattr(old_module, "bias", None) is not None:
            new_module.bias = old_module.bias

        # Quantized (bnb) layers carry extra state that must follow the weight.
        if getattr(old_module, "state", None) is not None:
            new_module.state = old_module.state
            new_module.to(old_module.weight.device)

        # dispatch to correct device
        for name, module in new_module.named_modules():
            if "lora_" in name:
                module.to(old_module.weight.device)

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.model, name)

    def get_peft_config_as_dict(self, inference: bool = False):
        """Return ``{adapter_name: config_fields}`` for every registered adapter.

        Enum-valued fields are replaced by their ``.value``. With ``inference=True`` each entry
        also gets ``"inference_mode": True``. An empty registry yields ``{}``.
        """
        config_dict = {}
        for key, value in self.peft_config.items():
            config = {k: v.value if isinstance(v, Enum) else v for k, v in asdict(value).items()}
            if inference:
                config["inference_mode"] = True
            # Previously this assignment sat outside the loop and `config` was returned instead.
            config_dict[key] = config
        return config_dict

    def _set_adapter_layers(self, enabled=True):
        """Set ``disable_adapters`` on every LoRA layer to ``not enabled``."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.disable_adapters = not enabled

    def enable_adapter_layers(self):
        self._set_adapter_layers(enabled=True)

    def disable_adapter_layers(self):
        self._set_adapter_layers(enabled=False)

    def set_adapter(self, adapter_name):
        """Make ``adapter_name`` active on every LoRA layer, unmerging merged layers first."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                if module.merged:
                    warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.")
                    module.unmerge()
                module.active_adapter = adapter_name

    def merge_adapter(self):
        """Merge the active adapter into the base weights of every LoRA layer."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.merge()

    def unmerge_adapter(self):
        """Undo ``merge_adapter`` on every LoRA layer."""
        for module in self.model.modules():
            if isinstance(module, LoraLayer):
                module.unmerge()

    @staticmethod
    def _prepare_lora_config(peft_config, model_config):
        """Fill in default ``target_modules`` for the model type and set ``merge_weights`` in inference mode."""
        if peft_config.target_modules is None:
            if model_config["model_type"] not in TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING:
                raise ValueError("Please specify `target_modules` in `peft_config`")
            peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]]
        if peft_config.inference_mode:
            peft_config.merge_weights = True
        return peft_config

    def merge_and_unload(self):
        r"""
        This method merges the LoRa layers into the base model. This is needed if someone wants to use the base model
        as a standalone model.

        LoRA ``Embedding`` layers become plain ``nn.Embedding`` modules; other LoRA layers become
        ``nn.Linear`` modules. The merged weight tensors are reused, not copied.

        Raises:
            ValueError: for GPT-2 models or models loaded in 8-bit / 4-bit mode.
        """
        if getattr(self.config, "model_type", None) == "gpt2":
            raise ValueError("GPT2 models are not supported for merging LORA layers")

        if getattr(self.model, "is_loaded_in_8bit", False) or getattr(self.model, "is_loaded_in_4bit", False):
            raise ValueError("Cannot merge LORA layers when the model is loaded in 8-bit or 4-bit mode")

        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            try:
                parent, target, target_name = _get_submodules(self.model, key)
            except AttributeError:
                continue
            if isinstance(target, LoraLayer):
                if isinstance(target, torch.nn.Embedding):
                    # Embedding LoRA layers have no `bias` and must stay embeddings.
                    new_module = torch.nn.Embedding(
                        target.num_embeddings,
                        target.embedding_dim,
                        padding_idx=target.padding_idx,
                        max_norm=target.max_norm,
                        norm_type=target.norm_type,
                        scale_grad_by_freq=target.scale_grad_by_freq,
                        sparse=target.sparse,
                    )
                else:
                    bias = target.bias is not None
                    new_module = torch.nn.Linear(target.in_features, target.out_features, bias=bias)
                target.merge()
                self._replace_module(parent, target_name, new_module, target)

            # save any additional trainable modules part of `modules_to_save`
            if isinstance(target, ModulesToSaveWrapper):
                setattr(parent, target_name, target.modules_to_save[target.active_adapter])

        return self.model

    def add_weighted_adapter(self, adapters, weights, adapter_name):
        """Create ``adapter_name`` as a weighted combination of existing adapters of equal rank.

        Each contributing A factor is scaled by ``weight * scaling`` and each B factor by
        ``weight``; the new adapter uses ``lora_alpha == r`` (scaling 1) and is frozen.

        Raises:
            ValueError: the source adapters do not all share the same ``r``.
        """
        if len({self.peft_config[adapter].r for adapter in adapters}) != 1:
            raise ValueError("All adapters must have the same r value")
        # Deep-copy: the config was previously aliased, so setting lora_alpha below (and the
        # fan_in_fan_out fix-ups in _find_and_replace) silently modified the source adapter.
        new_config = copy.deepcopy(self.peft_config[adapters[0]])
        new_config.lora_alpha = new_config.r
        self.peft_config[adapter_name] = new_config
        self._find_and_replace(adapter_name)
        mark_only_lora_as_trainable(self.model, self.peft_config[adapter_name].bias)
        _freeze_adapter(self.model, adapter_name)
        key_list = [key for key, _ in self.model.named_modules() if "lora" not in key]
        for key in key_list:
            _, target, _ = _get_submodules(self.model, key)
            if isinstance(target, LoraLayer):
                if adapter_name in target.lora_A:
                    target.lora_A[adapter_name].weight.data = target.lora_A[adapter_name].weight.data * 0.0
                    target.lora_B[adapter_name].weight.data = target.lora_B[adapter_name].weight.data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_A:
                            continue
                        target.lora_A[adapter_name].weight.data += (
                            target.lora_A[adapter].weight.data * weight * target.scaling[adapter]
                        )
                        target.lora_B[adapter_name].weight.data += target.lora_B[adapter].weight.data * weight

                elif adapter_name in target.lora_embedding_A:
                    target.lora_embedding_A[adapter_name].data = target.lora_embedding_A[adapter_name].data * 0.0
                    target.lora_embedding_B[adapter_name].data = target.lora_embedding_B[adapter_name].data * 0.0
                    for adapter, weight in zip(adapters, weights):
                        if adapter not in target.lora_embedding_A:
                            continue
                        target.lora_embedding_A[adapter_name].data += (
                            target.lora_embedding_A[adapter].data * weight * target.scaling[adapter]
                        )
                        target.lora_embedding_B[adapter_name].data += target.lora_embedding_B[adapter].data * weight


class LoraModelForCasualLM(PushToHubMixin, torch.nn.Module):
    """Causal-LM wrapper that injects LoRA adapters into a pretrained model.

    ``self.base_model`` holds a ``LoraModel`` built around the wrapped model;
    attributes not found on this wrapper are forwarded to it by ``__getattr__``.
    ``self.peft_config`` maps adapter names to their configurations and
    ``self.active_adapter`` names the adapter currently in use.
    """

    def __init__(self, model, peft_config: PeftConfig, adapter_name="default"):
        """Wrap ``model`` and register ``peft_config`` under ``adapter_name``.

        Args:
            model: The pretrained model to adapt; its ``config`` is exposed as ``self.config``.
            peft_config: LoRA configuration for the first adapter.
            adapter_name: Name under which the first adapter is registered.
        """
        super().__init__()
        self.base_model = model
        self.config = self.base_model.config
        self.modules_to_save = None
        self.peft_config = {}
        self.active_adapter = adapter_name
        self.peft_type = peft_config.peft_type
        self.base_model_torch_dtype = getattr(model, "dtype", None)
        self.peft_config[adapter_name] = peft_config
        self.base_model = LoraModel(self.base_model, self.peft_config, adapter_name)
        self.set_additional_trainable_modules(peft_config, adapter_name)

        # Keep the original hook so `generate` can restore it after temporarily overriding it.
        self.base_model_prepare_inputs_for_generation = self.base_model.prepare_inputs_for_generation

    def save_pretrained(self, save_directory, **kwargs):
        """Write each adapter's weights (``adapter_model.bin``) and config to disk.

        The ``"default"`` adapter is written directly into ``save_directory``; any
        other adapter goes into ``save_directory/<adapter_name>``. Each config is
        saved with ``inference_mode=True`` and then restored in memory.

        Args:
            save_directory: Target directory (created if missing).
            **kwargs: ``state_dict`` may be given to save from a precomputed state dict.

        Raises:
            ValueError: if ``save_directory`` is an existing file.
        """
        if os.path.isfile(save_directory):
            raise ValueError(f"Provided path ({save_directory}) should be a directory, not a file")
        os.makedirs(save_directory, exist_ok=True)

        for adapter_name, peft_config in self.peft_config.items():
            # save only the trainable weights
            output_state_dict = get_peft_model_state_dict(
                self, state_dict=kwargs.get("state_dict", None), adapter_name=adapter_name
            )
            output_dir = os.path.join(save_directory, adapter_name) if adapter_name != "default" else save_directory
            os.makedirs(output_dir, exist_ok=True)
            torch.save(output_state_dict, os.path.join(output_dir, "adapter_model.bin"))

            # save the config and change the inference mode to `True`
            if peft_config.base_model_name_or_path is None:
                peft_config.base_model_name_or_path = self.base_model.model.__dict__.get("name_or_path", None)
            inference_mode = peft_config.inference_mode
            peft_config.inference_mode = True
            peft_config.save_pretrained(output_dir)
            peft_config.inference_mode = inference_mode

    @classmethod
    def from_pretrained(cls, model, model_id, adapter_name="default", is_trainable=False, **kwargs):
        """Wrap ``model`` and load the LoRA adapter saved at ``model_id``.

        Args:
            model: The base model to adapt.
            model_id: Directory containing the adapter config and ``adapter_model.bin``.
            adapter_name: Name to register the loaded adapter under.
            is_trainable: When False the config is put into inference mode.
            **kwargs: Forwarded to ``load_adapter`` (e.g. ``subfolder``, ``device_map``).
        """
        # load the config
        config = LoraConfig.from_pretrained(model_id, subfolder=kwargs.get("subfolder", None))

        # With parts of the model offloaded to CPU/disk, drop the accelerate hooks;
        # `load_adapter` may re-dispatch the model afterwards.
        if (getattr(model, "hf_device_map", None) is not None) and len(
            set(model.hf_device_map.values()).intersection({"cpu", "disk"})
        ) > 0:
            remove_hook_from_submodules(model)

        config.inference_mode = not is_trainable

        model = LoraModelForCasualLM(model, config, adapter_name)
        model.load_adapter(model_id, adapter_name, is_trainable=is_trainable, **kwargs)
        return model

    def print_trainable_parameters(self):
        """
        Prints the number of trainable parameters in the model.
        """
        trainable_params = 0
        all_param = 0
        for _, param in self.named_parameters():
            num_params = param.numel()
            # if using DS Zero 3 and the weights are initialized empty
            if num_params == 0 and hasattr(param, "ds_numel"):
                num_params = param.ds_numel

            all_param += num_params
            if param.requires_grad:
                trainable_params += num_params
        # Guard against a parameter-less model instead of raising ZeroDivisionError.
        trainable_pct = 100 * trainable_params / all_param if all_param else 0.0
        print(
            f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {trainable_pct}"
        )

    def __getattr__(self, name: str):
        """Forward missing attributes to the wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            return getattr(self.base_model, name)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        **kwargs,
    ):
        """Run the LoRA-adapted model; all arguments are passed straight to ``self.base_model``."""
        return self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            labels=labels,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            **kwargs,
        )

    @contextmanager
    def disable_adapter(self):
        """
        Context manager that disables the adapter layers for the body of the
        ``with`` block and re-enables them on exit (also on exceptions).
        """
        try:
            self.base_model.disable_adapter_layers()
            # A @contextmanager function must yield exactly once; without this the
            # `with` statement fails and the adapters are never actually disabled
            # for the block.
            yield
        finally:
            self.base_model.enable_adapter_layers()

    def get_base_model(self):
        """
        Returns the base model.
        """
        return self.base_model.model

    def add_adapter(self, adapter_name, peft_config):
        """Register and inject a new adapter.

        Raises:
            ValueError: if ``peft_config`` has a different peft type than this model.
        """
        if peft_config.peft_type != self.peft_type:
            raise ValueError(
                f"Cannot combine adapters with different peft types. "
                f"Found {self.peft_type} and {peft_config.peft_type}."
            )
        self.peft_config[adapter_name] = peft_config
        self.base_model.add_adapter(adapter_name, peft_config)
        self.set_additional_trainable_modules(peft_config, adapter_name)

    def set_additional_trainable_modules(self, peft_config, adapter_name):
        """Record ``peft_config.modules_to_save`` (if any) and make those modules trainable."""
        if getattr(peft_config, "modules_to_save", None) is not None:
            if self.modules_to_save is None:
                self.modules_to_save = set(peft_config.modules_to_save)
            else:
                self.modules_to_save.update(peft_config.modules_to_save)
            _set_trainable(self, adapter_name)

    def generate(self, **kwargs):
        """Call ``generate`` on the wrapped model using this wrapper's generation config.

        ``self.base_model.prepare_inputs_for_generation`` is swapped for this
        wrapper's version for the duration of the call and always restored.
        """
        self.base_model.prepare_inputs_for_generation = self.prepare_inputs_for_generation
        try:
            if hasattr(self.base_model, "model"):
                self.base_model.model.generation_config = self.generation_config
            else:
                self.base_model.generation_config = self.generation_config
            return self.base_model.generate(**kwargs)
        finally:
            self.base_model.prepare_inputs_for_generation = self.base_model_prepare_inputs_for_generation

    def prepare_inputs_for_generation(self, *args, **kwargs):
        """Delegate to the wrapped model's ``prepare_inputs_for_generation``.

        LoRA only changes weights; it does not prepend virtual prompt embeddings,
        so the wrapped model's inputs are used unchanged. (The previous version
        called prompt-learning helpers, ``word_embeddings``/``get_prompt``, that
        this class never defines.)
        """
        return self.base_model_prepare_inputs_for_generation(*args, **kwargs)

    def load_adapter(self, model_id, adapter_name, is_trainable=False, **kwargs):
        """Load adapter weights from the local directory ``model_id`` into ``adapter_name``.

        Registers the adapter first (config via ``LoraConfig.from_pretrained``) when
        it is unknown. When this is the only adapter and part of the model sits on
        CPU/disk, the model is re-dispatched with accelerate. The model is left in
        eval mode.

        Raises:
            ValueError: if ``adapter_model.bin`` cannot be found.
        """
        if adapter_name not in self.peft_config:
            # load the config
            peft_config = LoraConfig.from_pretrained(model_id, subfolder=kwargs.get("subfolder", None))
            peft_config.inference_mode = not is_trainable
            self.add_adapter(adapter_name, peft_config)

        # load weights if any
        subfolder = kwargs.get("subfolder", None)
        path = os.path.join(model_id, subfolder) if subfolder is not None else model_id
        filename = os.path.join(path, "adapter_model.bin")
        if not os.path.exists(filename):
            raise ValueError(
                f"Can't find weights for {model_id} at {path}. "
                f"Please check that the file adapter_model.bin is present at {path}."
            )

        # NOTE(review): torch.load unpickles the file; only load adapters from trusted sources.
        adapters_weights = torch.load(
            filename, map_location=torch.device("cuda" if torch.cuda.is_available() else "cpu")
        )
        # load the weights into the model
        set_peft_model_state_dict(self, adapters_weights, adapter_name=adapter_name)
        if (
            (getattr(self, "hf_device_map", None) is not None)
            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
            and len(self.peft_config) == 1
        ):
            device_map = kwargs.get("device_map", "auto")
            max_memory = kwargs.get("max_memory", None)
            offload_dir = kwargs.get("offload_folder", None)
            offload_index = kwargs.get("offload_index", None)

            dispatch_model_kwargs = {}
            # Safety checker for previous `accelerate` versions
            # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
            if "offload_index" in inspect.signature(dispatch_model).parameters:
                dispatch_model_kwargs["offload_index"] = offload_index

            no_split_module_classes = self._no_split_modules

            if device_map != "sequential":
                max_memory = get_balanced_memory(
                    self,
                    max_memory=max_memory,
                    no_split_module_classes=no_split_module_classes,
                    low_zero=(device_map == "balanced_low_0"),
                )
            if isinstance(device_map, str):
                device_map = infer_auto_device_map(
                    self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
                )
            dispatch_model(
                self,
                device_map=device_map,
                offload_dir=offload_dir,
                **dispatch_model_kwargs,
            )
            hook = AlignDevicesHook(io_same_device=True)
            add_hook_to_module(self.get_base_model(), hook)

        # Set model in evaluation mode to deactivate Dropout modules by default
        self.eval()

    def set_adapter(self, adapter_name):
        """
        Sets the active adapter.

        Raises:
            ValueError: if ``adapter_name`` has not been registered.
        """
        if adapter_name not in self.peft_config:
            raise ValueError(f"Adapter {adapter_name} not found.")
        self.active_adapter = adapter_name
        self.base_model.set_adapter(adapter_name)
        _set_adapter(self, adapter_name)

    @property
    def active_peft_config(self):
        """Config of the currently active adapter."""
        return self.peft_config[self.active_adapter]

class Linear8bitLt(bnb.nn.Linear8bitLt, LoraLayer):
    """LoRA adapter stacked on a bitsandbytes 8-bit linear layer.

    The base weight is frozen. For the active adapter the layer adds
    ``lora_B(lora_A(lora_dropout(x))) * scaling`` to the base output.
    """

    def __init__(
        self,
        adapter_name,
        in_features,
        out_features,
        r: int = 0,
        lora_alpha: int = 1,
        lora_dropout: float = 0.0,
        **kwargs,
    ):
        bnb_options = dict(
            bias=kwargs.get("bias", True),
            has_fp16_weights=kwargs.get("has_fp16_weights", True),
            memory_efficient_backward=kwargs.get("memory_efficient_backward", False),
            threshold=kwargs.get("threshold", 0.0),
            index=kwargs.get("index", None),
        )
        bnb.nn.Linear8bitLt.__init__(self, in_features, out_features, **bnb_options)
        LoraLayer.__init__(self, in_features=in_features, out_features=out_features)

        # Freeze the pre-trained (quantized) weight matrix.
        self.weight.requires_grad = False
        init_lora_weights = kwargs.pop("init_lora_weights", True)
        self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights)
        self.active_adapter = adapter_name

    def forward(self, x: torch.Tensor):
        """Base 8-bit projection plus the scaled low-rank update of the active adapter."""
        result = super().forward(x)
        adapter = self.active_adapter

        # Adapters switched off, or this layer has no weights for the active adapter.
        if self.disable_adapters or adapter not in self.lora_A.keys():
            return result
        # A non-positive rank contributes nothing.
        if not self.r[adapter] > 0:
            return result

        if torch.is_autocast_enabled():
            # Under autocast no explicit dtype conversion is applied.
            lora_out = self.lora_B[adapter](self.lora_A[adapter](self.lora_dropout[adapter](x)))
            delta = lora_out * self.scaling[adapter]
        else:
            # Compute the LoRA branch in float32, then cast back to the base output dtype.
            target_dtype = result.dtype
            if x.dtype != torch.float32:
                x = x.float()
            lora_out = self.lora_B[adapter](self.lora_A[adapter](self.lora_dropout[adapter](x)))
            delta = lora_out.to(target_dtype) * self.scaling[adapter]

        result += delta
        return result

class ModulesToSaveWrapper(torch.nn.Module):
    """Holds the original module plus one deep copy of it per adapter.

    Calls are routed to the copy registered for ``active_adapter``; when no copy
    exists for that name, the original module is used.
    """

    def __init__(self, module_to_save, adapter_name):
        super().__init__()
        self.original_module = module_to_save
        self.modules_to_save = torch.nn.ModuleDict({})
        self.update(adapter_name)
        self.active_adapter = adapter_name

    def update(self, adapter_name):
        """Register a fresh deep copy of ``original_module`` under ``adapter_name``."""
        replica = copy.deepcopy(self.original_module)
        self.modules_to_save[adapter_name] = replica

    def forward(self, *args, **kwargs):
        """Dispatch to the active adapter's copy, falling back to the original module."""
        has_copy = self.active_adapter in self.modules_to_save
        target = self.modules_to_save[self.active_adapter] if has_copy else self.original_module
        return target(*args, **kwargs)

def mark_only_lora_as_trainable(model: nn.Module, bias: str = "none") -> None:
    """Freeze every parameter except LoRA ones, optionally re-enabling biases.

    Args:
        model: Module whose parameters are updated in place.
        bias: ``"none"`` keeps all non-LoRA parameters frozen; ``"all"`` re-enables
            every parameter whose name contains ``"bias"``; ``"lora_only"`` re-enables
            the ``bias`` of LoraLayer modules only.

    Raises:
        NotImplementedError: for any other ``bias`` value (after freezing).
    """
    for name, param in model.named_parameters():
        if "lora_" not in name:
            param.requires_grad = False

    if bias == "none":
        return
    if bias == "all":
        for name, param in model.named_parameters():
            if "bias" in name:
                param.requires_grad = True
        return
    if bias == "lora_only":
        for module in model.modules():
            if isinstance(module, LoraLayer) and hasattr(module, "bias") and module.bias is not None:
                module.bias.requires_grad = True
        return
    raise NotImplementedError

def _freeze_adapter(model, adapter_name):
    for n, p in model.named_parameters():
        if adapter_name in n:
            p.requires_grad = False
            
def _get_submodules(model, key):
    parent = model.get_submodule(".".join(key.split(".")[:-1]))
    target_name = key.split(".")[-1]
    target = model.get_submodule(key)
    return parent, target, target_name

def is_bnb_available():
    """Return True when the ``bitsandbytes`` package can be found on the import path."""
    spec = importlib.util.find_spec("bitsandbytes")
    return spec is not None



def _set_trainable(model, adapter_name):
    """Make every module named in ``model.modules_to_save`` trainable for ``adapter_name``.

    A module matches when its qualified name ends with one of the entries.
    A match that is already a ModulesToSaveWrapper gets ``update(adapter_name)``.
    Any other match has its parameters set trainable and is replaced in its
    parent by a new ModulesToSaveWrapper.
    """
    # Snapshot the names first: the loop replaces modules inside the tree.
    all_names = [name for name, _ in model.named_modules()]
    for name in all_names:
        if not any(name.endswith(suffix) for suffix in model.modules_to_save):
            continue
        parent, target, target_name = _get_submodules(model, name)
        if isinstance(target, ModulesToSaveWrapper):
            target.update(adapter_name)
            continue
        for param in target.parameters():
            param.requires_grad = True
        setattr(parent, target_name, ModulesToSaveWrapper(target, adapter_name))

def get_peft_model_state_dict(model, state_dict=None, adapter_name="default"):
    """Extract the saved-weight subset (LoRA, optional bias, modules_to_save) of one adapter.

    Args:
        model: Object exposing ``peft_config`` (dict keyed by adapter name) and
            ``modules_to_save`` (iterable of module names, or ``None``).
            ``model.state_dict()`` is only called when ``state_dict`` is not given.
        state_dict: Optional precomputed state dict to filter.
        adapter_name: Adapter whose weights are returned.

    Returns:
        dict of parameter name -> value, with the adapter-name component removed
        from each name.

    Raises:
        NotImplementedError: for a non-LoRA ``peft_type`` or an unknown ``bias`` mode.
    """
    config = model.peft_config[adapter_name]
    if state_dict is None:
        state_dict = model.state_dict()

    # Adapter names never contain "." (ModuleDict/ParameterDict keys), so comparing
    # whole dot-separated components is exact, unlike the former substring tests
    # that also matched e.g. "default2" for "default" or "lora_A" for "l".
    def _belongs_to_adapter(key):
        return adapter_name in key.split(".")

    def _strip_adapter(key):
        return ".".join(part for part in key.split(".") if part != adapter_name)

    # ("LORA",) is a real tuple; ("LORA") would be a plain substring test.
    if config.peft_type in ("LORA",):
        bias = config.bias
        if bias == "none":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k}
        elif bias == "all":
            to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k}
        elif bias == "lora_only":
            to_return = {}
            for k in state_dict:
                if "lora_" in k:
                    to_return[k] = state_dict[k]
                    bias_name = k.split("lora_")[0] + "bias"
                    if bias_name in state_dict:
                        to_return[bias_name] = state_dict[bias_name]
        else:
            raise NotImplementedError
        to_return = {
            k: v for k, v in to_return.items() if ("lora_" in k and _belongs_to_adapter(k)) or ("bias" in k)
        }
    else:
        raise NotImplementedError

    if model.modules_to_save is not None:
        for key, value in state_dict.items():
            # Trailing "." so adapter "default" does not also pick up "default2".
            if any(f"{module_name}.modules_to_save.{adapter_name}." in key for module_name in model.modules_to_save):
                to_return[key.replace("modules_to_save.", "")] = value

    return {_strip_adapter(k): v for k, v in to_return.items()}


def set_peft_model_state_dict(model, peft_model_state_dict, adapter_name="default"):
    """
    Set the state dict of the Peft model.

    Keys produced by ``get_peft_model_state_dict`` (adapter name stripped) are
    mapped back onto the adapter's slots, then loaded with ``strict=False``.

    Args:
        model ([`PeftModel`]): The Peft model.
        peft_model_state_dict (`dict`): The state dict of the Peft model.
        adapter_name (`str`): Adapter to load the weights into.

    Raises:
        NotImplementedError: if the adapter's ``peft_type`` is not LORA.
    """
    config = model.peft_config[adapter_name]
    state_dict = {}
    if model.modules_to_save is not None:
        for key, value in peft_model_state_dict.items():
            for module_name in model.modules_to_save:
                if module_name in key:
                    key = key.replace(module_name, f"{module_name}.modules_to_save.{adapter_name}")
                    break
            state_dict[key] = value
    else:
        state_dict = peft_model_state_dict

    def _with_adapter(key):
        # Insert the adapter name right after the `lora_*` component, e.g.
        # "x.lora_A.weight" -> "x.lora_A.<adapter>.weight" and
        # "x.lora_embedding_A" -> "x.lora_embedding_A.<adapter>". The former
        # `str.replace` on the suffix also rewrote unrelated occurrences of it
        # (e.g. "weight" inside "weight_proj").
        head, tail = key.rsplit("lora_", 1)
        layer, dot, rest = tail.partition(".")
        if not dot:
            return f"{key}.{adapter_name}"
        return f"{head}lora_{layer}.{adapter_name}.{rest}"

    # ("LORA",) is a real tuple; ("LORA") would be a plain substring test.
    if config.peft_type in ("LORA",):
        peft_model_state_dict = {}
        for k, v in state_dict.items():
            if "lora_" in k:
                k = _with_adapter(k)
            peft_model_state_dict[k] = v
    else:
        raise NotImplementedError

    model.load_state_dict(peft_model_state_dict, strict=False)
    

def _set_adapter(model, adapter_name):
    """Point every ModulesToSaveWrapper inside ``model`` at ``adapter_name``."""
    wrappers = (m for m in model.modules() if isinstance(m, ModulesToSaveWrapper))
    for wrapper in wrappers:
        wrapper.active_adapter = adapter_name
            
            

# Maps a model-type string (e.g. "llama", "gpt2") to the list of submodule
# names that LoRA layers are attached to for that architecture.
# NOTE(review): the lookup site is not in this chunk; presumably this is the
# fallback when a LoraConfig specifies no `target_modules` — confirm.
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = {
    "t5": ["q", "v"],
    "mt5": ["q", "v"],
    "bart": ["q_proj", "v_proj"],
    "gpt2": ["c_attn"],
    "bloom": ["query_key_value"],
    "blip-2": ["q", "v", "q_proj", "v_proj"],
    "opt": ["q_proj", "v_proj"],
    "gptj": ["q_proj", "v_proj"],
    "gpt_neox": ["query_key_value"],
    "gpt_neo": ["q_proj", "v_proj"],
    "bert": ["query", "value"],
    "roberta": ["query", "value"],
    "xlm-roberta": ["query", "value"],
    "electra": ["query", "value"],
    "deberta-v2": ["query_proj", "value_proj"],
    "deberta": ["in_proj"],
    "layoutlm": ["query", "value"],
    "llama": ["q_proj", "v_proj"],
    "chatglm": ["query_key_value"],
}

Overwriting lora_model.py


In [3]:
%%writefile prompt.py

from typing import Union

class Prompter(object):
    """Builds Alpaca-style prompts and extracts the response from generated text."""

    # A one-element tuple: ("template") would just be the string "template".
    __slots__ = ("template",)

    def __init__(self, template_name: str = "", verbose: bool = False):
        """Create a prompter that uses the built-in Alpaca-LoRA template.

        Args:
            template_name: Accepted for API compatibility; currently unused.
            verbose: Accepted for API compatibility; currently unused.
        """
        self.template = {
            "description": "Template used by Alpaca-LoRA.",
            "prompt_input": "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n",
            "prompt_no_input": "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n",
            "response_split": "### Response:"}

    def generate_prompt(
        self,
        instruction: str,
        input: Union[None, str] = None,
        label: Union[None, str] = None,
    ) -> str:
        """Return the full prompt for ``instruction`` and optional ``input``.

        If ``label`` (the expected response/output) is given, it is appended,
        which yields a complete training example.
        """
        if input:
            res = self.template["prompt_input"].format(
                instruction=instruction, input=input
            )
        else:
            res = self.template["prompt_no_input"].format(
                instruction=instruction
            )
        if label:
            res = f"{res}{label}"
        return res

    def get_response(self, output: str) -> str:
        """Return the stripped text between the first response marker and the next one.

        Raises:
            IndexError: if ``output`` contains no response marker.
        """
        marker = self.template["response_split"]
        parts = output.split(marker)
        if len(parts) < 2:
            raise IndexError(f"Response marker {marker!r} not found in model output")
        return parts[1].strip()



Writing prompt.py


In [4]:
%%writefile inference.py

import torch
from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer, AutoConfig

from lora_model import LoraModelForCasualLM
from prompt import Prompter

def get_response(prompt, tokenizer, model, generation_config, max_new_tokens):
    """Tokenize ``prompt``, sample a continuation on CUDA and return the decoded text.

    Returns the first decoded sequence with special tokens skipped.
    """
    inputs = tokenizer(prompt, return_tensors="pt")
    generate_kwargs = dict(
        input_ids=inputs["input_ids"].cuda(),
        generation_config=generation_config,
        max_new_tokens=max_new_tokens,
        do_sample=True,
    )
    # Pass the attention mask explicitly instead of letting `generate` infer it
    # from pad tokens.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        generate_kwargs["attention_mask"] = attention_mask.cuda()
    output = model.generate(**generate_kwargs)
    return tokenizer.batch_decode(output, skip_special_tokens=True)[0]

def generate_inference(instruction: str, user_inp: str, model_path: str, lora_weights_path: str,
                       top_k: int = 40, top_p: float = 0.75, temperature: float = 0.1,
                       num_beams: int = 1, max_new_tokens: int = 128):
    """Load the base model plus LoRA weights and generate a response to one instruction.

    Args:
        instruction: Task instruction inserted into the prompt template.
        user_inp: Optional extra input; ``None`` or ``"n/a"`` (any case, surrounding
            whitespace ignored) means no input.
        model_path: Path or hub id of the base causal LM and its tokenizer.
        lora_weights_path: Directory containing the saved LoRA adapter.
        top_k, top_p, temperature, num_beams: Settings for ``GenerationConfig``.
            ``top_p`` is a cumulative probability; the previous hard-coded 128
            (a value >= 1) left nucleus filtering inactive.
        max_new_tokens: Maximum number of generated tokens.

    Returns:
        The generated text that follows the prompt's response marker.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # `architectures` may be missing in some configs; treat that as "not Llama".
    architecture = (config.architectures or [""])[0]
    if "Llama" in architecture:
        print("Setting EOS, BOS, UNK, and PAD tokens for LLama tokenizer")
        # NOTE(review): bos/unk are mapped to "</s>" here; presumably this mirrors
        # the training setup — confirm before changing.
        tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "</s>",
                "unk_token": "</s>",
            }
        )
        tokenizer.pad_token_id = (
            0  # unk. we want this to be different from the eos token
        )
        tokenizer.padding_side = "left"

    model = AutoModelForCausalLM.from_pretrained(
            model_path,
            device_map="auto",
            trust_remote_code=True)
    model = LoraModelForCasualLM.from_pretrained(
            model,
            lora_weights_path,
            torch_dtype=torch.float16,
            trust_remote_code=True)

    model.config.pad_token_id = tokenizer.pad_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.eos_token_id = tokenizer.eos_token_id
    model.eval()

    generation_config = GenerationConfig(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams)

    prompter = Prompter()

    if user_inp is None or user_inp.lower().strip() == "n/a":
        user_inp = None
    prompt = prompter.generate_prompt(instruction, user_inp)
    output = get_response(prompt, tokenizer, model, generation_config, max_new_tokens)
    response = prompter.get_response(output)
    return response
    


Overwriting inference.py


In [5]:
%%writefile prepare_data.py

from prompt import Prompter
from datasets import load_dataset
import random

from typing import Union



def create_datasets(data_path, size_valid_set, tokenizer, max_length, seed,
                    val_output_path='dataset/val_data.json'):
    """Load a JSON instruction dataset, split it, and tokenize both splits.

    Args:
        data_path: JSON file(s) passed to ``load_dataset('json', data_files=...)``.
        size_valid_set: ``test_size`` for ``train_test_split`` (fraction or count).
        tokenizer: Tokenizer applied to each formatted prompt.
        max_length: Truncation length in tokens.
        seed: Seed for both the train/validation split and the training shuffle.
        val_output_path: Where the raw validation split is written as JSON; the
            parent directory is created if needed.

    Returns:
        ``(train_data, valid_data)`` in torch format, with the raw
        ``instruction``/``input``/``output`` columns removed.
    """
    import os

    def tokenize(prompt, add_eos_token=True):
        """Tokenize ``prompt``; append EOS if missing and room remains; labels copy input_ids."""
        result = tokenizer(
            prompt,
            truncation=True,
            max_length=max_length,
            padding=False,
            return_tensors=None
            )

        if (
            result["input_ids"][-1] != tokenizer.eos_token_id
            and len(result["input_ids"]) < max_length
            and add_eos_token
        ):

            result["input_ids"].append(tokenizer.eos_token_id)
            result["attention_mask"].append(1)

        result["labels"] = result["input_ids"].copy()
        return result

    def generate_and_tokenize_prompt(data_point):
        """Format one record (instruction/input/output) as a full prompt and tokenize it."""
        full_prompt = prompter.generate_prompt(
            data_point["instruction"],
            data_point["input"],
            data_point["output"],
        )
        tokenized_full_prompt = tokenize(full_prompt)

        return tokenized_full_prompt

    prompter = Prompter()

    print(f"Load dataset....")
    dataset = load_dataset('json', split='train', data_files=data_path)
    dataset = dataset.train_test_split(test_size=size_valid_set, seed=seed)

    # Seed the shuffle as well so runs with the same `seed` see the same order.
    train_data = dataset["train"].shuffle(seed=seed).map(generate_and_tokenize_prompt)
    valid_data = dataset["test"].map(generate_and_tokenize_prompt)

    train_data.set_format("torch")
    valid_data.set_format("torch")

    train_data = train_data.remove_columns(['instruction', 'input', 'output'])
    valid_data = valid_data.remove_columns(['instruction', 'input', 'output'])

    val_dir = os.path.dirname(val_output_path)
    if val_dir:
        os.makedirs(val_dir, exist_ok=True)
    dataset["test"].to_json(val_output_path)

    return train_data, valid_data


Overwriting prepare_data.py


In [6]:
%%writefile common.py
import gdown
def download_from_driver(path, location_path):
    """Download the Google Drive file at ``path`` to ``location_path`` using gdown.

    ``fuzzy=True`` lets gdown accept share-style Drive URLs; progress is shown.
    """
    source = path
    print(f"Begin download...., path: {source}")
    gdown.download(source, location_path, quiet=False, fuzzy=True)
    print(f"Completed download!!!: {source}")

Overwriting common.py


In [7]:
%%writefile logger_utils.py
import logging


class NoReceivedCommandFilter(logging.Filter):
    """Drop log records whose message contains ``'Received command c'``."""

    def filter(self, record):
        # Return an explicit bool. Returning the message string (as before) also
        # dropped every record with an empty message, because "" is falsy.
        return 'Received command c' not in record.getMessage()


class NoPythonDotEnvFilter(logging.Filter):
    """Drop log records whose message contains ``'Python-dotenv'``."""

    def filter(self, record):
        # Return an explicit bool. Returning the message string (as before) also
        # dropped every record with an empty message, because "" is falsy.
        return 'Python-dotenv' not in record.getMessage()


def get_logger():
    """Configure basic logging and return this module's logger with noise filters.

    Sets ``py4j.java_gateway`` to ERROR, calls ``logging.basicConfig`` (INFO level,
    timestamped format) and attaches one NoReceivedCommandFilter and one
    NoPythonDotEnvFilter. Repeated calls do not stack duplicate filters on the
    shared logger.
    """
    logging.getLogger('py4j.java_gateway').setLevel(logging.ERROR)
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s', level=logging.INFO)
    logger = logging.getLogger(__name__)

    for filter_cls in (NoReceivedCommandFilter, NoPythonDotEnvFilter):
        if not any(isinstance(existing, filter_cls) for existing in logger.filters):
            logger.addFilter(filter_cls())

    return logger


Overwriting logger_utils.py


In [8]:
%%writefile train.py
import os
import torch
from tqdm import tqdm


from peft import LoraConfig, get_peft_model
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, DataCollatorForSeq2Seq

from contextlib import nullcontext

from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader, SequentialSampler


from lora_model import LoraModelForCasualLM
from common import download_from_driver
from prepare_data import create_datasets

import warnings
warnings.filterwarnings('ignore')
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True


class Trainer:
    def __init__(self,
                 model,
                 tokenizer,
                 gpu_id: int,
                 is_ddp_training: bool = True,
                 output_dir: str = 'checkpoints/',
                 num_epochs: int = 10,
                 max_length: int = 128,
                 batch_size: int = 8,
                 mixed_precision_dtype=None,
                 gradient_accumulation_steps: int = 16):
        """
        Initialize the Trainer class.

        Args:
            model: Pretrained model object.
            tokenizer: Tokenizer object for text processing.
            num_epochs: Number of training epochs.
            max_length: Maximum sequence length.
            batch_size: Training batch size.
            gpu_id: GPU ID for training.
        """

        self.num_epochs = num_epochs
        self.max_length = max_length
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.tokenizer = tokenizer
        self.is_ddp_training = is_ddp_training

        self.gpu_id = gpu_id
        self.model = model.to(f"cuda:{self.gpu_id}")
        self.gradient_accumulation_steps = gradient_accumulation_steps

        self.mixed_precision_dtype = mixed_precision_dtype
        self.ctx = None
        self.gradscaler = None

        # set mixed precision context
        self.set_mixed_precision_context(mixed_precision_dtype)

    def set_mixed_precision_context(self, mixed_precision_dtype):
        
        # TODO: Setup mixed precision training context

        if mixed_precision_dtype is None:
            
            # If 'mixed_precision_dtype' is None, use 'nullcontext',
            self.ctx = None

        else:
        
            # TODO Otherwise, use 'torch.amp.autocast' context with the specified dtype, and initialize GradScaler if mixed_precision_dtype is float16.
            self.ctx = None
            self.gradscaler = None

    def _set_ddp_training(self):
        """Wrap ``self.model`` in DistributedDataParallel bound to this process's GPU."""
        local_device = self.gpu_id
        self.model = DDP(self.model, device_ids=[local_device], output_device=local_device)

    def _run_batch(self, batch):
        """
        Run a single training batch.

        Args:
            batch: Batch data.

        Returns:
            Loss value for the batch.
        """

        with self.ctx:
            outputs = self.model(**batch)
            loss = outputs.loss / self.gradient_accumulation_steps  # Normalize loss
        loss_val = loss.item()

        # TODO: If 'mixed_precision_dtype' is torch.float16, you have to modify the backward using the gradscaler.
        if self.mixed_precision_dtype == torch.float16:

            ### YOUR CODE HERE ###
            self.gradscaler.scale(loss).backward()
            ### YOUR CODE HERE ###
                      
        else:
            loss.backward()

        return loss_val

    def _run_epoch(self, train_dataloader, epoch):
        """
        Run a single training epoch with gradient accumulation.

        Args:
            train_dataloader: Training data loader.
            epoch: Current epoch number (0-based).

        Returns:
            Mean per-batch loss for the epoch (0.0 for an empty loader).
        """
        epoch_loss = 0
        self.model.train()

        # DistributedSampler derives its shuffle from the epoch; without set_epoch
        # every epoch would reuse the same ordering.
        sampler = getattr(train_dataloader, "sampler", None)
        if isinstance(sampler, DistributedSampler):
            sampler.set_epoch(epoch)

        if _is_master_process():
            train_progress_bar = tqdm(
                train_dataloader, desc=f"Epoch {epoch + 1} [Training]", position=0, leave=False)
        else:
            train_progress_bar = train_dataloader

        def _optimizer_step():
            # float16 training routes the update through the GradScaler.
            if self.mixed_precision_dtype == torch.float16:
                self.gradscaler.step(self.optimizer)
                self.gradscaler.update()
            else:
                self.optimizer.step()
            self.optimizer.zero_grad()
            torch.cuda.empty_cache()

        steps = 0
        self.optimizer.zero_grad()  # Reset gradients at the beginning of each epoch
        for batch in train_progress_bar:
            steps += 1
            batch = {key: value.to(self.gpu_id)
                     for key, value in batch.items()}
            epoch_loss += self._run_batch(batch)

            # Step once enough micro-batches have been accumulated.
            if steps % self.gradient_accumulation_steps == 0:
                _optimizer_step()

        # Apply gradients from a trailing partial window instead of discarding them.
        if steps % self.gradient_accumulation_steps != 0:
            _optimizer_step()

        num_batches = len(train_dataloader)
        if num_batches == 0:
            return 0.0
        # Each batch loss was already divided by the accumulation steps.
        epoch_loss /= (num_batches / self.gradient_accumulation_steps)
        return epoch_loss

    def _save_checkpoint(self, epoch):
        path_dir = f"{self.output_dir}/epoch_{epoch}"

        # check path_dir exited
        if not os.path.exists(path_dir):
            os.makedirs(path_dir)

        # save checkpoints
        if self.is_ddp_training and _is_master_process():
            self.model.module.save_pretrained(f'epoch_{epoch}_checkpoint')
        else:
            self.model.save_pretrained(f'epoch_{epoch}_checkpoint')

        print("Done saved at", f'epoch_{epoch}_checkpoint')

    def prepare_dataloader(self, train_dataset, eval_dataset):
        """Build the training and evaluation DataLoaders.

        Both loaders collate with ``DataCollatorForSeq2Seq`` (dynamic padding,
        padded length rounded up to a multiple of 8, PyTorch tensors) and drop
        the final incomplete batch.

        Sampling:
          * training   -> ``DistributedSampler`` under DDP, otherwise no sampler
                          (DataLoader default, i.e. dataset order);
          * evaluation -> always ``SequentialSampler``.

        Args:
            train_dataset: Dataset used for training.
            eval_dataset: Dataset used for evaluation.

        Returns:
            Tuple ``(train_loader, eval_loader)``.
        """

        def make_collator():
            # One freshly-built collator per loader, configured identically.
            return DataCollatorForSeq2Seq(
                self.tokenizer,
                padding=True,
                pad_to_multiple_of=8,
                return_tensors="pt",
            )

        # NOTE(review): with no sampler and shuffle left at its default, non-DDP
        # training visits the dataset in the same order every epoch.
        if self.is_ddp_training:
            train_sampler = DistributedSampler(train_dataset)
        else:
            train_sampler = None

        train_loader = DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            sampler=train_sampler,
            collate_fn=make_collator(),
            drop_last=True,
        )

        eval_loader = DataLoader(
            eval_dataset,
            batch_size=self.batch_size,
            sampler=SequentialSampler(eval_dataset),
            collate_fn=make_collator(),
            drop_last=True,
        )

        return train_loader, eval_loader

    def _eval(self, eval_dataloader, epoch: int):
        avg_loss = 0
        model.eval()
        if _is_master_process():
            eval_progress_bar = tqdm(
                eval_dataloader, desc=f"Epoch {epoch + 1} [Evaluation]", position=0, leave=False)
        else:
            eval_progress_bar = eval_dataloader


        for batch in eval_progress_bar:
            with self.ctx:
                with torch.no_grad():
                    if not self.is_ddp_training:
                        outputs = self.model(**batch.to(self.gpu_id))
                    else:
                        outputs = self.model(**batch)
            avg_loss += outputs.loss.item()
        avg_loss = avg_loss/(len(eval_dataloader))
        return avg_loss

    def run(self, data_path: str, size_valid_set: float = 0.25, seed: int = 123,
            learning_rate: float = 3e-4):
        """
        Run the training process.

        Builds the datasets and dataloaders and optionally sets up DDP. It then
        creates an AdamW optimizer and, for each epoch, trains, evaluates and
        checkpoints.

        Args:
            data_path: Data file forwarded to ``create_datasets``.
            size_valid_set: Validation-split size forwarded to
                ``create_datasets`` (a fraction such as 0.15 in this script,
                hence ``float`` rather than the former ``int`` annotation).
            seed: Seed forwarded to ``create_datasets``.
            learning_rate: AdamW learning rate. This used to be read from a
                module-level ``learning_rate`` that only exists when the file
                runs as a script (NameError otherwise). The default equals that
                script value.

        Returns:
            None
        """
        train_dataset, eval_dataset = create_datasets(
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            data_path=data_path,
            size_valid_set=size_valid_set,
            seed=seed
        )

        train_dataloader, eval_dataloader = self.prepare_dataloader(
            train_dataset, eval_dataset)

        if self.is_ddp_training:
            # Presumably wraps self.model for DDP (see _save_checkpoint's use of
            # self.model.module) — defined elsewhere in the class.
            self._set_ddp_training()

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=learning_rate)

        for epoch in range(self.num_epochs):

            if self.is_ddp_training:
                # DistributedSampler derives its shuffling from the epoch.
                train_dataloader.sampler.set_epoch(epoch)

            train_loss = self._run_epoch(train_dataloader, epoch)
            if self.is_ddp_training:
                # Wait for every rank to finish the epoch before eval/saving.
                dist.barrier()

            # The master evaluates every epoch; other ranks only on the last one.
            if _is_master_process() or (epoch == self.num_epochs - 1):
                eval_loss = self._eval(
                    eval_dataloader=eval_dataloader, epoch=epoch)

                print(
                    f"epoch = {epoch+1} | avg_train_loss = {train_loss} | eval_loss = {eval_loss}")

            if _is_master_process():
                self._save_checkpoint(epoch=epoch+1)

def load_tokenizer_from_pretrained_model(model_path):
    """Load the tokenizer for ``model_path`` and configure its special tokens.

    The pad token is always set to the EOS token. If the model config's first
    architecture name contains "Llama", EOS/BOS/UNK are overridden and the pad
    token id is set to 0.

    Args:
        model_path: Hugging Face hub id or local directory of the model.

    Returns:
        The configured tokenizer.
    """
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    # ``architectures`` may be missing/None in some configs; fall back to an
    # empty name instead of failing on ``None[0]``.
    architecture = (config.architectures or [""])[0]
    # No ``device_map`` here: tokenizers have no device, and the previous
    # kwarg (hard-wired to cuda:0 regardless of rank) was not a tokenizer option.
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token
    if _is_master_process():
        print('Completed to load config & tokenizer')

    if "Llama" in architecture:
        if _is_master_process():
            print("Setting EOS, BOS, UNK, and PAD tokens for LLama tokenizer")
        # NOTE(review): BOS and UNK are both set to "</s>" (same as EOS), which
        # contradicts the pad-id comment below; kept as-is to preserve behavior.
        tokenizer.add_special_tokens(
            {
                "eos_token": "</s>",
                "bos_token": "</s>",
                "unk_token": "</s>",
            }
        )
        tokenizer.pad_token_id = (
            0  # unk. we want this to be different from the eos token
        )

    return tokenizer


def _is_master_process():
    ddp_rank = int(os.environ['RANK'])
    return ddp_rank == 0


def load_pretrained_model(local_rank, model_path: str = ""):
    """Load a causal-LM checkpoint onto this rank's GPU and wrap it with LoRA.

    Args:
        local_rank: Index of the CUDA device owned by this process; the whole
            model is mapped onto ``cuda:<local_rank>``.
        model_path: Hugging Face hub id or local path of the base model.

    Returns:
        A ``LoraModelForCasualLM`` wrapping the loaded base model.
    """
    # device_map key "" targets the root module, i.e. the entire model.
    device = torch.device(f"cuda:{local_rank}")
    base_model = AutoModelForCausalLM.from_pretrained(
        model_path,
        trust_remote_code=True,
        device_map={"": device},
    )

    # LoRA adapter settings: rank 4, alpha 8, dropout 0.05, applied to the
    # listed target modules only.
    lora_config = LoraConfig(
        r=4,
        lora_alpha=8,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
        target_modules=['lm_head.linear', 'transformer.embd.wte'],
    )

    lora_model = LoraModelForCasualLM(base_model, lora_config)

    if _is_master_process():
        lora_model.print_trainable_parameters()

    return lora_model


if __name__ == "__main__":
    OUTPUT_DIR = "checkpoints/"

    backend = "nccl"
    model_path = "TheBloke/phi-2-GPTQ"
    # DEBUG switches to test_data.json instead of the full alpaca_data.json.
    if os.environ.get("DEBUG"):
        data_path = "/kaggle/input/cosodulieu/test_data.json"
    else:
        data_path = '/kaggle/input/cosodulieu/alpaca_data.json'

    size_valid_set = 0.15
    max_length = 128
    num_epochs = 3
    batch_size = 2
    gradient_accumulation_steps = 8

    # NOTE(review): these are module-level globals; in the visible code only
    # ``learning_rate`` is read (by Trainer.run). Kept in case other code
    # in the file references them.
    learning_rate = 3e-4
    lr_scheduler_type = 'cosine'
    num_warmup_steps = 100
    weight_decay = 0.06

    seed = 0
    log_freq = 1
    eval_freq = 150

    # Seed torch so random weight initialisation and dropout are reproducible.
    torch.manual_seed(seed)

    distributed_strategy = "ddp" if os.environ.get("ON_DDP") else "no"
    is_ddp = distributed_strategy == "ddp"

    if is_ddp:
        # Pin this process to its own GPU before any NCCL collective runs;
        # otherwise every rank defaults to cuda:0 (e.g. for dist.barrier()).
        local_rank = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_rank)
        # Rank and world size are read from the launcher's env:// variables.
        init_process_group(backend=backend, init_method="env://")
    else:
        # Single process: mark it as rank 0 so _is_master_process() is True.
        os.environ['RANK'] = '0'
        local_rank = 0

    try:
        # Prepare model
        model = load_pretrained_model(local_rank, model_path=model_path)

        # Get tokenizer
        tokenizer = load_tokenizer_from_pretrained_model(model_path=model_path)

        # prepare trainer
        trainer = Trainer(
            model=model,
            num_epochs=num_epochs,
            max_length=max_length,
            batch_size=batch_size,
            gpu_id=local_rank,
            mixed_precision_dtype=torch.float16 if os.environ.get("ON_MP") else None,
            tokenizer=tokenizer,
            output_dir=OUTPUT_DIR,
            is_ddp_training=is_ddp,
            gradient_accumulation_steps=gradient_accumulation_steps,
        )

        # execute trainer (DDP wrapping happens inside Trainer.run)
        trainer.run(
            data_path=data_path,
            size_valid_set=size_valid_set,
            seed=seed
        )
    finally:
        # Tear down the process group even if loading or training raises.
        if is_ddp:
            destroy_process_group()

Overwriting train.py


In [9]:
%%writefile requirements.txt

gdown
sentencepiece
transformers>=4.28.0
loralib
bitsandbytes 
appdirs
git+https://github.com/huggingface/accelerate.git
git+https://github.com/huggingface/datasets.git
einops
auto-gptq
optimum
git+https://github.com/huggingface/peft.git

Overwriting requirements.txt


In [10]:
!pip install -r requirements.txt

Collecting git+https://github.com/huggingface/accelerate.git (from -r requirements.txt (line 8))
  Cloning https://github.com/huggingface/accelerate.git to /tmp/pip-req-build-gl0xt4op
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/accelerate.git /tmp/pip-req-build-gl0xt4op
  Resolved https://github.com/huggingface/accelerate.git to commit 97d2168e5953fe7373a06c69c02c5a00a84d5344
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
[?25hCollecting git+https://github.com/huggingface/datasets.git (from -r requirements.txt (line 9))
  Cloning https://github.com/huggingface/datasets.git to /tmp/pip-req-build-1ht8_hga
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/datasets.git /tmp/pip-req-build-1ht8_hga
  Resolved https://github.com/huggingface/datasets.git to commit bdebf1922663c30744efb8869c86b28f102b

In [11]:
!DEBUG=true ON_MP=true python train.py 

2024-02-19 04:57:04.210120: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-02-19 04:57:04.210187: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-02-19 04:57:04.211790: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
config.json: 100%|█████████████████████████| 1.59k/1.59k [00:00<00:00, 9.89MB/s]
configuration_phi.py: 100%|████████████████| 2.03k/2.03k [00:00<00:00, 15.2MB/s]
A new version of the following files was downloaded from https://huggingface.co/TheBloke/phi-2-GPTQ:
- configuration_phi.py
. Make sure to double-check they do not contain any added malicious code. To av