In [None]:
# Install PyTorch and TensorBoard (TensorBoard receives metrics via report_to="tensorboard" in the training config)
%pip install "torch>=2.4.0" tensorboard

# Install transformers from the main branch of the GitHub repository.
# NOTE(review): @main is unpinned, so re-running this cell later may pull a different,
# possibly incompatible, version — pin a release or tag for reproducibility.
%pip install git+https://github.com/huggingface/transformers@main

# Install the remaining Hugging Face / training libraries (mostly pinned versions)
%pip install  --upgrade \
  "datasets==3.3.2" \
  "accelerate==1.4.0" \
  "evaluate==0.4.3" \
  "bitsandbytes==0.45.3" \
  "trl==0.15.2" \
  "peft==0.14.0" \
  "protobuf" \
  "sentencepiece"

# Write a default accelerate config file with fp16 mixed precision.
# NOTE(review): the model/training cells choose bf16 on GPUs with compute capability >= 8;
# confirm whether this config is consulted at all when training inside the notebook.
!python -c "from accelerate.utils import write_basic_config; write_basic_config(mixed_precision='fp16')"



In [None]:
# Print the environment report from `accelerate env` (useful when debugging setup issues)
!accelerate env

In [None]:

import datetime
import json
import torch
from torch.nn.parallel import DistributedDataParallel as DDP
import numpy as np
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer, 
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling,
    BitsAndBytesConfig
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from datasets import Dataset, load_dataset
from accelerate import Accelerator


In [None]:
import os
from huggingface_hub import login

# Read the Hugging Face token from the environment instead of hard-coding it in the
# notebook, where it would leak with every copy or commit. HF_TOKEN is checked first,
# HUGGINGFACE_TOKEN is accepted as an alias.
HUGGINGFACE_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_TOKEN", "")

def login_to_huggingface_token(token):
    """Log in to the Hugging Face Hub with ``token`` and print the outcome.

    Errors are reported rather than raised so the rest of the notebook keeps running.

    Args:
        token: A Hugging Face access token. An empty value skips the login call
            instead of sending an empty token to the Hub.
    """
    if not token:
        print("❌ Login skipped: no token found (set the HF_TOKEN environment variable).")
        return
    try:
        login(token=token)
        print("✅ Successfully logged into Hugging Face!")
    except Exception as e:
        print(f"❌ Login failed: {e}")

# Login using the token
login_to_huggingface_token(HUGGINGFACE_TOKEN)

In [None]:
# Check GPU availability and report every visible device with its memory in GiB.
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"GPU count: {gpu_count}")
    for gpu_index in range(gpu_count):
        # mem_get_info returns (free, total) in bytes for the given device.
        free_bytes, total_bytes = torch.cuda.mem_get_info(gpu_index)
        print(
            f"GPU {gpu_index}: {torch.cuda.get_device_name(gpu_index)} | "
            f"{free_bytes / 1024**3:.1f} GiB free / {total_bytes / 1024**3:.1f} GiB total"
        )


In [None]:
import os
import json
import tempfile
import glob
from datasets import Dataset
from snowflake.snowpark.context import get_active_session

def convert_to_conversation(sample):
    """Map one instruction-tuning record onto a three-turn chat transcript.

    ``instruction`` becomes the system turn, ``input`` the user turn and
    ``output`` the assistant turn. Raises KeyError if any of those fields is
    missing from ``sample``.
    """
    turn_sources = (
        ("system", "instruction"),
        ("user", "input"),
        ("assistant", "output"),
    )
    messages = [{"role": role, "content": sample[field]} for role, field in turn_sources]
    return {"messages": messages}

