In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
# Restrict this process to GPU index 1. This cell sits ahead of the cell that
# imports torch, so the variable is already set when torch is first loaded.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [3]:
import os
import pandas as pd
import torch
from dataset_preprocessing import TokenInfo
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools
import pandas as pd
from tqdm import tqdm

## Model

In [4]:
# Identifier handed to `from_pretrained` for both the tokenizer and the model.
model_id = "microsoft/phi-1_5"
# Commit hash passed as `revision=` when the model is loaded, pinning that load to one commit.
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

In [5]:
# Pin the tokenizer to the same commit as the model. With trust_remote_code=True an
# unpinned load would fetch (and may execute) whatever is at the head of the repo,
# which can drift away from the pinned model revision.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
)

In [6]:
# Fetch the tokenizer's vocabulary and display how many entries it has.
vocab = tokenizer.get_vocab()
len(vocab)

50295

In [7]:
# tokenizer.decode(token_info.get_prefixes(top_tokens[1000][0], 9, 10)[0])

In [8]:
# Load the causal-LM checkpoint at the pinned commit (`model_revision`).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
    # The two options below are intentionally left commented out, so neither
    # half precision nor flash attention is requested here.
    # be careful with this?
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)

## Train model

In [9]:
from peft import LoraConfig, PeftConfig
import transformers

In [10]:
from post_training import get_lora_config, get_training_arguments
from dataset import get_baseline_dataset
from trl import SFTTrainer



In [11]:
# Both objects come from the project's `post_training` helpers.
lora_config = get_lora_config()
# "./tmp" is the only argument given to the helper
# (presumably the trainer's output directory — TODO confirm in post_training).
training_arguments = get_training_arguments("./tmp")

In [12]:
# Override whatever save_steps value the helper produced, for this notebook only.
training_arguments.save_steps = 400

In [13]:
# Move the model to the GPU; the trailing ';' keeps the cell from echoing the model repr.
model.cuda();

In [14]:
# Turn off the config's cache flag for training
# (presumably because it conflicts with gradient checkpointing — TODO confirm).
model.config.use_cache = False
# NOTE(review): pretraining_tp looks like a Llama-style config field; confirm it
# has any effect on this Phi model.
model.config.pretraining_tp = 1
# Enable gradient checkpointing on the model.
model.gradient_checkpointing_enable()

In [15]:
# The project helper returns a mapping with "train" and "test" entries;
# the "test" entry is used as the evaluation set.
dataset = get_baseline_dataset()
train_data, eval_data = dataset["train"], dataset["test"]

reading pickle


In [16]:
# Reuse the end-of-sequence token as the padding token
# (presumably the tokenizer ships without one — TODO confirm).
tokenizer.pad_token = tokenizer.eos_token

