In [58]:
import torch
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import train_test_split
from transformers import pipeline, set_seed
from datasets import Dataset
from transformers import BertTokenizer
from transformers import GPT2LMHeadModel, GPT2Tokenizer, TrainingArguments,Trainer

In [59]:
dataset = pd.read_csv('../data/Advisor_data/final.csv')
dataset

Unnamed: 0,Disease ID,Disease Name,Affected Plant Species,Symptom Description,Diagnosis Method,Treatment Options
0,1,Powdery Mildew,Cucumber; Zucchini; Grapes,"Powdery white fungal growth appears on leaves,...","Diagnosed by visual inspection, especially ide...",Treat with sulfur-based fungicides or neem oil...
1,2,Downy Mildew,Cucumbers; Lettuce; Grapes,Downy mildew manifests as yellowish patches on...,Diagnosed through careful visual inspection fo...,Apply copper-based fungicides or organic treat...
2,3,Leaf Spot,Tomatoes; Potatoes; Lettuce,"Leaf spot causes small, round lesions that sta...",Diagnosis is typically based on visual inspect...,Use fungicides or bactericides depending on th...
3,4,Root Rot,Tomatoes; Lettuce; Cucumbers,Root rot leads to wilting and yellowing of the...,Diagnosed by examining the roots for signs of ...,Improve soil drainage and avoid over-watering....
4,5,Late Blight,Potatoes; Tomatoes,"Late blight results in dark, water-soaked lesi...",Diagnosis is typically confirmed by observing ...,Use fungicides containing copper or metalaxyl ...
...,...,...,...,...,...,...
645,646,Tulip Bud Rot,Tulips,Rotting buds; Soft discolored flowers,Visual inspection,Fungicides; Pruning infected flowers
646,647,Daffodil Rust,Daffodils,Orange pustules on leaves,Visual inspection,Fungicides; Pruning infected leaves
647,648,Orchid Botrytis Leaf Spot,Orchids,Water-soaked lesions on leaves,Visual inspection,Fungicides; Pruning infected parts
648,649,Rose Verticillium Wilt,Roses,Yellowing leaves; Stunted growth,Soil tests,Fungicides; Resistant varieties


In [60]:
device = 'cuda' if torch.cuda.is_available() else "cpu"
device

'cpu'

In [4]:
dataset = dataset.rename(columns = {
    'Disease ID' : 'id',
    'Disease Name': 'disease_name',
    'Affected Plant Species': 'affected_species',
    'Symptom Description': 'symptoms',
    'Diagnosis Method': 'diagnosis',
    'Treatment Options': 'treatment'
})

In [5]:
dataset.head()

Unnamed: 0,id,disease_name,affected_species,symptoms,diagnosis,treatment
0,1,Powdery Mildew,Cucumber; Zucchini; Grapes,"Powdery white fungal growth appears on leaves,...","Diagnosed by visual inspection, especially ide...",Treat with sulfur-based fungicides or neem oil...
1,2,Downy Mildew,Cucumbers; Lettuce; Grapes,Downy mildew manifests as yellowish patches on...,Diagnosed through careful visual inspection fo...,Apply copper-based fungicides or organic treat...
2,3,Leaf Spot,Tomatoes; Potatoes; Lettuce,"Leaf spot causes small, round lesions that sta...",Diagnosis is typically based on visual inspect...,Use fungicides or bactericides depending on th...
3,4,Root Rot,Tomatoes; Lettuce; Cucumbers,Root rot leads to wilting and yellowing of the...,Diagnosed by examining the roots for signs of ...,Improve soil drainage and avoid over-watering....
4,5,Late Blight,Potatoes; Tomatoes,"Late blight results in dark, water-soaked lesi...",Diagnosis is typically confirmed by observing ...,Use fungicides containing copper or metalaxyl ...


In [53]:
def format_species(species):
    # Split the string into a list of words by ';' and remove extra spaces
    species_list = [s.strip() for s in species.split(';')]
    
    # If there's only one item, return it in singular form
    if len(species_list) == 1:
        return f"It affects the plants such as {species_list[0]}"
    
    # If multiple items, join them with commas and "and" before the last item
    species_str = ', '.join(species_list[:-1]) + f" and {species_list[-1]}"
    return f"It affects the plants such as {species_str}"

