# INFO-RAG Implementation with MLX

Link: [ml-explore.github.io](https://ml-explore.github.io/mlx/build/html/index.html)

# Installation
- This notebook is ran on a project that has `uv` installed. So there's no `pip` install section

In [1]:
import mlx.core as mx
import joblib

# def fun(a, b, d1, d2):
#   x = mx.matmul(a, b, stream=d1)
#   for _ in range(500):
#       b = mx.exp(b, stream=d2)
#   return x, b


# a = mx.random.uniform(shape=(4096, 512))
# b = mx.random.uniform(shape=(512, 4))
# d1=mx.gpu
# d2=mx.cpu

# fun(a,b,d1,d2)

In [2]:
data_list = joblib.load("../INFO-RAG/uns_data")

In [3]:
data = data_list[:10000]
len(data)

10000

# INFORAG Functions

In [4]:
import transformers
from torch.utils.data import Dataset
from dataclasses import dataclass

IGNORE_INDEX = -100

def tokenize(
        prompt,
        completion,
        tokenizer: transformers.PreTrainedTokenizer,
):
    """Preprocess the data by tokenizing."""
    source_output = tokenizer.encode(prompt)
    input_seq = prompt + ' ' + completion
    passage_list = prompt
    tokenize_output = tokenizer(input_seq, padding=False, return_tensors=None,max_length=512,truncation=False)
    passage_list_tokenize_output = tokenizer(passage_list, padding=False, return_tensors=None, max_length=512, truncation=False)
    IGNORE_INDEX = -100
    source_len = len(source_output) - 1

    tokenize_output["labels"] = copy.deepcopy(tokenize_output["input_ids"])
    tokenize_output["labels"] = [IGNORE_INDEX] * source_len + tokenize_output["labels"][source_len:]
    return passage_list_tokenize_output,tokenize_output

special_token_list = [1,32000,32001]
import random
class SupervisedDataset(Dataset):
    """Dataset for supervised fine-tuning."""

    def __init__(self,
                 tokenizer: transformers.PreTrainedTokenizer,
                 data_type,
                 data_list):
        super(SupervisedDataset, self).__init__()
        self.tokenizer = tokenizer
        self.data_list = data_list

        if data_type == 'train':
            self.data_list = self.data_list[:int(1.0*len(self.data_list))]
        else:
            self.data_list = self.data_list[int(0.2*len(self.data_list))+1:]

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, i):
        if i % (1000) == 0 and int(-1) == 0:
            sp = subprocess.Popen(["nvidia-smi"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
            out_str = sp.communicate()
            for out_element in out_str:
                for line in str(out_element).split('\\n'):
                    print(line, file=sys.stderr)
        return self.data_list[i]
@dataclass
class DataCollatorForSupervisedDataset(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances):
        return instances


In [6]:
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
import math

train_dataset = SupervisedDataset(tokenizer=tokenizer,data_type='train',data_list=data)
train_sampler = SequentialSampler(train_dataset)
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

train_dataloader = DataLoader(train_dataset,
                                  collate_fn=data_collator,
                                  sampler=train_sampler,
                                  shuffle=False,
                                  batch_size=4)

num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / 0.1) # the higher the lower steps/epoch

num_update_steps_per_epoch

25000

In [7]:
from transformers import AutoTokenizer

model_name_or_path = "unsloth/llama-3-8b-bnb-4bit"

tokenizer = AutoTokenizer.from_pretrained(
  model_name_or_path, fast_tokenizer=True)
tokenizer.pad_token = tokenizer.eos_token
# make sure tokenizer is right pad in our logic
tokenizer.padding_side = 'right'


tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/9.09M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

## Tokenize

In [15]:
import os
import copy
from tqdm import tqdm

instances_new = []
data_tag = 0

datasets = []
cac,cs,sc = [],[],[]
for step, data_list in tqdm(enumerate(train_dataloader)):
    ### data start ###
    instances_new = []
    data_tag = 0
    for i in range(len(data_list)):
        if step%10 < 4:  # Correct and Complete
            data_tag = 0
            special_token = '[REFERENCE]'
            passage_list = data_list[i][2]
            score_list = data_list[i][3]
            passage_ids_list = []
            for passage in passage_list:
                input_passage_text = (special_token + ' ' + passage)
                output_passage, output_all_token = tokenize(input_passage_text, '',
                                                            tokenizer)
                output_passage_token_ids = output_passage['input_ids']
                label_ids = output_passage_token_ids

                mask_probability = 0.3 #process 30% tokens

                masked_indices = [i for i in range(int(len(output_passage_token_ids) / 3),len(output_passage_token_ids)) if
                                  random.random() < mask_probability]
                masked_ids = []
                idx = 0
                masked_idx = []
                while idx < len(output_passage_token_ids):
                    if (not (idx in masked_indices)) or (output_passage_token_ids[idx] in special_token_list):
                        masked_ids.append(output_passage_token_ids[idx])
                        idx += 1
                    else:  # Process two tokens consecutively
                        rand_num = random.random()
                        if rand_num < 0.5:  # [MASK]
                            masked_ids.append(32000)
                        elif rand_num > 0.5 and rand_num < 0.6:  # Keep
                            masked_ids.append(output_passage_token_ids[idx])
                        else:
                            masked_ids.append(random.randint(3, 31999))  # Replace
                        masked_idx.append(idx)
                        idx += 1
                        if idx < len(output_passage_token_ids) and (
                        not (output_passage_token_ids[idx] in special_token_list)):
                            rand_num = random.random()
                            if rand_num < 0.5:  # [MASK]
                                masked_ids.append(32000)
                            elif rand_num > 0.5 and rand_num < 0.6:  # Keep
                                masked_ids.append(output_passage_token_ids[idx])
                            else:
                                masked_ids.append(random.randint(3, 31999))  # Replace
                            masked_idx.append(idx)
                            idx += 1
                passage_ids_list.append((label_ids, masked_ids))

            query_ids = tokenize(
                'Complete this text according to the above [REFERENCE]: ' + data_list[i][1], '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            selected_s_i = 0
            for s_i in range(len(score_list)):
                if score_list[s_i] == 0:
                    selected_s_i = s_i
                    break
            trace_ids = tokenize(
                'This content is generated according to my knowledge and [REFERENCE] number {}'.format(selected_s_i), '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            start_generation = random.randint(int(len(query_ids) / 2),int((3 * len(query_ids)) / 4))
            answer_label_ids = [IGNORE_INDEX]*start_generation + query_ids[start_generation:] + trace_ids
            input_passage_ids = [1]
            origin_ids = [1]
            for item in passage_ids_list:
                input_passage_ids += item[1][1:]
                origin_ids += item[0][1:]
            input_ids = input_passage_ids + query_ids + trace_ids
            labels = [IGNORE_INDEX] * len(input_passage_ids) + answer_label_ids

            #print("CoC",tokenizer.decode(input_ids),"\n")
            #print("CoC ans",tokenizer.decode(query_ids[start_generation:] + trace_ids),"\n")
            cac.append({"text": tokenizer.decode(input_ids)})
            

        elif step%10 >= 4 and step%10 < 8:  # Contextual Stimulation
            data_tag = 1
            special_token = '[REFERENCE]'
            passage_list = data_list[i][2]
            selected_passage = '[QUERY] ' + data_list[i][1]
            score_list = data_list[i][3]
            score_list_fuben = []
            for s in score_list:
                if not s == 0:
                    score_list_fuben.append(s)
            if len(score_list) > 1:
                score_list = score_list_fuben
            passage_ids_list = []
            for passage in passage_list:
                input_passage_text = (special_token + ' ' + passage)
                output_passage, output_all_token = tokenize(input_passage_text, selected_passage,
                                                            tokenizer)
                output_passage_token_ids = output_passage['input_ids']
                label_ids = output_passage_token_ids
                if passage == data_list[i][1]:
                    selected_ids = label_ids
                else:
                    passage_ids_list.append(label_ids)

            query_ids = tokenize(
                'Complete this text according to the above [REFERENCE]: ' + data_list[i][1], '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            selected_s_i = 0
            for s_i in range(len(score_list)):
                if score_list[s_i] == 0:
                    selected_s_i = s_i
                    break
            trace_ids = tokenize(
                'This content is generated according to my knowledge'.format(selected_s_i), '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            start_generation = random.randint(int(len(query_ids) / 2),int((3 * len(query_ids)) / 4))
            answer_label_ids = [IGNORE_INDEX]*start_generation + query_ids[start_generation:] + trace_ids
            input_passage_ids = [1]
            origin_ids = [1]
            for item in passage_ids_list:
                input_passage_ids += item[1:]
                origin_ids += item[1:]
            input_ids = input_passage_ids + query_ids + trace_ids
            labels = [IGNORE_INDEX] * len(input_passage_ids) + answer_label_ids

            #print("CS",tokenizer.decode(input_ids),"\n")
            #print("CS ans", tokenizer.decode(query_ids[start_generation:] + trace_ids),"\n")
            cs.append({"text":tokenizer.decode(input_ids)})



        else:  # Select and Copy
            data_tag = 2
            special_token = '[REFERENCE]'
            passage_list = data_list[i][2]
            selected_passage = '[QUERY] ' + data_list[i][1]
            score_list = data_list[i][3]
            passage_ids_list = []
            for passage in passage_list:
                input_passage_text = (special_token + ' ' + passage)
                output_passage, output_all_token = tokenize(input_passage_text, selected_passage,
                                                            tokenizer)
                output_passage_token_ids = output_passage['input_ids']
                label_ids = output_passage_token_ids
                passage_ids_list.append(label_ids)
            query_ids = tokenize(
                'Complete this text according to the above [REFERENCE]: ' + data_list[i][1], '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            selected_s_i = 0
            for s_i in range(len(score_list)):
                if score_list[s_i] == 0:
                    selected_s_i = s_i
                    break
            trace_ids = tokenize(
                'This content is generated according to [REFERENCE] number {}'.format(selected_s_i), '',
                tokenizer)[0]['input_ids'][1:] # remove start token
            start_generation = random.randint(int(len(query_ids) / 2),int((3 * len(query_ids)) / 4))
            answer_label_ids = [IGNORE_INDEX]*start_generation + query_ids[start_generation:] + trace_ids
            input_passage_ids = [1]
            for item in passage_ids_list:
                input_passage_ids += item[1:]
            input_ids = input_passage_ids + query_ids + trace_ids
            labels = [IGNORE_INDEX] * len(input_passage_ids) + answer_label_ids

            #print("SC",tokenizer.decode(input_ids),"\n")
            #print("SC ans", tokenizer.decode(query_ids[start_generation:] + trace_ids),"\n")
            sc.append({"text":tokenizer.decode(input_ids)})
        #print(tokenizer.decode(input_ids))
        #end
#datasets[0]



2500it [00:43, 57.22it/s]


In [16]:
import json
from itertools import chain

#cac = [data['cac'] if 'cac' in data.keys() for data in datasets]
print(len(cac),len(sc),len(cs))

cac_train = cac[:int((len(cac)+1)*.80)] #Remaining 80% to training set
cac_test = cac[int(len(cac)*.80+1):] #Splits 20% data to test set

sc_train = sc[:int((len(sc)+1)*.80)] #Remaining 80% to training set
sc_test = sc[int(len(sc)*.80+1):] #Splits 20% data to test set

cs_train = cs[:int((len(cs)+1)*.80)] #Remaining 80% to training set
cs_test = cs[int(len(cs)*.80+1):] #Splits 20% data to test set

print(len(cac_train),len(cac_test))
print(len(sc_train),len(sc_test))
print(len(cs_train),len(cs_test))

train = []
test = []

train.extend(cac_train)
train.extend(sc_train)
train.extend(cs_train)

test.extend(cac_test)
test.extend(sc_test)
test.extend(cs_test)

#len(cac)
with open("train.jsonl", "w") as f:
    for data in train:
        json_string = json.dumps(data)
        f.write(f"{json_string}\n")

with open("test.jsonl", "w") as f:
    for data in test:
        json_string = json.dumps(data)
        f.write(f"{json_string}\n")

with open("data.jsonl","w") as f:
    for data in chain(cac,sc,cs):
        json_string = json.dumps(data)
        f.write(f"{json_string}\n")

4000 2000 4000
3200 799
1600 399
3200 799
