In [None]:
# Install into the running kernel's environment (%pip), and quote the extra so
# the shell cannot treat the square brackets as a glob pattern.
%pip install "pymongo[srv]" bitsandbytes

In [2]:
from pymongo import MongoClient, UpdateOne
from pymongo.server_api import ServerApi
from kaggle_secrets import UserSecretsClient
from datetime import datetime
import time
import torch
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer
import threading
import math

In [3]:
# Constants
# Number of pending documents claimed per fetch; the generation workers each
# process half of this amount per model.generate() call.
BATCH_SIZE = 48
# Identifier written to the "instance" field of every document this notebook
# claims or completes; it is also the key used to read those documents back.
instance_name = "valerie_1_2xt4"
# Name of the MongoDB collection that holds the run's documents.
RUN_COLLECTION_NAME = 'base_mistral_train_run'  # Replace with your run collection name

In [4]:
# Connect to MongoDB
# The connection string is read from the Kaggle secret named "mongodb", so no
# credentials are stored in the notebook itself.
user_secrets = UserSecretsClient()
mongo_url = user_secrets.get_secret("mongodb")
# The client is created with server API version '1'.
client = MongoClient(mongo_url, server_api=ServerApi('1'))
db = client['gsm8k_dataset']  # Replace with your database name
# Module-level handle to the run collection; the fetch/update helpers in this
# notebook read and write through it.
collection = db[RUN_COLLECTION_NAME]

In [5]:
# Claim a batch of pending documents for this instance and return its work list
def fetch_batch(instance_name):
    """Claim up to ``BATCH_SIZE`` pending documents and return this instance's work.

    Each document is claimed with its own ``find_one_and_update`` call. A
    single-document find-and-modify is atomic on the server, so two instances
    running concurrently can never claim the same pending document. (The
    previous aggregate + ``$merge`` pipeline matched and wrote in separate
    steps, so concurrent instances could both pick up the same documents.)

    Args:
        instance_name: Identifier stored in the ``instance`` field of every
            claimed document.

    Returns:
        A list of all documents whose ``status`` is ``"in progress"`` and whose
        ``instance`` equals ``instance_name``. As before, this also includes
        documents claimed by an earlier, unfinished run of the same instance,
        so the list can be longer than ``BATCH_SIZE`` or empty.
    """
    from datetime import timezone  # local import: only needed for the claim timestamp

    # One timestamp for the whole batch. Timezone-aware UTC replaces the
    # deprecated datetime.utcnow().
    claim_time = datetime.now(timezone.utc)
    claim_update = {
        "$set": {
            "status": "in progress",
            "instance": instance_name,
            "start_time": claim_time,
        }
    }

    for _ in range(BATCH_SIZE):
        claimed = collection.find_one_and_update(
            {"status": "pending"},
            claim_update,
            projection={"_id": 1},  # only need to know whether something was claimed
        )
        if claimed is None:
            # Nothing left to claim (or other instances took the rest).
            break

    # Read back everything currently assigned to this instance.
    in_progress = collection.find({
        "status": "in progress",
        "instance": instance_name
    })

    return list(in_progress)

In [6]:
# Function to check if there are any pending batches
def has_pending_batches():
    """Return True while at least one document still has status ``"pending"``.

    Only existence matters here, so the count is capped with ``limit=1``; the
    server can stop at the first match instead of counting every pending
    document on each loop iteration.
    """
    return collection.count_documents({"status": "pending"}, limit=1) > 0

In [7]:
# Local path handed to both AutoTokenizer.from_pretrained and
# AutoModelForCausalLM.from_pretrained in the cells below.
model_path = "/kaggle/input/mistral/pytorch/7b-instruct-v0.1-hf/1"

In [8]:
# bitsandbytes settings shared by both model copies: 4-bit loading with the
# "nf4" quant type, double quantization enabled, and float16 as compute dtype.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(model_path)
# Reuse EOS as the padding token so prompts of different lengths can be padded
# into one batch.
tokenizer.pad_token = tokenizer.eos_token
# Batched generation with a decoder-only model must be left-padded: with right
# padding the model would be asked to continue from pad tokens. Set it
# explicitly instead of relying on whatever default the checkpoint ships with.
tokenizer.padding_side = "left"

In [None]:
# Load one quantized copy of the model per GPU: model_0 on cuda:0, model_1 on cuda:1.
# Both copies share the same checkpoint, dtype and quantization settings; only
# the target device differs, so the load call is written once.
model_0, model_1 = (
    AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=torch.float16,
        device_map=gpu,
        quantization_config=quantization_config,
    )
    for gpu in ("cuda:0", "cuda:1")
)

In [11]:
# Tokenize a list of prompt records as one padded batch on the given device
def tokenize_inputs(data, device):
    """Encode the ``tagged_question`` of every item in ``data`` and move it to ``device``.

    Args:
        data: Sequence of mappings, each providing a ``'tagged_question'`` string.
        device: Target passed to the encoding's ``.to()`` (e.g. ``"cuda:0"``).

    Returns:
        The tokenizer output (PyTorch tensors, padded, truncation enabled)
        after being moved to ``device``.
    """
    prompts = [record['tagged_question'] for record in data]
    encoded = tokenizer(
        prompts,
        return_tensors="pt",
        padding=True,
        truncation=True,
    )
    return encoded.to(device)