In [17]:
# Build the supervised fine-tuning trainer from the model, the train/eval splits,
# the LoRA config and the training arguments prepared in the cells above.
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=lora_config,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
    # Training text is read from the "text" column of each example.
    dataset_text_field="text",
    max_seq_length=1024, # tweak this
    # No explicit collator is passed, so the trainer's own default is used.
    # TODO: think harder about the datacollator
    # data_collator=transformers.DataCollatorForSeq2Seq(
    #     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    # ),
)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [68]:
from torch import nn
class ModifiedLora(nn.Module):
    """Wrap a LoRA linear layer so the activation is applied inside the wrapper.

    The forward pass computes the frozen base projection ``h1`` and the scaled
    low-rank update ``h2`` separately and then combines them as:

    * ``is_fc1=True``:  ``activation_fn(h1) + activation_fn(h2)``
    * ``is_fc1=False``: ``activation_fn(h1 + h2)``

    NOTE(review): with ``is_fc1=False`` the activation is applied to this
    layer's own output; confirm that is intended for ``fc2``.

    Args:
        orig_lora: the LoRA layer being wrapped. It must expose ``base_layer``,
            ``lora_A``, ``lora_B``, ``lora_dropout``, ``scaling``,
            ``active_adapters``, ``disable_adapters`` and ``merged``.
        activation_fn: callable applied to the projections as described above.
        is_fc1: selects which of the two combination rules is used.
    """

    def __init__(self, orig_lora, activation_fn, is_fc1):
        super().__init__()
        self.orig_lora = orig_lora
        self.is_fc1 = is_fc1
        self.activation_fn = activation_fn

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the wrapped layer and return a tensor with the same dtype as ``x``.

        Raises:
            AssertionError: if adapters are disabled or merged, if the number of
                active adapters is not exactly one, or if the active adapter has
                no LoRA weights on the wrapped layer.
        """
        # Adapted from the original lora.Linear forward.
        previous_dtype = x.dtype
        layer = self.orig_lora

        # Explicit raises instead of `assert False`: asserts are stripped under
        # `python -O`, which would let these unsupported states fall through and
        # return the bare base-layer output without any activation.
        if layer.disable_adapters:
            raise AssertionError("ModifiedLora does not support disabled adapters")
        if layer.merged:
            raise AssertionError("ModifiedLora does not support merged adapter weights")

        active_adapters = layer.active_adapters
        if len(active_adapters) != 1:
            raise AssertionError(
                f"ModifiedLora expects exactly one active adapter, got {len(active_adapters)}"
            )
        active_adapter = next(iter(active_adapters))
        if active_adapter not in layer.lora_A.keys():
            raise AssertionError(
                f"active adapter {active_adapter!r} has no LoRA weights on this layer"
            )

        lora_A = layer.lora_A[active_adapter]
        lora_B = layer.lora_B[active_adapter]
        dropout = layer.lora_dropout[active_adapter]
        scaling = layer.scaling[active_adapter]

        h1 = layer.base_layer(x, *args, **kwargs)  # frozen base projection
        h2 = lora_B(lora_A(dropout(x))) * scaling  # scaled low-rank update
        if self.is_fc1:
            result = self.activation_fn(h1) + self.activation_fn(h2)
        else:
            result = self.activation_fn(h1 + h2)
        return result.to(previous_dtype)

In [50]:
# MLP block of the first decoder layer; the cells below patch it by hand.
mlp = model.model.layers[0].mlp

In [56]:
# Keep a handle on the MLP's real activation before the next cell replaces it.
# NOTE(review): the saved execution counts show this cell (56) was last run after
# the cell that swaps in Identity (51); if that is what happened, this variable
# holds Identity rather than the real activation — re-run in notebook order.
orig_activ_fn = mlp.activation_fn

In [51]:
# The ModifiedLora wrappers apply the activation themselves, so the MLP's own
# activation is replaced with a pass-through.
mlp.activation_fn = torch.nn.Identity()

In [69]:
# Wrap both LoRA linears of the MLP; `is_fc1` selects how each wrapper combines
# the base and LoRA projections with the activation.
lora_linear_fc1 = ModifiedLora(mlp.fc1, orig_activ_fn, is_fc1=True)
lora_linear_fc2 = ModifiedLora(mlp.fc2, orig_activ_fn, is_fc1=False)

In [71]:
# Install the wrappers in place of the original fc1/fc2 modules.
mlp.fc1 = lora_linear_fc1
mlp.fc2 = lora_linear_fc2

In [None]:
def improve_lora_adapter(mlp):
    """Patch one MLP block in place so its LoRA layers apply the activation themselves.

    ``mlp.fc1`` and ``mlp.fc2`` are replaced by ``ModifiedLora`` wrappers that
    carry the MLP's current activation, and ``mlp.activation_fn`` is replaced
    by ``torch.nn.Identity()``.

    Calling this on an MLP that is already patched is a no-op. Without that
    guard a second call would capture ``Identity`` as the activation and wrap
    the wrappers again, silently removing the non-linearity.

    Args:
        mlp: module exposing ``activation_fn``, ``fc1`` and ``fc2`` attributes.
    """
    if isinstance(mlp.fc1, ModifiedLora) or isinstance(mlp.fc2, ModifiedLora):
        return

    orig_activ_fn = mlp.activation_fn
    mlp.activation_fn = torch.nn.Identity()
    mlp.fc1 = ModifiedLora(mlp.fc1, orig_activ_fn, is_fc1=True)
    mlp.fc2 = ModifiedLora(mlp.fc2, orig_activ_fn, is_fc1=False)

In [None]:
def improve_lora_adapters(decoder_layers=None):
    """Apply ``improve_lora_adapter`` to the MLP of every decoder layer.

    NOTE(review): the original cell had no body; this is the loop implied by
    the single-layer cells above — confirm it matches the intent.

    Args:
        decoder_layers: iterable of layers that each expose an ``mlp``
            attribute. Defaults to the notebook-level ``model.model.layers``.
    """
    if decoder_layers is None:
        decoder_layers = model.model.layers
    for layer in decoder_layers:
        # Skip layers that are already wrapped (e.g. one patched by hand) so
        # they are never wrapped twice.
        if isinstance(layer.mlp.fc1, ModifiedLora):
            continue
        improve_lora_adapter(layer.mlp)

In [34]:
# Display the model's module tree to check which layers carry LoRA adapters.
model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2048)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): lora.Linear(
            (base_layer): Linear(in_features=2048, out_features=8192, bias=True)
            (lora_dropout): ModuleDict(
              (default): Dropout(p=0.05, inplace=False)
            )
            (lora_A): ModuleDict(
              (default): Linear(in_features=2048, out_features=8, bias=False)
            )
 