In [1]:
import os
import time

os.environ["CUDA_VISIBLE_DEVICES"] = '0'

SAVING_DIR='/home/data/taxonomy/'
os.environ["TRANSFORMERS_CACHE"] = SAVING_DIR + "hf_cache/"
os.environ["HF_HOME"] = SAVING_DIR + "hf_cache/"

import torch
from transformers_modified.src.transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaTokenizer, AutoConfig, PretrainedConfig, AutoTokenizer
import math
import re
from operator import attrgetter
import torch.nn as nn
import torch.nn.functional as F
import transformers_modified.src.transformers as transformers

from typing import Optional, Dict, Sequence
from argparse import ArgumentParser
from pathlib import Path

from tqdm import tqdm
from dataclasses import dataclass, field
from itertools import chain
from functools import partial

import numpy as np
import torch
from tqdm import tqdm
from peft import PeftModel, get_peft_model, TaskType, LoraConfig

from ste_utils import prepare_llama_ste
import datasets
from datasets import load_dataset, load_from_disk

#import transformers
from transformers_modified.src.transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    default_data_collator,
    LlamaForCausalLM,
    TrainerCallback,
    DataCollatorForSeq2Seq
)
import transformers_modified.src.transformers as transformers


IGNORE_INDEX = -100



In [29]:

@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """
    model_name_or_path: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
            )
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=False,
        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    token: str = field(
        default=None,
        metadata={
            "help": (
                "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token "
                "generated when running `huggingface-cli login` (stored in `~/.huggingface`)."
            )
        },
    )
    trust_remote_code: bool = field(
        default=False,
        metadata={
            "help": (
                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
                "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
                "execute code present on the Hub on your local machine."
            )
        },
    )
    torch_dtype: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
                "dtype will be automatically derived from the model's weights."
            ),
            "choices": ["auto", "bfloat16", "float16", "float32"],
        },
    )

@dataclass
class DataTrainingArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    """

    dataset_name: Optional[str] = field(
        default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
    )
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    streaming: bool = field(default=False, metadata={"help": "Enable streaming mode"})
    block_size: Optional[int] = field(
        default=None,
        metadata={
            "help": (
                "Optional input sequence length after tokenization. "
                "The training dataset will be truncated in block of this size for training. "
                "Default to the model max input length for single sentence inputs (take into account special tokens)."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
    validation_split_percentage: Optional[int] = field(
        default=5,
        metadata={
            "help": "The percentage of the train set used as validation set in case there's no validation split"
        },
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the preprocessing."},
    )

    dataset_percentage: Optional[int] = field(
        default=100,
        metadata={"help": "The number of percentage to take from entire dataset"},
    
    )
    
    seed: Optional[int] = field(
        default=54,
    )
    
    load_from_disk: bool = field(
        default=False
    )

@dataclass
class DataCollatorWithMaskForCausalLM(object):
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, batch):
        input_ids = []
        labels = []
        attention_masks = []

        for item_dict in batch:
            input_ids.append(torch.tensor(item_dict["input_ids"]))
            attention_masks.append(torch.tensor(item_dict["attention_mask"]))
            label = torch.tensor(item_dict["labels"])
            label[:-1] = IGNORE_INDEX
            labels.append(label)

        input_ids = torch.vstack(input_ids)
        attention_masks = torch.vstack(attention_masks)
        labels = torch.vstack(labels)
            
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_masks,
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict
    



In [36]:
dataset_name = "allenai/tulu-v2-sft-mixture" #"wikitext" #'/home/LLM_Compression/logs/wikitext_gpt2'
dataset_config_name = None#'wikitext-2-raw-v1'
valid_split = 10
block_size = 100
dataset_percentage = 0.1

data_args = DataTrainingArguments(
    dataset_name = dataset_name,
    dataset_config_name = dataset_config_name,
    validation_split_percentage = valid_split,
    block_size = block_size,
    dataset_percentage = dataset_percentage
)


In [4]:
model_name_or_path = 'gpt2'
cache_dir = SAVING_DIR + "hf_cache/"

model_args = ModelArguments(
    model_name_or_path = model_name_or_path,
    config_name = None, 
    tokenizer_name = None,
    use_fast_tokenizer = True,
    token = None,
    trust_remote_code = True,
    cache_dir= cache_dir
)


In [5]:
 model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        cache_dir=model_args.cache_dir,
        device_map="auto"	    
)

In [6]:

# Load pretrained tokenizer
tokenizer_kwargs = {
    "cache_dir": model_args.cache_dir,
    "use_fast": model_args.use_fast_tokenizer,
    "revision": model_args.model_revision,
    "token": model_args.token,
    "trust_remote_code": model_args.trust_remote_code,
}

if model_args.tokenizer_name:
    tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
elif model_args.model_name_or_path:
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)