In [12]:
# Run model generation over the tokenized inputs in fixed-size sub-batches
def generate_responses(model, inputs, device, batch_size):
    """Generate and decode one response per input row, ``batch_size`` rows at a time.

    Args:
        model: Object exposing ``generate(input_ids=..., attention_mask=..., ...)``.
        inputs: Tokenizer output with ``input_ids`` and ``attention_mask``
            tensors whose first dimension indexes the prompts.
        device: Device the sliced tensors are moved to before generation.
        batch_size: Number of prompts per ``generate`` call. A float such as
            ``24.0`` (the result of a true division) is accepted and truncated
            to an int; previously it produced float slice indices, which
            raised ``TypeError``.

    Returns:
        A list of decoded strings, one per input row, in input order.
        NOTE(review): for a decoder-only model the generated sequences
        presumably start with the prompt tokens, so each decoded string would
        contain the prompt as well — confirm this is what consumers expect.

    Raises:
        ValueError: If ``batch_size`` truncates to less than 1.
    """
    batch_size = int(batch_size)
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")

    total = len(inputs.input_ids)
    responses = []
    for start_idx in range(0, total, batch_size):
        # Slicing clamps at the end of the tensor, so the last chunk may be short.
        end_idx = start_idx + batch_size
        batch_input_ids = inputs.input_ids[start_idx:end_idx].to(device)
        batch_attention_mask = inputs.attention_mask[start_idx:end_idx].to(device)
        batch_responses = model.generate(
            input_ids=batch_input_ids,
            attention_mask=batch_attention_mask,
            max_length=800,
            do_sample=True,
            temperature=0.05,
            top_p=0.95,
            top_k=60,
            repetition_penalty=1.15,
            num_return_sequences=1
        )
        responses.extend(
            tokenizer.decode(response, skip_special_tokens=True)
            for response in batch_responses
        )
        # Release cached GPU blocks between sub-batches.
        torch.cuda.empty_cache()
    return responses

In [13]:
# Thread targets: one per model/GPU. Each reads its inputs from a module-level
# variable and publishes its result through a module-level variable, because
# threading.Thread targets cannot return a value.
def generate_responses_0():
    """Generate responses for ``inputs_0`` with ``model_0`` on cuda:0 into ``responses_0``."""
    global responses_0
    # Integer division: a float batch size cannot be used to slice the input tensors.
    responses_0 = generate_responses(model_0, inputs_0, "cuda:0", batch_size=BATCH_SIZE // 2)

def generate_responses_1():
    """Generate responses for ``inputs_1`` with ``model_1`` on cuda:1 into ``responses_1``."""
    global responses_1
    # Integer division: a float batch size cannot be used to slice the input tensors.
    responses_1 = generate_responses(model_1, inputs_1, "cuda:1", batch_size=BATCH_SIZE // 2)

In [14]:
def update_responses(responses, instance_name):
    """Store generated responses and mark their documents as completed.

    Args:
        responses: Iterable of ``(doc_id, response)`` pairs.
        instance_name: Identifier written to each document's ``instance`` field.

    All updates are sent in a single ``bulk_write``. When ``responses`` is
    empty nothing is sent, because ``bulk_write`` rejects an empty operation
    list with an exception.
    """
    bulk_ops = [
        UpdateOne(
            {"_id": doc_id},
            {
                "$set": {
                    "response": response,
                    "status": "completed",
                    "instance": instance_name
                }
            }
        )
        for doc_id, response in responses
    ]

    if not bulk_ops:
        return

    collection.bulk_write(bulk_ops)
    print(f"Updated responses for instance {instance_name}")

In [None]:
# Main worker loop: claim a batch, split it across the two GPUs, generate in
# parallel threads, and write the results back.
while has_pending_batches():
    batch = fetch_batch(instance_name=instance_name)
    if not batch:
        # Pending documents existed a moment ago but none ended up assigned to
        # this instance (e.g. another instance claimed them first). Back off
        # briefly and re-check instead of tokenizing an empty batch.
        time.sleep(1)
        continue

    # Prepare data for processing
    # NOTE(review): the prompt starts with a literal "<s>"; if the tokenizer
    # also prepends its own BOS token the sequence may begin with two BOS
    # tokens — confirm whether that is intended before changing the prompt.
    data = [{"_id": doc["_id"], "tagged_question": f"<s>[INST]{doc['question']} \n do it step by step[/INST] "} for doc in batch]
    half = len(data) // 2
    data_0 = data[:half]
    data_1 = data[half:]

    # Reset the result slots so that a worker thread that dies with an
    # exception cannot leave the previous iteration's results in place.
    responses_0 = None
    responses_1 = None

    # Only tokenize and start a worker for a half that actually has prompts
    # (a one-document batch leaves data_0 empty).
    workers = []
    if data_0:
        inputs_0 = tokenize_inputs(data_0, "cuda:0")
        workers.append(threading.Thread(target=generate_responses_0))
    else:
        responses_0 = []
    if data_1:
        inputs_1 = tokenize_inputs(data_1, "cuda:1")
        workers.append(threading.Thread(target=generate_responses_1))
    else:
        responses_1 = []

    # Start the threads
    for worker in workers:
        worker.start()

    # Wait for both threads to complete
    for worker in workers:
        worker.join()

    # An exception inside a thread does not propagate to this loop; the only
    # visible symptom is a result slot that was never filled.
    if responses_0 is None or responses_1 is None:
        raise RuntimeError(
            "A generation thread failed before producing responses; "
            "see the traceback printed by the thread."
        )
    if len(responses_0) != len(data_0) or len(responses_1) != len(data_1):
        raise RuntimeError("Number of generated responses does not match the number of prompts.")

    # Pair every response with the _id of the document it was generated for.
    all_responses = [
        (item["_id"], text)
        for item, text in zip(data_0 + data_1, responses_0 + responses_1)
    ]

    update_responses(all_responses, instance_name)

    print(f"Instance {instance_name} completed a batch.")