def _read_jsonl_records(jsonl_file):
    """Parse one downloaded JSONL file into a list of records.

    Blank lines are skipped. A line that is not valid JSON is reported and skipped,
    so one bad line does not discard the rest of the file. The first dict record is
    echoed (first three fields, long strings truncated) for debugging.

    Raises:
        OSError / UnicodeDecodeError: if the file cannot be opened or decoded; the
        caller reports these and moves on to the next file.
    """
    records = []
    with open(jsonl_file, "r", encoding="utf-8") as file:
        for line_num, line in enumerate(file, 1):
            stripped = line.strip()
            if not stripped:
                continue  # blank / trailing lines are not records
            try:
                record = json.loads(stripped)
            except json.JSONDecodeError as e:
                print(f"ERROR: Failed to parse JSON at line {line_num} in {jsonl_file}: {e}")
                continue
            # Print a sample of the first record; guard the .items() call so a
            # non-object JSON line cannot abort the whole file.
            if not records and isinstance(record, dict):
                print(f"DEBUG: Sample record from {os.path.basename(jsonl_file)}:")
                sample_data = {
                    k: str(v)[:50] + "..." if isinstance(v, str) and len(v) > 50 else v
                    for k, v in list(record.items())[:3]
                }
                print(f"DEBUG: {sample_data}")
            records.append(record)
    return records


def load_data_from_snowflake_stage(schema="GIT", stage="CARECONNECT_TRAINING_DATA_STAGE"):
    """Download every ``.jsonl`` file from a Snowflake stage and build a chat dataset.

    Each JSONL record must provide ``instruction``, ``input`` and ``output`` fields;
    records are converted with ``convert_to_conversation`` into a single
    ``messages`` column.

    Args:
        schema: Schema that owns the stage; it is also set as the session's schema.
        stage: Name of the stage to list and download from.

    Returns:
        A Hugging Face ``Dataset`` with one ``messages`` column, or an empty dataset
        with that column if no records could be loaded.

    Raises:
        KeyError: if a loaded record lacks one of the required fields.
    """
    records = []

    # Get the active Snowpark session and point it at the stage's schema.
    session = get_active_session()
    print("DEBUG: Got active Snowpark session")

    fully_qualified_stage = f"{schema}.{stage}"
    print(f"DEBUG: Using fully qualified stage: {fully_qualified_stage}")

    session.use_schema(schema)
    print(f"DEBUG: Set active schema to {schema}")

    # List the staged files; each row's "name" column is used as the file path below.
    files_in_stage = session.sql(f"LIST @{fully_qualified_stage}").collect()
    print(f"DEBUG: Found {len(files_in_stage)} files in stage")
    for i, file_info in enumerate(files_in_stage[:5], 1):  # show the first 5 entries
        print(f"DEBUG: File {i}: {file_info}")

    # Downloads live in a temporary directory that is removed when the block exits.
    with tempfile.TemporaryDirectory() as temp_dir:
        print(f"DEBUG: Created temp directory: {temp_dir}")

        jsonl_files_count = 0
        processed_files_count = 0

        for file_info in files_in_stage:
            file_name = os.path.basename(file_info["name"])
            if not file_name.endswith(".jsonl"):
                continue  # only JSONL files carry training records
            jsonl_files_count += 1

            # One sub-directory per staged file so downloads cannot overwrite each other.
            file_dir = os.path.join(temp_dir, f"file_{jsonl_files_count}")
            os.makedirs(file_dir, exist_ok=True)

            # NOTE(review): only the basename is appended to the stage path, so a file in
            # a stage sub-folder is requested from the stage root — confirm the stage is flat.
            print(f"DEBUG: Downloading {file_name} to {file_dir}")
            try:
                session.file.get(f"@{fully_qualified_stage}/{file_name}", file_dir)
                print(f"DEBUG: Download command executed for {file_name}")
            except Exception as e:
                print(f"ERROR: Failed to download {file_name}: {str(e)}")
                continue

            downloaded_files = glob.glob(os.path.join(file_dir, "**", "*"), recursive=True)
            print(f"DEBUG: Found {len(downloaded_files)} files in download directory")
            for downloaded in downloaded_files[:5]:
                print(f"DEBUG: Downloaded file: {downloaded}")

            jsonl_files = [f for f in downloaded_files if f.endswith(".jsonl") and os.path.isfile(f)]
            if not jsonl_files:
                print(f"WARNING: No .jsonl files found in download directory for {file_name}")
                continue

            for jsonl_file in jsonl_files:
                # Existence is already guaranteed by the isfile() filter above.
                print(f"DEBUG: Processing {jsonl_file}")
                print(f"DEBUG: File size: {os.path.getsize(jsonl_file)} bytes")
                try:
                    file_records = _read_jsonl_records(jsonl_file)
                except Exception as e:
                    print(f"ERROR: Failed to read {jsonl_file}: {str(e)}")
                    continue
                records.extend(file_records)
                print(f"DEBUG: Processed {len(file_records)} records from {os.path.basename(jsonl_file)}")
                processed_files_count += 1

        print(f"DEBUG: Processed {processed_files_count} out of {jsonl_files_count} JSONL files")
        print(f"DEBUG: Total records in dataset: {len(records)}")

    # Convert the raw records to the chat "messages" format.
    print("DEBUG: Converting to conversation format...")
    if not records:
        print("ERROR: No data was loaded into the dataset!")
        return Dataset.from_dict({"messages": []})

    conversations = [convert_to_conversation(record) for record in records]
    print(f"DEBUG: Created {len(conversations)} conversation entries")
    print("DEBUG: Sample conversation format:")
    print(conversations[0])

    data_dict = {"messages": [item["messages"] for item in conversations]}
    print(f"DEBUG: Final dataset structure has {len(data_dict['messages'])} messages")

    huggingface_dataset = Dataset.from_dict(data_dict)
    print(f"DEBUG: Created Hugging Face dataset with shape: {huggingface_dataset.shape}")
    return huggingface_dataset

