## load model

In [1]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, DataCollatorForSeq2Seq
last_checkpoint_34000 = "./third-results/checkpoint-4000"

tokenizer = T5Tokenizer.from_pretrained(last_checkpoint_34000)
model = T5ForConditionalGeneration.from_pretrained(last_checkpoint_34000, device_map="auto")
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## load dataset

In [2]:
from datasets import load_dataset


ds = load_dataset("keivalya/MedQuad-MedicalQnADataset")
ds

DatasetDict({
    train: Dataset({
        features: ['qtype', 'Question', 'Answer'],
        num_rows: 16407
    })
})

## split the dataset

In [3]:
ds=ds['train'].train_test_split(test_size=0.1)
ds

DatasetDict({
    train: Dataset({
        features: ['qtype', 'Question', 'Answer'],
        num_rows: 14766
    })
    test: Dataset({
        features: ['qtype', 'Question', 'Answer'],
        num_rows: 1641
    })
})

## processing dataset

In [4]:
prefix = "Please answer this medical related question: "

# Define the preprocessing function

def preprocess_function(examples):
   inputs = [prefix + doc for doc in examples["Question"]]
   model_inputs = tokenizer(inputs, max_length=128, truncation=True)
  
   labels = tokenizer(text_target=examples["Answer"], 
                      max_length=512,         
                      truncation=True)

   model_inputs["labels"] = labels["input_ids"]
   return model_inputs

tokenized_dataset = ds.map(preprocess_function, batched=True)

Map: 100%|██████████| 14766/14766 [00:10<00:00, 1445.29 examples/s]
Map: 100%|██████████| 1641/1641 [00:01<00:00, 1475.59 examples/s]


## compute_metrics

In [5]:
import nltk
import evaluate
import numpy as np

nltk.download("punkt", quiet=True)
metric = evaluate.load("rouge")

def compute_metrics(eval_preds):
   preds, labels = eval_preds

   # decode preds and labels
   labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
   decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
   decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

   # rougeLSum expects newline after each sentence
   decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
   decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

   result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
  
   return result

## fine-tuning

In [8]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
# Define training arguments
training_args = Seq2SeqTrainingArguments(
    output_dir="./fourth-results",
    evaluation_strategy="epoch",
    learning_rate=5e-5,  
    per_device_train_batch_size=8,  
    per_device_eval_batch_size=4, 
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=10,
    save_total_limit=5,
    predict_with_generate=True,
    generation_max_length=200,
    push_to_hub=False
   
)

# Initialize the Trainer
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()



  0%|          | 10/5538 [00:10<1:58:16,  1.28s/it]

{'loss': 1.6471, 'grad_norm': 0.6815581321716309, 'learning_rate': 4.99097146984471e-05, 'epoch': 0.01}


  0%|          | 20/5538 [00:20<1:48:17,  1.18s/it]

{'loss': 1.8984, 'grad_norm': 0.8914475440979004, 'learning_rate': 4.981942939689419e-05, 'epoch': 0.01}


  1%|          | 30/5538 [00:34<1:42:32,  1.12s/it]

{'loss': 1.9592, 'grad_norm': 0.8586159944534302, 'learning_rate': 4.972914409534128e-05, 'epoch': 0.02}


  1%|          | 40/5538 [00:43<1:13:27,  1.25it/s]

{'loss': 1.6447, 'grad_norm': 0.6837711930274963, 'learning_rate': 4.963885879378837e-05, 'epoch': 0.02}


  1%|          | 50/5538 [00:51<1:25:00,  1.08it/s]

{'loss': 1.6331, 'grad_norm': 0.750500500202179, 'learning_rate': 4.9548573492235465e-05, 'epoch': 0.03}


  1%|          | 60/5538 [00:58<1:21:13,  1.12it/s]

{'loss': 1.7649, 'grad_norm': 0.8571430444717407, 'learning_rate': 4.945828819068256e-05, 'epoch': 0.03}


  1%|▏         | 70/5538 [01:04<54:12,  1.68it/s]  

{'loss': 2.2469, 'grad_norm': 1.2145253419876099, 'learning_rate': 4.936800288912965e-05, 'epoch': 0.04}


  1%|▏         | 80/5538 [01:13<1:32:58,  1.02s/it]

{'loss': 1.8709, 'grad_norm': 1.0647796392440796, 'learning_rate': 4.9277717587576744e-05, 'epoch': 0.04}


  2%|▏         | 90/5538 [01:17<41:05,  2.21it/s]  

{'loss': 1.8941, 'grad_norm': 0.8765599131584167, 'learning_rate': 4.918743228602384e-05, 'epoch': 0.05}


  2%|▏         | 100/5538 [01:24<1:03:14,  1.43it/s]

{'loss': 1.821, 'grad_norm': 0.8860837817192078, 'learning_rate': 4.9097146984470935e-05, 'epoch': 0.05}


  2%|▏         | 110/5538 [01:29<1:04:50,  1.40it/s]

{'loss': 1.864, 'grad_norm': 0.6963024139404297, 'learning_rate': 4.900686168291802e-05, 'epoch': 0.06}


  2%|▏         | 120/5538 [01:37<1:10:51,  1.27it/s]

{'loss': 1.9927, 'grad_norm': 0.9631701707839966, 'learning_rate': 4.891657638136511e-05, 'epoch': 0.07}


  2%|▏         | 130/5538 [01:42<59:49,  1.51it/s]  

{'loss': 1.7978, 'grad_norm': 1.0837020874023438, 'learning_rate': 4.882629107981221e-05, 'epoch': 0.07}


  3%|▎         | 140/5538 [01:52<1:30:05,  1.00s/it]

{'loss': 1.929, 'grad_norm': 1.0039842128753662, 'learning_rate': 4.87360057782593e-05, 'epoch': 0.08}


  3%|▎         | 150/5538 [01:59<1:10:03,  1.28it/s]

{'loss': 1.6342, 'grad_norm': 1.105271816253662, 'learning_rate': 4.86457204767064e-05, 'epoch': 0.08}


  3%|▎         | 160/5538 [02:05<47:12,  1.90it/s]  

{'loss': 1.9198, 'grad_norm': 0.9102680683135986, 'learning_rate': 4.8555435175153486e-05, 'epoch': 0.09}


  3%|▎         | 170/5538 [02:14<1:50:29,  1.24s/it]

{'loss': 1.7038, 'grad_norm': 1.058592438697815, 'learning_rate': 4.846514987360058e-05, 'epoch': 0.09}


  3%|▎         | 180/5538 [02:23<1:13:43,  1.21it/s]

{'loss': 1.8843, 'grad_norm': 1.1375313997268677, 'learning_rate': 4.837486457204768e-05, 'epoch': 0.1}


  3%|▎         | 190/5538 [02:33<1:16:09,  1.17it/s]

{'loss': 1.871, 'grad_norm': 0.7807405591011047, 'learning_rate': 4.8284579270494765e-05, 'epoch': 0.1}


  4%|▎         | 200/5538 [02:41<1:04:21,  1.38it/s]

{'loss': 1.8063, 'grad_norm': 0.8816958069801331, 'learning_rate': 4.8194293968941854e-05, 'epoch': 0.11}


  4%|▍         | 210/5538 [02:50<1:39:28,  1.12s/it]

{'loss': 1.8383, 'grad_norm': 1.2237052917480469, 'learning_rate': 4.810400866738895e-05, 'epoch': 0.11}


  4%|▍         | 220/5538 [02:57<48:15,  1.84it/s]  

{'loss': 1.9797, 'grad_norm': 1.0917038917541504, 'learning_rate': 4.8013723365836044e-05, 'epoch': 0.12}


  4%|▍         | 230/5538 [03:06<1:16:46,  1.15it/s]

{'loss': 1.8336, 'grad_norm': 0.8610579371452332, 'learning_rate': 4.792343806428314e-05, 'epoch': 0.12}


  4%|▍         | 240/5538 [03:15<51:10,  1.73it/s]  

{'loss': 1.7903, 'grad_norm': 0.9725875854492188, 'learning_rate': 4.7833152762730235e-05, 'epoch': 0.13}


  5%|▍         | 250/5538 [03:23<1:02:24,  1.41it/s]

{'loss': 1.9911, 'grad_norm': 1.0571547746658325, 'learning_rate': 4.774286746117732e-05, 'epoch': 0.14}


  5%|▍         | 260/5538 [03:34<1:18:38,  1.12it/s]

{'loss': 1.7662, 'grad_norm': 0.7394495010375977, 'learning_rate': 4.765258215962441e-05, 'epoch': 0.14}


  5%|▍         | 270/5538 [03:40<1:07:56,  1.29it/s]

{'loss': 1.9107, 'grad_norm': 0.630281388759613, 'learning_rate': 4.756229685807151e-05, 'epoch': 0.15}


  5%|▌         | 280/5538 [03:46<42:38,  2.06it/s]  

{'loss': 1.8296, 'grad_norm': 0.8466852903366089, 'learning_rate': 4.74720115565186e-05, 'epoch': 0.15}


  5%|▌         | 290/5538 [03:54<1:08:12,  1.28it/s]

{'loss': 1.8828, 'grad_norm': 0.9840760231018066, 'learning_rate': 4.738172625496569e-05, 'epoch': 0.16}


  5%|▌         | 300/5538 [04:02<1:30:35,  1.04s/it]

{'loss': 1.6746, 'grad_norm': 1.0164296627044678, 'learning_rate': 4.7291440953412786e-05, 'epoch': 0.16}


  6%|▌         | 310/5538 [04:12<1:29:02,  1.02s/it]

{'loss': 1.6817, 'grad_norm': 0.9772088527679443, 'learning_rate': 4.720115565185988e-05, 'epoch': 0.17}


  6%|▌         | 320/5538 [04:20<43:29,  2.00it/s]  

{'loss': 1.9799, 'grad_norm': 0.8360973596572876, 'learning_rate': 4.711087035030698e-05, 'epoch': 0.17}


  6%|▌         | 330/5538 [04:24<34:54,  2.49it/s]

{'loss': 1.9313, 'grad_norm': 0.7448869347572327, 'learning_rate': 4.7020585048754065e-05, 'epoch': 0.18}


  6%|▌         | 340/5538 [04:33<1:23:12,  1.04it/s]

{'loss': 1.7664, 'grad_norm': 0.8633103966712952, 'learning_rate': 4.6930299747201154e-05, 'epoch': 0.18}


  6%|▋         | 350/5538 [04:44<1:35:25,  1.10s/it]

{'loss': 1.7923, 'grad_norm': 0.7415929436683655, 'learning_rate': 4.684001444564825e-05, 'epoch': 0.19}


  7%|▋         | 360/5538 [04:55<1:46:24,  1.23s/it]

{'loss': 1.7635, 'grad_norm': 1.019365668296814, 'learning_rate': 4.6749729144095344e-05, 'epoch': 0.2}


  7%|▋         | 370/5538 [05:03<55:12,  1.56it/s]  

{'loss': 1.8141, 'grad_norm': 1.0299142599105835, 'learning_rate': 4.665944384254244e-05, 'epoch': 0.2}


  7%|▋         | 380/5538 [05:10<59:14,  1.45it/s]  

{'loss': 1.6017, 'grad_norm': 0.853905975818634, 'learning_rate': 4.656915854098953e-05, 'epoch': 0.21}


  7%|▋         | 390/5538 [05:20<1:36:58,  1.13s/it]

{'loss': 1.9144, 'grad_norm': 0.9976400136947632, 'learning_rate': 4.647887323943662e-05, 'epoch': 0.21}


  7%|▋         | 400/5538 [05:26<42:44,  2.00it/s]  

{'loss': 1.8972, 'grad_norm': 0.9246968626976013, 'learning_rate': 4.638858793788371e-05, 'epoch': 0.22}


  7%|▋         | 410/5538 [05:33<1:06:40,  1.28it/s]

{'loss': 1.77, 'grad_norm': 1.171756386756897, 'learning_rate': 4.629830263633081e-05, 'epoch': 0.22}


  8%|▊         | 420/5538 [05:43<1:26:47,  1.02s/it]

{'loss': 1.6806, 'grad_norm': 1.0541118383407593, 'learning_rate': 4.6208017334777896e-05, 'epoch': 0.23}


  8%|▊         | 430/5538 [05:50<54:55,  1.55it/s]  

{'loss': 1.817, 'grad_norm': 0.8242820501327515, 'learning_rate': 4.611773203322499e-05, 'epoch': 0.23}


  8%|▊         | 440/5538 [05:57<1:07:04,  1.27it/s]

{'loss': 1.7633, 'grad_norm': 0.6329175233840942, 'learning_rate': 4.6027446731672086e-05, 'epoch': 0.24}


  8%|▊         | 450/5538 [06:04<1:12:42,  1.17it/s]

{'loss': 1.8852, 'grad_norm': 1.025809407234192, 'learning_rate': 4.593716143011918e-05, 'epoch': 0.24}


  8%|▊         | 460/5538 [06:11<47:14,  1.79it/s]  

{'loss': 1.5864, 'grad_norm': 1.4477018117904663, 'learning_rate': 4.584687612856628e-05, 'epoch': 0.25}


  8%|▊         | 470/5538 [06:19<1:02:59,  1.34it/s]

{'loss': 1.8104, 'grad_norm': 0.792488694190979, 'learning_rate': 4.5756590827013365e-05, 'epoch': 0.25}


  9%|▊         | 480/5538 [06:27<55:15,  1.53it/s]  

{'loss': 1.903, 'grad_norm': 1.1075599193572998, 'learning_rate': 4.5666305525460454e-05, 'epoch': 0.26}


  9%|▉         | 490/5538 [06:37<1:26:27,  1.03s/it]

{'loss': 1.7159, 'grad_norm': 0.9168406128883362, 'learning_rate': 4.557602022390755e-05, 'epoch': 0.27}


  9%|▉         | 500/5538 [06:49<1:42:09,  1.22s/it]

{'loss': 1.8683, 'grad_norm': 0.8399244546890259, 'learning_rate': 4.5485734922354644e-05, 'epoch': 0.27}


  9%|▉         | 510/5538 [06:56<56:27,  1.48it/s]  

{'loss': 2.0352, 'grad_norm': 0.9292173981666565, 'learning_rate': 4.539544962080173e-05, 'epoch': 0.28}


  9%|▉         | 520/5538 [07:05<1:09:00,  1.21it/s]

{'loss': 1.7849, 'grad_norm': 1.2919830083847046, 'learning_rate': 4.530516431924883e-05, 'epoch': 0.28}


 10%|▉         | 530/5538 [07:19<1:35:48,  1.15s/it]

{'loss': 1.5185, 'grad_norm': 1.0103663206100464, 'learning_rate': 4.521487901769592e-05, 'epoch': 0.29}


 10%|▉         | 540/5538 [07:25<50:07,  1.66it/s]  

{'loss': 1.8582, 'grad_norm': 1.0337111949920654, 'learning_rate': 4.512459371614302e-05, 'epoch': 0.29}


 10%|▉         | 550/5538 [07:30<37:31,  2.22it/s]

{'loss': 1.6711, 'grad_norm': 1.0402787923812866, 'learning_rate': 4.503430841459011e-05, 'epoch': 0.3}


 10%|█         | 560/5538 [07:36<58:24,  1.42it/s]

{'loss': 1.824, 'grad_norm': 1.0783880949020386, 'learning_rate': 4.4944023113037196e-05, 'epoch': 0.3}


 10%|█         | 570/5538 [07:44<1:02:40,  1.32it/s]

{'loss': 1.5984, 'grad_norm': 1.0584968328475952, 'learning_rate': 4.485373781148429e-05, 'epoch': 0.31}


 10%|█         | 580/5538 [07:49<39:03,  2.12it/s]  

{'loss': 1.5746, 'grad_norm': 1.0708978176116943, 'learning_rate': 4.4763452509931386e-05, 'epoch': 0.31}


 11%|█         | 590/5538 [07:58<1:12:08,  1.14it/s]

{'loss': 1.9093, 'grad_norm': 1.009237289428711, 'learning_rate': 4.467316720837848e-05, 'epoch': 0.32}


 11%|█         | 600/5538 [08:06<41:35,  1.98it/s]  

{'loss': 1.8209, 'grad_norm': 1.0297877788543701, 'learning_rate': 4.458288190682557e-05, 'epoch': 0.33}


 11%|█         | 610/5538 [08:14<1:18:30,  1.05it/s]

{'loss': 1.6644, 'grad_norm': 0.8634269833564758, 'learning_rate': 4.4492596605272665e-05, 'epoch': 0.33}


 11%|█         | 620/5538 [08:26<1:41:29,  1.24s/it]

{'loss': 1.8339, 'grad_norm': 0.8431247472763062, 'learning_rate': 4.4402311303719754e-05, 'epoch': 0.34}


 11%|█▏        | 630/5538 [08:37<1:23:23,  1.02s/it]

{'loss': 1.9133, 'grad_norm': 1.063477873802185, 'learning_rate': 4.431202600216685e-05, 'epoch': 0.34}


 12%|█▏        | 640/5538 [08:44<45:28,  1.80it/s]  

{'loss': 1.66, 'grad_norm': 0.9915026426315308, 'learning_rate': 4.422174070061394e-05, 'epoch': 0.35}


 12%|█▏        | 650/5538 [08:50<1:00:13,  1.35it/s]

{'loss': 1.691, 'grad_norm': 0.9447178244590759, 'learning_rate': 4.413145539906103e-05, 'epoch': 0.35}


 12%|█▏        | 660/5538 [08:57<43:01,  1.89it/s]  