In [7]:
from data_utils import encode_with_prompt_completion_format, encode_with_messages_format, load_hf_datasets

In [37]:
raw_datasets = load_hf_datasets(data_args)

In [40]:
len(tokenizer(raw_datasets['validation'][4]['messages'][0]['content'])[0])

208

In [41]:
if (
    "prompt" in raw_datasets["train"].column_names
    and "completion" in raw_datasets["train"].column_names
):
    encode_function = partial(
        encode_with_prompt_completion_format,
        tokenizer=tokenizer,
        max_seq_length=data_args.block_size,
    )
elif "messages" in raw_datasets["train"].column_names:
    encode_function = partial(
        encode_with_messages_format,
        tokenizer=tokenizer,
        max_seq_length=data_args.block_size,
    )

lm_datasets = raw_datasets.map(
    encode_function,
    batched=False,
    num_proc=data_args.preprocessing_num_workers,
    remove_columns=[
        name
        for name in raw_datasets["train"].column_names
        if name not in ["input_ids", "labels", "attention_mask"]
    ],
    desc="Tokenizing and reformatting instruction data",
)

lm_datasets.set_format(type="pt")
lm_datasets = lm_datasets.filter(
    lambda example: (example["labels"] != -100).any()
)


Tokenizing and reformatting instruction data:   0%|          | 0/293 [00:00<?, ? examples/s]

Tokenizing and reformatting instruction data:   0%|          | 0/33 [00:00<?, ? examples/s]

Filter:   0%|          | 0/293 [00:00<?, ? examples/s]

Filter:   0%|          | 0/33 [00:00<?, ? examples/s]

In [43]:
lm_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'labels', 'attention_mask'],
        num_rows: 171
    })
    validation: Dataset({
        features: ['input_ids', 'labels', 'attention_mask'],
        num_rows: 1
    })
})

In [65]:
inp = [torch.tensor([   27,    91,  7220,    91,    29,   198,  5492,  2555,   198,    27,
           91,   562, 10167,    91,    29,   198,  6610,   835,   284,   787,
          262,  1628,   517,  4050,   561,   307,   284, 16094,  4875,  8514,
          284,  2987,   262,  7124,   290,  6082,   286,   262, 25047, 14240,
         3186,  4635,   416,   262, 13053,  2449,   305, 19216,    13,   198,
          198,  3198,  3164,   561,   307,   284,  2251,   281,  2691, 17432,
          810,  7008,   460,  3538,  5001,   262,  3186,  3264,   422,   262,
         2449,   305, 19216,    11, 17286,   278,  4569,  6082,  9619,   290,
         8868,  3484,    13,   383, 17432,   714,   307,  3562,   284,  2148,
         6496,  1321,   319,   262,  3227,  7767,   973,    11,   262,  6142]), torch.tensor([   27,    91,  7220,    91,    29,   198, 16447,   257, 11933,  1430,
          326,  2753,   734,  3146,   290,  5860,   262,  2160,   286,   511,
        24438,    13,   198,    27,    91,   562, 10167,    91,    29,   198,
         8818,  2160,  5189, 22266,  3565,     7,    64,    11,   275,     8,
         1391,   198,   220,  1441,   257,     9,    64,  1343,   275,     9,
           65,    26,   198,    92, 50256])]