In [56]:
dataset['affected_species'] = dataset['affected_species'].apply(format_species)
dataset['affected_species']

0      It affects the plants such as Cucumber, Zucchi...
1      It affects the plants such as Cucumbers, Lettu...
2      It affects the plants such as Tomatoes, Potato...
3      It affects the plants such as Tomatoes, Lettuc...
4      It affects the plants such as Potatoes and Tom...
                             ...                        
645                 It affects the plants such as Tulips
646              It affects the plants such as Daffodils
647                It affects the plants such as Orchids
648                  It affects the plants such as Roses
649             It affects the plants such as Sunflowers
Name: affected_species, Length: 650, dtype: object

In [6]:
dataset['input_text'] = dataset['disease_name']+':'+ dataset['affected_species'] + ' ' + dataset['symptoms'] + ' ' + dataset['diagnosis']
dataset['output_text'] = dataset['treatment']

In [7]:
# Split the dataset into train and test
train_data,test_data = train_test_split(dataset,test_size=0.2)

In [8]:
train_data

Unnamed: 0,id,disease_name,affected_species,symptoms,diagnosis,treatment,input_text,output_text
634,635,Daffodil Stem Blight,Daffodils,Dark lesions on stems; Wilting flowers,Visual inspection,Fungicides; Pruning infected parts,Daffodil Stem Blight Daffodils Dark lesions on...,Fungicides; Pruning infected parts
499,500,Walnut Xylella fastidiosa,Walnuts,Yellowing leaves; Dieback,Soil tests; Lab tests (PCR),Fungicides; Pruning infected parts,Walnut Xylella fastidiosa Walnuts Yellowing le...,Fungicides; Pruning infected parts
23,24,Fusarium Wilt,Tomato; Banana; Cotton,"Fusarium wilt causes yellowing, wilting, and s...",Diagnosed by observing yellowing and wilting s...,Use resistant plant varieties and treat seeds ...,Fusarium Wilt Tomato; Banana; Cotton Fusarium ...,Use resistant plant varieties and treat seeds ...
604,605,Daffodil Leaf Spot,Daffodils,Small dark spots with yellow margins on leaves,Visual inspection,Fungicides; Pruning infected parts,Daffodil Leaf Spot Daffodils Small dark spots...,Fungicides; Pruning infected parts
447,448,Bacterial Necrosis,Cashews; Almonds,Sunken lesions on branches; Yellowing leaves,Visual inspection; Lab tests (PCR),Copper-based fungicides,Bacterial Necrosis Cashews; Almonds Sunken les...,Copper-based fungicides
...,...,...,...,...,...,...,...,...
494,495,Cashew Dieback,Cashews,Stunted growth; Dead branches,Visual inspection,Fungicides; Pruning infected parts,Cashew Dieback Cashews Stunted growth; Dead br...,Fungicides; Pruning infected parts
492,493,Almond Leaf Scorch,Almonds,Yellowing of leaf edges; Dry spots,Visual inspection,Fungicides; Proper irrigation,Almond Leaf Scorch Almonds Yellowing of leaf e...,Fungicides; Proper irrigation
75,76,Brown Stripe Downy Mildew,Maize,"Brown stripe downy mildew causes brown, stripe...",Diagnosed by observing symptoms and fungal cul...,Apply fungicides like metalaxyl and improve cr...,Brown Stripe Downy Mildew Maize Brown stripe d...,Apply fungicides like metalaxyl and improve cr...
510,511,Almond Mite,Cashews,Yellowing leaves; Deformed growth,Visual inspection; Insect traps,Insecticide; Pruning infected leaves,Almond Mite Cashews Yellowing leaves; Deformed...,Insecticide; Pruning infected leaves


