In [1]:
!pip install datasets



In [15]:
import re
import json
import pandas as pd
from datasets import Dataset
from transformers import BartTokenizer, BartForConditionalGeneration, Trainer, TrainingArguments

def clean_text(text):
    text = re.sub(r"[^a-zA-Z0-9,.:;()\-\s]", "", text)
    text = re.sub(r"\s+", " ", text)

    text = text.strip()
    return text


def json_to_summary(data):
    summary = f"Patient Name: {data['patient_name']}, referred by {data['referring_doctor']}, presented with {data['compliant']}, diagnosed with {data['diagnosis']}. "

    summary += "Therapy sessions focused on "
    sessions_summary = []
    for session in data['sessions_done']:
        activities = ', '.join([act['name'] for act in session['activities']])
        sessions_summary.append(f"{session['name']} ({session['count']} sessions, activities: {activities})")
    summary += ", ".join(sessions_summary) + ". "

    improvements = "; ".join([f"{improvement['improvement']} on {improvement['date']}" for improvement in data['improvements']])
    summary += f"Notable improvements observed were {improvements}. "

    vitals = data['vital_last_week']
    summary += f"Vital signs showed a heart rate of {vitals['heart_rate']} bpm and blood pressure of {vitals['blood_pressure']} mmHg. "

    procedures = "; ".join([f"{procedure['name']}" for procedure in data['procedures_done']])
    summary += f"Procedures performed included {procedures}. "

    milestones = data['milestones']
    summary += f"Milestones achieved include {milestones}."

    return summary


def json_to_dataset(data):
    text_pairs = []
    for item in data:
        input_text = clean_text(json.dumps(item))
        output_text = clean_text(json_to_summary(item))
        text_pairs.append({"input_text": input_text, "output_text": output_text})
    return Dataset.from_pandas(pd.DataFrame(text_pairs))

data = [
    {
        "patient_name": "Mr. Abc",
        "referring_doctor": "Dr. Xyz",
        "compliant": "Slurred speech and Weakness",
        "diagnosis": "ACUTE ISCHEMIC STROKE",
        "sessions_done": [
            {
                "date": "15-Jul-24",
                "name": "Strength Training",
                "count": 2,
                "activities": [
                    {"name": "Weight Cuff", "level": "Level 1"},
                    {"name": "Swiss Ball", "level": "Level 1"},
                    {"name": "Erogmeter", "level": "Level 1"}
                ]
            }
        ],
        "improvements": [
            {"date": "15-Jul-24", "improvement": "Foleys Catheter Removal"},
            {"date": "18-Jul-24", "improvement": "Sitting with support"},
            {"date": "21-Jul-24", "improvement": "Wheel chair mobilisation"}
        ],
        "milestones": "Tracstomy Tube Removed, Sitting without support",
        "vitals_admission": {
            "heart_rate": "70",
            "blood_pressure": "120/80",
            "temperature": "98.6",
            "respiratory_rate": 16
        },
        "vital_last_week": {
            "heart_rate": "80",
            "blood_pressure": "130/85",
            "temperature": "99.5",
            "respiratory_rate": 18
        },
        "vital_weekly": [
            {"date": "15-Jul-24", "heart_rate": "70", "blood_pressure": "120/80", "temperature": "98.6", "respiratory_rate": 16}
        ],
        "procedures_done": [
            {"date": "15-Jul-24", "name": "Wound Dressings"},
            {"date": "17-Jul-24", "name": "Tube Changing(NG,FC)"},
            {"date": "19-Jul-24", "name": "Nebulization"}
        ]
    }
]


dataset = json_to_dataset(data)

tokenizer = BartTokenizer.from_pretrained('facebook/bart-large')
model = BartForConditionalGeneration.from_pretrained('facebook/bart-large')


def preprocess_function(examples):
    inputs = examples['input_text']
    targets = examples['output_text']
    model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(targets, max_length=1024, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

tokenized_dataset = dataset.map(preprocess_function, batched=True)


training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    eval_dataset=tokenized_dataset,
)

trainer.train()


'''def generate_summary_from_dataset(model, tokenizer, tokenized_dataset, index):
    example = tokenized_dataset[index]
    input_ids = example['input_ids']
    input_ids = torch.tensor(input_ids).unsqueeze(0)  # Convert to tensor and add batch dimension
    summary_ids = model.generate(input_ids, max_length=1024, min_length=100, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary'''
def generate_summary_from_dataset(model, tokenizer, tokenized_dataset, index):
    example = tokenized_dataset[index]
    input_ids = example['input_ids']
    input_ids = torch.tensor(input_ids).unsqueeze(0)


    input_ids = input_ids.to(model.device)

    summary_ids = model.generate(input_ids, max_length=1024, min_length=100, length_penalty=2.0, num_beams=4, early_stopping=True)
    summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
    return summary

index = 0
summary = generate_summary_from_dataset(model, tokenizer, tokenized_dataset, index)
print(summary)


Map:   0%|          | 0/1 [00:00<?, ? examples/s]



Step,Training Loss


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


patientname: Mr. Abc, referringdoctor: Dr. Xyz, compliant: Slurred speech and Weakness, diagnosis: ACUTE ISCHEMIC STROKE, sessionsdone: date: 15-Jul-24, name: Strength Training, count: 2, activities: name: Weight Cuff, level: Level 1, time: 1:30, time of day: 0:00, time spent: 1-2: 30 minutes, activity level: 1, activity: weight: Level: 1/2, activities level: 2/3/4/5/6/7/8/9/10/11/12/13/15/16/17/18/19/20/2021/21/22/2/23/24/2022/23-24/25/26/27/28/29/30/31/32/1/1-2/3, 2-3/2:1/2-1,3,4/3:4/4,5/5:6/1:3,6,7/4:5,8/8:7,9/9,10/10:12,11/11,12/12:13/14:15/15:16/16:17/17:18/20:21:22:23/22:24/23:25/24:26/26:27/29:28/30:31/31:32/33:34/34:35/34/35/36, vitalsadmission: heartrate: 70, bloodpressure: 12080, temperature: 98.6, respiratoryrate: 16, vitallastweek: heart: 70.5.5, respiratory: 18, vital, vital: 15/16, vit: 18.5/18, vitalweekly: date 14/14/15, vitan: 14/15.vitals: date 15/date 16/15-16/18.6/16-17.6: 16/18-19-20/19-21-22/22-23/25-26-28-30/27-28/28-29-31/30-31-01/01/02/03/04/02:Name: Mr Abc (FC