In [67]:
torch.nn.utils.rnn.pad_sequence(
            inp, batch_first=True, padding_value=tokenizer.pad_token_id
)

tensor([[   27,    91,  7220,    91,    29,   198,  5492,  2555,   198,    27,
            91,   562, 10167,    91,    29,   198,  6610,   835,   284,   787,
           262,  1628,   517,  4050,   561,   307,   284, 16094,  4875,  8514,
           284,  2987,   262,  7124,   290,  6082,   286,   262, 25047, 14240,
          3186,  4635,   416,   262, 13053,  2449,   305, 19216,    13,   198,
           198,  3198,  3164,   561,   307,   284,  2251,   281,  2691, 17432,
           810,  7008,   460,  3538,  5001,   262,  3186,  3264,   422,   262,
          2449,   305, 19216,    11, 17286,   278,  4569,  6082,  9619,   290,
          8868,  3484,    13,   383, 17432,   714,   307,  3562,   284,  2148,
          6496,  1321,   319,   262,  3227,  7767,   973,    11,   262,  6142],
        [   27,    91,  7220,    91,    29,   198, 16447,   257, 11933,  1430,
           326,  2753,   734,  3146,   290,  5860,   262,  2160,   286,   511,
         24438,    13,   198,    27,    91,   562, 

In [74]:
@torch.no_grad()
def format_logit(example):
    global model
    global tokenizer

    ids = torch.nn.utils.rnn.pad_sequence(
            example['input_ids'], batch_first=True, padding_value=tokenizer.pad_token_id
    ).to(model.device)
    masks =torch.nn.utils.rnn.pad_sequence(
            example['attention_mask'], batch_first=True, padding_value=0
    ).to(model.device)
    logits = model(ids, attention_mask=masks).logits
    new_ls = []
    for i, l in enumerate(logits):
        #print(e)
        mask = (ids[i] != tokenizer.pad_token_id)
        cur_logit = l[mask]
        new_ls.append(cur_logit.cpu().tolist())
    example['logits'] = new_ls
    return example

dataset_with_logits = lm_datasets.map(format_logit, batched=True, batch_size=2, desc=f"Obtaining logits")

Obtaining logits:   0%|          | 0/171 [00:00<?, ? examples/s]

Obtaining logits:   0%|          | 0/1 [00:00<?, ? examples/s]

In [78]:
dataset_with_logits['train'][0]['logits']

tensor([[ -14.1875,  -14.0000,  -16.2500,  ...,  -21.6250,  -21.5000,
          -14.2500],
        [ -56.2500,  -53.5000,  -53.7500,  ...,  -65.5000,  -66.0000,
          -51.7500],
        [ -84.0000,  -87.5000,  -84.0000,  ...,  -93.5000,  -95.5000,
          -84.0000],
        ...,
        [ -97.5000,  -96.0000, -100.0000,  ...,  -99.0000, -101.5000,
          -93.0000],
        [-114.5000, -113.0000, -117.0000,  ..., -117.5000, -119.5000,
         -111.5000],
        [ -74.0000,  -70.5000,  -77.0000,  ...,  -79.5000,  -82.0000,
          -73.0000]])

In [133]:
@dataclass
class DistillDataCollatorWithMaskForCausalLM(object):
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, batch):
        input_ids = []
        labels = []
        attention_masks = []
        logits = []

        for item_dict in batch:
            input_ids.append(torch.tensor(item_dict["input_ids"]))
            attention_masks.append(torch.tensor(item_dict["attention_mask"]))
            label = torch.tensor(item_dict["labels"])
            label[:-1] = IGNORE_INDEX
            labels.append(label)
            logits.append(torch.tensor(item_dict['logits']).unsqueeze(0))

        input_ids = torch.vstack(input_ids)
        attention_masks = torch.vstack(attention_masks)
        labels = torch.vstack(labels)
        logits = torch.vstack(logits)
            
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_masks,
            'logits': logits
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict
    
@dataclass
class DistillDataCollatorSeq2Seq(object):
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, batch):



        input_ids = []
        labels = []
        attention_masks = []
        logits = []

        for item_dict in batch:
            input_ids.append(torch.tensor(item_dict["input_ids"]))
            attention_masks.append(torch.tensor(item_dict["attention_mask"]))
            label = torch.tensor(item_dict["labels"])
            labels.append(label)
            logits.append(torch.tensor(item_dict['logits']))

        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids, batch_first=True, padding_value=tokenizer.pad_token_id
        )
        attention_masks =torch.nn.utils.rnn.pad_sequence(
            attention_masks, batch_first=True, padding_value=0
        )
        labels =torch.nn.utils.rnn.pad_sequence(
            labels, batch_first=True, padding_value=-100
        )
        logits =torch.nn.utils.rnn.pad_sequence(
            logits, batch_first=True, padding_value=0
        )
            
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_masks,
            'logits': logits
        }
        if labels is not None:
            data_dict['labels'] = labels
        return data_dict
    


