In [None]:
!pip install -q transformers datasets accelerate rouge-score bert-score


In [None]:
import torch
print("GPU Available:", torch.cuda.is_available())


GPU Available: True


In [None]:
from transformers import LEDTokenizer, LEDForConditionalGeneration, LEDTokenizerFast
device = "cuda" if torch.cuda.is_available() else "cpu"

model_name = "allenai/led-base-16384"

tokenizer = LEDTokenizer.from_pretrained(model_name)
model = LEDForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)


Loading weights:   0%|          | 0/299 [00:00<?, ?it/s]



In [None]:
def preprocess(batch):
    model_inputs = tokenizer(
        batch["article"],
        truncation=True,
        padding="max_length",
        max_length=2048
    )

    labels = tokenizer(
        text_target=batch["abstract"],
        truncation=True,
        padding="max_length",
        max_length=256
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [None]:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

print("Model loaded on:", device)


Model loaded on: cuda


In [None]:
text = """
Recent studies indicate that COVID-19 vaccines significantly reduce hospitalization rates.
However, long-term immunity effects are still under investigation.
"""

inputs = tokenizer(
    text,
    return_tensors="pt",
    truncation=True,
    padding="max_length",
    max_length=1024
).to(device)

summary_ids = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=128
)

print(tokenizer.decode(summary_ids[0], skip_special_tokens=True))


Recent studies indicate that COVID-19 vaccines significantly reduce hospitalization rates.However, long-term immunity effects are still under investigation.###


In [None]:
!pip install --upgrade --force-reinstall datasets huggingface_hub fsspec