In [9]:
# Convert the pandas DataFrame to Hugging Face dataset format
train_dataset = Dataset.from_pandas(train_data[['input_text', 'output_text']])
test_dataset = Dataset.from_pandas(test_data[['input_text', 'output_text']])


In [10]:
train_dataset

Dataset({
    features: ['input_text', 'output_text', '__index_level_0__'],
    num_rows: 520
})

In [11]:
# Load pre-trained GPT-2 model and tokenizer
model_name = 'gpt2'  # You can also try 'EleutherAI/gpt-neo-2.7B' for GPT-Neo
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

In [12]:
# GPT-2 requires padding token to be set explicitly
tokenizer.pad_token = tokenizer.eos_token  # Set padding token to eos_token for GPT-2

# Step 3: Preprocessing the Dataset (Tokenization)
def tokenize_function(examples):
    return tokenizer(examples['input_text'], truncation=True, padding='max_length', max_length=512)

# Apply the tokenization function to both train and test datasets
train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

Map: 100%|██████████| 520/520 [00:00<00:00, 2799.38 examples/s]
Map: 100%|██████████| 130/130 [00:00<00:00, 2256.39 examples/s]


In [26]:
# Modify the inputs and labels for language modeling (shift labels for next-token prediction)
def shift_labels(batch):
    # Shift the labels by one token
    batch["labels"] = batch["input_ids"].copy()
    batch["labels"] = [ids[1:] + [tokenizer.pad_token_id] for ids in batch["labels"]]  # Shift labels
    return batch

train_dataset = train_dataset.map(shift_labels, batched=True)
test_dataset = test_dataset.map(shift_labels, batched=True)


