In [26]:
#!pip install torch torchtext transformers sentencepiece pandas tqdm datasets

In [40]:
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd
import ast
import datasets
from tqdm import tqdm
import time

In [41]:
# Load data set from huggingface
data_sample = load_dataset("QuyenAnhDE/Diseases_Symptoms")

Repo card metadata block was not found. Setting CardData to empty.


In [42]:
data_sample

DatasetDict({
    train: Dataset({
        features: ['Code', 'Name', 'Symptoms', 'Treatments'],
        num_rows: 400
    })
})

In [43]:
# Convert to a pandas dataframe
updated_data = [{'Name': item['Name'], 'Symptoms': item['Symptoms']} for item in data_sample['train']]
df = pd.DataFrame(updated_data)
df

Unnamed: 0,Name,Symptoms
0,Panic disorder,"Palpitations, Sweating, Trembling, Shortness o..."
1,Vocal cord polyp,"Hoarseness, Vocal Changes, Vocal Fatigue"
2,Turner syndrome,"Short stature, Gonadal dysgenesis, Webbed neck..."
3,Cryptorchidism,"Absence or undescended testicle(s), empty scro..."
4,Ethylene glycol poisoning-1,"Nausea, vomiting, abdominal pain, General mala..."
...,...,...
395,Urinary Stones (Kidney Stones),"Severe abdominal or back pain, blood in urine,..."
396,Osteoporosis,"Fragile bones, loss of height over time, back ..."
397,Rheumatoid Arthritis,"Joint pain, stiffness, swelling, fatigue, loss..."
398,Type 1 Diabetes,"Frequent urination, Increased thirst, Weight loss"


In [5]:
# Just extract the Symptoms
df['Symptoms'] = df['Symptoms'].apply(lambda x: ', '.join(x.split(', ')))
display(df.head())

Unnamed: 0,Name,Symptoms
0,Panic disorder,"Palpitations, Sweating, Trembling, Shortness o..."
1,Vocal cord polyp,"Hoarseness, Vocal Changes, Vocal Fatigue"
2,Turner syndrome,"Short stature, Gonadal dysgenesis, Webbed neck..."
3,Cryptorchidism,"Absence or undescended testicle(s), empty scro..."
4,Ethylene glycol poisoning-1,"Nausea, vomiting, abdominal pain, General mala..."


In [6]:
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split

In [7]:
# If you have an NVIDIA GPU attached, use 'cuda'
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    # If Apple Silicon, set to 'mps' - otherwise 'cpu' (not advised)
    try:
        device = torch.device('mps')
    except Exception:
        device = torch.device('cpu')

In [8]:
device

device(type='cuda')

In [9]:
# The tokenizer turns texts to numbers (and vice-versa)
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

# The transformer
model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/762 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/353M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [10]:
model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [11]:
# Model params
BATCH_SIZE = 64

In [12]:
df.describe()

Unnamed: 0,Name,Symptoms
count,400,400
unique,392,395
top,Sciatica,"Swelling, pain, dry mouth, bad taste"
freq,3,3


In [13]:
# Dataset Prep
class LanguageDataset(Dataset):
    """
    An extension of the Dataset object to:
      - Make training loop cleaner
      - Make ingestion easier from pandas df's
    """
    def __init__(self, df, tokenizer):
        self.labels = df.columns
        self.data = df.to_dict(orient='records')
        self.tokenizer = tokenizer
        x = self.fittest_max_length(df)  # Fix here
        self.max_length = x

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x = self.data[idx][self.labels[0]]
        y = self.data[idx][self.labels[1]]
        text = f"{x} | {y}"
        tokens = self.tokenizer.encode_plus(text, return_tensors='pt', max_length=128, padding='max_length', truncation=True)
        return tokens

    def fittest_max_length(self, df):  # Fix here
        """
        Smallest power of two larger than the longest term in the data set.
        Important to set up max length to speed training time.
        """
        max_length = max(len(max(df[self.labels[0]], key=len)), len(max(df[self.labels[1]], key=len)))
        x = 2
        while x < max_length: x = x * 2
        return x