Collecting datasets
  Downloading datasets-4.5.0-py3-none-any.whl.metadata (19 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-1.4.1-py3-none-any.whl.metadata (13 kB)
Collecting fsspec
  Downloading fsspec-2026.2.0-py3-none-any.whl.metadata (10 kB)
Collecting filelock (from datasets)
  Downloading filelock-3.20.3-py3-none-any.whl.metadata (2.1 kB)
Collecting numpy>=1.17 (from datasets)
  Downloading numpy-2.4.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (6.6 kB)
Collecting pyarrow>=21.0.0 (from datasets)
  Downloading pyarrow-23.0.0-cp312-cp312-manylinux_2_28_x86_64.whl.metadata (3.0 kB)
Collecting dill<0.4.1,>=0.3.0 (from datasets)
  Downloading dill-0.4.0-py3-none-any.whl.metadata (10 kB)
Collecting pandas (from datasets)
  Downloading pandas-3.0.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 kB[0m [31m5.0 MB/s[0m eta [36m0:00:00[0m
[?

In [None]:
from datasets import load_dataset

dataset = load_dataset("ccdv/pubmed-summarization", split="train[:1%]")
dataset


README.md: 0.00B [00:00, ?B/s]

section/train-00000-of-00005.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

section/train-00001-of-00005.parquet:   0%|          | 0.00/208M [00:00<?, ?B/s]

section/train-00002-of-00005.parquet:   0%|          | 0.00/207M [00:00<?, ?B/s]

section/train-00003-of-00005.parquet:   0%|          | 0.00/211M [00:00<?, ?B/s]

section/train-00004-of-00005.parquet:   0%|          | 0.00/210M [00:00<?, ?B/s]

section/validation-00000-of-00001.parque(…):   0%|          | 0.00/59.0M [00:00<?, ?B/s]

section/test-00000-of-00001.parquet:   0%|          | 0.00/58.9M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/119924 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/6633 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/6658 [00:00<?, ? examples/s]

Dataset({
    features: ['article', 'abstract'],
    num_rows: 1199
})

In [None]:
sample = dataset[0]
print(sample.keys())
print("\nARTICLE (short):\n", sample["article"][:500])
print("\nREFERENCE SUMMARY:\n", sample["abstract"])


dict_keys(['article', 'abstract'])

ARTICLE (short):
 a recent systematic analysis showed that in 2011 , 314 ( 296 - 331 ) million children younger than 5 years were mildly , moderately or severely stunted and 258 ( 240 - 274 ) million were mildly , moderately or severely underweight in the developing countries . 
 in iran a study among 752 high school girls in sistan and baluchestan showed prevalence of 16.2% , 8.6% and 1.5% , for underweight , overweight and obesity , respectively . 
 the prevalence of malnutrition among elementary school aged ch

REFERENCE SUMMARY:
 background : the present study was carried out to assess the effects of community nutrition intervention based on advocacy approach on malnutrition status among school - aged children in shiraz , iran.materials and methods : this case - control nutritional intervention has been done between 2008 and 2009 on 2897 primary and secondary school boys and girls ( 7 - 13 years old ) based on advocacy approach in shiraz , iran .

In [None]:
tokenized_dataset = dataset.map(
    preprocess,
    batched=True,
    remove_columns=dataset.column_names
)


Map:   0%|          | 0/1199 [00:00<?, ? examples/s]

In [None]:
from transformers import Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_steps=5,
    logging_steps=1,
    save_steps=5,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
)

trainer.train()


Step,Training Loss
1,18.306986
2,16.21747
3,16.509434
4,14.792157
5,17.660091


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

TrainOutput(global_step=5, training_loss=16.697227668762206, metrics={'train_runtime': 25.0847, 'train_samples_per_second': 0.797, 'train_steps_per_second': 0.199, 'total_flos': 36489249423360.0, 'train_loss': 16.697227668762206, 'epoch': 0.016680567139282735})

In [None]:
def generate_summary(text, max_input_len=2048, max_output_len=256):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=max_input_len
    ).to(device)

    summary_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],

        num_beams=4,
        repetition_penalty=1.2,
        length_penalty=1.1,
        max_length=max_output_len,
        early_stopping=True
    )

    return tokenizer.decode(summary_ids[0], skip_special_tokens=True)


In [None]:
def build_refinement_input(article, draft_summary):
    return (
        "ORIGINAL DOCUMENT:\n"
        + article
        + "\n\nDRAFT SUMMARY:\n"
        + draft_summary
        + "\n\nREFINED SUMMARY:"
    )


In [None]:
def refine_summary(article):
    draft = generate_summary(article)
    refinement_input = build_refinement_input(article, draft)
    refined = generate_summary(refinement_input)
    return draft, refined
def clean_text(text):
    return text.replace("Â", "").strip()



In [None]:
def build_refinement_input(article, draft_summary):
    article = clean_text(article)
    draft_summary = clean_text(draft_summary)

    return article + "\n\n" + draft_summary


In [None]:
def refine_summary(article):
    article = clean_text(article)

    draft = generate_summary(article)
    refinement_input = build_refinement_input(article, draft)
    refined = generate_summary(refinement_input)

    return draft, refined


In [None]:
article = dataset[2]["article"]

draft, refined = refine_summary(article)

print("DRAFT SUMMARY:\n", draft)
print("\nREFINED SUMMARY:\n", refined)


DRAFT SUMMARY:
 ia , in particular , have been extensively studied in schizophrenia . even though a number of studies suggest that bipolar patients experience higher rates of eps ( parkinsonism , dystonia , akathisia ) and td compared to patients with a diagnosis of schizophrenia , research within the bd population has been limited .  In fact , the risk is found to be 3 to 5 times higher in elderly patients compared to patients with atypical antipsychotics .  in addition to age , the risk is thought to be due to the presence of atypical antipsychotics , the presence of atypical antipsychotics , the presence of antipsychotic agents , the presence of antipsychotic agents , the presence of antipsychotics , the presence of antipsychotic agents , the use of anticholinergics with neuroleptics , previous physical therapies ( electroconvulsive therapy ) , the presence of other physical illness such as diabetes or an organic disorder , younger age of exposure , and the presence of extrapyramida

In [None]:
import re

def extract_entities(text):
    text = text.lower()
    numbers = re.findall(r"\b\d+\.?\d*\b", text)
    medical_terms = re.findall(r"\b[a-z]{5,}\b", text)
    return set(numbers + medical_terms)


In [None]:
def fact_score(document, summary):
    doc_entities = extract_entities(document)
    sum_entities = extract_entities(summary)

    if len(sum_entities) == 0:
        return 0.0

    overlap = doc_entities.intersection(sum_entities)
    return len(overlap) / len(sum_entities)


In [None]:
article = dataset[3]["article"]

draft, refined = refine_summary(article)

print("Draft Fact Score:", fact_score(article, draft))
print("Refined Fact Score:", fact_score(article, refined))


Draft Fact Score: 0.6724137931034483
Refined Fact Score: 0.7375


In [None]:
from collections import Counter

def redundancy_score(text, n=3):
    tokens = text.lower().split()
    if len(tokens) < n:
        return 0.0

    ngrams = [" ".join(tokens[i:i+n]) for i in range(len(tokens)-n+1)]
    counts = Counter(ngrams)

    repeated = sum(c for c in counts.values() if c > 1)
    total = len(ngrams)

    return repeated / total


In [None]:
article = dataset[4]["article"]

draft, refined = refine_summary(article)

print("Draft Redundancy:", redundancy_score(draft))
print("Refined Redundancy:", redundancy_score(refined))


Draft Redundancy: 0.2676767676767677
Refined Redundancy: 0.0


In [None]:
def clarion_summary(article):
    draft, refined = refine_summary(article)
    return {
        "draft": draft,
        "refined": refined,
        "fact_score": fact_score(article, refined),
        "redundancy": redundancy_score(refined)
    }


In [None]:
output = clarion_summary(dataset[5]["article"])

print("FINAL SUMMARY:\n", output["refined"])
print("\nFactScore:", output["fact_score"])
print("Redundancy:", output["redundancy"])


FINAL SUMMARY:
 world - wide , infertility affects 1015% of couples who are trying to conceive , and about 15% of these cases are caused by male factors , which affect 1 out of 20 men in the general population . Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â Â

FactScore: 1.0
Redundancy: 0.7272727272727273


In [None]:
!pip install -q evaluate bert-score rouge-score



[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/84.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import evaluate

rouge = evaluate.load("rouge")
bertscore = evaluate.load("bertscore")


Downloading builder script: 0.00B [00:00, ?B/s]

Downloading builder script: 0.00B [00:00, ?B/s]

In [None]:
def evaluate_summary(pred, ref, doc):
    r = rouge.compute(predictions=[pred], references=[ref])
    b = bertscore.compute(
        predictions=[pred],
        references=[ref],
        lang="en"
    )

    return {
        "rougeL": r["rougeL"],
        "bertscore": sum(b["f1"]) / len(b["f1"]),
        "factscore": fact_score(doc, pred),
        "redundancy": redundancy_score(pred)
    }


In [None]:
results = []

for i in range(10):
    doc = dataset[i]["article"]
    ref = dataset[i]["abstract"]

    output = clarion_summary(doc)
    scores = evaluate_summary(output["refined"], ref, doc)

    results.append(scores)

results


config.json:   0%|          | 0.00/482 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.42G [00:00<?, ?B/s]

Loading weights:   0%|          | 0/389 [00:00<?, ?it/s]

RobertaModel LOAD REPORT from: roberta-large
Key                             | Status     | 
--------------------------------+------------+-
lm_head.dense.bias              | UNEXPECTED | 
lm_head.bias                    | UNEXPECTED | 
lm_head.layer_norm.bias         | UNEXPECTED | 
lm_head.layer_norm.weight       | UNEXPECTED | 
lm_head.dense.weight            | UNEXPECTED | 
roberta.embeddings.position_ids | UNEXPECTED | 
pooler.dense.bias               | MISSING    | 
pooler.dense.weight             | MISSING    | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.
- MISSING	:those params were newly initialized because missing from the checkpoint. Consider training on your downstream task.


[{'rougeL': np.float64(0.1557788944723618),
  'bertscore': 0.8229395151138306,
  'factscore': 1.0,
  'redundancy': 0.3165829145728643},
 {'rougeL': np.float64(0.1601489757914339),
  'bertscore': 0.8282485008239746,
  'factscore': 0.9375,
  'redundancy': 0.09777777777777778},
 {'rougeL': np.float64(0.03846153846153846),
  'bertscore': 0.8038714528083801,
  'factscore': 1.0,
  'redundancy': 0.0},
 {'rougeL': np.float64(0.2),
  'bertscore': 0.8158096671104431,
  'factscore': 0.8333333333333334,
  'redundancy': 0.8421052631578947},
 {'rougeL': np.float64(0.08866995073891626),
  'bertscore': 0.8350473046302795,
  'factscore': 0.96875,
  'redundancy': 0.0},
 {'rougeL': np.float64(0.08290155440414508),
  'bertscore': 0.7906744480133057,
  'factscore': 1.0,
  'redundancy': 0.0},
 {'rougeL': np.float64(0.13953488372093026),
  'bertscore': 0.8255873322486877,
  'factscore': 1.0,
  'redundancy': 0.8669724770642202},
 {'rougeL': np.float64(0.1588089330024814),
  'bertscore': 0.8153097629547119,
  

In [None]:
training_args = TrainingArguments(
    output_dir="./outputs",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_steps=200,   # <-- IMPORTANT
    logging_steps=10,
    save_steps=200,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset
)

trainer.train()


Step,Training Loss
10,17.259215
20,19.537303
30,18.547823
40,17.585991
50,14.860515
60,12.779249
70,13.124251
80,13.244727
90,12.850232
100,11.486486


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

TrainOutput(global_step=200, training_loss=13.0520845413208, metrics={'train_runtime': 391.5584, 'train_samples_per_second': 2.043, 'train_steps_per_second': 0.511, 'total_flos': 1459569976934400.0, 'train_loss': 13.0520845413208, 'epoch': 0.6672226855713094})

In [None]:
def led_baseline(article):
    return generate_summary(article)


In [None]:
baseline_results = []
clarion_results = []

for i in range(10):
    doc = dataset[i]["article"]
    ref = dataset[i]["abstract"]

    # Baseline
    base_summary = led_baseline(doc)
    base_scores = evaluate_summary(base_summary, ref, doc)
    baseline_results.append(base_scores)

    # CLARION
    clarion_out = clarion_summary(doc)
    clarion_scores = evaluate_summary(
        clarion_out["refined"], ref, doc
    )
    clarion_results.append(clarion_scores)


In [None]:
import numpy as np

def average_metrics(results):
    return {
        k: float(np.mean([r[k] for r in results]))
        for k in results[0]
    }

avg_baseline = average_metrics(baseline_results)
avg_clarion = average_metrics(clarion_results)

avg_baseline, avg_clarion


({'rougeL': 0.22390288442045017,
  'bertscore': 0.8457440674304962,
  'factscore': 0.9740197675681547,
  'redundancy': 0.5038985337421992},
 {'rougeL': 0.2384759368038593,
  'bertscore': 0.849017721414566,
  'factscore': 0.9721298099140652,
  'redundancy': 0.5604867877376934})

In [None]:
training_args = TrainingArguments(
    output_dir="./ckpt_pubmed",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_steps=1000,
    fp16=True,
    save_steps=1000,
    report_to="none"
)


In [None]:
training_args = TrainingArguments(
    output_dir="./ckpt_clinicaltrials",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    max_steps=500,
    fp16=True,
    save_steps=500,
    report_to="none"
)


In [None]:
training_args = TrainingArguments(
    output_dir="./ckpt_mediqa",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    max_steps=400,
    fp16=True,
    save_steps=400,
    report_to="none"
)


In [None]:
from datasets import load_dataset

pubmed = load_dataset(
    "ccdv/pubmed-summarization",
    split="train[:5%]"
)

print(len(pubmed))


5996


In [None]:
sample = pubmed[0]
print(sample.keys())
print(sample["article"][:300])
print(sample["abstract"][:200])


dict_keys(['article', 'abstract'])
a recent systematic analysis showed that in 2011 , 314 ( 296 - 331 ) million children younger than 5 years were mildly , moderately or severely stunted and 258 ( 240 - 274 ) million were mildly , moderately or severely underweight in the developing countries . 
 in iran a study among 752 high school
background : the present study was carried out to assess the effects of community nutrition intervention based on advocacy approach on malnutrition status among school - aged children in shiraz , iran


In [None]:
pubmed_tok = pubmed.map(
    preprocess,
    batched=True,
    remove_columns=pubmed.column_names
)


Map:   0%|          | 0/5996 [00:00<?, ? examples/s]

In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./ckpt_pubmed",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=2e-5,
    max_steps=1000,
    fp16=True,
    logging_steps=25,
    save_steps=1000,
    report_to="none"
)


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=pubmed_tok
)

trainer.train()


Step,Training Loss
25,10.566172
50,10.347302
75,9.4978
100,9.142291
125,8.851495
150,8.421496
175,9.198519
200,8.525671
225,8.265666
250,8.418966


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

TrainOutput(global_step=1000, training_loss=8.387555068969727, metrics={'train_runtime': 1503.1831, 'train_samples_per_second': 2.661, 'train_steps_per_second': 0.665, 'total_flos': 7297849884672000.0, 'train_loss': 8.387555068969727, 'epoch': 0.66711140760507})

In [None]:
trainer.save_model("./ckpt_pubmed")
tokenizer.save_pretrained("./ckpt_pubmed")


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

('./ckpt_pubmed/tokenizer_config.json', './ckpt_pubmed/tokenizer.json')

In [None]:
!unzip CHQA-Corpus-1.0.zip -d chqa



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-121846105.xml.txt  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-121846335.xml.ann  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-121846335.xml.txt  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-122818445.xml.ann  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-122818445.xml.txt  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-122827105.xml.ann  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-122827105.xml.txt  
  inflating: chqa/CHQA-Corpus-1.0/CHQA-email/1720_Unadjudicated/RawWithConfidenceEstimates/Ann6/1-122827772.x

In [None]:
!ls chqa


CHQA-Corpus-1.0


In [None]:
!ls chqa/CHQA-Corpus-1.0


calculations.xlsx  CHQA-email  CHQA-web  Guidelines  readme.txt


In [None]:
!ls chqa/CHQA-Corpus-1.0/CHQA-email


1720_Unadjudicated  20_Practice  annotation.conf  visual.conf


In [None]:
!ls chqa/CHQA-Corpus-1.0/CHQA-web


annotation.conf   What.c.0.ann	    Where.c.48.ann    Who.r.41.ann
How100.c.0.ann	  What.c.0.txt	    Where.c.48.txt    Who.r.41.txt
How100.c.0.txt	  What.c.11.ann     Where.c.49.ann    Who.r.49.ann
How100.c.1.ann	  What.c.11.txt     Where.c.49.txt    Who.r.49.txt
How100.c.1.txt	  What.c.12.ann     Where.c.4.ann     Who.r.52.ann
How100.c.2.ann	  What.c.12.txt     Where.c.4.txt     Who.r.52.txt
How100.c.2.txt	  What.c.13.ann     Where.c.51.ann    Who.r.53.ann
How100.c.4.ann	  What.c.13.txt     Where.c.51.txt    Who.r.53.txt
How100.c.4.txt	  What.c.14.ann     Where.c.52.ann    Who.r.57.ann
How100.c.5.ann	  What.c.14.txt     Where.c.52.txt    Who.r.57.txt
How100.c.5.txt	  What.c.15.ann     Where.c.54.ann    Who.r.58.ann
How100.r.1.ann	  What.c.15.txt     Where.c.54.txt    Who.r.58.txt
How100.r.1.txt	  What.c.16.ann     Where.c.55.ann    Who.r.8.ann
How100.r.2.ann	  What.c.16.txt     Where.c.55.txt    Who.r.8.txt
How100.r.2.txt	  What.c.17.ann     Where.c.56.ann    Who.r.9.ann
How100.r.3.ann

In [None]:
import os

def load_chqa_pairs(folder):
    questions = {}
    answers = {}

    for fname in os.listdir(folder):
        if not fname.endswith(".txt"):
            continue

        path = os.path.join(folder, fname)
        with open(path, "r", errors="ignore") as f:
            text = f.read().strip()

        if ".c." in fname:
            key = fname.replace(".c.", ".")
            questions[key] = text

        elif ".r." in fname:
            key = fname.replace(".r.", ".")
            answers[key] = text

    pairs = []
    for key in questions:
        if key in answers:
            pairs.append({
                "article": questions[key],
                "abstract": answers[key]
            })

    return pairs


In [None]:
email_path = "chqa/CHQA-Corpus-1.0/CHQA-email"
web_path   = "chqa/CHQA-Corpus-1.0/CHQA-web"

email_pairs = load_chqa_pairs(email_path)
web_pairs   = load_chqa_pairs(web_path)

all_pairs = email_pairs + web_pairs

print("Total QA pairs:", len(all_pairs))


Total QA pairs: 333


In [None]:
from datasets import Dataset

chqa_ds = Dataset.from_list(all_pairs)
chqa_ds


Dataset({
    features: ['article', 'abstract'],
    num_rows: 333
})

In [None]:
def preprocess(batch):
    model_inputs = tokenizer(
        batch["article"],
        truncation=True,
        padding="max_length",
        max_length=2048
    )

    labels = tokenizer(
        text_target=batch["abstract"],
        truncation=True,
        padding="max_length",
        max_length=256
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs


In [None]:
chqa_tok = chqa_ds.map(
    preprocess,
    batched=True,
    remove_columns=chqa_ds.column_names
)


Map:   0%|          | 0/333 [00:00<?, ? examples/s]

In [None]:
!ls

chqa  CHQA-Corpus-1.0.zip  ckpt_pubmed	mediqa_osf.zip	outputs  sample_data


In [None]:
from transformers import LEDForConditionalGeneration

model = LEDForConditionalGeneration.from_pretrained("./ckpt_pubmed")
model = model.to(device)


Loading weights:   0%|          | 0/299 [00:00<?, ?it/s]



In [None]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./ckpt_final_clarion",
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,
    learning_rate=5e-6,
    max_steps=400,
    fp16=True,
    logging_steps=25,
    save_steps=400,
    report_to="none"
)


In [None]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=chqa_tok
)

trainer.train()


Step,Training Loss
25,1.129693
50,0.798188
75,0.697015
100,0.668031
125,0.674721
150,0.677479
175,0.611614
200,0.613806
225,0.637214
250,0.629604


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

TrainOutput(global_step=400, training_loss=0.667027804851532, metrics={'train_runtime': 635.1718, 'train_samples_per_second': 2.519, 'train_steps_per_second': 0.63, 'total_flos': 2897246404214784.0, 'train_loss': 0.667027804851532, 'epoch': 4.768768768768769})

In [None]:
mv ckpt_final_clarion clarion_runA


In [None]:
!ls clarion_runA


checkpoint-400	generation_config.json	tokenizer_config.json
config.json	model.safetensors	tokenizer.json


In [None]:
output_dir="./ckpt_final_clarion"


In [None]:
from transformers import LEDForConditionalGeneration, LEDTokenizerFast

model_B = LEDForConditionalGeneration.from_pretrained(
    "./ckpt_pubmed/checkpoint-1000",
    local_files_only=True
)

tokenizer = LEDTokenizerFast.from_pretrained(
    "allenai/led-base-16384",
    local_files_only=False
)

print("Checkpoint-1000 loaded successfully")


Loading weights:   0%|          | 0/299 [00:00<?, ?it/s]



Checkpoint-1000 loaded successfully


In [None]:
model_B.save_pretrained("clarion_runB")
tokenizer.save_pretrained("clarion_runB")

print("clarion_runB created from checkpoint-1000")


Writing model shards:   0%|          | 0/1 [00:00<?, ?it/s]

clarion_runB created from checkpoint-1000