In [134]:
#data_collator = DistillDataCollatorWithMaskForCausalLM(tokenizer)
data_collator = DataCollatorWithMaskForCausalLM(tokenizer)
# data_collator = DataCollatorForSeq2Seq(
#     tokenizer=tokenizer, model=model, padding="longest"
# )
data_collator = DistillDataCollatorSeq2Seq(tokenizer)

In [122]:
tokenizer.pad_token = tokenizer.eos_token

In [123]:
ld = dataset_with_logits

In [143]:
training_args = TrainingArguments(
    output_dir = './test_distill',
    learning_rate = 3e-4, 
    seed = 2, 
    num_train_epochs = 1,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    gradient_accumulation_steps = 1,
    gradient_checkpointing=False,
    save_strategy = 'steps',
    save_steps = 1000,
    evaluation_strategy = 'steps',
    eval_steps = 2,
    weight_decay = 0.1,
    warmup_ratio = 0.03,
    lr_scheduler_type = "cosine",
    logging_steps = 1,
    do_train = True,
    do_eval = True,
    report_to="none",
    remove_unused_columns=False
)

train_dataset = ld["train"]
eval_dataset = ld["validation"]

class MegaTrainer(Trainer):
    def __init__(self, model, temperature=None, lambda_param=None,  *args, **kwargs):
        super().__init__(model=model, *args, **kwargs)
        self.loss_function = nn.KLDivLoss(reduction="batchmean")
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.temperature = temperature
        self.lambda_param = lambda_param
        
    def compute_loss(self, model, inputs, return_outputs=False):
        """
        How the loss is computed by Trainer. By default, all models return the loss in the first element.

        Subclass and override for custom behavior.
        """
        #print(inputs)
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs["labels"]
        else:
            labels = None

        label_mask = inputs["labels"] != -100

        logits = inputs.pop("logits")[label_mask]

        outputs = model(**inputs)

        #https://huggingface.co/docs/transformers/tasks/knowledge_distillation_for_image_classification
        soft_teacher = F.softmax(logits / self.temperature, dim=-1)
        soft_student = F.log_softmax(outputs.logits[label_mask] / self.temperature, dim=-1)

        # Compute the loss
        distillation_loss = self.loss_function(soft_student, soft_teacher) * (self.temperature ** 2)

        print(distillation_loss)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            unwrapped_model = unwrap_model(model)
            if _is_peft_model(unwrapped_model):
                model_name = unwrapped_model.base_model.model._get_name()
            else:
                model_name = unwrapped_model._get_name()
            if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values():
                student_target_loss = self.label_smoother(outputs, labels, shift_labels=True)
            else:
                student_target_loss = self.label_smoother(outputs, labels)
        else:
            if isinstance(outputs, dict) and "loss" not in outputs:
                raise ValueError(
                    "The model did not return a loss from the inputs, only the following keys: "
                    f"{','.join(outputs.keys())}. For reference, the inputs it received are {','.join(inputs.keys())}."
                )
            # We don't use .loss here since the model may return tuples instead of ModelOutput.
            student_target_loss = outputs["loss"] if isinstance(outputs, dict) else outputs[0]

        self.state.distill_loss = distillation_loss.detach().cpu().item()
        self.state.CE_loss = student_target_loss.detach().cpu().item()
        # Calculate final loss
        loss = (1. - self.lambda_param) * student_target_loss + self.lambda_param * distillation_loss
        return (loss, outputs) if return_outputs else loss

    def log(self, logs: Dict[str, float]) -> None:
        """
        Log `logs` on the various objects watching training.

        Subclass and override this method to inject custom behavior.

        Args:
            logs (`Dict[str, float]`):
                The values to log.
        """
        if self.state.epoch is not None:
            logs["epoch"] = round(self.state.epoch, 2)
        if self.args.include_num_input_tokens_seen:
            logs["num_input_tokens_seen"] = self.state.num_input_tokens_seen

        #print(self.state)
        if (self.state.global_step % self.state.eval_steps) == 0:
            #print(self.state.global_step, self.state.logging_steps, (self.state.global_step % self.state.logging_steps) == 0)
            prefix = 'eval_'
        else:
            prefix = 'train_'
        logs[f'{prefix}distill_loss'] = self.state.distill_loss
        logs[f'{prefix}CE_loss'] = self.state.CE_loss

        output = {**logs, **{"step": self.state.global_step}}
        self.state.log_history.append(output)
        self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)


class PrinterCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        #print(logs)
        pass
    
# Initialize our Trainer
trainer = MegaTrainer(
    model=model,
    args=training_args,
    train_dataset=ld['train'],
    eval_dataset=ld['validation'],
    tokenizer=tokenizer,
    # Data collator will default to DataCollatorWithPadding, so we change it.
    data_collator=data_collator, 
    temperature=1, 
    lambda_param=1
)
#trainer.add_callback(PrinterCallback)

In [144]:
trainer.train()

  input_ids.append(torch.tensor(item_dict["input_ids"]))
  attention_masks.append(torch.tensor(item_dict["attention_mask"]))
  label = torch.tensor(item_dict["labels"])
  logits.append(torch.tensor(item_dict['logits']))


tensor(0.2909, device='cuda:0', grad_fn=<MulBackward0>)


Step,Training Loss,Validation Loss


tensor(0.2672, device='cuda:0', grad_fn=<MulBackward0>)


KeyboardInterrupt: 

In [38]:
trainer.state.distill_loss

5

In [17]:
outputs.predictions.shape

(54, 32, 50257)

In [19]:
logits = torch.rand(5000, 1024, 50257)

In [None]:
raw_datase

In [97]:
outputs.predictions[0].shape

(32, 50257)

In [98]:
out = model.forward(torch.tensor([[0, 1]]))

In [101]:
out.logits.shape

torch.Size([1, 2, 50257])

In [63]:
logits[0][0].shape

(12, 8, 32, 64)

In [45]:
logits[0][0].shape

(12, 8, 32, 64)

In [49]:
len(eval_dataset[0]['input_ids'])

32

In [53]:
eval_dataset[1]

