In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ['TORCH_USE_CUDA_DSA'] = "1"

In [2]:
import torch
from transformers import TrainingArguments
from transformers import (
                        AutoModelForCausalLM,
                        AutoTokenizer,
                        BitsAndBytesConfig,
                        DataCollatorForLanguageModeling
)
from peft import PeftConfig, PeftModel

import datasets
# import transformers
from trl import SFTTrainer
# model_name = "IlyaGusev/saiga_llama3_8b"
model_name = "meta-llama/Meta-Llama-3-8B"

# Loading the Model

In [3]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    # bnb_4bit_quant_type="nf4",
    # bnb_4bit_compute_dtype=torch.float16,
)

In [4]:
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config,
    device_map="auto",
    low_cpu_mem_usage=True,
    trust_remote_code=True,
    cache_dir="/app/model"
)
model.config.use_cache = False

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [5]:
tokenizer = AutoTokenizer.from_pretrained(model_name,
                                        #   token=access_token,
                                          trust_remote_code=True,
                                          cache_dir="/app/model")
tokenizer.pad_token = tokenizer.eos_token

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Dataset

In [6]:
examples_dir = os.path.join("/app/data")

def load_example(filename):
    with open(os.path.join(examples_dir, filename) , 'r', encoding='utf-8') as f:
        return f.read()

data = load_example("rukava_data.txt")

In [7]:
def prompt_input(row):
    return ("Below is an instruction that describes a task, paired with an input that provides further context. "
            "Write a response that appropriately fits the request.\n\n"
            "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n").format_map(row)


In [8]:
prompts, outputs = [],[]
EOS_TOKEN = "</s>"
max_seq_length = 256

In [9]:
def tokenize_training_text(training_text, max_seq_length, tokenizer, separator="\n\n\n", **kwargs):
    samples_from = training_text.split(separator)
    samples_to=[]
    
    for x in samples_from:
        human_part, assistant_part = x.split("Assistant:")
        human_part = human_part.replace("Human: ", "")
        human_part = human_part.replace("\n", "")
        assistant_text = assistant_part.strip()
        # assistant_dict={}
        # for assistant_line in assistant_text.split("\n"):
        #     key, value = assistant_line.split(" - ", 1)
        #     assistant_dict[key]=value
        prompts.append(prompt_input({"instruction": human_part, "input": ""}))
        outputs.append(assistant_text + EOS_TOKEN)
        
        # samples_to.append({"input_ids": input_seq, "labels": output_seq})
    #     training_dataset = datasets.Dataset.from_list(samples_to)
#     return training_dataset
#         samples_to.append({"human": human_part, "assistant": assistant_dict})

#     training_dataset = datasets.Dataset.from_list(samples_to)
#     training_dataset = training_dataset.shuffle().map(
#         lambda x: tokenize_sample(x, max_seq_length, tokenizer), 
#         batched=False
#     )
    dataset = [{"text":s, "output":t, "example": s+t} for s, t in zip(prompts, outputs)]
    training_dataset = datasets.Dataset.from_list(dataset)
    return training_dataset
    
# def tokenize_sample(item, max_seq_length, tokenizer, add_eos_token=True):
#     assert tokenizer is not None
#     result = tokenizer(
#         item["human"],
#         truncation=True,
#         max_length=max_seq_length,
#         padding="max_length",
#     ) 

#     if add_eos_token and (len(result["input_ids"]) < max_seq_length or result["input_ids"][-1] != tokenizer.eos_token_id):
#         result["input_ids"].append(tokenizer.eos_token_id)
#         result["attention_mask"].append(1)
#     print(result)
#     return result


def load_data_collator(tokenizer, mlm=False):
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=mlm,
        pad_to_multiple_of=8, 
    )
    return data_collator

In [10]:
train_dataset = tokenize_training_text(training_text=data, max_seq_length=max_seq_length, tokenizer=tokenizer, cache_dir="/app/model")
data_collator = load_data_collator(tokenizer)

In [11]:
print(train_dataset)

Dataset({
    features: ['text', 'output', 'example'],
    num_rows: 2865
})


In [12]:
from peft import LoraConfig, get_peft_model