# Load the dataset from the Snowflake stage, then shuffle and split it reproducibly.
SPLIT_SEED = 42  # fixed seed so the shuffle and the train/test split are reproducible

print("Starting to load data from Snowflake stage...")
try:
    dataset = load_data_from_snowflake_stage()
    print(f"Dataset loaded successfully with {len(dataset)} entries")
    if len(dataset) == 0:
        raise ValueError("No training records were loaded from the Snowflake stage.")

    # Print dataset info
    print("\nDATASET INFO:")
    print(dataset)
    dataset = dataset.shuffle(seed=SPLIT_SEED)
    print("\nDATASET FEATURES:")
    print(dataset.features)

    # Show a sample from the dataset
    print("\nSAMPLE FROM DATASET:")
    print(dataset[0])

    split_dataset = dataset.train_test_split(test_size=0.2, seed=SPLIT_SEED)  # 80% train / 20% test
    train_dataset = split_dataset["train"]
    test_dataset = split_dataset["test"]

    print("Train dataset size:", len(train_dataset))
    print("Test dataset size:", len(test_dataset))

    # The splits keep the conversational "messages" column as-is: TRL's SFTTrainer
    # applies the tokenizer's chat template to that column itself, so no manual
    # flattening to a "text" column is needed here.
except Exception as e:
    print(f"ERROR: Failed to load dataset: {str(e)}")
    raise  # stop here: every later cell depends on train_dataset / test_dataset


In [None]:
MODEL_NAME = "google/gemma-3-4b-pt"  # base (pre-trained) checkpoint to fine-tune
FINETUNED_MODEL_DIR = "./Careconnect-gemma3-huggingface"  # SFTTrainer output_dir
NUM_GPUS = torch.cuda.device_count()  # NOTE(review): appears unused in this notebook
MAX_SEQ_LENGTH = 512  # max tokens per (packed) training sequence
BATCH_SIZE_PER_GPU = 2  # per-device micro-batch size
GRADIENT_ACCUMULATION_STEPS = 4  # micro-batches accumulated per optimizer update; adjust for the effective batch size needed
NUM_EPOCHS = 0.05  # fraction of one epoch (short trial run); increase for a real training run
LEARNING_RATE = 2e-4  # optimizer learning rate
LORA_RANK = 16  # LoRA rank r
LORA_ALPHA = 32  # LoRA alpha
LORA_DROPOUT = 0.05  # LoRA dropout probability
# NOTE(review): SAVE_STEPS, EVAL_STEPS and WARMUP_STEPS appear unused — the trainer
# config saves once per epoch and uses a warmup ratio instead.
SAVE_STEPS = 500
LOGGING_STEPS = 10  # log every N optimizer steps
EVAL_STEPS = 500
WARMUP_STEPS = 100
# NOTE(review): the two paths below (.h5 weights, .spm vocabulary) look like leftovers
# from a Keras-style workflow and appear unused here — confirm before relying on them.
FINETUNED_WEIGHTS_PATH = f"{FINETUNED_MODEL_DIR}/model.weights.h5"
FINETUNED_VOCAB_PATH = f"{FINETUNED_MODEL_DIR}/vocabulary.spm"