# Cast the Huggingface data set as a LanguageDataset we defined above
data_sample = LanguageDataset(df, tokenizer)


In [17]:
# Create train, valid
train_size = int(0.8 * len(data_sample))
valid_size = len(data_sample) - train_size
train_data, valid_data = random_split(data_sample, [train_size, valid_size])

In [18]:
# Make the iterators
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)

In [33]:
for batch in train_loader:
    break


print(batch.keys())

print( len(batch['input_ids'].shape) )

dict_keys(['input_ids', 'attention_mask'])
3


In [36]:
batch['input_ids'] .shape

torch.Size([64, 1, 128])

In [52]:
# Set the number of epochs
num_epochs = 50

In [53]:
# Training parameters
batch_size = BATCH_SIZE
model_name = 'distilgpt2'
gpu = 0

In [56]:
# Set the learning rate and loss function
## CrossEntropyLoss measures how close answers to the truth.
## More punishing for high confidence wrong answers
criterion = nn.CrossEntropyLoss(ignore_index = tokenizer.pad_token_id)
optimizer = optim.Adam(model.parameters(), lr=5e-4)
tokenizer.pad_token = tokenizer.eos_token

In [57]:
# Init a results dataframe
results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
                                'training_loss', 'validation_loss', 'epoch_duration_sec'])

In [58]:
# The training loop
for epoch in range(num_epochs):
    start_time = time.time()  # Start the timer for the epoch

    # Training
    ## This line tells the model we're in 'learning mode'
    model.train()
    epoch_training_loss = 0
    train_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs} Batch Size: {batch_size}, Transformer: {model_name}")
    for batch in train_iterator:
        optimizer.zero_grad()
        inputs = batch['input_ids'].squeeze(1).to(device)
        targets = inputs.clone()
        outputs = model(input_ids=inputs, labels=targets)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        train_iterator.set_postfix({'Training Loss': loss.item()})
        epoch_training_loss += loss.item()
    avg_epoch_training_loss = epoch_training_loss / len(train_iterator)

    # Validation
    ## This line below tells the model to 'stop learning'
    model.eval()
    epoch_validation_loss = 0
    total_loss = 0
    valid_iterator = tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
    with torch.no_grad():
        for batch in valid_iterator:
            inputs = batch['input_ids'].squeeze(1).to(device)
            targets = inputs.clone()
            outputs = model(input_ids=inputs, labels=targets)
            loss = outputs.loss
            total_loss += loss
            valid_iterator.set_postfix({'Validation Loss': loss.item()})
            epoch_validation_loss += loss.item()

    avg_epoch_validation_loss = epoch_validation_loss / len(valid_loader)

    end_time = time.time()  # End the timer for the epoch
    epoch_duration_sec = end_time - start_time  # Calculate the duration in seconds

    new_row = {'transformer': model_name,
               'batch_size': batch_size,
               'gpu': gpu,
               'epoch': epoch+1,
               'training_loss': avg_epoch_training_loss,
               'validation_loss': avg_epoch_validation_loss,
               'epoch_duration_sec': epoch_duration_sec}  # Add epoch_duration to the dataframe

    results.loc[len(results)] = new_row
    print(f"Epoch: {epoch+1}, Validation Loss: {total_loss/len(valid_loader)}")

Training Epoch 1/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.82it/s, Training Loss=0.531]
Validation Epoch 1/50: 100%|██████████| 2/2 [00:00<00:00,  8.21it/s, Validation Loss=0.618]


Epoch: 1, Validation Loss: 0.6878660917282104


Training Epoch 2/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.417]
Validation Epoch 2/50: 100%|██████████| 2/2 [00:00<00:00,  8.56it/s, Validation Loss=0.598]


Epoch: 2, Validation Loss: 0.6892023086547852


Training Epoch 3/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.291]
Validation Epoch 3/50: 100%|██████████| 2/2 [00:00<00:00,  8.63it/s, Validation Loss=0.607]


Epoch: 3, Validation Loss: 0.715057373046875


Training Epoch 4/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.266]
Validation Epoch 4/50: 100%|██████████| 2/2 [00:00<00:00,  8.62it/s, Validation Loss=0.614]


Epoch: 4, Validation Loss: 0.7309919595718384


Training Epoch 5/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.238]
Validation Epoch 5/50: 100%|██████████| 2/2 [00:00<00:00,  8.63it/s, Validation Loss=0.642]


Epoch: 5, Validation Loss: 0.7670233845710754


Training Epoch 6/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.196]
Validation Epoch 6/50: 100%|██████████| 2/2 [00:00<00:00,  8.63it/s, Validation Loss=0.676]


Epoch: 6, Validation Loss: 0.8099104166030884


Training Epoch 7/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.165]
Validation Epoch 7/50: 100%|██████████| 2/2 [00:00<00:00,  8.64it/s, Validation Loss=0.684]


Epoch: 7, Validation Loss: 0.8321250677108765


Training Epoch 8/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.176]
Validation Epoch 8/50: 100%|██████████| 2/2 [00:00<00:00,  8.62it/s, Validation Loss=0.705]


Epoch: 8, Validation Loss: 0.8640424609184265


Training Epoch 9/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.123]
Validation Epoch 9/50: 100%|██████████| 2/2 [00:00<00:00,  8.63it/s, Validation Loss=0.717]


Epoch: 9, Validation Loss: 0.8810622692108154


Training Epoch 10/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0944]
Validation Epoch 10/50: 100%|██████████| 2/2 [00:00<00:00,  8.64it/s, Validation Loss=0.744]


Epoch: 10, Validation Loss: 0.9048079252243042


Training Epoch 11/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0849]
Validation Epoch 11/50: 100%|██████████| 2/2 [00:00<00:00,  8.54it/s, Validation Loss=0.777]


Epoch: 11, Validation Loss: 0.9317841529846191


Training Epoch 12/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0703]
Validation Epoch 12/50: 100%|██████████| 2/2 [00:00<00:00,  8.61it/s, Validation Loss=0.802]


Epoch: 12, Validation Loss: 0.9637366533279419


Training Epoch 13/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0675]
Validation Epoch 13/50: 100%|██████████| 2/2 [00:00<00:00,  8.66it/s, Validation Loss=0.799]


Epoch: 13, Validation Loss: 0.9651272296905518


Training Epoch 14/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0571]
Validation Epoch 14/50: 100%|██████████| 2/2 [00:00<00:00,  8.65it/s, Validation Loss=0.815]


Epoch: 14, Validation Loss: 0.982245683670044


Training Epoch 15/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0465]
Validation Epoch 15/50: 100%|██████████| 2/2 [00:00<00:00,  8.53it/s, Validation Loss=0.844]


Epoch: 15, Validation Loss: 1.014608383178711


Training Epoch 16/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0462]
Validation Epoch 16/50: 100%|██████████| 2/2 [00:00<00:00,  8.49it/s, Validation Loss=0.834]


Epoch: 16, Validation Loss: 1.0135736465454102


Training Epoch 17/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0397]
Validation Epoch 17/50: 100%|██████████| 2/2 [00:00<00:00,  8.53it/s, Validation Loss=0.86]


Epoch: 17, Validation Loss: 1.0278730392456055


Training Epoch 18/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0404]
Validation Epoch 18/50: 100%|██████████| 2/2 [00:00<00:00,  8.45it/s, Validation Loss=0.863]


Epoch: 18, Validation Loss: 1.0238735675811768


Training Epoch 19/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0413]
Validation Epoch 19/50: 100%|██████████| 2/2 [00:00<00:00,  8.57it/s, Validation Loss=0.864]


