In [None]:
# Install runtime dependencies into the kernel's environment.
# NOTE(review): versions are unpinned; pin known-good transformers/torch versions for reproducible runs.
%pip install transformers torch

In [None]:
import json
import torch
from typing import List
from torch.utils.data import Dataset
from transformers import PreTrainedTokenizer
from transformers import GPT2LMHeadModel, GPT2Tokenizer, DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments

In [None]:
class DatasetFromJSON(Dataset):
    """Causal-LM training examples built from consecutive turns of each conversation.

    Every adjacent pair of turns ``(turn[i], turn[i + 1])`` becomes one example,
    encoded in DialoGPT's dialogue format ``context <eos> response <eos>``.
    Examples are returned unpadded (plain lists of token ids); padding a batch
    is left to the data collator.

    Args:
        data: Conversations, each a list of turn strings. Conversations with
            fewer than two turns yield no examples.
        tokenizer: Tokenizer providing ``encode`` and ``eos_token_id``.
        max_length: Maximum number of tokens per example. Longer examples keep
            their LAST ``max_length`` tokens, so the response survives and the
            oldest context is dropped. ``None`` disables truncation.

    Raises:
        ValueError: if the tokenizer defines no EOS token.
    """

    def __init__(self, data: List[List[str]], tokenizer: "PreTrainedTokenizer", max_length=1024):
        self.tokenizer = tokenizer
        self.input_data = []
        self.max_length = max_length
        eos_id = tokenizer.eos_token_id
        if eos_id is None:
            raise ValueError("tokenizer must define an EOS token to separate dialogue turns")
        # NOTE(review): train_dialo_gpt sets pad_token = eos_token, and
        # DataCollatorForLanguageModeling(mlm=False) masks labels equal to the pad id,
        # so these turn-ending EOS tokens get no loss. Confirm that is acceptable.
        for conversation in data:
            # Encode each turn once (it serves as a response and then as the next
            # context), terminating it with EOS so turn boundaries are explicit.
            turns = [tokenizer.encode(turn, add_special_tokens=False) + [eos_id] for turn in conversation]
            for context_ids, response_ids in zip(turns, turns[1:]):
                ids = context_ids + response_ids
                if self.max_length is not None and len(ids) > self.max_length:
                    # Left-truncate: keep the response and the most recent context.
                    ids = ids[-self.max_length:]
                self.input_data.append(ids)

    def __len__(self):
        """Number of (context, response) examples."""
        return len(self.input_data)

    def __getitem__(self, idx):
        """Return example ``idx`` as an unpadded list of token ids."""
        return self.input_data[idx]

In [None]:
# Load the medical-conversations training file.
def read_and_process_json(file_path: str) -> List[List[str]]:
    """Load conversations from a JSON file and validate their shape.

    The file must hold a JSON array of conversations, each an array of turn
    strings, e.g. ``[["Hi", "Hello, how can I help?"], ...]``.

    Args:
        file_path: Path to the UTF-8 encoded JSON file.

    Returns:
        The parsed list of conversations, unchanged.

    Raises:
        ValueError: if the top-level value is not a list, if any conversation
            is not a list, or if any turn is not a string. Malformed JSON raises
            ``json.JSONDecodeError``, which is a subclass of ``ValueError``.
    """
    print("Reading...")
    with open(file_path, "r", encoding="utf-8") as file:
        data = json.load(file)
    # Validate up front: a flat list of strings would otherwise be indexed
    # character-by-character downstream and silently yield garbage pairs.
    if not isinstance(data, list):
        raise ValueError(f"{file_path}: expected a JSON array of conversations, got {type(data).__name__}")
    for idx, conversation in enumerate(data):
        if not isinstance(conversation, list) or not all(isinstance(turn, str) for turn in conversation):
            raise ValueError(f"{file_path}: conversation {idx} must be a list of strings")
    return data

def train_dialo_gpt(model_name, conversations, output_dir, epochs=1):
    """Fine-tune a GPT-2-family causal LM (e.g. DialoGPT) on multi-turn conversations.

    Args:
        model_name: Model id or local path passed to ``from_pretrained``.
        conversations: List of conversations, each a list of turn strings;
            turned into examples by ``DatasetFromJSON``.
        output_dir: Directory for Trainer checkpoints. The final model and
            tokenizer are also saved at its top level, so it can be reloaded
            with ``from_pretrained(output_dir)``.
        epochs: Number of training epochs.

    Returns:
        The ``Trainer`` used for training, e.g. for inspecting ``state.log_history``.
    """
    print("Training...")
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 tokenizers have no pad token; reuse EOS so the collator can pad batches.
    tokenizer.pad_token = tokenizer.eos_token
    # No manual .to(device): Trainer moves the model to the training device itself.
    model = GPT2LMHeadModel.from_pretrained(model_name)
    train_dataset = DatasetFromJSON(conversations, tokenizer)

    # mlm=False builds causal-LM labels from input_ids; padding is done per batch.
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,
        pad_to_multiple_of=8,
    )

    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=2,
        save_steps=100,
        save_total_limit=3,
        logging_steps=100,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()

    # Checkpoints go to output_dir/checkpoint-* and do not include the tokenizer
    # (none is passed to Trainer). Persist both at the top level so output_dir
    # is a complete, loadable model directory.
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return trainer

In [None]:
# Paths consumed by train_medical(): the conversations JSON to read, and the
# directory the Trainer writes checkpoints to.
training_file, model_output = "./train_data.json", "./models"

In [None]:
def train_medical(file_path=None, output_dir=None, model_name="microsoft/DialoGPT-large", epochs=5):
    """Fine-tune DialoGPT on the medical conversations file.

    Called with no arguments, this reproduces the original run: it reads
    ``training_file``, writes to ``model_output``, and trains
    ``microsoft/DialoGPT-large`` for 5 epochs.

    Args:
        file_path: Conversations JSON to read. Defaults to ``training_file``.
        output_dir: Training output directory. Defaults to ``model_output``.
        model_name: Pretrained model id or path to fine-tune.
        epochs: Number of training epochs.
    """
    # Resolve the path defaults at call time (not at def time), so edits to the
    # config cell take effect without re-running this cell.
    source_path = training_file if file_path is None else file_path
    target_dir = model_output if output_dir is None else output_dir
    conversations = read_and_process_json(source_path)
    train_dialo_gpt(model_name, conversations, target_dir, epochs=epochs)

In [None]:
# Launch fine-tuning using the training_file / model_output paths set earlier.
train_medical()