In [None]:
# Show GPU status (utilisation and memory) before loading the model
!nvidia-smi

In [None]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the Gemma tokenizer.
# NOTE(review): this is the instruction-tuned 1B tokenizer while MODEL_NAME is the 4B
# pre-trained checkpoint — presumably chosen for its chat template (used by SFTTrainer
# and apply_chat_template later); confirm it matches the 4B model's vocabulary.
tokenizer = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

In [None]:
from transformers import BitsAndBytesConfig

# Index of the CUDA device torch is currently using.
current_device = torch.cuda.current_device()

# Model to fine-tune and the Auto class used to load it.
# NOTE(review): the class is always AutoModelForCausalLM regardless of model_id; confirm
# it can load this Gemma 3 checkpoint with the installed transformers version.
model_id = MODEL_NAME
model_class = AutoModelForCausalLM

# Compute capability >= 8 (Ampere or newer) gets bfloat16; older GPUs use float16.
torch_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16

# from_pretrained keyword arguments: eager attention, the dtype chosen above, automatic
# device placement, and 4-bit NF4 quantization (with double quantization) whose compute
# and storage dtypes follow torch_dtype.
model_kwargs = {
    "attn_implementation": "eager",  # "flash_attention_2" is an option on Ampere or newer GPUs
    "torch_dtype": torch_dtype,
    "device_map": "auto",
    "quantization_config": BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch_dtype,
        bnb_4bit_quant_storage=torch_dtype,
    ),
}

model = model_class.from_pretrained(model_id, **model_kwargs)

In [None]:
from peft import LoraConfig

# LoRA adapter configuration: adapters on every linear layer of a causal LM, no bias
# training, and lm_head / embed_tokens trained and saved in full next to the adapters.
# NOTE(review): training lm_head and embed_tokens in full adds many trainable parameters;
# this notebook adds no special tokens, so confirm these modules still need saving.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=LORA_RANK,
    lora_alpha=LORA_ALPHA,
    lora_dropout=LORA_DROPOUT,
    bias="none",
    target_modules="all-linear",
    modules_to_save=["lm_head", "embed_tokens"],
)

In [None]:
# Hand-written prompts for spot-checking the fine-tuned model.
# Every entry needs a trailing comma: adjacent string literals without one are
# silently concatenated by Python into a single prompt.
TEST_EXAMPLES = [
    "Who are you? What is your name?",
    "What are the treatments for Diabetic Neuropathies: The Nerve Damage of Diabetes?",
    "My BMI is 22.81, and I do not smoke. I do not drink regularly. I do not have diabetes. I have not had a stroke. I do engage in physical activity. My general health is Very good. I sleep 8.0 hours per night. I do not have asthma. I do not have skin cancer. But i am very ill, what could be wrong with me?",
    "I have Variable, including almost any neurological symptom or sign, with autonomic, visual, motor, and sensory problems being the most common. What do you think I might have?",
    "What are the risks of taking Brozeet-LS 1mg Syrup?",
    'What should I do if I have a sore throat?',
]

In [None]:
from trl import SFTConfig, SFTTrainer

