In [2]:
import json
from datasets import Dataset
from unsloth.chat_templates import get_chat_template
from unsloth import FastLanguageModel

# Requested context window. The loader output recorded for this cell reports
# that this checkpoint handles at most 8192 tokens natively and that RoPE
# scaling (factor 6.104) is applied to reach 50000.
max_seq_length = 50000 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True  # forwarded to from_pretrained below; the checkpoint name ends in "-bnb-4bit"
# Load the base model together with its tokenizer. `tokenizer` is rebound
# further down in this cell by get_chat_template() before the datasets are mapped.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/llama-3-8b-Instruct-bnb-4bit", # Choose ANY! eg teknium/OpenHermes-2.5-Mistral-7B
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)

def load_custom_dataset(file_path):
    """Read a JSON conversation file and wrap it in a Hugging Face ``Dataset``.

    The file is parsed with ``json.load`` and iterated; every item is indexed
    with ``"messages"`` and that value becomes one row of the resulting
    dataset's single ``"conversations"`` column.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A ``Dataset`` built via ``Dataset.from_dict`` with one column,
        ``"conversations"``, holding the per-item ``"messages"`` values in
        file order.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        records = json.load(handle)

    # Collect the message list of every record, preserving file order.
    conversations = []
    for record in records:
        conversations.append(record["messages"])

    return Dataset.from_dict({"conversations": conversations})

def format_conversations(conversation):
    """Re-key chat messages from ``role``/``content`` to ``from``/``value``.

    Args:
        conversation: Iterable of messages, each indexable by ``"role"`` and
            ``"content"``.

    Returns:
        A new list containing one ``{"from": ..., "value": ...}`` dict per
        input message, in the same order. Only the keys are renamed; the role
        and content values are copied through unchanged and any other keys on
        the input message are not carried over.

    NOTE(review): role values are forwarded as-is, while the mapping passed to
    get_chat_template in this file names "human"/"gpt" -- confirm the dataset's
    role names render as intended with that template.
    """
    return [
        {"from": turn["role"], "value": turn["content"]}
        for turn in conversation
    ]

def formatting_prompts_func(examples):
    """Render every conversation of a batch into one chat-template string.

    Used as the callback for ``Dataset.map(..., batched=True)`` in this file.
    Relies on the module-level ``tokenizer`` and on ``format_conversations``.

    Args:
        examples: Batch mapping whose ``"conversations"`` entry is a sequence
            of conversations (each a sequence of role/content messages).

    Returns:
        ``{"text": [...]}`` with one rendered string per conversation, in
        batch order.

    Raises:
        Exception: Whatever ``tokenizer.apply_chat_template`` raises is
            re-raised after the offending conversation has been printed.
    """
    rendered = []
    for raw_convo in examples["conversations"]:
        turns = format_conversations(raw_convo)
        try:
            # Produce plain text (no token ids) and do not append a trailing
            # generation prompt.
            rendered.append(
                tokenizer.apply_chat_template(
                    turns,
                    tokenize=False,
                    add_generation_prompt=False,
                )
            )
        except Exception as err:
            # Surface the conversation that failed before propagating.
            print(f"Error processing conversation: {raw_convo}")
            raise err
    return {"text": rendered}

# Rebind the tokenizer to the version returned by get_chat_template for the
# "llama-3" template. The mapping is passed through verbatim.
# NOTE(review): format_conversations() renames keys to "from"/"value" but
# leaves role values untouched -- confirm they line up with this mapping.
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3",
    mapping={
        "role": "from",
        "content": "value",
        "user": "human",
        "assistant": "gpt",
    },
)

# Training split: load, shuffle with a fixed seed, keep shuffled rows
# 2000..5999 (4000 rows), then add the rendered "text" column.
# Alternative slice (disabled): .select(range(85000, 96000))
train_dataset = (
    load_custom_dataset("datasetB_train_0-17599.json")
    .shuffle(seed=42)
    .select(range(2000, 6000))
    .map(formatting_prompts_func, batched=True)
)

# Validation split: first 100 rows in file order (no shuffle), then the same
# "text" rendering.
val_dataset = (
    load_custom_dataset("datasetB_eval_17600-21999.json")
    .select(range(100))
    .map(formatting_prompts_func, batched=True)
)

==((====))==  Unsloth: Fast Llama patching release 2024.7
   \\   /|    GPU: NVIDIA RTX A6000. Max memory: 47.536 GB. Platform = Linux.
O^O/ \_/ \    Pytorch: 2.3.0+cu121. CUDA = 8.6. CUDA Toolkit = 12.1.
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.26.post1. FA2 = False]
 "-____-"     Free Apache license: http://github.com/unslothai/unsloth


Unsloth: unsloth/llama-3-8b-Instruct-bnb-4bit can only handle sequence lengths of at most 8192.
But with kaiokendev's RoPE scaling of 6.104, it can be magically be extended to 50000!
Map: 100%|██████████| 4000/4000 [00:00<00:00, 4227.13 examples/s]
Map: 100%|██████████| 100/100 [00:00<00:00, 5894.60 examples/s]


In [5]:
# Sanity check: dump the first training row (its "conversations" entry plus
# the "text" column added by formatting_prompts_func via map()).
print(train_dataset[0])

{'conversations': [{'content': 'You are a helpful assistant that predicts human mobility trajectories in a city. The target city is divided into equally sized cells, creating a 200 x 200 grid. We use coordinate <x>,<y> to indicate the location of a cell within the target area. The horizontal coordinate <x> increases from left to right, and the vertical coordinate <y> increases from top to bottom. The coordinates of the top-left corner are (0, 0), and the coordinates of the bottom-right corner are (199, 199). A trajectory is a sequence of quadruples ordered by time. Each quadruple follows the format <day_id>, <time_id>, <x>, <y>. It represents a person\'s location <x>, <y> at the timeslot <time_id> of day <day_id>. The <day_id> is the index of day, representing a specific day. Each day\'s 24 hours are discretized into 48 time slots with a time interval of 30 minutes. <time_id> is the index of the time slot, ranging from 0 to 47, representing a specific half-hour in a day. Let me give yo