{'loss': 1.8875, 'grad_norm': 1.7987275123596191, 'learning_rate': 4.404117009750813e-05, 'epoch': 0.36}


 12%|█▏        | 670/5538 [09:05<1:20:48,  1.00it/s]

{'loss': 1.6484, 'grad_norm': 0.9506717324256897, 'learning_rate': 4.395088479595522e-05, 'epoch': 0.36}


 12%|█▏        | 680/5538 [09:11<1:09:28,  1.17it/s]

{'loss': 1.7409, 'grad_norm': 0.7467921376228333, 'learning_rate': 4.386059949440232e-05, 'epoch': 0.37}


 12%|█▏        | 690/5538 [09:18<57:12,  1.41it/s]  

{'loss': 1.8262, 'grad_norm': 0.8157848715782166, 'learning_rate': 4.377031419284941e-05, 'epoch': 0.37}


 13%|█▎        | 700/5538 [09:27<1:03:41,  1.27it/s]

{'loss': 1.8127, 'grad_norm': 0.8204064965248108, 'learning_rate': 4.3680028891296496e-05, 'epoch': 0.38}


 13%|█▎        | 710/5538 [09:37<1:19:01,  1.02it/s]

{'loss': 1.8326, 'grad_norm': 0.6814621090888977, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}


 13%|█▎        | 720/5538 [09:42<46:27,  1.73it/s]  

{'loss': 1.7131, 'grad_norm': 0.6976842284202576, 'learning_rate': 4.3499458288190686e-05, 'epoch': 0.39}


 13%|█▎        | 730/5538 [09:48<43:46,  1.83it/s]

{'loss': 1.6653, 'grad_norm': 1.1761175394058228, 'learning_rate': 4.3409172986637775e-05, 'epoch': 0.4}


 13%|█▎        | 740/5538 [09:54<1:02:32,  1.28it/s]

{'loss': 1.6897, 'grad_norm': 0.9046758413314819, 'learning_rate': 4.331888768508487e-05, 'epoch': 0.4}


 14%|█▎        | 750/5538 [09:58<33:50,  2.36it/s]  

{'loss': 1.7646, 'grad_norm': 0.9503940939903259, 'learning_rate': 4.3228602383531965e-05, 'epoch': 0.41}


 14%|█▎        | 760/5538 [10:03<49:16,  1.62it/s]

{'loss': 1.9821, 'grad_norm': 0.9977975487709045, 'learning_rate': 4.313831708197906e-05, 'epoch': 0.41}


 14%|█▍        | 770/5538 [10:11<1:18:23,  1.01it/s]

{'loss': 1.9034, 'grad_norm': 0.6804600358009338, 'learning_rate': 4.304803178042615e-05, 'epoch': 0.42}


 14%|█▍        | 780/5538 [10:22<1:39:20,  1.25s/it]

{'loss': 1.8316, 'grad_norm': 0.9937782287597656, 'learning_rate': 4.295774647887324e-05, 'epoch': 0.42}


 14%|█▍        | 790/5538 [10:31<1:04:48,  1.22it/s]

{'loss': 1.8007, 'grad_norm': 0.8556992411613464, 'learning_rate': 4.286746117732033e-05, 'epoch': 0.43}


 14%|█▍        | 800/5538 [10:38<44:49,  1.76it/s]  

{'loss': 1.7604, 'grad_norm': 1.0011152029037476, 'learning_rate': 4.277717587576743e-05, 'epoch': 0.43}


 15%|█▍        | 810/5538 [10:45<49:05,  1.61it/s]  

{'loss': 1.6318, 'grad_norm': 1.0097262859344482, 'learning_rate': 4.268689057421452e-05, 'epoch': 0.44}


 15%|█▍        | 820/5538 [10:51<52:21,  1.50it/s]

{'loss': 1.7379, 'grad_norm': 1.016278624534607, 'learning_rate': 4.259660527266161e-05, 'epoch': 0.44}


 15%|█▍        | 830/5538 [10:56<37:01,  2.12it/s]

{'loss': 1.9177, 'grad_norm': 0.7414392828941345, 'learning_rate': 4.250631997110871e-05, 'epoch': 0.45}


 15%|█▌        | 840/5538 [11:04<1:03:27,  1.23it/s]

{'loss': 1.8541, 'grad_norm': 0.8735577464103699, 'learning_rate': 4.2416034669555795e-05, 'epoch': 0.46}


 15%|█▌        | 850/5538 [11:13<59:56,  1.30it/s]  

{'loss': 1.8998, 'grad_norm': 0.9635006785392761, 'learning_rate': 4.232574936800289e-05, 'epoch': 0.46}


 16%|█▌        | 860/5538 [11:18<32:52,  2.37it/s]

{'loss': 1.7278, 'grad_norm': 0.9373649954795837, 'learning_rate': 4.223546406644998e-05, 'epoch': 0.47}


 16%|█▌        | 870/5538 [11:27<1:25:33,  1.10s/it]

{'loss': 1.684, 'grad_norm': 0.5763220191001892, 'learning_rate': 4.2145178764897075e-05, 'epoch': 0.47}


 16%|█▌        | 880/5538 [11:33<44:26,  1.75it/s]  

{'loss': 1.6401, 'grad_norm': 1.209186315536499, 'learning_rate': 4.205489346334417e-05, 'epoch': 0.48}


 16%|█▌        | 890/5538 [11:42<1:09:22,  1.12it/s]

{'loss': 1.7383, 'grad_norm': 0.9575021266937256, 'learning_rate': 4.1964608161791265e-05, 'epoch': 0.48}


 16%|█▋        | 900/5538 [11:51<59:57,  1.29it/s]  

{'loss': 1.8635, 'grad_norm': 1.1302375793457031, 'learning_rate': 4.187432286023836e-05, 'epoch': 0.49}


 16%|█▋        | 910/5538 [11:58<46:16,  1.67it/s]

{'loss': 1.8091, 'grad_norm': 1.1046762466430664, 'learning_rate': 4.178403755868545e-05, 'epoch': 0.49}


 17%|█▋        | 920/5538 [12:04<40:12,  1.91it/s]

{'loss': 1.7364, 'grad_norm': 1.0219578742980957, 'learning_rate': 4.169375225713254e-05, 'epoch': 0.5}


 17%|█▋        | 930/5538 [12:08<43:48,  1.75it/s]

{'loss': 1.6447, 'grad_norm': 0.6920832395553589, 'learning_rate': 4.160346695557963e-05, 'epoch': 0.5}


 17%|█▋        | 940/5538 [12:16<55:55,  1.37it/s]

{'loss': 1.7666, 'grad_norm': 0.7879363298416138, 'learning_rate': 4.151318165402673e-05, 'epoch': 0.51}


 17%|█▋        | 950/5538 [12:23<54:19,  1.41it/s]  

{'loss': 1.8345, 'grad_norm': 0.7536249160766602, 'learning_rate': 4.1422896352473816e-05, 'epoch': 0.51}


 17%|█▋        | 960/5538 [12:29<46:23,  1.64it/s]

{'loss': 1.7687, 'grad_norm': 0.9117267727851868, 'learning_rate': 4.133261105092091e-05, 'epoch': 0.52}


 18%|█▊        | 970/5538 [12:36<54:19,  1.40it/s]  

{'loss': 1.9825, 'grad_norm': 1.113742470741272, 'learning_rate': 4.124232574936801e-05, 'epoch': 0.53}


 18%|█▊        | 980/5538 [12:41<41:10,  1.85it/s]

{'loss': 1.7662, 'grad_norm': 0.7962827086448669, 'learning_rate': 4.11520404478151e-05, 'epoch': 0.53}


 18%|█▊        | 990/5538 [12:47<37:09,  2.04it/s]

{'loss': 1.8227, 'grad_norm': 0.9702363610267639, 'learning_rate': 4.106175514626219e-05, 'epoch': 0.54}


 18%|█▊        | 1000/5538 [12:52<40:21,  1.87it/s]

{'loss': 1.8343, 'grad_norm': 1.3977845907211304, 'learning_rate': 4.097146984470928e-05, 'epoch': 0.54}


 18%|█▊        | 1010/5538 [12:58<57:33,  1.31it/s]  

{'loss': 1.7071, 'grad_norm': 0.7000864148139954, 'learning_rate': 4.0881184543156375e-05, 'epoch': 0.55}


 18%|█▊        | 1020/5538 [13:06<46:57,  1.60it/s]  

{'loss': 1.5963, 'grad_norm': 1.0614999532699585, 'learning_rate': 4.079089924160347e-05, 'epoch': 0.55}


 19%|█▊        | 1030/5538 [13:12<44:50,  1.68it/s]

{'loss': 1.8937, 'grad_norm': 0.9256168603897095, 'learning_rate': 4.0700613940050565e-05, 'epoch': 0.56}


 19%|█▉        | 1040/5538 [13:18<47:31,  1.58it/s]

{'loss': 1.7599, 'grad_norm': 0.9471101760864258, 'learning_rate': 4.0610328638497654e-05, 'epoch': 0.56}


 19%|█▉        | 1050/5538 [13:25<1:04:37,  1.16it/s]

{'loss': 1.9593, 'grad_norm': 0.9363771677017212, 'learning_rate': 4.052004333694475e-05, 'epoch': 0.57}


 19%|█▉        | 1060/5538 [13:32<58:03,  1.29it/s]  

{'loss': 1.6992, 'grad_norm': 0.7511029243469238, 'learning_rate': 4.042975803539184e-05, 'epoch': 0.57}


 19%|█▉        | 1070/5538 [13:38<43:20,  1.72it/s]  

{'loss': 1.5933, 'grad_norm': 1.2100608348846436, 'learning_rate': 4.033947273383893e-05, 'epoch': 0.58}


 20%|█▉        | 1080/5538 [13:45<59:14,  1.25it/s]

{'loss': 1.9283, 'grad_norm': 0.7060545682907104, 'learning_rate': 4.024918743228603e-05, 'epoch': 0.59}


 20%|█▉        | 1090/5538 [13:50<47:33,  1.56it/s]

{'loss': 1.8168, 'grad_norm': 1.0365084409713745, 'learning_rate': 4.0158902130733116e-05, 'epoch': 0.59}


 20%|█▉        | 1100/5538 [13:57<35:59,  2.05it/s]  

{'loss': 1.7576, 'grad_norm': 1.2207707166671753, 'learning_rate': 4.006861682918021e-05, 'epoch': 0.6}


 20%|██        | 1110/5538 [14:03<49:30,  1.49it/s]

{'loss': 1.8674, 'grad_norm': 0.952212929725647, 'learning_rate': 3.997833152762731e-05, 'epoch': 0.6}


 20%|██        | 1120/5538 [14:10<38:23,  1.92it/s]  

{'loss': 1.8403, 'grad_norm': 0.8669528365135193, 'learning_rate': 3.98880462260744e-05, 'epoch': 0.61}


 20%|██        | 1130/5538 [14:14<28:55,  2.54it/s]

{'loss': 1.9415, 'grad_norm': 1.2032324075698853, 'learning_rate': 3.979776092452149e-05, 'epoch': 0.61}


 21%|██        | 1140/5538 [14:22<56:08,  1.31it/s]  

{'loss': 1.7073, 'grad_norm': 0.797305703163147, 'learning_rate': 3.970747562296858e-05, 'epoch': 0.62}


 21%|██        | 1150/5538 [14:29<57:00,  1.28it/s]

{'loss': 1.7366, 'grad_norm': 1.1880298852920532, 'learning_rate': 3.9617190321415675e-05, 'epoch': 0.62}


 21%|██        | 1160/5538 [14:37<46:12,  1.58it/s]  