# Training hyperparameters for TRL's SFTTrainer (QLoRA-style settings).
args = SFTConfig(
    output_dir=FINETUNED_MODEL_DIR,                  # checkpoints and the final adapter go here
    max_seq_length=MAX_SEQ_LENGTH,                   # max tokens per (packed) training sequence
    packing=True,                                    # concatenate several samples into one sequence
    num_train_epochs=NUM_EPOCHS,                     # fractional values train on part of one epoch
    per_device_train_batch_size=BATCH_SIZE_PER_GPU,  # micro-batch size per device
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,  # micro-batches accumulated per optimizer update
    gradient_checkpointing=True,                     # trade extra compute for lower activation memory
    optim="adamw_torch_fused",                       # fused AdamW optimizer
    logging_steps=LOGGING_STEPS,                     # log every LOGGING_STEPS steps
    save_strategy="epoch",                           # save a checkpoint at the end of each epoch
    learning_rate=LEARNING_RATE,                     # taken from the config cell (2e-4, as in QLoRA)
    fp16=torch_dtype == torch.float16,               # mixed precision matching the dtype the model was loaded in
    bf16=torch_dtype == torch.bfloat16,
    max_grad_norm=0.3,                               # gradient clipping, as in the QLoRA paper
    # NOTE(review): with lr_scheduler_type="constant" the warmup is presumably ignored;
    # use "constant_with_warmup" if the 3% warmup is intended — confirm.
    warmup_ratio=0.03,
    lr_scheduler_type="constant",
    # NOTE(review): pushes the trained adapter to the Hugging Face Hub; confirm the target
    # repository's visibility is appropriate for this training data.
    push_to_hub=True,
    report_to="tensorboard",                         # write metrics for TensorBoard
    dataset_kwargs={
        "add_special_tokens": False,  # the chat template already inserts special tokens
        "append_concat_token": True,  # add EOS as separator token between packed examples
    },
)

# Create the trainer. Only the train split is used; test_dataset stays held out for
# the generation check at the end of the notebook.
trainer = SFTTrainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    peft_config=peft_config,
    processing_class=tokenizer,
)



In [None]:
# Fine-tune the model
print("Starting fine-tuning...")
trainer.train()

# Write the final LoRA adapter (and the tokenizer passed as processing_class) to the
# root of args.output_dir (FINETUNED_MODEL_DIR). Periodic saves land in checkpoint-*
# sub-directories, but the merge step loads the adapter from the directory root.
trainer.save_model()

In [None]:
from peft import PeftModel

# Reload the base model for merging: not quantized, and with no device_map or torch_dtype
# passed, so from_pretrained's defaults apply (NOTE(review): presumably CPU and float32,
# which also determines the size/number of the 2GB shards written below — confirm).
# This rebinds `model`; the quantized training model is still referenced by `trainer`.
model = model_class.from_pretrained(MODEL_NAME, low_cpu_mem_usage=True)

# Attach the trained LoRA adapter (assumed to be saved in the root of FINETUNED_MODEL_DIR),
# fold its weights into the base model, and save the merged model as safetensors shards
# of at most 2GB each.
peft_model = PeftModel.from_pretrained(model, FINETUNED_MODEL_DIR)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained("CareConnect-gemma3-4b-pt", safe_serialization=True, max_shard_size="2GB")

# Copy the tokenizer from FINETUNED_MODEL_DIR next to the merged weights so the merged
# directory can be loaded on its own. (`processor` holds an AutoTokenizer here.)
processor = AutoTokenizer.from_pretrained(FINETUNED_MODEL_DIR)
processor.save_pretrained("CareConnect-gemma3-4b-pt")

In [None]:
# Free memory before reloading the merged model on the GPU.
# `trainer` still holds the quantized training model, and `model`, `peft_model` and
# `merged_model` all refer to the merge-time model (peft_model wraps `model`, and
# merged_model is derived from it), so every one of these names must be dropped before
# the memory can be reclaimed. globals().pop keeps the cell safe to re-run.
import gc

