**DATA CLEANING**

In [3]:
abbreviation_dict = {
    "#": "broken bone (fracture)",
    "A&E": "accident and emergency",
    "a.c.": "before meals",
    "a.m.": "morning",
    "AF": "atrial fibrillation",
    "AMHP": "approved mental health professional",
    "APTT": "activated partial thromboplastin time",
    "ASQ": "Ages and Stages Questionnaire",
    "b.d.s": "2 times a day",
    "b.i.d.": "twice a day",
    "BMI": "body mass index",
    "BNO": "bowels not open",
    "BO": "bowels open",
    "BP": "blood pressure",
    "c/c": "chief complaint",
    "CMHN": "community mental health nurse",
    "CPN": "community psychiatric nurse",
    "CSF": "cerebrospinal fluid",
    "CSU": "catheter stream urine sample",
    "CT scan": "computerised tomography scan",
    "CVP": "central venous pressure",
    "CXR": "chest X-ray",
    "DNACPR": "do not attempt cardiopulmonary resuscitation",
    "DNAR": "do not attempt resuscitation",
    "DNR": "do not resuscitate",
    "Dr": "doctor",
    "DVT": "deep vein thrombosis",
    "Dx": "diagnosis",
    "ECG": "electrocardiogram",
    "ED": "emergency department",
    "EEG": "electroencephalogram",
    "EMU": "early morning urine sample",
    "ESR": "erythrocyte sedimentation rate",
    "EUA": "examination under anaesthetic",
    "FBC": "full blood count",
    "FOBT": "faecal occult blood test",
    "FIT": "faecal immunochemical test",
    "FY1": "foundation year 1 doctor",
    "FY2": "foundation year 2 doctor",
    "GA": "general anaesthetic",
    "gtt.": "drop(s)",
    "h.": "hour",
    "h/o": "history of",
    "Hb": "haemoglobin",
    "HCA": "healthcare assistant",
    "HCSW": "healthcare support worker",
    "HDL": "high-density lipoprotein",
    "HRT": "hormone replacement therapy",
    "Ht": "height",
    "Hx": "history",
    "i": "1 tablet",
    "ii": "2 tablets",
    "iii": "3 tablets",
    "i.m.": "injection into a muscle",
    "i.v.": "injection directly to a vein",
    "INR": "international normalised ratio",
    "IVI": "intravenous infusion",
    "IVP": "intravenous pyelogram",
    "Ix": "investigations",
    "LA": "local anaesthetic",
    "LDL": "low-density lipoprotein",
    "LFT": "liver function test",
    "LMP": "last menstrual period",
    "M/R": "modified release",
    "MRI": "magnetic resonance imaging",
    "MRSA": "methicillin-resistant Staphylococcus aureus",
    "MSU": "mid-stream urine sample",
    "n.p.o.": "nothing by mouth",
    "NAD": "nothing abnormal discovered",
    "NAI": "non-accidental injury",
    "NBM": "nil by mouth",
    "NG": "nasogastric",
    "nocte": "every night",
    "NoF": "neck of femur",
    "NSAID": "non-steroidal anti-inflammatory drug",
    "o.d.": "once a day",
    "o/e": "on examination",
    "OT": "occupational therapist",
    "p.c.": "after food",
    "p.m.": "afternoon or evening",
    "p.o.": "orally",
    "POD": "podiatrist",
    "p.r.": "rectally",
    "p.r.n.": "as needed"
}

In [4]:
import re

def expand_abbreviations(text, abbr_dict):
    for abbr, full in abbr_dict.items():
        # \b makes sure we match the full word only
        pattern = r'\b' + re.escape(abbr) + r'\b'
        text = re.sub(pattern, full, text)
    return text

df['question'] = df['question'].apply(lambda x: expand_abbreviations(x, abbreviation_dict))


In [5]:
df['answer'] = df['answer'].apply(lambda x: expand_abbreviations(x, abbreviation_dict))