{'loss': 1.9338, 'grad_norm': 1.4435477256774902, 'learning_rate': 3.952690501986277e-05, 'epoch': 0.63}


 21%|██        | 1170/5538 [14:45<1:06:23,  1.10it/s]

{'loss': 1.8659, 'grad_norm': 0.8962294459342957, 'learning_rate': 3.943661971830986e-05, 'epoch': 0.63}


 21%|██▏       | 1180/5538 [14:52<1:01:30,  1.18it/s]

{'loss': 1.5548, 'grad_norm': 1.0566977262496948, 'learning_rate': 3.9346334416756954e-05, 'epoch': 0.64}


 21%|██▏       | 1190/5538 [15:01<1:07:57,  1.07it/s]

{'loss': 1.6933, 'grad_norm': 0.8990573883056641, 'learning_rate': 3.925604911520405e-05, 'epoch': 0.64}


 22%|██▏       | 1200/5538 [15:10<1:06:29,  1.09it/s]

{'loss': 1.822, 'grad_norm': 0.9377164840698242, 'learning_rate': 3.916576381365114e-05, 'epoch': 0.65}


 22%|██▏       | 1210/5538 [15:20<1:02:57,  1.15it/s]

{'loss': 1.8737, 'grad_norm': 0.9209846258163452, 'learning_rate': 3.907547851209823e-05, 'epoch': 0.66}


 22%|██▏       | 1220/5538 [15:27<50:07,  1.44it/s]  

{'loss': 1.7278, 'grad_norm': 0.9030419588088989, 'learning_rate': 3.898519321054532e-05, 'epoch': 0.66}


 22%|██▏       | 1230/5538 [15:38<1:07:49,  1.06it/s]

{'loss': 1.8817, 'grad_norm': 0.7841484546661377, 'learning_rate': 3.8894907908992416e-05, 'epoch': 0.67}


 22%|██▏       | 1240/5538 [15:46<1:04:56,  1.10it/s]

{'loss': 1.5815, 'grad_norm': 1.169050693511963, 'learning_rate': 3.880462260743951e-05, 'epoch': 0.67}


 23%|██▎       | 1250/5538 [15:54<48:51,  1.46it/s]  

{'loss': 1.8016, 'grad_norm': 0.9695791006088257, 'learning_rate': 3.871433730588661e-05, 'epoch': 0.68}


 23%|██▎       | 1260/5538 [16:02<57:05,  1.25it/s]

{'loss': 1.8177, 'grad_norm': 0.8585286736488342, 'learning_rate': 3.8624052004333695e-05, 'epoch': 0.68}


 23%|██▎       | 1270/5538 [16:08<33:27,  2.13it/s]

{'loss': 1.9518, 'grad_norm': 1.0909608602523804, 'learning_rate': 3.853376670278079e-05, 'epoch': 0.69}


 23%|██▎       | 1280/5538 [16:18<1:22:11,  1.16s/it]

{'loss': 1.9608, 'grad_norm': 0.7609909772872925, 'learning_rate': 3.844348140122788e-05, 'epoch': 0.69}


 23%|██▎       | 1290/5538 [16:27<1:13:08,  1.03s/it]

{'loss': 1.8757, 'grad_norm': 0.8785742521286011, 'learning_rate': 3.8353196099674974e-05, 'epoch': 0.7}


 23%|██▎       | 1300/5538 [16:38<1:10:05,  1.01it/s]

{'loss': 1.8163, 'grad_norm': 1.00180983543396, 'learning_rate': 3.826291079812207e-05, 'epoch': 0.7}


 24%|██▎       | 1310/5538 [16:50<1:32:40,  1.32s/it]

{'loss': 1.8699, 'grad_norm': 1.1745020151138306, 'learning_rate': 3.817262549656916e-05, 'epoch': 0.71}


 24%|██▍       | 1320/5538 [17:03<1:44:02,  1.48s/it]

{'loss': 2.0099, 'grad_norm': 0.9541602730751038, 'learning_rate': 3.8082340195016254e-05, 'epoch': 0.72}


 24%|██▍       | 1330/5538 [17:16<1:43:58,  1.48s/it]

{'loss': 1.7941, 'grad_norm': 1.194696307182312, 'learning_rate': 3.799205489346335e-05, 'epoch': 0.72}


 24%|██▍       | 1340/5538 [17:29<1:25:03,  1.22s/it]

{'loss': 1.8291, 'grad_norm': 0.8053624033927917, 'learning_rate': 3.7901769591910444e-05, 'epoch': 0.73}


 24%|██▍       | 1350/5538 [17:40<1:15:25,  1.08s/it]

{'loss': 1.9845, 'grad_norm': 0.8668246269226074, 'learning_rate': 3.781148429035753e-05, 'epoch': 0.73}


 25%|██▍       | 1360/5538 [17:48<55:01,  1.27it/s]  

{'loss': 1.9111, 'grad_norm': 0.9432799220085144, 'learning_rate': 3.772119898880462e-05, 'epoch': 0.74}


 25%|██▍       | 1370/5538 [17:56<51:05,  1.36it/s]  

{'loss': 1.7726, 'grad_norm': 0.7921280264854431, 'learning_rate': 3.7630913687251716e-05, 'epoch': 0.74}


 25%|██▍       | 1380/5538 [18:01<29:03,  2.38it/s]  

{'loss': 1.6457, 'grad_norm': 0.7302278876304626, 'learning_rate': 3.754062838569881e-05, 'epoch': 0.75}


 25%|██▌       | 1390/5538 [18:05<31:45,  2.18it/s]

{'loss': 1.855, 'grad_norm': 0.7948406934738159, 'learning_rate': 3.74503430841459e-05, 'epoch': 0.75}


 25%|██▌       | 1400/5538 [18:10<28:15,  2.44it/s]

{'loss': 1.5568, 'grad_norm': 0.769368588924408, 'learning_rate': 3.7360057782592995e-05, 'epoch': 0.76}


 25%|██▌       | 1410/5538 [18:14<43:02,  1.60it/s]