for _stale_name in ("trainer", "model", "peft_model", "merged_model"):
    globals().pop(_stale_name, None)
del _stale_name
gc.collect()  # collect reference cycles so the underlying tensors are actually released
torch.cuda.empty_cache()  # release cached, now-unused CUDA memory

In [None]:
import torch
from transformers import pipeline

model_id = "CareConnect-gemma3-4b-pt"

# Reload the merged (base + LoRA) model saved by the merge step, placed automatically on
# the available devices in the training dtype, together with its tokenizer.
model = model_class.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
    attn_implementation="eager",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [None]:
# List the working directory (expects the adapter and merged-model folders here)
!ls

In [None]:
# List the SFTTrainer output directory (LoRA adapter, checkpoints, logs)
!ls Careconnect-gemma3-huggingface 

In [None]:
# List the merged-model directory written by save_pretrained
!ls CareConnect-gemma3-4b-pt

In [None]:
# Print the working directory; the stage-upload cell uses absolute /home/app/... paths
!pwd

In [None]:
from snowflake.snowpark.context import get_active_session
import os

session = get_active_session()

# Destination stage and local source directories.
TARGET_STAGE = "@SOFTWARESURGEONS_DB.GIT.CARECONNECT_GEMMA3_STAGE"
MERGED_MODEL_LOCAL_DIR = "/home/app/CareConnect-gemma3-4b-pt"
ADAPTER_LOCAL_DIR = "/home/app/Careconnect-gemma3-huggingface"

# Every file in the merged-model directory (config, tokenizer files, shard index and
# however many safetensors shards save_pretrained wrote), followed by the LoRA adapter.
merged_model_files = sorted(
    path
    for path in (os.path.join(MERGED_MODEL_LOCAL_DIR, name) for name in os.listdir(MERGED_MODEL_LOCAL_DIR))
    if os.path.isfile(path)
)
adapter_files = [
    os.path.join(ADAPTER_LOCAL_DIR, "adapter_config.json"),
    os.path.join(ADAPTER_LOCAL_DIR, "adapter_model.safetensors"),
]
files_to_upload = merged_model_files + adapter_files

# Fail before uploading anything if an expected file is missing, rather than leaving the
# stage with a partial model.
missing_files = [path for path in files_to_upload if not os.path.isfile(path)]
if missing_files:
    raise FileNotFoundError(f"Cannot upload, files not found: {missing_files}")

# Upload each file uncompressed, replacing any existing copy, and record its status.
upload_statuses = {}
for local_path in files_to_upload:
    put_result = session.file.put(local_path, TARGET_STAGE, auto_compress=False, overwrite=True)
    upload_statuses[os.path.basename(local_path)] = put_result[0].status

upload_statuses

In [None]:
from random import randint
from transformers import pipeline
# Build a text-generation pipeline around the merged model and its tokenizer.
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)

# Pick one random held-out conversation. randint is inclusive on both ends, so the upper
# bound is len - 1 to stay in range.
rand_idx = randint(0, len(test_dataset) - 1)
test_sample = test_dataset[rand_idx]
messages = test_sample["messages"]  # [system, user, assistant] as built by convert_to_conversation

# Prompt with the system + user turns only (the assistant turn is the reference answer),
# formatted with the Gemma chat template. Stop on EOS or the end-of-turn token.
stop_token_ids = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<end_of_turn>")]
prompt = pipe.tokenizer.apply_chat_template(messages[:2], tokenize=False, add_generation_prompt=True)

# Generate the model's answer greedily (do_sample=False, so no sampling parameters apply).
outputs = pipe(prompt, max_new_tokens=256, do_sample=False, eos_token_id=stop_token_ids, disable_compile=True)

# Show the question, the reference answer and the generated answer; the generated text
# starts with the prompt, which is sliced off.
print(f"User Query:\n{messages[1]['content']}")
print(f"Original Answer:\n{messages[2]['content']}")
print(f"Generated Answer:\n{outputs[0]['generated_text'][len(prompt):].strip()}")