{'input_ids': [247,
  8230,
  281,
  253,
  2798,
  273,
  27015,
  41151,
  23749,
  1157,
  253,
  2846,
  1907,
  644,
  387,
  253,
  43213,
  273,
  1302,
  314,
  1495,
  1157,
  1445,
  285,
  10336,
  323,
  689,
  247,
  8014,
  1107,
  1157,
  342],
 'attention_mask': [1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1,
  1],
 'labels': [247,
  8230,
  281,
  253,
  2798,
  273,
  27015,
  41151,
  23749,
  1157,
  253,
  2846,
  1907,
  644,
  387,
  253,
  43213,
  273,
  1302,
  314,
  1495,
  1157,
  1445,
  285,
  10336,
  323,
  689,
  247,
  8014,
  1107,
  1157,
  342]}

In [72]:
logits[1][1].shape

(12, 8, 32, 64)

In [74]:
64 * 8 * 12

6144

In [73]:
model

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 512)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-5): 6 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=512, out_features=1536, bias=True)
          (dense): Linear(in_features=512, out_features=512, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=512, out_features=2048, bias=True)
          (dense_4h_to_h): Linear(in_features=2048, out_features=512, bias=True)
          (a

In [None]:

def run_train(
    model_args,
    data_args,
    training_args,
    config,
):
    
    # Load pretrained model
    # if config.model_type == 'Llama':
    #     model_type = LlamaForCausalLM
    # else:
    #     model_type = AutoModelForCausalLM
    

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        torch_dtype=torch.bfloat16,
        token=model_args.token,
        cache_dir=model_args.cache_dir,
        device_map="auto"	    
)
    
    if config.zero_outliers:
        make_zero_outliers(model, config.outlier_fraction)

    if config.use_clip_softmax:
        model.set_clipped_sm(gamma=config.clip_softmax_gamma, eta=config.clip_softmax_eta)

    if config.ste.enable:
        outlier_ids, layer_bit = prepare_llama_ste(config.ste.path_to_act_scales, config.ste.fp_features_num, **config.ste.layer_bits)
        model.enable_ste(outlier_ids=outlier_ids, layer_bit=layer_bit, block_size=config.ste.block_size)

    if config.use_lora:
        task_type = TaskType.CAUSAL_LM
        target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "up_proj", "down_proj", "gate_proj"]
        lora_config = LoraConfig(
            task_type=task_type,
            inference_mode=False,
            r=config.lora_rank,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            target_modules=config.lora_target_modules,
            init_lora_weights=True,
        )
        model = get_peft_model(model, lora_config)

    # Load pretrained tokenizer
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
        "use_fast": model_args.use_fast_tokenizer,
        "revision": model_args.model_revision,
        "token": model_args.token,
        "trust_remote_code": model_args.trust_remote_code,
    }

    if model_args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
    elif model_args.model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, **tokenizer_kwargs)

    #Load and preprocessing dataset
    raw_datasets = load_hf_datasets(data_args)
    tokenized_datasets = tokenize_datasets(data_args, raw_datasets, tokenizer)
    lm_datasets = format_datasets(data_args, tokenized_datasets, tokenizer)

    data_collator = DataCollatorWithMaskForCausalLM(
        tokenizer=tokenizer
    )

    if config.norm_tweek:
        layernorm_names = [f"model.layers.{layer_block_num}.input_layernorm.weight" for layer_block_num in range(len(model.model.layers))]
        layernorm_names += [f"model.layers.{layer_block_num}.post_attention_layernorm.weight" for layer_block_num in range(len(model.model.layers))]

        #Set model parameters to be learned
        for name, param in model.named_parameters():
            if name not in layernorm_names:
                # freeze base model's layers
                param.requires_grad = False
            else:
                param.requires_grad = True

    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()

    print(f"trainable_params: {trainable_params}")

    
    #Train
    train_dataset = lm_datasets["train"]
    eval_dataset = lm_datasets["validation"]

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        # Data collator will default to DataCollatorWithPadding, so we change it.
        data_collator=default_data_collator
    )

    trainer.save_model()
    train_result = trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload

