In [1]:
%%capture
import torch
major_version, minor_version = torch.cuda.get_device_capability()
# Must install separately since Colab has torch 2.2.1, which breaks packages
!pip install "unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git"
if major_version >= 8:
    # Use this for new GPUs like Ampere, Hopper GPUs (RTX 30xx, RTX 40xx, A100, H100, L40)
    !pip install --no-deps packaging ninja einops flash-attn xformers trl peft accelerate bitsandbytes
else:
    # Use this for older GPUs (V100, Tesla T4, RTX 20xx)
    !pip install --no-deps xformers trl peft accelerate bitsandbytes
pass
# Llama 3 Video Tutorial https://www.youtube.com/watch?v=aQmoog_s8HE

In [2]:
def get_device_map() -> str:
  return 'cuda' if torch.cuda.is_available() else 'cpu'

device = get_device_map()

In [3]:
device

'cuda'

In [4]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = False # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
# fourbit_models = [
#     "unsloth/mistral-7b-bnb-4bit",
#     "unsloth/mistral-7b-instruct-v0.2-bnb-4bit",
#     "unsloth/llama-2-7b-bnb-4bit",
#     "unsloth/gemma-7b-bnb-4bit",
#     "unsloth/gemma-7b-it-bnb-4bit", # Instruct version of Gemma 7b
#     "unsloth/gemma-2b-bnb-4bit",
#     "unsloth/gemma-2b-it-bnb-4bit", # Instruct version of Gemma 2b
#     "unsloth/llama-3-8b-bnb-4bit", # [NEW] 15 Trillion token Llama-3
# ] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    # model_name = "meta-llama/Meta-Llama-3-8B-Instruct",
    model_name = "meta-llama/Llama-2-7b-hf",
    max_seq_length = max_seq_length,
    device_map='auto',
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    token = "****", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.1.7: Fast Llama patching. Transformers: 4.47.1.
   \\   /|    GPU: Tesla T4. Max memory: 14.748 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.1+cu121. CUDA: 7.5. CUDA Toolkit: 12.1. Triton: 3.1.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors.index.json:   0%|          | 0.00/23.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]

model-00001-of-00003.safetensors:   0%|          | 0.00/4.94G [00:00<?, ?B/s]

model-00002-of-00003.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00003.safetensors:   0%|          | 0.00/3.59G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/183 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/948 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

In [5]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 1, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 5,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None, # And LoftQ
)

Unsloth 2025.1.7 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [6]:
!pip install langchain huggingface_hub transformers datasets sentence-transformers chromadb

Collecting chromadb
  Downloading chromadb-0.6.3-py3-none-any.whl.metadata (6.8 kB)
Collecting build>=1.0.3 (from chromadb)
  Downloading build-1.2.2.post1-py3-none-any.whl.metadata (6.5 kB)
Collecting chroma-hnswlib==0.7.6 (from chromadb)
  Downloading chroma_hnswlib-0.7.6-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (252 bytes)
Collecting fastapi>=0.95.2 (from chromadb)
  Downloading fastapi-0.115.7-py3-none-any.whl.metadata (27 kB)
Collecting uvicorn>=0.18.3 (from uvicorn[standard]>=0.18.3->chromadb)
  Downloading uvicorn-0.34.0-py3-none-any.whl.metadata (6.5 kB)
Collecting posthog>=2.4.0 (from chromadb)
  Downloading posthog-3.11.0-py2.py3-none-any.whl.metadata (2.9 kB)
Collecting onnxruntime>=1.14.1 (from chromadb)
  Downloading onnxruntime-1.20.1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (4.5 kB)
Collecting opentelemetry-exporter-otlp-proto-grpc>=1.2.0 (from chromadb)
  Downloading opentelemetry_exporter_otlp_proto_grpc-1.29.0-py3

In [7]:
!pip install -q transformers langchain langchain_chroma sentence_transformers
!pip install bitsandbytes
!pip install -U bitsandbytes
!pip install -U langchain-community
!pip install pypdf