lora_alpha = 16
lora_dropout = 0.1
lora_r = 32#64

In [13]:
peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_r,
    bias="none",
    task_type="CAUSAL_LM"
)

In [14]:
output_dir = "./results"
per_device_train_batch_size = 1
gradient_accumulation_steps = 2
optim = "paged_adamw_32bit"
save_steps = 1
num_train_epochs = 2
logging_steps = 1
learning_rate = 2e-4
max_grad_norm = 0.3
max_steps = 20
warmup_ratio = 0.03
lr_scheduler_type = "linear"

In [15]:
training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    bf16=True,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    group_by_length=True,
    lr_scheduler_type=lr_scheduler_type,
)

In [16]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    tokenizer=tokenizer,
    args=training_arguments,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.


Map:   0%|          | 0/2865 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs


In [17]:
# for name, module in trainer.model.named_modules():
#     if "norm" in name:
#         module = module.to(torch.float16)

In [17]:
# Run the training
# transformers.logging.set_verbosity_info()
trainer.train()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mdaria_tan[0m ([33mextreme_weather[0m). Use [1m`wandb login --relogin`[0m to force relogin


RuntimeError: CUDA error: unspecified launch failure
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


In [3]:
torch.cuda.get_device_properties(0)

_CudaDeviceProperties(name='NVIDIA L40', major=8, minor=9, total_memory=45385MB, multi_processor_count=142)

In [None]:
model_to_save = trainer.model.module if hasattr(trainer.model, 'module') else trainer.model  # Take care of distributed/parallel training
model_to_save.save_pretrained("outputs")

In [None]:
lora_config = LoraConfig.from_pretrained('outputs')
model = get_peft_model(model, lora_config)

In [None]:
dataset['text'][0]

"Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\nAnalyze and explain the legal reasoning behind the judgment in the given case.\n\n### Input:\nCentral Inland Water Transport Corporation Ltd. vs Brojo Nath Ganguly & Anr., 1986 AIR 1571, 1986 SCR (2) 278\n\n### Response:The Supreme Court in this case applied a broad interpretation of the term 'State' under Article 12 of the Constitution. The court reasoned that a government company undertaking public functions qualifies as 'State' based on factors like government control, public importance of activities etc. This interpretation was based on previous decisions that have defined 'State' under Article 12 broadly to include various agencies and instrumentalities beyond just statutory bodies. The court also applied the principle that unreasonable and arbitrary contractual terms can be struck down under Article 14 

In [None]:
dataset['output'][0]

"The Supreme Court in this case applied a broad interpretation of the term 'State' under Article 12 of the Constitution. The court reasoned that a government company undertaking public functions qualifies as 'State' based on factors like government control, public importance of activities etc. This interpretation was based on previous decisions that have defined 'State' under Article 12 broadly to include various agencies and instrumentalities beyond just statutory bodies. The court also applied the principle that unreasonable and arbitrary contractual terms can be struck down under Article 14 of the Constitution. The court found that Rule 9(i) of the service rules, which allowed for termination of service without reason, conferred unfettered power to terminate employment without hearing. This was deemed arbitrary and violative of principles of natural justice and right to equality under Article 14. Furthermore, the court held that the right to life and livelihood under Article 21 is a

In [None]:
text = dataset['text'][0]
device = "cuda:0"

In [None]:
inputs = tokenizer(text, return_tensors="pt").to(device)
outputs = model.generate(**inputs, max_new_tokens=50)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
Analyze and explain the legal reasoning behind the judgment in the given case.

### Input:
Central Inland Water Transport Corporation Ltd. vs Brojo Nath Ganguly & Anr., 1986 AIR 1571, 1986 SCR (2) 278

### Response:The Supreme Court in this case applied a broad interpretation of the term 'State' under Article 12 of the Constitution. The court reasoned that a government company undertaking public functions qualifies as 'State' based on factors like government control, public importance of activities etc. This interpretation was based on previous decisions that have defined 'State' under Article 12 broadly to include various agencies and instrumentalities beyond just statutory bodies. The court also applied the principle that unreasonable and arbitrary contractual terms can be struck down under Article 14 of the Co