{'loss': 1.6763, 'grad_norm': 1.0197564363479614, 'learning_rate': 3.726977248104009e-05, 'epoch': 0.76}


 26%|██▌       | 1420/5538 [18:20<46:12,  1.49it/s]

{'loss': 1.8775, 'grad_norm': 1.6636837720870972, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}


 26%|██▌       | 1430/5538 [18:27<57:32,  1.19it/s]

{'loss': 1.7393, 'grad_norm': 0.9662932753562927, 'learning_rate': 3.7089201877934274e-05, 'epoch': 0.77}


 26%|██▌       | 1440/5538 [18:35<42:19,  1.61it/s]  

{'loss': 1.8366, 'grad_norm': 0.9638604521751404, 'learning_rate': 3.699891657638136e-05, 'epoch': 0.78}


 26%|██▌       | 1450/5538 [18:43<1:01:47,  1.10it/s]

{'loss': 1.8933, 'grad_norm': 1.0412836074829102, 'learning_rate': 3.690863127482846e-05, 'epoch': 0.79}


 26%|██▋       | 1460/5538 [18:52<56:44,  1.20it/s]  

{'loss': 1.7114, 'grad_norm': 0.8038713932037354, 'learning_rate': 3.6818345973275554e-05, 'epoch': 0.79}


 27%|██▋       | 1470/5538 [18:59<42:32,  1.59it/s]  

{'loss': 1.8574, 'grad_norm': 1.1008892059326172, 'learning_rate': 3.672806067172265e-05, 'epoch': 0.8}


 27%|██▋       | 1480/5538 [19:05<56:19,  1.20it/s]

{'loss': 1.6812, 'grad_norm': 0.9325599670410156, 'learning_rate': 3.663777537016974e-05, 'epoch': 0.8}


 27%|██▋       | 1490/5538 [19:13<48:56,  1.38it/s]  

{'loss': 1.7024, 'grad_norm': 0.9266835451126099, 'learning_rate': 3.654749006861683e-05, 'epoch': 0.81}


 27%|██▋       | 1500/5538 [19:18<39:20,  1.71it/s]

{'loss': 1.9717, 'grad_norm': 0.9111840724945068, 'learning_rate': 3.645720476706392e-05, 'epoch': 0.81}


 27%|██▋       | 1510/5538 [19:23<29:36,  2.27it/s]

{'loss': 1.9283, 'grad_norm': 0.9872065186500549, 'learning_rate': 3.6366919465511016e-05, 'epoch': 0.82}


 27%|██▋       | 1520/5538 [19:30<31:17,  2.14it/s]  

{'loss': 1.7826, 'grad_norm': 1.0094488859176636, 'learning_rate': 3.627663416395811e-05, 'epoch': 0.82}


 28%|██▊       | 1530/5538 [19:35<32:48,  2.04it/s]

{'loss': 1.9012, 'grad_norm': 0.775837242603302, 'learning_rate': 3.61863488624052e-05, 'epoch': 0.83}


 28%|██▊       | 1540/5538 [19:40<30:24,  2.19it/s]

{'loss': 1.6963, 'grad_norm': 0.6568133234977722, 'learning_rate': 3.6096063560852295e-05, 'epoch': 0.83}


 28%|██▊       | 1550/5538 [19:46<48:00,  1.38it/s]

{'loss': 1.7662, 'grad_norm': 1.4289215803146362, 'learning_rate': 3.600577825929939e-05, 'epoch': 0.84}


 28%|██▊       | 1560/5538 [19:52<36:51,  1.80it/s]

{'loss': 1.9163, 'grad_norm': 1.507603645324707, 'learning_rate': 3.5915492957746486e-05, 'epoch': 0.85}


 28%|██▊       | 1570/5538 [19:59<36:10,  1.83it/s]  

{'loss': 1.7214, 'grad_norm': 1.1663612127304077, 'learning_rate': 3.5825207656193574e-05, 'epoch': 0.85}


 29%|██▊       | 1580/5538 [20:07<1:09:17,  1.05s/it]

{'loss': 1.6806, 'grad_norm': 0.9973610043525696, 'learning_rate': 3.573492235464066e-05, 'epoch': 0.86}


 29%|██▊       | 1590/5538 [20:17<46:48,  1.41it/s]  

{'loss': 1.9506, 'grad_norm': 0.9228989481925964, 'learning_rate': 3.564463705308776e-05, 'epoch': 0.86}


 29%|██▉       | 1600/5538 [20:24<52:12,  1.26it/s]  

{'loss': 1.6906, 'grad_norm': 0.7713722586631775, 'learning_rate': 3.5554351751534853e-05, 'epoch': 0.87}


 29%|██▉       | 1610/5538 [20:34<51:08,  1.28it/s]  

{'loss': 1.8263, 'grad_norm': 0.9157726764678955, 'learning_rate': 3.546406644998194e-05, 'epoch': 0.87}


 29%|██▉       | 1620/5538 [20:43<1:08:17,  1.05s/it]

{'loss': 1.826, 'grad_norm': 0.7964504957199097, 'learning_rate': 3.537378114842904e-05, 'epoch': 0.88}


 29%|██▉       | 1630/5538 [20:53<46:56,  1.39it/s]  

{'loss': 1.8307, 'grad_norm': 1.1585484743118286, 'learning_rate': 3.528349584687613e-05, 'epoch': 0.88}


 30%|██▉       | 1640/5538 [21:01<1:08:37,  1.06s/it]

{'loss': 1.7573, 'grad_norm': 0.8231692910194397, 'learning_rate': 3.519321054532322e-05, 'epoch': 0.89}


 30%|██▉       | 1650/5538 [21:07<51:28,  1.26it/s]  

{'loss': 1.7228, 'grad_norm': 0.9693516492843628, 'learning_rate': 3.5102925243770316e-05, 'epoch': 0.89}


 30%|██▉       | 1660/5538 [21:15<43:43,  1.48it/s]  