Collecting langchain-community
  Downloading langchain_community-0.3.16-py3-none-any.whl.metadata (2.9 kB)
Collecting dataclasses-json<0.7,>=0.5.7 (from langchain-community)
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting httpx-sse<0.5.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.0-py3-none-any.whl.metadata (9.0 kB)
Collecting langchain<0.4.0,>=0.3.16 (from langchain-community)
  Downloading langchain-0.3.16-py3-none-any.whl.metadata (7.1 kB)
Collecting langchain-core<0.4.0,>=0.3.32 (from langchain-community)
  Downloading langchain_core-0.3.32-py3-none-any.whl.metadata (6.3 kB)
Collecting pydantic-settings<3.0.0,>=2.4.0 (from langchain-community)
  Downloading pydantic_settings-2.7.1-py3-none-any.whl.metadata (3.5 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain-community)
  Downloading marshmallow-3.26.0-py3-none-any.whl.metadata (7.3 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses-

In [8]:
!pip install trl



In [9]:
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.chains import RetrievalQA
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
import torch

In [10]:
from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer, TrainingArguments
from datasets import Dataset, DatasetDict
from sentence_transformers import SentenceTransformer
import torch
import os

In [11]:
from trl import SFTTrainer
from transformers import TrainingArguments
from huggingface_hub import login

In [12]:
# Function to load and preprocess the document
def load_and_preprocess_document(file_path: str, chunk_size: int = 1000, chunk_overlap: int = 200):
    """
    Loads a PDF document, splits it into smaller chunks, and returns split documents.

    Args:
        file_path (str): Path to the PDF file.
        chunk_size (int): Maximum size of each text chunk.
        chunk_overlap (int): Overlap between consecutive text chunks.

    Returns:
        List of split documents.
    """
    loader = PyPDFLoader(file_path)
    docs = loader.load()
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
    return text_splitter.split_documents(docs)

In [13]:
# Prepare Dataset with Retrieval-Augmented Context (RAG)
def prepare_dataset(questions, contexts, answers):
    """
    Prepare a dataset for finr-tuning LLaMA, combining questions, context, and answers.

    Args:
        questions (list): List of questions.
        contexts (list): List of retrieved contexts corresponding to each question.
        answers (list): List of answers corresponding to each question.

    Returns:
        Dataset: dataset object ready for fine-tuning.
    """

    data = {"input_text": [], "output_text": []}
    question_count = 0
    for question, context, answer in zip(questions, contexts, answers):
        question_count += 1
        print(f"Question number: {question_count}")
        input_text = f"Question: {question}\nContext: {context}\nAnswer:"
        output_text = answer
        data["input_text"].append(input_text)
        data["output_text"].append(output_text)

    return Dataset.from_dict(data)

In [14]:
# Fine-Tune LLaMA Model
def fine_tune_llama(model, tokenizer, train_dataset, output_dir="/content/llama_finetuned"):
    """
    Fine-tunes LLaMA using the provided dataset.

    Args:
        model: LLaMA model to be finr-tuned.
        tokenizer: Tokenizer for the model.
        train_dataset: Training dataset.
        output_dir (str): Directory to save the finr-tuned model.

    Returns:
        None
    """
    def preprocess_function(examples):
        # print("Examples:", examples)  # Debug input
        if "input_text" not in examples or "output_text" not in examples:
            raise KeyError("Missing 'input_text' or 'output_text' in the dataset.")
        tokenized_inputs = tokenizer(examples["input_text"], truncation=True, padding="max_length", max_length=512)
        labels = tokenizer(examples["output_text"], truncation=True, padding="max_length", max_length=512)
        tokenized_inputs["labels"] = labels["input_ids"]
        # print("Tokenized:", tokenized_inputs)  # Debug output
        print("Preprocess function is done")
        return tokenized_inputs

    tokenized_train = train_dataset.map(preprocess_function, batched=True)

    # Move the model to the correct device
    if torch.cuda.is_available():
        model = model.to("cuda")
    else:
        model = model.to("cpu")

    trainer = SFTTrainer(
        model = model,
        tokenizer = tokenizer,
        train_dataset = tokenized_train,
        dataset_text_field = "text",
        max_seq_length = 2048,
        dataset_num_proc = 2,
        packing = False, # Can make training 5x faster for short sequences.
        args = TrainingArguments(
            per_device_train_batch_size = 1,
            gradient_accumulation_steps = 8,
            warmup_steps = 5,
            max_steps = 60,
            learning_rate = 2e-4,
            fp16 = not torch.cuda.is_bf16_supported(),
            bf16 = torch.cuda.is_bf16_supported(),
            logging_steps = 1,
            optim = "adamw_8bit",
            weight_decay = 0.01,
            lr_scheduler_type = "linear",
            seed = 3407,
            output_dir = output_dir,
            logging_dir="./logs",
            save_steps=10,
            save_total_limit=2,
            evaluation_strategy="no",
            ),
    )

    # training_args = TrainingArguments(
    #     output_dir = output_dir,
    #     evaluation_strategy="no",
    #     learning_rate=2e-5,
    #     per_device_train_batch_size=2,
    #     num_train_epochs=3,
    #     save_steps=10,
    #     save_total_limit=2,
    #     fp16=True,
    #     logging_dir="./logs"
    # )
    # training_args = TrainingArguments(
    #     output_dir = output_dir,
    #     evaluation_strategy="no",
    #     learning_rate=2e-5,
    #     per_device_train_batch_size=2,
    #     num_train_epochs=3,
    #     save_steps=10,
    #     save_total_limit=2,
    #     fp16=True,
    #     logging_dir="./logs"
    # )

    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=tokenized_train
    # )

    # trainer = Trainer(
    #     model=model,
    #     args=training_args,
    #     train_dataset=tokenized_train
    # )

    print("Tuning started .....")
    trainer.train()
    print("Tuning finished .....")
    trainer.save_model("/content/llama_finetuned")

In [15]:
# Perform Retrieval using Sentence-Transformer
def retrieve_contexts(questions, documents, retriever_model="all-MiniLM-L6-v2", top_k=5):
    """
    Retrieves the most relevant contexts for each question using a retriever model.

    Args:
        questions (list): List of questions.
        documents (list): List of all documents.
        retriever_model (str): SentenceTransformer model name.
        top_k (int): Number of top contexts to retrieve for each question.

    Returns:
        list: List of top-k retrieved contexts for each question.
    """
    retriever = SentenceTransformer(retriever_model)
    doc_embeddings = retriever.encode(documents)
    retrieved_contexts = []

    for question in questions:
        question_embedding = retriever.encode(question)
        scores = cosine_similarity([question_embedding], doc_embeddings)[0]
        top_indices = scores.argsort()[-top_k:][::-1]
        top_contexts = [documents[i] for i in top_indices]
        retrieved_contexts.append(" ".join(top_contexts)) # Combine top contexts
    return retrieved_contexts

In [16]:
pdf_paths = ["/content/Banking_Act_Directions_No_5_of_2024.pdf", "/content/bsd_circular_no_1_of_2025_e.pdf", "/content/bsd_circular_no_3_of_2024_e.pdf", "/content/bsd_circular_no_5_of_2024_e.pdf"]
questions = ["What is the main purpose of these Banking Act Directions?", "To whom do these directions apply?", "What are the responsibilities of the Board of Directors under these directions?", "How often must the Board of Directors meet according to the directions?", "What qualifications are required for the Chairperson of the Board?", "What is the composition requirement for the Board's committees?", "What is the required quorum for Board meetings?", "What are the disclosure requirements for licensed banks under the directions?", "What is the purpose of Circular No. 1 of 2025?", "What new requirement is introduced in this circular for licensed banks?", "What is the maximum period allowed for rescheduling credit facilities?", "What factors must be considered when rescheduling loans?", "What grievance-handling mechanism must licensed banks establish?", "How are the relief measures linked to Business Revival Units?", "What type of reporting is required from licensed banks under this circular?", "Can interest be waived during the rescheduling process?", "What is the purpose of BSD Circular No. 3 of 2024?",  "What are the permissible activities for Mobile Banking Units?", "What types of Mobile Banking Units are recognized under this circular?", "Who has the authority to approve the establishment of Mobile Banking Units?", "What governance framework must Mobile Banking Units follow?", "What is the reporting requirement for licensed banks planning to operate MBUs?", "Are Mobile Banking Units required to update customer accounts in real-time?", "What security measures are required for MBUs?", "What is the timeline for publishing quarterly and annual financial statements?", "What corrective actions can the Central Bank enforce for misleading disclosures?", "In what formats must financial statements be published?", "What key financial data must be disclosed by licensed banks?", "What are the minimum publication requirements for financial statements?"]
answers = ["The purpose is to strengthen corporate governance processes and practices in licensed banks to enhance the stability of the banking sector and financial system.", "These directions apply to all licensed banks in Sri Lanka.", "The Board is responsible for overseeing the management of the bank, ensuring compliance with laws and regulations, approving business strategies, maintaining risk governance, and promoting the safety and soundness of the bank.", "The Board must meet at least once a month, holding a minimum of 12 meetings annually.", "The Chairperson must be an independent non-executive director and must not be involved in the direct supervision of key management personnel or executive duties.", "Each Board committee must have at least three non-executive directors, with a majority being independent directors. The Chairperson of the Audit Committee must also be independent.", "At least half of the Board members must be present, with more than one-third being independent non-executive directors.", "Banks must disclose information such as financial performance, capital adequacy, related party transactions, corporate governance compliance, and other regulatory requirements in their annual reports.", "The purpose is to clarify Circular No. 04 of 2024 on Relief Measures to Assist Affected Small and Medium Enterprises (SMEs) and ensure consistent implementation across all licensed banks.", "Licensed banks are required to establish a Relief Banking Unit to extend and monitor relief measures under Circular No. 04 of 2024.", "Credit facilities can be rescheduled for a maximum period of up to ten years, unless the original agreement allowed a longer period.", "Factors include the repayment capacity of the borrower, an acceptable revival plan, and terms such as interest rates based on prevailing benchmark rates.",  "Banks must implement a transparent grievance-handling mechanism for disputes regarding the valuation of auctioned properties.", "While the Relief Banking Unit focuses on relief measures, Business Revival Units established under Circular No. 02 of 2024 will continue their role in supporting business revival.", "Licensed banks must provide monthly reports detailing borrowers who approached banks for relief, including loan amounts, interest rates, and the status of discussions.", "Yes, interest can be waived off by the bank as part of the rescheduling agreement.", "The purpose of the circular is to provide guidelines for the establishment of Mobile Banking Units (MBUs) by licensed commercial banks to strengthen processes and adopt uniform practices.", "Permissible activities include accepting deposits and withdrawals, account opening, loan/credit card applications, receiving payments, promotions, onboarding to digital channels, facilitating utility payments, and providing advisory services.", "Recognized types include banking services in vehicles, barefoot banking, units operating a few days a week in permanent locations, and ad-hoc services at public places like schools, carnivals, and exhibitions.", "The Deputy Governor of the Central Bank of Sri Lanka has the authority to approve the establishment of Mobile Banking Units.", "MBUs must adhere to a Board-approved governance and risk management framework that includes internal controls, security arrangements, reporting procedures, and compliance with customer charters.", "Banks must submit the application form BSD-MBU-01 to the Director of Bank Supervision at least 15 working days before the start of the quarter during which the MBU will operate.", "Yes, except for MBUs established at schools for students, customer accounts must be updated in real-time and incorporated into the general ledger of the affiliated branch.", "MBUs must have security arrangements for both on-site activities and cash in transit. Additionally, they must display the bank's name and the branch affiliation at the operating location and on vehicles.", "Quarterly financial statements must be published within two months of the quarter-end, and annual audited financial statements within three months of the financial year-end.", "The Central Bank can require corrections, removal of false information, or additional specified actions.", "Simplified formats for press, detailed formats for websites, and specified formats for annual reports.", "Key financial data includes performance indicators, ratios, and comparative data from the previous financial year.", "Statements must be published in Sinhala, Tamil, and English newspapers and on the bank's official website."]

In [17]:
# Load and preprocess documents
all_documents = []
for pdf_path in pdf_paths:
    print(pdf_path)
    splits = load_and_preprocess_document(pdf_path)
    all_documents.extend([chunk.page_content for chunk in splits])

/content/Banking_Act_Directions_No_5_of_2024.pdf
/content/bsd_circular_no_1_of_2025_e.pdf
/content/bsd_circular_no_3_of_2024_e.pdf
/content/bsd_circular_no_5_of_2024_e.pdf


In [18]:
# Retrieve relevant contexts for each question
contexts = retrieve_contexts(questions, all_documents)

modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

In [19]:
# Prepare dataset
print("dataset preparation has started")
dataset = prepare_dataset(questions, contexts, answers)
print("done with dataset preparation")

dataset preparation has started
Question number: 1
Question number: 2
Question number: 3
Question number: 4
Question number: 5
Question number: 6
Question number: 7
Question number: 8
Question number: 9
Question number: 10
Question number: 11
Question number: 12
Question number: 13
Question number: 14
Question number: 15
Question number: 16
Question number: 17
Question number: 18
Question number: 19
Question number: 20
Question number: 21
Question number: 22
Question number: 23
Question number: 24
Question number: 25
Question number: 26
Question number: 27
Question number: 28
Question number: 29
done with dataset preparation


In [20]:
def preprocess_function(examples):
        # print("Examples:", examples)  # Debug input
        tokenized_inputs = tokenizer(examples["input_text"], truncation=True, padding="max_length", max_length=512)
        labels = tokenizer(examples["output_text"], truncation=True, padding="max_length", max_length=512)
        tokenized_inputs["labels"] = labels["input_ids"]
        # print("Tokenized:", tokenized_inputs)  # Debug output
        return tokenized_inputs

tokenized_train = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/29 [00:00<?, ? examples/s]

In [21]:
tokenized_train

Dataset({
    features: ['input_text', 'output_text', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 29
})

In [22]:
training_args = TrainingArguments(
        #output_dir = "/content/llama_finetuned",
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        max_steps = 60,
        learning_rate = 2e-4,
        fp16 = not torch.cuda.is_bf16_supported(),
        bf16 = torch.cuda.is_bf16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "/content/llama_finetuned"
    )

In [23]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(32000, 4096, padding_idx=0)
        (layers): ModuleList(
          (0-31): 32 x LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=1, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=1, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
      

In [24]:
trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train
    )

In [25]:
trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs = 1
   \\   /|    Num examples = 29 | Num Epochs = 20
O^O/ \_/ \    Batch size per device = 2 | Gradient Accumulation steps = 4
\        /    Total batch size = 8 | Total steps = 60
 "-____-"     Number of trainable parameters = 2,498,560


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


Step,Training Loss
1,23.8977
2,23.6945
3,24.0284
4,47.6881
5,23.4396
6,23.5328
7,46.1802
8,23.6582
9,23.1055
10,45.7699


TrainOutput(global_step=60, training_loss=14.851806743939717, metrics={'train_runtime': 653.9232, 'train_samples_per_second': 0.734, 'train_steps_per_second': 0.092, 'total_flos': 1.0518215415300096e+16, 'train_loss': 14.851806743939717, 'epoch': 19.8})

In [26]:
trainer.save_model("/content/llama_finetuned")