In [6]:
synonym_dict = {
    "heart attack": "myocardial infarction",
    "high blood pressure": "hypertension",
    "low blood pressure": "hypotension",
    "high blood sugar": "hyperglycemia",
    "low blood sugar": "hypoglycemia",
    "stroke": "cerebrovascular accident",
    "brain attack": "cerebrovascular accident",
    "shortness of breath": "dyspnea",
    "difficulty breathing": "dyspnea",
    "fainting": "syncope",
    "passing out": "syncope",
    "fever": "pyrexia",
    "headache": "cephalalgia",
    "chest pain": "angina",
    "cold sore": "herpes labialis",
    "kidney stone": "renal calculus",
    "urinary tract infection": "UTI",
    "bladder infection": "UTI",
    "lung infection": "pneumonia",
    "high cholesterol": "hyperlipidemia",
    "blood clot": "thrombosis",
    "swollen lymph nodes": "lymphadenopathy",
    "irregular heartbeat": "arrhythmia",
    "fast heartbeat": "tachycardia",
    "slow heartbeat": "bradycardia",
    "acid reflux": "gastroesophageal reflux disease",
    "stomach flu": "gastroenteritis",
    "pink eye": "conjunctivitis",
    "nosebleed": "epistaxis",
    "runny nose": "rhinorrhea",
    "dry mouth": "xerostomia",
    "itching": "pruritus",
    "rash": "dermatitis",
    "skin inflammation": "dermatitis",
    "joint pain": "arthralgia",
    "muscle pain": "myalgia",
    "bloody urine": "hematuria",
    "bloody stool": "hematochezia",
    "vomiting blood": "hematemesis",
    "black stool": "melena",
    "fluid in lungs": "pulmonary edema",
    "yellow skin": "jaundice",
    "liver failure": "hepatic failure",
    "kidney failure": "renal failure",
    "low oxygen": "hypoxia",
    "high carbon dioxide": "hypercapnia",
    "low sodium": "hyponatremia",
    "high sodium": "hypernatremia",
    "low potassium": "hypokalemia",
    "high potassium": "hyperkalemia",
    "abnormal heartbeat": "arrhythmia",
    "dizzy": "vertigo",
    "dizziness": "vertigo",
    "numbness": "paresthesia",
    "tingling": "paresthesia",
    "diabetes": "diabetes mellitus",
    "lung cancer": "pulmonary carcinoma",
    "liver cancer": "hepatocellular carcinoma",
    "skin cancer": "melanoma",
    "breast cancer": "mammary carcinoma",
    "cervical cancer": "cervical carcinoma",
    "uterine cancer": "endometrial carcinoma",
    "brain cancer": "glioblastoma",
    "eye pressure": "intraocular pressure",
    "broken bone": "fracture",
    "back pain": "lumbalgia",
    "neck pain": "cervicalgia",
    "pregnancy loss": "spontaneous abortion",
    "miscarriage": "spontaneous abortion",
    "water breaking": "rupture of membranes",
    "labor pains": "uterine contractions",
    "baby dropping": "lightening",
    "spotting": "light vaginal bleeding",
    "night sweats": "nocturnal hyperhidrosis",
}


In [7]:
import re

def replace_synonyms(text, synonyms):
    for k, v in synonyms.items():
        pattern = r'\b' + re.escape(k) + r'\b'
        text = re.sub(pattern, v, text, flags=re.IGNORECASE)
    return text

df['question'] = df['question'].apply(lambda x: replace_synonyms(x, synonym_dict))


In [36]:
df.iloc[2000].question

'a neuro-oncology investigator has recently conducted a randomized controlled trial in which the addition of a novel alkylating agent to radiotherapy was found to prolong survival in comparison to survival radiotherapy alone hr  0.7, p  0.01. a number of surviving participants who took the alkylating agent reported that they had experienced significant nausea from the medication. the investigator surveyed all participants in both the treatment and the control group on their nausea symptoms by self-report rated mild, moderate, or severe. the investigator subsequently compared the two treatment groups with regards to nausea level. mild nausea moderate nausea severe nausea treatment group  20 30 50 control group  35 35 30 which of the following statistical methods would be most appropriate to assess the statistical significance of these results'

In [10]:
from huggingface_hub import login
login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [None]:
import pandas as pd
import numpy as np
import torch
from tqdm import tqdm
import spacy
from spacy.matcher import Matcher
from spacy.util import filter_spans
from transformers import AutoTokenizer, AutoModel

# Load SciSpaCy model
nlp = spacy.load("en_core_sci_sm")

# Load ClinicalBERT
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Medical abbreviations
MEDICAL_ABBREVIATIONS = {
    'MI', 'STEMI', 'NSTEMI', 'COPD', 'CHF',
    'BP', 'HR', 'RR', 'ECG', 'CXR', 'CT'
}