{'loss': 1.7917, 'grad_norm': 0.8724406957626343, 'learning_rate': 3.5012639942217405e-05, 'epoch': 0.9}


 30%|███       | 1670/5538 [21:23<1:05:36,  1.02s/it]

{'loss': 1.7762, 'grad_norm': 1.2063342332839966, 'learning_rate': 3.49223546406645e-05, 'epoch': 0.9}


 30%|███       | 1680/5538 [21:32<47:55,  1.34it/s]  

{'loss': 1.7953, 'grad_norm': 0.6880330443382263, 'learning_rate': 3.4832069339111595e-05, 'epoch': 0.91}


 31%|███       | 1690/5538 [21:37<36:34,  1.75it/s]

{'loss': 1.8439, 'grad_norm': 0.7816957831382751, 'learning_rate': 3.474178403755869e-05, 'epoch': 0.92}


 31%|███       | 1700/5538 [21:46<1:12:47,  1.14s/it]

{'loss': 1.7843, 'grad_norm': 0.696368396282196, 'learning_rate': 3.465149873600578e-05, 'epoch': 0.92}


 31%|███       | 1710/5538 [21:58<1:28:58,  1.39s/it]

{'loss': 1.9701, 'grad_norm': 1.2972025871276855, 'learning_rate': 3.4561213434452874e-05, 'epoch': 0.93}


 31%|███       | 1720/5538 [22:07<1:16:24,  1.20s/it]

{'loss': 1.6916, 'grad_norm': 1.087615728378296, 'learning_rate': 3.447092813289996e-05, 'epoch': 0.93}


 31%|███       | 1730/5538 [22:14<45:05,  1.41it/s]  

{'loss': 1.8819, 'grad_norm': 0.8951455354690552, 'learning_rate': 3.438064283134706e-05, 'epoch': 0.94}


 31%|███▏      | 1740/5538 [22:24<53:50,  1.18it/s]  

{'loss': 1.8082, 'grad_norm': 0.9364892840385437, 'learning_rate': 3.4290357529794153e-05, 'epoch': 0.94}


 32%|███▏      | 1750/5538 [22:31<39:05,  1.61it/s]  

{'loss': 1.9542, 'grad_norm': 1.0447875261306763, 'learning_rate': 3.420007222824124e-05, 'epoch': 0.95}


 32%|███▏      | 1760/5538 [22:41<1:09:44,  1.11s/it]

{'loss': 1.7794, 'grad_norm': 0.9895156621932983, 'learning_rate': 3.410978692668834e-05, 'epoch': 0.95}


 32%|███▏      | 1770/5538 [22:50<44:36,  1.41it/s]  

{'loss': 1.7927, 'grad_norm': 2.1817872524261475, 'learning_rate': 3.401950162513543e-05, 'epoch': 0.96}


 32%|███▏      | 1780/5538 [22:56<40:02,  1.56it/s]

{'loss': 1.5904, 'grad_norm': 0.8917847275733948, 'learning_rate': 3.392921632358253e-05, 'epoch': 0.96}


 32%|███▏      | 1790/5538 [23:00<29:38,  2.11it/s]

{'loss': 1.8522, 'grad_norm': 0.8007518649101257, 'learning_rate': 3.3838931022029616e-05, 'epoch': 0.97}


 33%|███▎      | 1800/5538 [23:07<38:46,  1.61it/s]

{'loss': 1.543, 'grad_norm': 0.7430760264396667, 'learning_rate': 3.3748645720476705e-05, 'epoch': 0.98}


 33%|███▎      | 1810/5538 [23:15<1:00:01,  1.04it/s]

{'loss': 1.6932, 'grad_norm': 0.7997698187828064, 'learning_rate': 3.36583604189238e-05, 'epoch': 0.98}


 33%|███▎      | 1820/5538 [23:22<35:51,  1.73it/s]  

{'loss': 1.8053, 'grad_norm': 0.9231575131416321, 'learning_rate': 3.3568075117370895e-05, 'epoch': 0.99}


 33%|███▎      | 1830/5538 [23:27<40:49,  1.51it/s]

{'loss': 1.7176, 'grad_norm': 1.0609110593795776, 'learning_rate': 3.3477789815817984e-05, 'epoch': 0.99}


 33%|███▎      | 1840/5538 [23:33<43:13,  1.43it/s]

{'loss': 1.8386, 'grad_norm': 0.9416612386703491, 'learning_rate': 3.338750451426508e-05, 'epoch': 1.0}


 33%|███▎      | 1846/5538 [23:40<55:07,  1.12it/s]  

IndexError: piece id is out of range.

## inference

In [1]:
from transformers import T5ForConditionalGeneration, T5Tokenizer
last_checkpoint = "./third-results/checkpoint-4000"

finetuned_model = T5ForConditionalGeneration.from_pretrained(last_checkpoint).to("cuda")
finetuned_tokenizer = T5Tokenizer.from_pretrained(last_checkpoint)
question="what are marine toxins?"

input_text = "Please answer this medical related question: "+question
input_ids = finetuned_tokenizer(input_text, return_tensors="pt").input_ids.to("cuda")

outputs = finetuned_model.generate(
    input_ids,
    max_length=200,
    min_length=20,
    repetition_penalty=2.0
)
answer = finetuned_tokenizer.decode(outputs[0], skip_special_tokens=True)
from textwrap import fill

print(fill(answer, width=100))

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Marine toxins are substances that cause damage to the body's tissues and organs. They can be toxic,
but they do not affect other parts of the body. The most common types of marine toxins include:
Lymphadenoma (the type of lymph nodes in the blood) Affected people may have an increased risk for
developing certain diseases such as cancer or heart disease. Some cases of this condition occur when
there is too much fluid in the brain or spinal cord. In some instances, it causes pain, swelling,
loss of appetite, nausea, vomiting, diarrhea, headache, seizures, fatigue, weight gain, muscle
weakness, difficulty swallowing, and/or confusion.