[A
Map: 100%|██████████| 520/520 [00:00<00:00, 2507.19 examples/s]

Map: 100%|██████████| 130/130 [00:00<00:00, 2679.79 examples/s]


In [27]:
train_dataset

Dataset({
    features: ['input_text', 'output_text', '__index_level_0__', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 520
})

In [28]:

# Settting up the Trainer and Training Arguments
training_args = TrainingArguments(
    output_dir="./gpt2_finetuned",  # Output directory where model checkpoints will be saved
    evaluation_strategy="epoch",  # Evaluate at the end of each epoch
    learning_rate=5e-5,  # Learning rate for fine-tuning
    per_device_train_batch_size=2,  # Batch size for training
    per_device_eval_batch_size=2,  # Batch size for evaluation
    num_train_epochs=3,  # Number of epochs to train
    weight_decay=0.01,  # Weight decay to prevent overfitting
    save_total_limit=2,  # Limit the number of saved checkpoints
    logging_dir="./logs",  # Directory for logs
    logging_steps=10,  # Log every 10 steps
    push_to_hub=False,  # Disable pushing to Hugging Face Hub
)



In [29]:
trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset= train_dataset,
    eval_dataset=test_dataset
)

In [30]:
trainer.train()

  0%|          | 0/780 [07:04<?, ?it/s]
                                                 
  1%|▏         | 10/780 [00:45<58:36,  4.57s/it]

{'loss': 2.4588, 'grad_norm': 2.9180593490600586, 'learning_rate': 4.935897435897436e-05, 'epoch': 0.04}


                                                
  3%|▎         | 20/780 [01:30<57:31,  4.54s/it]

{'loss': 0.475, 'grad_norm': 1.4517159461975098, 'learning_rate': 4.871794871794872e-05, 'epoch': 0.08}


                                                
  4%|▍         | 30/780 [02:16<56:32,  4.52s/it]

{'loss': 0.3077, 'grad_norm': 1.1693569421768188, 'learning_rate': 4.8076923076923084e-05, 'epoch': 0.12}


                                                
  5%|▌         | 40/780 [03:01<56:54,  4.61s/it]

{'loss': 0.3937, 'grad_norm': 1.2612810134887695, 'learning_rate': 4.7435897435897435e-05, 'epoch': 0.15}


                                                
  6%|▋         | 50/780 [03:47<55:58,  4.60s/it]

{'loss': 0.2484, 'grad_norm': 1.2207701206207275, 'learning_rate': 4.67948717948718e-05, 'epoch': 0.19}


                                                
  8%|▊         | 60/780 [04:33<55:02,  4.59s/it]

{'loss': 0.2704, 'grad_norm': 1.969726324081421, 'learning_rate': 4.615384615384616e-05, 'epoch': 0.23}


                                                
  9%|▉         | 70/780 [05:19<53:25,  4.52s/it]

{'loss': 0.2662, 'grad_norm': 4.54311466217041, 'learning_rate': 4.5512820512820516e-05, 'epoch': 0.27}


                                                
 10%|█         | 80/780 [06:03<52:13,  4.48s/it]

{'loss': 0.3039, 'grad_norm': 1.7762080430984497, 'learning_rate': 4.4871794871794874e-05, 'epoch': 0.31}


                                                
 12%|█▏        | 90/780 [06:48<51:04,  4.44s/it]

{'loss': 0.2375, 'grad_norm': 1.3824553489685059, 'learning_rate': 4.423076923076923e-05, 'epoch': 0.35}


                                                 
 13%|█▎        | 100/780 [07:32<50:00,  4.41s/it]

{'loss': 0.2308, 'grad_norm': 1.9171643257141113, 'learning_rate': 4.358974358974359e-05, 'epoch': 0.38}


                                                 
 14%|█▍        | 110/780 [08:17<49:36,  4.44s/it]

{'loss': 0.2522, 'grad_norm': 2.2309203147888184, 'learning_rate': 4.294871794871795e-05, 'epoch': 0.42}


                                                 
 15%|█▌        | 120/780 [09:01<48:06,  4.37s/it]

{'loss': 0.2398, 'grad_norm': 1.6443493366241455, 'learning_rate': 4.230769230769231e-05, 'epoch': 0.46}


                                                 
 17%|█▋        | 130/780 [09:45<48:09,  4.45s/it]

{'loss': 0.2166, 'grad_norm': 1.9878499507904053, 'learning_rate': 4.166666666666667e-05, 'epoch': 0.5}


                                                 
 18%|█▊        | 140/780 [10:30<47:50,  4.49s/it]

{'loss': 0.4245, 'grad_norm': 1.518293023109436, 'learning_rate': 4.1025641025641023e-05, 'epoch': 0.54}


                                                 
 19%|█▉        | 150/780 [11:15<47:05,  4.49s/it]

{'loss': 0.3334, 'grad_norm': 1.519122838973999, 'learning_rate': 4.038461538461539e-05, 'epoch': 0.58}


                                                 
 21%|██        | 160/780 [12:00<45:49,  4.43s/it]

{'loss': 0.3257, 'grad_norm': 1.343960165977478, 'learning_rate': 3.974358974358974e-05, 'epoch': 0.62}


                                                 
 22%|██▏       | 170/780 [12:45<45:11,  4.45s/it]

{'loss': 0.2684, 'grad_norm': 1.3471447229385376, 'learning_rate': 3.9102564102564105e-05, 'epoch': 0.65}


                                                 
 23%|██▎       | 180/780 [13:29<44:46,  4.48s/it]

{'loss': 0.21, 'grad_norm': 1.3680946826934814, 'learning_rate': 3.846153846153846e-05, 'epoch': 0.69}


                                                 
 24%|██▍       | 190/780 [14:14<43:48,  4.46s/it]

{'loss': 0.3495, 'grad_norm': 1.2451105117797852, 'learning_rate': 3.782051282051282e-05, 'epoch': 0.73}


                                                 
 26%|██▌       | 200/780 [15:00<44:31,  4.61s/it]

{'loss': 0.179, 'grad_norm': 1.9198607206344604, 'learning_rate': 3.717948717948718e-05, 'epoch': 0.77}


                                                 
 27%|██▋       | 210/780 [15:45<41:48,  4.40s/it]

{'loss': 0.2325, 'grad_norm': 1.4309719800949097, 'learning_rate': 3.653846153846154e-05, 'epoch': 0.81}


                                                 
 28%|██▊       | 220/780 [16:30<42:10,  4.52s/it]

{'loss': 0.2593, 'grad_norm': 1.2105952501296997, 'learning_rate': 3.58974358974359e-05, 'epoch': 0.85}


                                                 
 29%|██▉       | 230/780 [17:16<41:23,  4.52s/it]

{'loss': 0.1571, 'grad_norm': 1.3449122905731201, 'learning_rate': 3.525641025641026e-05, 'epoch': 0.88}


                                                 
 31%|███       | 240/780 [18:00<39:37,  4.40s/it]

{'loss': 0.1987, 'grad_norm': 1.9150651693344116, 'learning_rate': 3.461538461538462e-05, 'epoch': 0.92}


                                                 
 32%|███▏      | 250/780 [18:46<40:15,  4.56s/it]

{'loss': 0.2024, 'grad_norm': 1.8233810663223267, 'learning_rate': 3.397435897435898e-05, 'epoch': 0.96}


                                                 
 33%|███▎      | 260/780 [19:29<37:55,  4.38s/it]

{'loss': 0.2594, 'grad_norm': 1.6448605060577393, 'learning_rate': 3.3333333333333335e-05, 'epoch': 1.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                    

 33%|███▎      | 260/780 [20:41<37:55,  4.38s/it]
[A
[A

{'eval_loss': 0.18184229731559753, 'eval_runtime': 71.6663, 'eval_samples_per_second': 1.814, 'eval_steps_per_second': 0.907, 'epoch': 1.0}


                                                   
 35%|███▍      | 270/780 [21:24<43:53,  5.16s/it]

{'loss': 0.201, 'grad_norm': 1.5832115411758423, 'learning_rate': 3.269230769230769e-05, 'epoch': 1.04}


                                                 
 36%|███▌      | 280/780 [22:09<37:44,  4.53s/it]

{'loss': 0.1616, 'grad_norm': 1.8737314939498901, 'learning_rate': 3.205128205128206e-05, 'epoch': 1.08}


                                                 
 37%|███▋      | 290/780 [22:54<35:47,  4.38s/it]

{'loss': 0.2393, 'grad_norm': 2.306267261505127, 'learning_rate': 3.141025641025641e-05, 'epoch': 1.12}


                                                 
 38%|███▊      | 300/780 [23:36<35:00,  4.38s/it]

{'loss': 0.1389, 'grad_norm': 1.726609706878662, 'learning_rate': 3.0769230769230774e-05, 'epoch': 1.15}


                                                 
 40%|███▉      | 310/780 [24:21<34:56,  4.46s/it]

{'loss': 0.1631, 'grad_norm': 1.8067774772644043, 'learning_rate': 3.012820512820513e-05, 'epoch': 1.19}


                                                 
 41%|████      | 320/780 [25:06<33:56,  4.43s/it]

{'loss': 0.151, 'grad_norm': 1.6355897188186646, 'learning_rate': 2.948717948717949e-05, 'epoch': 1.23}


                                                 
 42%|████▏     | 330/780 [25:50<33:06,  4.42s/it]

{'loss': 0.2038, 'grad_norm': 2.085676431655884, 'learning_rate': 2.8846153846153845e-05, 'epoch': 1.27}


                                                 
 44%|████▎     | 340/780 [26:34<32:27,  4.43s/it]

{'loss': 0.1763, 'grad_norm': 2.9818949699401855, 'learning_rate': 2.8205128205128207e-05, 'epoch': 1.31}


                                                 
 45%|████▍     | 350/780 [27:18<31:40,  4.42s/it]

{'loss': 0.2341, 'grad_norm': 1.367983341217041, 'learning_rate': 2.756410256410257e-05, 'epoch': 1.35}


                                                 
 46%|████▌     | 360/780 [28:03<31:07,  4.45s/it]

{'loss': 0.2058, 'grad_norm': 1.3778948783874512, 'learning_rate': 2.6923076923076923e-05, 'epoch': 1.38}


                                                 
 47%|████▋     | 370/780 [28:47<30:12,  4.42s/it]

{'loss': 0.1942, 'grad_norm': 0.8361591100692749, 'learning_rate': 2.6282051282051285e-05, 'epoch': 1.42}


                                                 
 49%|████▊     | 380/780 [29:32<29:30,  4.43s/it]

{'loss': 0.1947, 'grad_norm': 2.047224283218384, 'learning_rate': 2.564102564102564e-05, 'epoch': 1.46}


                                                 
 50%|█████     | 390/780 [30:16<29:15,  4.50s/it]

{'loss': 0.1245, 'grad_norm': 1.7003039121627808, 'learning_rate': 2.5e-05, 'epoch': 1.5}


                                                 
 51%|█████▏    | 400/780 [31:01<28:58,  4.58s/it]

{'loss': 0.2225, 'grad_norm': 1.4357749223709106, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}


                                                 
 53%|█████▎    | 410/780 [31:46<27:42,  4.49s/it]

{'loss': 0.2297, 'grad_norm': 2.745622158050537, 'learning_rate': 2.3717948717948718e-05, 'epoch': 1.58}


                                                 
 54%|█████▍    | 420/780 [32:31<26:57,  4.49s/it]

{'loss': 0.3133, 'grad_norm': 1.315982699394226, 'learning_rate': 2.307692307692308e-05, 'epoch': 1.62}


                                                 
 55%|█████▌    | 430/780 [33:16<26:13,  4.50s/it]

{'loss': 0.1763, 'grad_norm': 2.0714974403381348, 'learning_rate': 2.2435897435897437e-05, 'epoch': 1.65}


                                                 
 56%|█████▋    | 440/780 [34:01<25:15,  4.46s/it]

{'loss': 0.1997, 'grad_norm': 1.387109398841858, 'learning_rate': 2.1794871794871795e-05, 'epoch': 1.69}


                                                 
 58%|█████▊    | 450/780 [34:45<24:26,  4.44s/it]

{'loss': 0.2173, 'grad_norm': 2.2341675758361816, 'learning_rate': 2.1153846153846154e-05, 'epoch': 1.73}


                                                 
 59%|█████▉    | 460/780 [35:30<24:02,  4.51s/it]

{'loss': 0.1665, 'grad_norm': 2.8453948497772217, 'learning_rate': 2.0512820512820512e-05, 'epoch': 1.77}


                                                 
 60%|██████    | 470/780 [36:22<29:39,  5.74s/it]

{'loss': 0.1587, 'grad_norm': 1.720787763595581, 'learning_rate': 1.987179487179487e-05, 'epoch': 1.81}


                                                 
 62%|██████▏   | 480/780 [37:07<22:31,  4.50s/it]

{'loss': 0.1413, 'grad_norm': 1.4468289613723755, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.85}


                                                 
 63%|██████▎   | 490/780 [37:51<21:25,  4.43s/it]

{'loss': 0.1715, 'grad_norm': 1.3408221006393433, 'learning_rate': 1.858974358974359e-05, 'epoch': 1.88}


                                                 
 64%|██████▍   | 500/780 [38:36<20:25,  4.38s/it]

{'loss': 0.1832, 'grad_norm': 1.4620161056518555, 'learning_rate': 1.794871794871795e-05, 'epoch': 1.92}


                                                 
 65%|██████▌   | 510/780 [39:21<19:44,  4.39s/it]

{'loss': 0.2143, 'grad_norm': 2.809804916381836, 'learning_rate': 1.730769230769231e-05, 'epoch': 1.96}


                                                 
 67%|██████▋   | 520/780 [40:04<18:38,  4.30s/it]

{'loss': 0.1695, 'grad_norm': 2.0505611896514893, 'learning_rate': 1.6666666666666667e-05, 'epoch': 2.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                 
[A                                    

 67%|██████▋   | 520/780 [41:15<18:38,  4.30s/it]
[A
[A

{'eval_loss': 0.16074250638484955, 'eval_runtime': 71.1527, 'eval_samples_per_second': 1.827, 'eval_steps_per_second': 0.914, 'epoch': 2.0}


                                                   
 68%|██████▊   | 530/780 [41:58<21:22,  5.13s/it]

{'loss': 0.1676, 'grad_norm': 1.4050089120864868, 'learning_rate': 1.602564102564103e-05, 'epoch': 2.04}


                                                 
 69%|██████▉   | 540/780 [42:42<17:47,  4.45s/it]

{'loss': 0.1699, 'grad_norm': 1.621198058128357, 'learning_rate': 1.5384615384615387e-05, 'epoch': 2.08}


                                                 
 71%|███████   | 550/780 [43:27<17:00,  4.44s/it]

{'loss': 0.1539, 'grad_norm': 1.7315987348556519, 'learning_rate': 1.4743589743589745e-05, 'epoch': 2.12}


                                                 
 72%|███████▏  | 560/780 [44:11<16:29,  4.50s/it]

{'loss': 0.1381, 'grad_norm': 1.6584547758102417, 'learning_rate': 1.4102564102564104e-05, 'epoch': 2.15}


                                                 
 73%|███████▎  | 570/780 [44:57<16:03,  4.59s/it]

{'loss': 0.1568, 'grad_norm': 1.326914668083191, 'learning_rate': 1.3461538461538462e-05, 'epoch': 2.19}


                                                 
 74%|███████▍  | 580/780 [45:43<15:17,  4.59s/it]

{'loss': 0.1772, 'grad_norm': 1.2647120952606201, 'learning_rate': 1.282051282051282e-05, 'epoch': 2.23}


                                                 
 76%|███████▌  | 590/780 [46:28<14:21,  4.53s/it]

{'loss': 0.1959, 'grad_norm': 1.7430015802383423, 'learning_rate': 1.217948717948718e-05, 'epoch': 2.27}


                                                 
 77%|███████▋  | 600/780 [47:13<13:24,  4.47s/it]

{'loss': 0.2316, 'grad_norm': 2.2990643978118896, 'learning_rate': 1.153846153846154e-05, 'epoch': 2.31}


                                                 
 78%|███████▊  | 610/780 [47:59<12:50,  4.53s/it]

{'loss': 0.1263, 'grad_norm': 1.5915508270263672, 'learning_rate': 1.0897435897435898e-05, 'epoch': 2.35}


                                                 
 79%|███████▉  | 620/780 [48:44<11:51,  4.45s/it]

{'loss': 0.1027, 'grad_norm': 1.4683643579483032, 'learning_rate': 1.0256410256410256e-05, 'epoch': 2.38}


                                                 
 81%|████████  | 630/780 [49:29<11:07,  4.45s/it]

{'loss': 0.1766, 'grad_norm': 1.8124926090240479, 'learning_rate': 9.615384615384616e-06, 'epoch': 2.42}


                                                 
 82%|████████▏ | 640/780 [50:13<10:22,  4.45s/it]

{'loss': 0.1367, 'grad_norm': 1.1199349164962769, 'learning_rate': 8.974358974358976e-06, 'epoch': 2.46}


                                                 
 83%|████████▎ | 650/780 [50:58<09:43,  4.49s/it]

{'loss': 0.1632, 'grad_norm': 2.987386703491211, 'learning_rate': 8.333333333333334e-06, 'epoch': 2.5}


                                                 
 85%|████████▍ | 660/780 [51:43<08:55,  4.47s/it]

{'loss': 0.1358, 'grad_norm': 1.5392460823059082, 'learning_rate': 7.692307692307694e-06, 'epoch': 2.54}


                                                 
 86%|████████▌ | 670/780 [52:27<08:13,  4.48s/it]

{'loss': 0.1519, 'grad_norm': 1.2219432592391968, 'learning_rate': 7.051282051282052e-06, 'epoch': 2.58}


                                                 
 87%|████████▋ | 680/780 [53:12<07:13,  4.34s/it]

{'loss': 0.138, 'grad_norm': 1.4626113176345825, 'learning_rate': 6.41025641025641e-06, 'epoch': 2.62}


                                                 
 88%|████████▊ | 690/780 [53:56<06:31,  4.35s/it]

{'loss': 0.1614, 'grad_norm': 2.3443222045898438, 'learning_rate': 5.76923076923077e-06, 'epoch': 2.65}


                                                 
 90%|████████▉ | 700/780 [54:38<05:33,  4.17s/it]

{'loss': 0.1828, 'grad_norm': 1.169180154800415, 'learning_rate': 5.128205128205128e-06, 'epoch': 2.69}


                                                 
 91%|█████████ | 710/780 [55:20<04:55,  4.22s/it]

{'loss': 0.2049, 'grad_norm': 2.40272855758667, 'learning_rate': 4.487179487179488e-06, 'epoch': 2.73}


                                                 
 92%|█████████▏| 720/780 [56:02<04:08,  4.14s/it]

{'loss': 0.1464, 'grad_norm': 1.3954224586486816, 'learning_rate': 3.846153846153847e-06, 'epoch': 2.77}


                                                 
 94%|█████████▎| 730/780 [56:44<03:27,  4.15s/it]

{'loss': 0.2007, 'grad_norm': 1.5068132877349854, 'learning_rate': 3.205128205128205e-06, 'epoch': 2.81}


                                                 
 95%|█████████▍| 740/780 [57:27<02:51,  4.30s/it]

{'loss': 0.1106, 'grad_norm': 1.1655679941177368, 'learning_rate': 2.564102564102564e-06, 'epoch': 2.85}


                                                 
 96%|█████████▌| 750/780 [58:09<02:08,  4.30s/it]

{'loss': 0.178, 'grad_norm': 1.3250336647033691, 'learning_rate': 1.9230769230769234e-06, 'epoch': 2.88}


                                                 
 97%|█████████▋| 760/780 [58:51<01:24,  4.23s/it]

{'loss': 0.1923, 'grad_norm': 1.6255768537521362, 'learning_rate': 1.282051282051282e-06, 'epoch': 2.92}


                                                 
 99%|█████████▊| 770/780 [59:33<00:41,  4.16s/it]

{'loss': 0.1858, 'grad_norm': 1.5783381462097168, 'learning_rate': 6.41025641025641e-07, 'epoch': 2.96}


                                                   
100%|██████████| 780/780 [1:00:15<00:00,  4.18s/it]

{'loss': 0.2478, 'grad_norm': 3.1815764904022217, 'learning_rate': 0.0, 'epoch': 3.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   
[A                                      

100%|██████████| 780/780 [1:01:27<00:00,  4.18s/it]
[A
                                                   
100%|██████████| 780/780 [1:01:27<00:00,  4.73s/it]

{'eval_loss': 0.15382349491119385, 'eval_runtime': 70.8919, 'eval_samples_per_second': 1.834, 'eval_steps_per_second': 0.917, 'epoch': 3.0}
{'train_runtime': 3687.8368, 'train_samples_per_second': 0.423, 'train_steps_per_second': 0.212, 'train_loss': 0.23827883157974633, 'epoch': 3.0}





TrainOutput(global_step=780, training_loss=0.23827883157974633, metrics={'train_runtime': 3687.8368, 'train_samples_per_second': 0.423, 'train_steps_per_second': 0.212, 'total_flos': 407615569920000.0, 'train_loss': 0.23827883157974633, 'epoch': 3.0})

In [31]:
model.save_pretrained("./gpt2_finetuned")
tokenizer.save_pretrained("./gpt2_finetuned")

('./gpt2_finetuned\\tokenizer_config.json',
 './gpt2_finetuned\\special_tokens_map.json',
 './gpt2_finetuned\\vocab.json',
 './gpt2_finetuned\\merges.txt',
 './gpt2_finetuned\\added_tokens.json')

In [32]:
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [39]:
input_text = "Describe the symptoms and causes of apple scab disease."
inputs = tokenizer(input_text, return_tensors="pt")

In [40]:
output = model.generate(inputs['input_ids'], max_length=100, num_return_sequences=1)

# Decode the output text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print("Generated Text:")
print(generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Generated Text:
Describe the symptoms and causes of apple scab disease.osed Visual; testsmicroopy