# Clinical tokenizer
def clinical_tokenizer(text):
    if not isinstance(text, str):
        return []
    doc = nlp(text)
    matcher = Matcher(nlp.vocab)
    patterns = [
        [{"IS_DIGIT": True}, {"TEXT": {"REGEX": "[-/]"}}, {"IS_DIGIT": True}],
        [{"TEXT": {"REGEX": "^[A-Z]{2,4}$"}}],
        [{"TEXT": {"REGEX": "^[IVX]+$"}}]
    ]
    matcher.add("CLINICAL_PATTERNS", patterns, greedy="LONGEST")
    matches = matcher(doc)
    spans = [doc[start:end] for _, start, end in matches]
    spans = filter_spans(spans)
    with doc.retokenize() as retokenizer:
        for span in spans:
            retokenizer.merge(span)
    return [token.text for token in doc if not token.is_space]

# Clinical lemmatizer
def clinical_lemmatizer_from_tokens(tokens):
    if not isinstance(tokens, list):
        return ""

    doc = nlp(" ".join(tokens))  # Recreate a doc from tokens
    lemmas = []
    for token in doc:
        if token.text.upper() in MEDICAL_ABBREVIATIONS:
            lemmas.append(token.text)
        elif token.ent_type_ in ["DISEASE", "CHEMICAL", "ANATOMY"]:
            lemmas.append(token.text)
        else:
            lemmas.append(token.lemma_.lower())
    return " ".join(lemmas)


# ClinicalBERT vectorizer
def clinical_vectorizer(text):
    if not isinstance(text, str) or not text.strip():
        return np.zeros(model.config.hidden_size)
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
    with torch.no_grad():
        outputs = model(**inputs)
    return torch.mean(outputs.last_hidden_state, dim=1).squeeze().numpy()

# Main processing function for multiple columns
def process_clinical_text_multi(df, columns):
    tqdm.pandas()

    for col in columns:
        # Tokenization
        # Tokenization
        df[f'{col}_tokens'] = df[col].progress_apply(clinical_tokenizer)

        df[f'{col}_lemmas'] = df[f'{col}_tokens'].progress_apply(clinical_lemmatizer_from_tokens)


        # Vectorization
        df[f'{col}_vector'] = df[f'{col}_lemmas'].progress_apply(clinical_vectorizer)

    # Combine vectors: average of both columns
        #df['combined_vector'] = df[[f"{col}_vector" for col in columns]].progress_apply(
        #l#ambda row: np.mean([row[f"{col}_vector"] for col in columns], axis=0), axis=1
    #)

    return df



# Process and vectorize both columns
df = process_clinical_text_multi(train_df, ['question', 'answer'])

# Show final results
df.head()