Epoch: 19, Validation Loss: 1.0278997421264648


Training Epoch 20/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0412]
Validation Epoch 20/50: 100%|██████████| 2/2 [00:00<00:00,  8.58it/s, Validation Loss=0.888]


Epoch: 20, Validation Loss: 1.0433002710342407


Training Epoch 21/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0407]
Validation Epoch 21/50: 100%|██████████| 2/2 [00:00<00:00,  8.58it/s, Validation Loss=0.889]


Epoch: 21, Validation Loss: 1.0448321104049683


Training Epoch 22/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0365]
Validation Epoch 22/50: 100%|██████████| 2/2 [00:00<00:00,  8.56it/s, Validation Loss=0.882]


Epoch: 22, Validation Loss: 1.0421209335327148


Training Epoch 23/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0306]
Validation Epoch 23/50: 100%|██████████| 2/2 [00:00<00:00,  8.43it/s, Validation Loss=0.899]


Epoch: 23, Validation Loss: 1.0673978328704834


Training Epoch 24/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0299]
Validation Epoch 24/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.918]


Epoch: 24, Validation Loss: 1.081768274307251


Training Epoch 25/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0273]
Validation Epoch 25/50: 100%|██████████| 2/2 [00:00<00:00,  8.53it/s, Validation Loss=0.91]


Epoch: 25, Validation Loss: 1.0751359462738037


Training Epoch 26/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0263]
Validation Epoch 26/50: 100%|██████████| 2/2 [00:00<00:00,  8.53it/s, Validation Loss=0.939]


Epoch: 26, Validation Loss: 1.0925984382629395


Training Epoch 27/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0275]
Validation Epoch 27/50: 100%|██████████| 2/2 [00:00<00:00,  6.65it/s, Validation Loss=0.954]


Epoch: 27, Validation Loss: 1.094219446182251


Training Epoch 28/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.84it/s, Training Loss=0.0215]
Validation Epoch 28/50: 100%|██████████| 2/2 [00:00<00:00,  8.52it/s, Validation Loss=0.965]


Epoch: 28, Validation Loss: 1.0981156826019287


Training Epoch 29/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.023] 
Validation Epoch 29/50: 100%|██████████| 2/2 [00:00<00:00,  8.50it/s, Validation Loss=0.962]


Epoch: 29, Validation Loss: 1.107938528060913


Training Epoch 30/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0257]
Validation Epoch 30/50: 100%|██████████| 2/2 [00:00<00:00,  8.40it/s, Validation Loss=0.963]


Epoch: 30, Validation Loss: 1.1124930381774902


Training Epoch 31/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0293]
Validation Epoch 31/50: 100%|██████████| 2/2 [00:00<00:00,  8.59it/s, Validation Loss=0.95]


Epoch: 31, Validation Loss: 1.1017218828201294


Training Epoch 32/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0247]
Validation Epoch 32/50: 100%|██████████| 2/2 [00:00<00:00,  8.59it/s, Validation Loss=0.934]


Epoch: 32, Validation Loss: 1.090221643447876


Training Epoch 33/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0231]
Validation Epoch 33/50: 100%|██████████| 2/2 [00:00<00:00,  8.66it/s, Validation Loss=0.954]


Epoch: 33, Validation Loss: 1.1076908111572266


Training Epoch 34/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0273]
Validation Epoch 34/50: 100%|██████████| 2/2 [00:00<00:00,  8.65it/s, Validation Loss=0.956]


Epoch: 34, Validation Loss: 1.1124439239501953


Training Epoch 35/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0237]
Validation Epoch 35/50: 100%|██████████| 2/2 [00:00<00:00,  8.65it/s, Validation Loss=0.978]


Epoch: 35, Validation Loss: 1.1297359466552734


