In [3]:
import os,sys
os.environ['TOKENIZERS_PARALLELISM'] = "false"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
os.environ['HF_HOME'] = '/mnt/Data1/akann1w0w1ck/AlanTuring/.cache'
os.environ['TRANSFORMERS_CACHE'] = '/mnt/Data1/akann1w0w1ck/AlanTuring/.cache/transformers'


import bitsandbytes as bnb
from sklearn.metrics import accuracy_score
from argparse import Namespace

from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import glob
from torch.utils.data import DataLoader, Dataset as TorchDataset
import yaml
from datasets import Dataset  # type: ignore
import torch
from argparse import ArgumentParser
from lightning.pytorch import loggers as pl_loggers
import lightning.pytorch as pl
import transformers
import logging
logging.getLogger("transformers").setLevel(logging.CRITICAL)

from sklearn.metrics import precision_recall_fscore_support

from transformers import BitsAndBytesConfig

from peft import get_peft_config, prepare_model_for_int8_training, get_peft_model, LoraConfig, TaskType

from datasets import interleave_datasets, load_dataset

from transformers import get_constant_schedule_with_warmup
import pandas as pd
import warnings

import gc
model_id = 'TheBloke/Wizard-Vicuna-13B-Uncensored-HF'


warnings.filterwarnings("ignore", category=UserWarning)
map_modelid_targetmodule = {
    'TheBloke/Wizard-Vicuna-7B-Uncensored-HF': ['k_proj', 'v_proj'],
    'TheBloke/Wizard-Vicuna-13B-Uncensored-HF': ['k_proj', 'v_proj']
}
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [2]:
# Load Model
# Creating Model
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16)

model_id = 'TheBloke/Wizard-Vicuna-13B-Uncensored-HF'
model = transformers.AutoModelForCausalLM.from_pretrained(model_id,
                                                        trust_remote_code=True,
                                                        quantization_config=bnb_config,
                                                        # device_map={'':0},  
                                                            device_map = 'auto'
                                                        )

# Implementing Lora version
peft_config = LoraConfig(
    r=8, 
    lora_alpha=32, 
    target_modules=map_modelid_targetmodule[model_id ], 
    lora_dropout=0, 
    bias="none", 
    inference_mode=False,
    task_type=TaskType.CAUSAL_LM
)
# prepare int-8 model for training
# model = prepare_model_for_int8_training(model)

model = get_peft_model(model, peft_config)
print_trainable_parameters(model)

# Creating Tokenizer
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_id, use_fast=True, )


The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function.
Loading checkpoint shards: 100%|██████████| 3/3 [00:18<00:00,  6.11s/it]


trainable params: 6553600 || all params: 6678533120 || trainable%: 0.09812933292752765


In [4]:
# Load dataset for preprocessed research papers
dir_data = '../../data'
dataset_rp = Dataset.load_from_disk(os.path.join(
                dir_data, 'researchpapers', 'preprocessed' ,f"rp_{model_id.replace('/','_')}_train.arrow"))
dataloader_rp = DataLoader(dataset_rp, batch_size=1, shuffle=False)

In [7]:
# Load dataset for instruct to finetune
dir_data = '../../data'
dataset_ft = Dataset.load_from_disk(os.path.join(
                dir_data, 'instruct', 'preprocessed' , f"wLM70k_nofilt_{model_id.replace('/','_')}_train.arrow"))
dataloader_ft = DataLoader(dataset_ft, batch_size=1, shuffle=False)

In [5]:
len(dataset_rp)

11888

In [8]:
len(dataset_ft)

43979

In [10]:
# create function which gets loss and evaluate with different settings

def get_loss(llm, batch ):
    
    outputs = llm(**batch, output_hidden_states=False, output_attentions=False)
    loss = outputs.loss

    # if torch.isnan(loss):
    #     return None
    return loss

### Evaluating RP loss

In [5]:
# Get a batch of outputs from dataset_rp
batch_rp = next(iter(dataloader_rp))
loss_rp = get_loss(model, batch_rp)
print(loss_rp)

# Move the tensors to CPU
batch_rp = {key: value.cpu() for key, value in batch_rp.items() if torch.is_tensor(value)}
if torch.is_tensor(loss_rp):
    loss_rp = loss_rp.cpu()


tensor(11.5469, dtype=torch.float16, grad_fn=<ToCopyBackward0>)


In [22]:
# Testing w/ labels tag changed
# Get a batch of outputs from dataset_rp
batch_rp = next(iter(dataloader_rp))


batch_rp["labels"] = batch_rp["input_ids"].clone()


loss_rp = get_loss(model, batch_rp)
print(loss_rp)

# Move the tensors to CPU
batch_rp = {key: value.cpu() for key, value in batch_rp.items() if torch.is_tensor(value)}
if torch.is_tensor(loss_rp):
    loss_rp = loss_rp.cpu()


tensor(3.7188, dtype=torch.float16, grad_fn=<ToCopyBackward0>)


In [23]:
loss_rp

tensor(3.7188, dtype=torch.float16, grad_fn=<ToCopyBackward0>)

In [24]:
# Delete the variables
try:
    del batch_rp
except Exception as e:
    print(e)
try:
    del loss_rp
except Exception as e:
    print(e)
gc.collect()

# Clear the GPU cache
torch.cuda.empty_cache()

### Evaluating instruct loss

In [29]:
batch_ft = next(iter(dataloader_ft))
batch_ft['labels'] = batch_ft['input_ids'].clone()
loss_ft = get_loss(model, batch_ft)
loss_ft

# Move the tensors to CPU
batch_ft = {key: value.cpu() for key, value in batch_ft.items() if torch.is_tensor(value)}
if torch.is_tensor(loss_ft):
    loss_ft = loss_ft.cpu()

In [30]:
loss_ft

tensor(1.1777, dtype=torch.float16, grad_fn=<ToCopyBackward0>)

In [None]:
# Delete the variables
try:
    del batch_ft
except Exception as e:
    print(e)
try:
    del loss_ft
except Exception as e:
    print(e)
gc.collect()

# Clear the GPU cache
torch.cuda.empty_cache()

In [31]:
torch.cuda.empty_cache()