2025-04-20 00:37:48.999868: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745109469.299863      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745109469.372893      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/436M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]


  0%|          | 0/10479 [00:00<?, ?it/s][A
  0%|          | 3/10479 [00:00<06:37, 26.39it/s][A
  0%|          | 7/10479 [00:00<05:18, 32.89it/s][A
  0%|          | 11/10479 [00:00<05:30, 31.63it/s][A
  0%|          | 15/10479 [00:00<05:53, 29.60it/s][A
  0%|          | 18/10479 [00:00<06:46, 25.71it/s][A
  0%|          | 22/10479 [00:00<06:17, 27.68it/s][A
  0%|          | 26/10479 [00:00<05:58, 29.15it/s][A
  0%|          | 31/10479 [00:01<05:21, 32.54it/s][A
  0%|          | 35/10479 [00:01<06:04, 28.64it/s][A
  0%|          | 38/10479 [00:01<06:21, 27.36it/s][A
  0%|          | 41/10479 [00:01<06:37, 26.25it/s][A
  0%|          | 45/10479 [00:01<06:19, 27.50it/s][A
  0%|          | 48/10479 [00:01<06:18, 27.53it/s][A
  0%|          | 51/10479 [00:01<06:13, 27.90it/s][A
  1%|          | 55/10479 [00:01<05:42, 30.45it/s][A
  1%|          | 60/10479 [00:02<05:08, 33.80it/s][A
  1%|          | 65/10479 [00:02<04:37, 37.47it/s][A
  1%|          | 70/10479 [00:02<04:19

Unnamed: 0,ID,question,answer,question_tokens,question_lemmas,question_vector,answer_tokens,answer_lemmas,answer_vector
0,8358,A 24-year-old man comes to the physician becau...,0%,"[A, 24-year-old, man, comes, to, the, physicia...",a 24-year-old man come to the physician becaus...,"[-0.013192865, -0.1455072, -0.23957036, -0.014...","[0, %]",0 %,"[0.10693209, -0.0056015886, -0.5959619, 0.3666..."
1,5260,A 64-year-old male retired farmer presents to ...,0.002,"[A, 64-year-old, male, retired, farmer, presen...",a 64-year-old male retired farmer present to t...,"[-0.0036502706, -0.14283353, -0.056406815, -0....",[0.002],0.002,"[0.44734398, -0.41792175, 0.051748876, 0.16662..."
2,7648,A healthy 29-year-old nulligravid woman comes ...,0.20%,"[A, healthy, 29-year-old, nulligravid, woman, ...",a healthy 29-year-old nulligravid woman come t...,"[-0.08983001, -0.17551515, -0.255084, -0.11629...","[0.20, %]",0.20 %,"[0.000620534, -0.27700904, -0.09529225, 0.6122..."
3,1703,A 25-year-old man with a genetic disorder pres...,1%,"[A, 25-year-old, man, with, a, genetic, disord...",a 25-year-old man with a genetic disorder pres...,"[-0.05037262, -0.2267687, -0.25451264, 0.06623...","[1, %]",1 %,"[0.2371484, 0.11435668, -0.22044116, 0.1120575..."
4,945,A 14-month-old boy is brought in by his parent...,2.50%,"[A, 14-month-old, boy, is, brought, in, by, hi...",a 14-month-old boy be bring in by his parent w...,"[-0.13312687, -0.2356358, -0.26788232, -0.0925...","[2.50, %]",2.50 %,"[-0.13349502, 0.19601762, -0.36892185, 0.44642..."


In [8]:
!pip install unsloth --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.2/46.2 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m192.7/192.7 kB[0m [31m8.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m4.7 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m0:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m31.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m13.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m

In [9]:
from unsloth import FastLanguageModel
import torch
max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!
dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+
load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.

# 4bit pre quantized models we support for 4x faster downloading + no OOMs.
fourbit_models = [
    "unsloth/Meta-Llama-3.1-8B-bnb-4bit",      # Llama-3.1 15 trillion tokens model 2x faster!
    "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
    "unsloth/Meta-Llama-3.1-70B-bnb-4bit",
    "unsloth/Meta-Llama-3.1-405B-bnb-4bit",    # We also uploaded 4bit for 405b!
    "unsloth/Mistral-Nemo-Base-2407-bnb-4bit", # New Mistral 12b 2x faster!
    "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit",
    "unsloth/mistral-7b-v0.3-bnb-4bit",        # Mistral v3 2x faster!
    "unsloth/mistral-7b-instruct-v0.3-bnb-4bit",
    "unsloth/Phi-3.5-mini-instruct",           # Phi-3.5 2x faster!
    "unsloth/Phi-3-medium-4k-instruct",
    "unsloth/gemma-2-9b-bnb-4bit",
    "unsloth/gemma-2-27b-bnb-4bit",            # Gemma 2x faster!
] # More models at https://huggingface.co/unsloth

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Meta-Llama-3.1-8B",
    max_seq_length = max_seq_length,
    dtype = dtype,
    load_in_4bit = load_in_4bit,
    # token = "hf_...", # use one if using gated models like meta-llama/Llama-2-7b-hf
)



🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.


2025-04-20 04:20:11.053988: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745122811.239918      31 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745122811.296624      31 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Unsloth: Failed to patch Gemma3ForConditionalGeneration.
🦥 Unsloth Zoo will now patch everything to make training faster!
==((====))==  Unsloth 2025.3.19: Fast Llama patching. Transformers: 4.51.1.
   \\   /|    Tesla P100-PCIE-16GB. Num GPUs = 1. Max memory: 15.888 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 6.0. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = FALSE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


model.safetensors:   0%|          | 0.00/5.96G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/235 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/50.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/459 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/17.2M [00:00<?, ?B/s]

In [None]:
model

In [10]:
model = FastLanguageModel.get_peft_model(
    model,
    r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj",],
    lora_alpha = 16,
    lora_dropout = 0, # Supports any, but = 0 is optimized
    bias = "none",    # Supports any, but = "none" is optimized
    # [NEW] "unsloth" uses 30% less VRAM, fits 2x larger batch sizes!
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for very long context
    random_state = 3407,
    use_rslora = False,  # We support rank stabilized LoRA
    loftq_config = None,) # And LoftQ

Unsloth 2025.3.19 patched 32 layers with 32 QKV layers, 32 O layers and 32 MLP layers.


In [44]:
model

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): LlamaForCausalLM(
      (model): LlamaModel(
        (embed_tokens): Embedding(128256, 4096, padding_idx=128004)
        (layers): ModuleList(
          (0): LlamaDecoderLayer(
            (self_attn): LlamaAttention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Identity()
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=16, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=16, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear

In [None]:
train_df=pd.read_csv('data/train.csv')

In [29]:
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported


# Combine question and answer into a single text prompt
df["text"] = df["question"] + " Answer: " + df["answer"].astype(str)

from datasets import Dataset
train_dataset = Dataset.from_pandas(df[["text"]])




# Step 2: Fine-tune using SFTTrainer
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = train_dataset,
    dataset_text_field = "text",
    max_seq_length = max_seq_length,
    dataset_num_proc = 4,
    packing = False,
    args = TrainingArguments(
        per_device_train_batch_size = 8,
        gradient_accumulation_steps = 4,
        warmup_steps = 5,
        num_train_epochs = 1,
        max_steps = 80,
        learning_rate = 2e-5,
        fp16 = not is_bfloat16_supported(),
        bf16 = is_bfloat16_supported(),
        logging_steps = 1,
        optim = "adamw_8bit",
        weight_decay = 0.01,
        lr_scheduler_type = "linear",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",
    ),
)


Unsloth: Tokenizing ["text"] (num_proc=4):   0%|          | 0/10479 [00:00<?, ? examples/s]

In [30]:
trainer_stat = trainer.train()

==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 10,479 | Num Epochs = 1 | Total steps = 80
O^O/ \_/ \    Batch size per device = 8 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (8 x 4 x 1) = 32
 "-____-"     Trainable parameters = 41,943,040/8,000,000,000 (0.52% trained)


Step,Training Loss
1,1.5493
2,1.4212
3,1.4492
4,1.5558
5,1.7154
6,1.524
7,1.6416
8,1.5531
9,1.5938
10,1.5793


In [None]:
test_df=pd.read_csv('data/test_f.csv')

In [34]:
# Inference code to add after your training section
import torch
from tqdm import tqdm

# Path to save results
submission_path = 'submission.csv'

# Set the model to evaluation mode
model.eval()

# Batch size for inference
batch_size = 2

# Create submission dataframe
submission_df = pd.DataFrame(columns=['ID', 'answer'])

# Create batches
test_batches = [test_df[i:i+batch_size] for i in range(0, len(test_df), batch_size)]

# Process each batch with tqdm
all_answers = []
all_ids = []

with torch.no_grad():
    for batch in tqdm(test_batches, desc="Running inference"):
        batch_answers = []
        batch_ids = batch['ID'].tolist()
        
        for _, row in batch.iterrows():
            question = row['question']
            
            # Format the input using the chat template
            prompt =tokenizer.chat_template = "{% for message in messages %}{{ message['role'] }}: {{ message['content'] }}\n{% endfor %}Assistant:"

            
            # Tokenize
            inputs = tokenizer(prompt, return_tensors="pt", truncation=True, 
                              max_length=512).to(model.device)
            
            # Generate
            output_ids = model.generate(
                **inputs,
                max_new_tokens=32,
                do_sample=False,
                temperature=0.7,
                top_p=0.9,
                num_beams=1,
                pad_token_id=tokenizer.eos_token_id
            )
            
            # Decode and extract only the model's response
            full_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)
            
            # Extract only the assistant's response
            answer = full_output.split("[/INST]")[-1].strip() if "[/INST]" in full_output else full_output
            
            batch_answers.append(answer)
        
        all_answers.extend(batch_answers)
        all_ids.extend(batch_ids)

# Create and save submission
submission_df = pd.DataFrame({'ID': all_ids, 'answer': all_answers})
submission_df.to_csv(submission_path, index=False)
print(f"Submission saved to {submission_path}")

Running inference: 100%|██████████| 747/747 [24:58<00:00,  2.01s/it]

Submission saved to submission.csv