Training Epoch 36/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0284]
Validation Epoch 36/50: 100%|██████████| 2/2 [00:00<00:00,  8.63it/s, Validation Loss=0.959]


Epoch: 36, Validation Loss: 1.11360502243042


Training Epoch 37/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0228]
Validation Epoch 37/50: 100%|██████████| 2/2 [00:00<00:00,  8.65it/s, Validation Loss=0.956]


Epoch: 37, Validation Loss: 1.109984278678894


Training Epoch 38/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0232]
Validation Epoch 38/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.959]


Epoch: 38, Validation Loss: 1.1110011339187622


Training Epoch 39/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0227]
Validation Epoch 39/50: 100%|██████████| 2/2 [00:00<00:00,  8.50it/s, Validation Loss=0.964]


Epoch: 39, Validation Loss: 1.1160163879394531


Training Epoch 40/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0244]
Validation Epoch 40/50: 100%|██████████| 2/2 [00:00<00:00,  8.49it/s, Validation Loss=0.968]


Epoch: 40, Validation Loss: 1.123340129852295


Training Epoch 41/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0234]
Validation Epoch 41/50: 100%|██████████| 2/2 [00:00<00:00,  8.52it/s, Validation Loss=0.977]


Epoch: 41, Validation Loss: 1.1328998804092407


Training Epoch 42/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0195]
Validation Epoch 42/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.993]


Epoch: 42, Validation Loss: 1.1464756727218628


Training Epoch 43/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0204]
Validation Epoch 43/50: 100%|██████████| 2/2 [00:00<00:00,  8.52it/s, Validation Loss=0.992]


Epoch: 43, Validation Loss: 1.1464056968688965


Training Epoch 44/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0179]
Validation Epoch 44/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=1]  


Epoch: 44, Validation Loss: 1.1497448682785034


Training Epoch 45/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0189]
Validation Epoch 45/50: 100%|██████████| 2/2 [00:00<00:00,  8.53it/s, Validation Loss=0.99]


Epoch: 45, Validation Loss: 1.1434646844863892


Training Epoch 46/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0185]
Validation Epoch 46/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.987]


Epoch: 46, Validation Loss: 1.1411021947860718


Training Epoch 47/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0133]
Validation Epoch 47/50: 100%|██████████| 2/2 [00:00<00:00,  8.49it/s, Validation Loss=0.987]


Epoch: 47, Validation Loss: 1.1409238576889038


Training Epoch 48/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0179]
Validation Epoch 48/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.986]


Epoch: 48, Validation Loss: 1.138296365737915


Training Epoch 49/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.85it/s, Training Loss=0.0159]
Validation Epoch 49/50: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s, Validation Loss=0.988]


Epoch: 49, Validation Loss: 1.1431069374084473


Training Epoch 50/50 Batch Size: 64, Transformer: distilgpt2: 100%|██████████| 5/5 [00:02<00:00,  1.86it/s, Training Loss=0.0206]
Validation Epoch 50/50: 100%|██████████| 2/2 [00:00<00:00,  8.64it/s, Validation Loss=0.967]

Epoch: 50, Validation Loss: 1.129504919052124





In [59]:
index = 15
input_str = df['Name'][index]
output_str = df['Symptoms'][index]

print(f"Input: {input_str}", f"Output: {output_str}", sep='\n')

Input: Warthin tumor
Output: Painless lump or swelling, Facial changes


In [60]:
input_str = input_str 
input_ids = tokenizer.encode(input_str, return_tensors='pt').to(device)
output = model.generate(input_ids,max_length=20,num_return_sequences=1,do_sample=True,top_k=8,top_p=0.95,temperature=0.5,repetition_penalty=1.2)
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)

print("-----------------------")
print(decoded_output)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


-----------------------
Warthin tumor | Joint pain, stiffness, swelling, limited range of motion


In [None]:
torch.save(model, 'SmallMedLM.pt')

In [None]:
torch.save(model, 'drive/My Drive/SmallMedLM.pt')