In [None]:
!pip install -U "transformers>=4.44.0" "accelerate>=0.30.0" "peft>=0.11.0" "bitsandbytes>=0.43.0" "datasets" "evaluate" "sentencepiece"




In [None]:
from huggingface_hub import login

login()  # paste your HF token when prompted

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

CELL 1 - Load CodeSearchNet (Python) and create input/target text

We’ll use Nan-Do/code-search-net-python and the code + docstring columns.

In [None]:
from datasets import load_dataset

dataset_id = "Nan-Do/code-search-net-python"
raw = load_dataset(dataset_id)  # single 'train' split with partition column

def split_by_partition(ds, part):
    return ds.filter(lambda ex: ex["partition"] == part)

train_raw = split_by_partition(raw["train"], "train")
valid_raw = split_by_partition(raw["train"], "valid")
test_raw  = split_by_partition(raw["train"], "test")

len(train_raw), len(valid_raw), len(test_raw)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


(410175, 22977, 22091)

CELL 2 — Preprocess: Extract One-Line Docstrings & Build Training Prompts

In [None]:
def extract_one_line_docstring(docstring: str) -> str:
    if docstring is None:
        return ""
    doc = docstring.strip().strip('"').strip("'")
    first = ""
    for line in doc.split("\n"):
        line = line.strip()
        if line:
            first = line
            break
    return first

def make_example(ex):
    code = ex["code"]
    one_line = extract_one_line_docstring(ex["docstring"])
    if not one_line:
        return {"skip": True}

    prompt = f'Write a one-line Python docstring for this function:\n\n{code}\n\n"""'

    target = f'{one_line}\"\"\"'
    full_text = prompt + " " + target

    return {
        "full_text": full_text,
        "skip": False,
    }

train_proc = train_raw.map(make_example)
valid_proc = valid_raw.map(make_example)
test_proc = test_raw.map(make_example)

train_proc = train_proc.filter(lambda ex: not ex["skip"])
valid_proc = valid_proc.filter(lambda ex: not ex["skip"])
test_proc = test_proc.filter(lambda ex: not ex["skip"])

train_proc = train_proc.select(range(min(2000, len(train_proc))))
valid_proc = valid_proc.select(range(min(200, len(valid_proc))))
test_proc  = test_proc.select(range(min(200, len(test_proc))))

len(train_proc), len(valid_proc), len(test_proc)

(2000, 200, 200)

CELL 3 — Load Llama-3.2-1B-Instruct in 4-bit

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "meta-llama/Llama-3.2-1B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)

# Ensure pad token exists
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)

`torch_dtype` is deprecated! Use `dtype` instead!


CELL 4 — Tokenize the Dataset

In [None]:
from functools import partial

max_length = 512

def tokenize_fn(ex, tokenizer):
    out = tokenizer(
        ex["full_text"],
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    out["labels"] = out["input_ids"].copy()
    return out

tokenize_with_tok = partial(tokenize_fn, tokenizer=tokenizer)

train_tok = train_proc.map(tokenize_with_tok, batched=True, remove_columns=train_proc.column_names)
valid_tok = valid_proc.map(tokenize_with_tok, batched=True, remove_columns=valid_proc.column_names)
test_tok  = test_proc.map(tokenize_with_tok,  batched=True, remove_columns=test_proc.column_names)

train_tok[0]


Map:   0%|          | 0/200 [00:00<?, ? examples/s]

{'input_ids': [128000,
  8144,
  264,
  832,
  8614,
  13325,
  4733,
  928,
  369,
  420,
  734,
  1473,
  755,
  5542,
  20925,
  4432,
  11,
  1646,
  15737,
  2703,
  5980,
  11,
  308,
  57025,
  5980,
  11,
  1168,
  77,
  78093,
  1151,
  4047,
  11925,
  518,
  14008,
  5725,
  997,
  262,
  3270,
  262,
  1183,
  1771,
  264,
  597,
  41078,
  15795,
  19228,
  34465,
  369,
  3663,
  18324,
  382,
  262,
  551,
  913,
  5542,
  4432,
  25,
  6352,
  430,
  5727,
  264,
  1207,
  54734,
  369,
  1855,
  3967,
  1732,
  11,
  449,
  1202,
  836,
  382,
  257,
  320,
  860,
  304,
  2592,
  2082,
  311,
  1518,
  5542,
  4432,
  3187,
  5021,
  6070,
  696,
  257,
  29696,
  512,
  286,
  366,
  10613,
  4432,
  29,
  6018,
  286,
  81593,
  366,
  9164,
  16,
  29,
  6018,
  286,
  34491,
  256,
  81593,
  366,
  57839,
  1878,
  16,
  14611,
  32021,
  198,
  286,
  34491,
  256,
  81593,
  366,
  57839,
  1878,
  17,
  14611,
  32021,
  198,
  286,
  34491,
  256,
  81593,
  

CELL 5 — Attach QLoRA Adapters (PEFT)

In [None]:
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    bias="none",
    task_type="CAUSAL_LM",
)

model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


trainable params: 11,272,192 || all params: 1,247,086,592 || trainable%: 0.9039


CELL 6 — Training Setup (Trainer)

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False,
)

output_dir = "llama32_1b_docstring_qlora"

training_args = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    num_train_epochs=1,
    warmup_ratio=0.03,
    logging_steps=50,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="steps",
    bf16=True,
    save_total_limit=2,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    data_collator=data_collator,
)

trainer.train()


  return fn(*args, **kwargs)


Epoch,Training Loss,Validation Loss
1,1.5128,1.583723


TrainOutput(global_step=125, training_loss=1.5637748413085937, metrics={'train_runtime': 3820.9569, 'train_samples_per_second': 0.523, 'train_steps_per_second': 0.033, 'total_flos': 6048266059776000.0, 'train_loss': 1.5637748413085937, 'epoch': 1.0})

CELL 7 — Inference Helper Function

In [None]:
import torch

def make_inference_prompt(code: str) -> str:
    return f'Write a one-line Python docstring for this function:\n\n{code}\n\n"""'

def generate_docstring(code: str, max_new_tokens: int = 64):
    prompt = make_inference_prompt(code)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=0.0,
            pad_token_id=tokenizer.eos_token_id,
        )

    full = tokenizer.decode(output[0], skip_special_tokens=True)

    # Take only the continuation after the prompt
    if full.startswith(prompt):
        continuation = full[len(prompt):]
    else:
        continuation = full

    continuation = continuation.strip()

    # 🔽 NEW: stop at first triple-quote or newline
    for sep in ['"""', '\n']:
        if sep in continuation:
            continuation = continuation.split(sep)[0]
            break

    # Clean leftover quotes / spaces
    continuation = continuation.strip().strip('"').strip()

    return continuation



CELL 8 — Quick Sample Predictions

In [None]:
import textwrap

for i in range(3):
    ex = test_raw[i]
    print("CODE:")
    print(textwrap.indent(ex["code"], "  "))
    print("\nGROUND TRUTH:")
    print(" ", extract_one_line_docstring(ex["docstring"]))
    print("\nPREDICTION:")
    print(" ", generate_docstring(ex["code"]))
    print("\n" + "="*60 + "\n")

CODE:
  def get_vid_from_url(url):
          """Extracts video ID from URL.
          """
          return match1(url, r'youtu\.be/([^?/]+)') or \
            match1(url, r'youtube\.com/embed/([^/?]+)') or \
            match1(url, r'youtube\.com/v/([^/?]+)') or \
            match1(url, r'youtube\.com/watch/([^/?]+)') or \
            parse_query_param(url, 'v') or \
            parse_query_param(parse_query_param(url, 'u'), 'v')

GROUND TRUTH:
  Extracts video ID from URL.

PREDICTION:
  Extracts video ID from URL.


CODE:
  def sina_xml_to_url_list(xml_data):
      """str->list
      Convert XML to URL List.
      From Biligrab.
      """
      rawurl = []
      dom = parseString(xml_data)
      for node in dom.getElementsByTagName('durl'):
          url = node.getElementsByTagName('url')[0]
          rawurl.append(url.childNodes[0].data)
      return rawurl

GROUND TRUTH:
  str->list

PREDICTION:
  str->list


CODE:
  def makeMimi(upid):
      """From http://cdn37.atwikiimg.com/sit

CELL 9 — Compute BLEU & ROUGE

In [None]:
import evaluate

bleu = evaluate.load("sacrebleu")
rouge = evaluate.load("rouge")

def eval_on_subset(raw_ds, n_samples=50):
    preds, refs = [], []

    subset = raw_ds.select(range(min(n_samples, len(raw_ds))))

    for ex in subset:
        ref = extract_one_line_docstring(ex["docstring"])
        if not ref:
            continue
        pred = generate_docstring(ex["code"])
        preds.append(pred)
        refs.append([ref])   # sacrebleu format: list of list

    bleu_res = bleu.compute(predictions=preds, references=refs)
    rouge_res = rouge.compute(
        predictions=preds,
        references=[r[0] for r in refs]
    )
    return bleu_res, rouge_res

bleu_res, rouge_res = eval_on_subset(test_raw)
bleu_res, rouge_res


({'score': 12.407732153452939,
  'counts': [320, 270, 230, 194],
  'totals': [2083, 2033, 1984, 1936],
  'precisions': [15.362457993278925,
   13.280865715691096,
   11.59274193548387,
   10.020661157024794],
  'bp': 1.0,
  'sys_len': 2083,
  'ref_len': 363},
 {'rouge1': np.float64(0.779739287347534),
  'rouge2': np.float64(0.7439124439625177),
  'rougeL': np.float64(0.7771134915746791),
  'rougeLsum': np.float64(0.778847467011564)})

Lets take a look at some random samples

In [None]:
import random, textwrap

def show_random_examples(n=5):
    indices = random.sample(range(len(test_raw)), k=min(n, len(test_raw)))
    for i in indices:
        ex = test_raw[i]
        code = ex["code"]
        gt = extract_one_line_docstring(ex["docstring"])
        pred = generate_docstring(code, max_new_tokens=32)

        print("CODE:")
        print(textwrap.indent(code, "  "))
        print("\nGROUND TRUTH:")
        print(" ", gt)
        print("\nPREDICTION:")
        print(" ", pred)
        print("\n" + "="*60 + "\n")

show_random_examples(5)


CODE:
  def wait(self, timeout=-1):
          """wait for result to complete."""
          start = time.time()
          if self._ready:
              return
          local_ids = filter(lambda msg_id: msg_id in self._client.outstanding, self.msg_ids)
          local_ready = self._client.wait(local_ids, timeout)
          if local_ready:
              remote_ids = filter(lambda msg_id: msg_id not in self._client.results, self.msg_ids)
              if not remote_ids:
                  self._ready = True
              else:
                  rdict = self._client.result_status(remote_ids, status_only=False)
                  pending = rdict['pending']
                  while pending and (timeout < 0 or time.time() < start+timeout):
                      rdict = self._client.result_status(remote_ids, status_only=False)
                      pending = rdict['pending']
                      if pending:
                          time.sleep(0.1)
                  if not pending:
             