In [1]:
from datasets import load_dataset

dataset = load_dataset("duxprajapati/symptom-disease-dataset")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 5634
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1409
    })
})

In [3]:
#show the number of Name column which is not duplicated
#show that unique element of dataset["train"]['Name'] which is list type
labels= dataset["train"]['label']

unique_names = set(labels)


In [4]:
len(unique_names)

866

In [5]:
labels= dataset["test"]['label']

In [6]:
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
from transformers import AutoModel
import torch

# Load the base model
model_base = AutoModelForSequenceClassification.from_pretrained("duxprajapati/symptom-disease-model")


config = AutoConfig.from_pretrained("duxprajapati/symptom-disease-model")

In [7]:
print(model_base)

DistilBertForSequenceClassification(
  (distilbert): DistilBertModel(
    (embeddings): Embeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (transformer): Transformer(
      (layer): ModuleList(
        (0-5): 6 x TransformerBlock(
          (attention): MultiHeadSelfAttention(
            (dropout): Dropout(p=0.1, inplace=False)
            (q_lin): Linear(in_features=768, out_features=768, bias=True)
            (k_lin): Linear(in_features=768, out_features=768, bias=True)
            (v_lin): Linear(in_features=768, out_features=768, bias=True)
            (out_lin): Linear(in_features=768, out_features=768, bias=True)
          )
          (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (ffn): FFN(
            (dropout): Dropout(p=0.1, inplace=False)
 

In [8]:
import nltk
from nltk.corpus import wordnet
from random import randint

nltk.download('averaged_perceptron_tagger')
nltk.download('wordnet')

def get_synonyms(word):
    synonyms = set()
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonyms.add(lemma.name())
    if word in synonyms:
        synonyms.remove(word)
    return list(synonyms)

def synonym_replacement(sentence, num_replacements=1):
    words = nltk.word_tokenize(sentence)
    pos_tags = nltk.pos_tag(words)

    replacements = 0
    for i, (word, tag) in enumerate(pos_tags):
        if replacements >= num_replacements or wordnet.synsets(word) == []:
            continue
        synonyms = get_synonyms(word)
        if synonyms:
            synonym = synonyms[randint(0, len(synonyms) - 1)]
            words[i] = synonym
            replacements += 1

    return ' '.join(words)

# Example usage
sentence = "Fever and cough might indicate a flu infection."
augmented_sentence = synonym_replacement(sentence, 2)
print("Original:", sentence)
print("Augmented:", augmented_sentence)


[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\GamerPc\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\GamerPc\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


Original: Fever and cough might indicate a flu infection.
Augmented: febricity and coughing might indicate a flu infection .


In [9]:
from datasets import load_dataset

# Assuming the functions `get_synonyms` and `synonym_replacement` are defined as above

# Load your dataset
dataset = load_dataset("duxprajapati/symptom-disease-dataset")

# Function to augment the data
def augment_dataset(data, augmentation_function, num_replacements=1):
    augmented_texts = []
    labels = []
    for example in data:
        augmented_text = augmentation_function(example['text'], num_replacements)
        augmented_texts.append(augmented_text)
        labels.append(example['label'])
    return augmented_texts, labels

# Augment the training data
augmented_texts, labels = augment_dataset(dataset['train'], synonym_replacement, 2)

# Create a new dataset from the augmented texts and labels
augmented_dataset = {
    'text': augmented_texts + [example['text'] for example in dataset['train']],
    'label': labels + [example['label'] for example in dataset['train']]
}


In [10]:
augmented_dataset

{'text': ["single cause been having migraines and headaches . I ca n't sleep . My whole body is shaking and shivering . I feel dizzy sometimes .",
  'iodin own asthma and I get wheezing and breathing problems . I also have fevers , headaches , and I feel tired all the time .',
  "contract and symptom of primary ovarian insufficiency are similar to those of menopause or estrogen deficiency . They include : Irregular or skipped periods , which might be present for years or develop after a pregnancy or after stopping birth control pills Difficulty getting pregnant Hot flashes Night sweats Vaginal dryness Dry eyes Irritability or difficulty concentrating Decreased sexual desire When to see a health care provider If you 've missed your period for three months or more , see your health care provider to determine the cause . You can miss your period for a number of reasons — including pregnancy , stress , or a change in diet or exercise habits — but it 's best to get evaluated whenever your m

In [11]:
# Assuming 'dataset' is your original DatasetDict containing 'train' and 'test' splits
# And assuming 'augmented_texts' and 'labels' are your lists of augmented data and their corresponding labels

# Convert augmented data into a list of dictionaries
augmented_data = [{'text': text, 'label': label} for text, label in zip(augmented_texts, labels)]

# Append augmented data to the original dataset
for item in augmented_data:
    dataset['train'] = dataset['train'].add_item(item)

# Now, 'dataset['train']' contains both the original and augmented data as separate rows


In [12]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 11268
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1409
    })
})

In [13]:
# from googletrans import Translator, LANGUAGES

# translator = Translator()

# def translate_augment(text, lang='es'):
#     # Translate to the target language
#     translated = translator.translate(text, dest=lang).text
#     # Translate back to the original language
#     retranslated = translator.translate(translated, dest='en').text
#     return retranslated


In [14]:
# # Function to augment the data
# def augment_dataset_2(data, augmentation_function):
#     augmented_texts = []
#     labels = []
#     for example in data:
#         augmented_text = augmentation_function(example['text'],)
#         augmented_texts.append(augmented_text)
#         labels.append(example['label'])
#     return augmented_texts, labels

# # Augment the training data
# augmented_texts, labels = augment_dataset_2(dataset['train'], translate_augment)

# # Create a new dataset from the augmented texts and labels
# augmented_dataset = {
#     'text': augmented_texts + [example['text'] for example in dataset['train']],
#     'label': labels + [example['label'] for example in dataset['train']]
# }


In [15]:
# Assuming 'dataset' is your original DatasetDict containing 'train' and 'test' splits
# And assuming 'augmented_texts' and 'labels' are your lists of augmented data and their corresponding labels

# Convert augmented data into a list of dictionaries
# augmented_data = [{'text': text, 'label': label} for text, label in zip(augmented_texts, labels)]

# # Append augmented data to the original dataset
# for item in augmented_data:
#     dataset['train'] = dataset['train'].add_item(item)

# Now, 'dataset['train']' contains both the original and augmented data as separate rows

In [16]:
dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 11268
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 1409
    })
})

In [17]:
#fine tune the model
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    per_device_train_batch_size=24,
    num_train_epochs=3,
    learning_rate=2e-8,
    logging_dir='./logs',
    logging_steps=10,
    do_train=True,
    do_eval=True,
    output_dir='./results',
    overwrite_output_dir=True,
)





In [18]:
#tokenize the data
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("duxprajapati/symptom-disease-model")

tokenized_dataset = dataset.map(lambda examples: tokenizer(examples['text'], padding='max_length', truncation=True), batched=True)


Map: 100%|██████████| 11268/11268 [00:02<00:00, 5543.69 examples/s]


In [19]:
tokenized_dataset

DatasetDict({
    train: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 11268
    })
    test: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 1409
    })
})

In [20]:
trainer = Trainer(
    model=model_base,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [21]:
#show the result of the model
trainer.train()


  1%|          | 10/1410 [02:07<4:55:23, 12.66s/it]

{'loss': 0.343, 'grad_norm': 4.519108772277832, 'learning_rate': 1.9858156028368796e-08, 'epoch': 0.02}


  1%|▏         | 20/1410 [04:15<4:56:38, 12.80s/it]

{'loss': 0.239, 'grad_norm': 1.4740196466445923, 'learning_rate': 1.9716312056737588e-08, 'epoch': 0.04}


  2%|▏         | 30/1410 [06:23<4:54:11, 12.79s/it]

{'loss': 0.2721, 'grad_norm': 2.0385046005249023, 'learning_rate': 1.9574468085106384e-08, 'epoch': 0.06}


  3%|▎         | 40/1410 [08:29<4:51:09, 12.75s/it]

{'loss': 0.2751, 'grad_norm': 3.2401652336120605, 'learning_rate': 1.9432624113475176e-08, 'epoch': 0.09}


  4%|▎         | 50/1410 [10:38<4:51:36, 12.87s/it]

{'loss': 0.2991, 'grad_norm': 3.4341206550598145, 'learning_rate': 1.9290780141843972e-08, 'epoch': 0.11}


  4%|▍         | 60/1410 [12:48<4:50:58, 12.93s/it]

{'loss': 0.2769, 'grad_norm': 1.252944827079773, 'learning_rate': 1.9148936170212767e-08, 'epoch': 0.13}


  5%|▍         | 70/1410 [14:57<4:49:29, 12.96s/it]

{'loss': 0.3092, 'grad_norm': 4.060464382171631, 'learning_rate': 1.9007092198581563e-08, 'epoch': 0.15}


  6%|▌         | 80/1410 [17:07<4:47:43, 12.98s/it]

{'loss': 0.2828, 'grad_norm': 2.6505727767944336, 'learning_rate': 1.8865248226950355e-08, 'epoch': 0.17}


  6%|▋         | 90/1410 [19:17<4:45:31, 12.98s/it]

{'loss': 0.3266, 'grad_norm': 2.779618740081787, 'learning_rate': 1.872340425531915e-08, 'epoch': 0.19}


  7%|▋         | 100/1410 [21:26<4:41:35, 12.90s/it]

{'loss': 0.2695, 'grad_norm': 2.8971502780914307, 'learning_rate': 1.8581560283687943e-08, 'epoch': 0.21}


  8%|▊         | 110/1410 [23:35<4:38:24, 12.85s/it]

{'loss': 0.1779, 'grad_norm': 0.37143227458000183, 'learning_rate': 1.843971631205674e-08, 'epoch': 0.23}


  9%|▊         | 120/1410 [25:45<4:37:29, 12.91s/it]

{'loss': 0.335, 'grad_norm': 3.372220516204834, 'learning_rate': 1.829787234042553e-08, 'epoch': 0.26}


  9%|▉         | 130/1410 [27:54<4:35:49, 12.93s/it]

{'loss': 0.2788, 'grad_norm': 2.6434264183044434, 'learning_rate': 1.8156028368794327e-08, 'epoch': 0.28}


 10%|▉         | 140/1410 [30:04<4:33:39, 12.93s/it]

{'loss': 0.3164, 'grad_norm': 3.6421146392822266, 'learning_rate': 1.801418439716312e-08, 'epoch': 0.3}


 11%|█         | 150/1410 [32:14<4:31:57, 12.95s/it]

{'loss': 0.2886, 'grad_norm': 2.319040298461914, 'learning_rate': 1.7872340425531914e-08, 'epoch': 0.32}


 11%|█▏        | 160/1410 [34:23<4:29:27, 12.93s/it]

{'loss': 0.3056, 'grad_norm': 2.1156792640686035, 'learning_rate': 1.773049645390071e-08, 'epoch': 0.34}


 12%|█▏        | 170/1410 [36:33<4:27:01, 12.92s/it]

{'loss': 0.3251, 'grad_norm': 3.2693331241607666, 'learning_rate': 1.7588652482269506e-08, 'epoch': 0.36}


 13%|█▎        | 180/1410 [38:42<4:24:55, 12.92s/it]

{'loss': 0.2436, 'grad_norm': 1.7332932949066162, 'learning_rate': 1.7446808510638298e-08, 'epoch': 0.38}


 13%|█▎        | 190/1410 [40:52<4:22:12, 12.90s/it]

{'loss': 0.3194, 'grad_norm': 5.750783920288086, 'learning_rate': 1.7304964539007093e-08, 'epoch': 0.4}


 14%|█▍        | 200/1410 [43:02<4:21:18, 12.96s/it]

{'loss': 0.2553, 'grad_norm': 0.0353192500770092, 'learning_rate': 1.7163120567375886e-08, 'epoch': 0.43}


 15%|█▍        | 210/1410 [45:11<4:18:13, 12.91s/it]

{'loss': 0.2432, 'grad_norm': 1.9035440683364868, 'learning_rate': 1.702127659574468e-08, 'epoch': 0.45}


 16%|█▌        | 220/1410 [47:20<4:16:03, 12.91s/it]

{'loss': 0.3407, 'grad_norm': 2.464313507080078, 'learning_rate': 1.6879432624113477e-08, 'epoch': 0.47}


 16%|█▋        | 230/1410 [49:29<4:14:24, 12.94s/it]

{'loss': 0.3051, 'grad_norm': 3.648425579071045, 'learning_rate': 1.673758865248227e-08, 'epoch': 0.49}


 17%|█▋        | 240/1410 [51:39<4:11:46, 12.91s/it]

{'loss': 0.2184, 'grad_norm': 2.062157154083252, 'learning_rate': 1.6595744680851065e-08, 'epoch': 0.51}


 18%|█▊        | 250/1410 [53:48<4:08:10, 12.84s/it]

{'loss': 0.311, 'grad_norm': 4.987637519836426, 'learning_rate': 1.6453900709219857e-08, 'epoch': 0.53}


 18%|█▊        | 260/1410 [55:58<4:09:05, 13.00s/it]

{'loss': 0.2717, 'grad_norm': 5.834586143493652, 'learning_rate': 1.6312056737588653e-08, 'epoch': 0.55}


 19%|█▉        | 270/1410 [58:08<4:05:34, 12.92s/it]

{'loss': 0.2512, 'grad_norm': 2.9063501358032227, 'learning_rate': 1.6170212765957445e-08, 'epoch': 0.57}


 20%|█▉        | 280/1410 [1:00:17<4:02:20, 12.87s/it]

{'loss': 0.2091, 'grad_norm': 2.1859242916107178, 'learning_rate': 1.6028368794326244e-08, 'epoch': 0.6}


 21%|██        | 290/1410 [1:02:27<4:02:45, 13.00s/it]

{'loss': 0.2745, 'grad_norm': 3.105147361755371, 'learning_rate': 1.5886524822695036e-08, 'epoch': 0.62}


 21%|██▏       | 300/1410 [1:04:36<3:57:56, 12.86s/it]

{'loss': 0.3364, 'grad_norm': 2.4538135528564453, 'learning_rate': 1.5744680851063832e-08, 'epoch': 0.64}


 22%|██▏       | 310/1410 [1:06:46<3:57:22, 12.95s/it]

{'loss': 0.3175, 'grad_norm': 4.7908172607421875, 'learning_rate': 1.5602836879432624e-08, 'epoch': 0.66}


 23%|██▎       | 320/1410 [1:08:56<3:54:44, 12.92s/it]

{'loss': 0.2764, 'grad_norm': 2.5694448947906494, 'learning_rate': 1.546099290780142e-08, 'epoch': 0.68}


 23%|██▎       | 330/1410 [1:11:04<3:51:20, 12.85s/it]

{'loss': 0.3047, 'grad_norm': 1.8648723363876343, 'learning_rate': 1.5319148936170212e-08, 'epoch': 0.7}


 24%|██▍       | 340/1410 [1:13:14<3:50:20, 12.92s/it]

{'loss': 0.3228, 'grad_norm': 2.522742986679077, 'learning_rate': 1.5177304964539007e-08, 'epoch': 0.72}


 25%|██▍       | 350/1410 [1:15:23<3:48:05, 12.91s/it]

{'loss': 0.2546, 'grad_norm': 1.3586111068725586, 'learning_rate': 1.50354609929078e-08, 'epoch': 0.74}


 26%|██▌       | 360/1410 [1:17:33<3:46:02, 12.92s/it]

{'loss': 0.3039, 'grad_norm': 2.502718687057495, 'learning_rate': 1.4893617021276595e-08, 'epoch': 0.77}


 26%|██▌       | 370/1410 [1:19:42<3:43:59, 12.92s/it]

{'loss': 0.2718, 'grad_norm': 2.1951470375061035, 'learning_rate': 1.475177304964539e-08, 'epoch': 0.79}


 27%|██▋       | 380/1410 [1:21:52<3:41:58, 12.93s/it]

{'loss': 0.2903, 'grad_norm': 1.3533530235290527, 'learning_rate': 1.4609929078014187e-08, 'epoch': 0.81}


 28%|██▊       | 390/1410 [1:24:01<3:39:09, 12.89s/it]

{'loss': 0.2862, 'grad_norm': 2.8570706844329834, 'learning_rate': 1.446808510638298e-08, 'epoch': 0.83}


 28%|██▊       | 400/1410 [1:26:10<3:36:54, 12.89s/it]

{'loss': 0.2552, 'grad_norm': 2.824723243713379, 'learning_rate': 1.4326241134751774e-08, 'epoch': 0.85}


 29%|██▉       | 410/1410 [1:28:20<3:35:32, 12.93s/it]

{'loss': 0.245, 'grad_norm': 2.8757433891296387, 'learning_rate': 1.4184397163120568e-08, 'epoch': 0.87}


 30%|██▉       | 420/1410 [1:30:29<3:33:05, 12.91s/it]

{'loss': 0.3128, 'grad_norm': 3.1623713970184326, 'learning_rate': 1.4042553191489362e-08, 'epoch': 0.89}


 30%|███       | 430/1410 [1:32:39<3:31:28, 12.95s/it]

{'loss': 0.2222, 'grad_norm': 2.55362868309021, 'learning_rate': 1.3900709219858156e-08, 'epoch': 0.91}


 31%|███       | 440/1410 [1:34:48<3:28:28, 12.90s/it]

{'loss': 0.2712, 'grad_norm': 2.295351982116699, 'learning_rate': 1.375886524822695e-08, 'epoch': 0.94}


 32%|███▏      | 450/1410 [1:36:57<3:26:22, 12.90s/it]

{'loss': 0.4359, 'grad_norm': 3.8230137825012207, 'learning_rate': 1.3617021276595744e-08, 'epoch': 0.96}


 33%|███▎      | 460/1410 [1:39:07<3:25:10, 12.96s/it]

{'loss': 0.3194, 'grad_norm': 2.9421284198760986, 'learning_rate': 1.3475177304964538e-08, 'epoch': 0.98}


 33%|███▎      | 470/1410 [1:41:11<2:53:17, 11.06s/it]

{'loss': 0.2415, 'grad_norm': 4.339226245880127, 'learning_rate': 1.3333333333333334e-08, 'epoch': 1.0}


 34%|███▍      | 480/1410 [1:43:20<3:18:24, 12.80s/it]

{'loss': 0.3222, 'grad_norm': 3.6636478900909424, 'learning_rate': 1.3191489361702128e-08, 'epoch': 1.02}


 35%|███▍      | 490/1410 [1:45:30<3:17:41, 12.89s/it]

{'loss': 0.3436, 'grad_norm': 4.145240783691406, 'learning_rate': 1.3049645390070923e-08, 'epoch': 1.04}


 35%|███▌      | 500/1410 [1:47:39<3:15:45, 12.91s/it]

{'loss': 0.2084, 'grad_norm': 1.4678629636764526, 'learning_rate': 1.2907801418439717e-08, 'epoch': 1.06}


 36%|███▌      | 510/1410 [1:49:49<3:13:38, 12.91s/it]

{'loss': 0.2312, 'grad_norm': 2.3991997241973877, 'learning_rate': 1.2765957446808511e-08, 'epoch': 1.09}


 37%|███▋      | 520/1410 [1:51:58<3:11:32, 12.91s/it]

{'loss': 0.2704, 'grad_norm': 4.050826072692871, 'learning_rate': 1.2624113475177305e-08, 'epoch': 1.11}


 38%|███▊      | 530/1410 [1:54:08<3:10:12, 12.97s/it]

{'loss': 0.3061, 'grad_norm': 3.8242709636688232, 'learning_rate': 1.24822695035461e-08, 'epoch': 1.13}


 38%|███▊      | 540/1410 [1:56:19<3:08:22, 12.99s/it]

{'loss': 0.2329, 'grad_norm': 3.195796251296997, 'learning_rate': 1.2340425531914894e-08, 'epoch': 1.15}


 39%|███▉      | 550/1410 [1:58:29<3:05:10, 12.92s/it]

{'loss': 0.2835, 'grad_norm': 3.1722326278686523, 'learning_rate': 1.2198581560283688e-08, 'epoch': 1.17}


 40%|███▉      | 560/1410 [2:00:39<3:03:29, 12.95s/it]

{'loss': 0.2236, 'grad_norm': 3.997420310974121, 'learning_rate': 1.2056737588652482e-08, 'epoch': 1.19}


 40%|████      | 570/1410 [2:02:49<3:01:42, 12.98s/it]

{'loss': 0.1781, 'grad_norm': 2.4055466651916504, 'learning_rate': 1.1914893617021276e-08, 'epoch': 1.21}


 41%|████      | 580/1410 [2:04:58<2:58:56, 12.94s/it]

{'loss': 0.3628, 'grad_norm': 3.3016483783721924, 'learning_rate': 1.177304964539007e-08, 'epoch': 1.23}


 42%|████▏     | 590/1410 [2:07:08<2:57:05, 12.96s/it]

{'loss': 0.3096, 'grad_norm': 3.211503505706787, 'learning_rate': 1.1631205673758864e-08, 'epoch': 1.26}


 43%|████▎     | 600/1410 [2:09:18<2:55:09, 12.98s/it]

{'loss': 0.3123, 'grad_norm': 2.514883279800415, 'learning_rate': 1.1489361702127661e-08, 'epoch': 1.28}


 43%|████▎     | 610/1410 [2:11:28<2:52:48, 12.96s/it]

{'loss': 0.3328, 'grad_norm': 4.229534149169922, 'learning_rate': 1.1347517730496455e-08, 'epoch': 1.3}


 44%|████▍     | 620/1410 [2:13:37<2:50:29, 12.95s/it]

{'loss': 0.2792, 'grad_norm': 3.635962963104248, 'learning_rate': 1.120567375886525e-08, 'epoch': 1.32}


 45%|████▍     | 630/1410 [2:15:47<2:48:05, 12.93s/it]

{'loss': 0.2544, 'grad_norm': 2.4591751098632812, 'learning_rate': 1.1063829787234043e-08, 'epoch': 1.34}


 45%|████▌     | 640/1410 [2:17:57<2:46:42, 12.99s/it]

{'loss': 0.3273, 'grad_norm': 6.615872859954834, 'learning_rate': 1.0921985815602837e-08, 'epoch': 1.36}


 46%|████▌     | 650/1410 [2:20:07<2:43:43, 12.93s/it]

{'loss': 0.2867, 'grad_norm': 3.6989712715148926, 'learning_rate': 1.0780141843971631e-08, 'epoch': 1.38}


 47%|████▋     | 660/1410 [2:22:16<2:41:13, 12.90s/it]

{'loss': 0.3253, 'grad_norm': 3.2314815521240234, 'learning_rate': 1.0638297872340425e-08, 'epoch': 1.4}


 48%|████▊     | 670/1410 [2:24:26<2:39:50, 12.96s/it]

{'loss': 0.2953, 'grad_norm': 2.8293204307556152, 'learning_rate': 1.0496453900709219e-08, 'epoch': 1.43}


 48%|████▊     | 680/1410 [2:26:36<2:37:13, 12.92s/it]

{'loss': 0.2554, 'grad_norm': 2.7370870113372803, 'learning_rate': 1.0354609929078015e-08, 'epoch': 1.45}


 49%|████▉     | 690/1410 [2:28:46<2:35:17, 12.94s/it]

{'loss': 0.2418, 'grad_norm': 3.9628398418426514, 'learning_rate': 1.0212765957446808e-08, 'epoch': 1.47}


 50%|████▉     | 700/1410 [2:30:55<2:32:41, 12.90s/it]

{'loss': 0.3481, 'grad_norm': 1.9300477504730225, 'learning_rate': 1.0070921985815602e-08, 'epoch': 1.49}


 50%|█████     | 710/1410 [2:33:04<2:30:51, 12.93s/it]

{'loss': 0.2821, 'grad_norm': 3.3532238006591797, 'learning_rate': 9.929078014184398e-09, 'epoch': 1.51}


 51%|█████     | 720/1410 [2:35:14<2:28:33, 12.92s/it]

{'loss': 0.3174, 'grad_norm': 3.3241302967071533, 'learning_rate': 9.787234042553192e-09, 'epoch': 1.53}


 52%|█████▏    | 730/1410 [2:37:24<2:27:02, 12.97s/it]

{'loss': 0.2958, 'grad_norm': 7.728574275970459, 'learning_rate': 9.645390070921986e-09, 'epoch': 1.55}


 52%|█████▏    | 740/1410 [2:39:34<2:24:21, 12.93s/it]

{'loss': 0.2952, 'grad_norm': 2.3109922409057617, 'learning_rate': 9.503546099290781e-09, 'epoch': 1.57}


 53%|█████▎    | 750/1410 [2:41:44<2:22:59, 13.00s/it]

{'loss': 0.3382, 'grad_norm': 3.154874086380005, 'learning_rate': 9.361702127659575e-09, 'epoch': 1.6}


 54%|█████▍    | 760/1410 [2:43:54<2:20:59, 13.01s/it]

{'loss': 0.2201, 'grad_norm': 2.905703544616699, 'learning_rate': 9.21985815602837e-09, 'epoch': 1.62}


 55%|█████▍    | 770/1410 [2:46:04<2:18:06, 12.95s/it]

{'loss': 0.3505, 'grad_norm': 3.48734712600708, 'learning_rate': 9.078014184397163e-09, 'epoch': 1.64}


 55%|█████▌    | 780/1410 [2:48:14<2:16:24, 12.99s/it]

{'loss': 0.2571, 'grad_norm': 3.42716908454895, 'learning_rate': 8.936170212765957e-09, 'epoch': 1.66}


 56%|█████▌    | 790/1410 [2:50:24<2:13:57, 12.96s/it]

{'loss': 0.3148, 'grad_norm': 8.046710968017578, 'learning_rate': 8.794326241134753e-09, 'epoch': 1.68}


 57%|█████▋    | 800/1410 [2:52:33<2:11:10, 12.90s/it]

{'loss': 0.3023, 'grad_norm': 3.951936960220337, 'learning_rate': 8.652482269503547e-09, 'epoch': 1.7}


 57%|█████▋    | 810/1410 [2:54:43<2:09:36, 12.96s/it]

{'loss': 0.2445, 'grad_norm': 4.572831630706787, 'learning_rate': 8.51063829787234e-09, 'epoch': 1.72}


 58%|█████▊    | 820/1410 [2:56:53<2:07:33, 12.97s/it]

{'loss': 0.2931, 'grad_norm': 2.840482234954834, 'learning_rate': 8.368794326241135e-09, 'epoch': 1.74}


 59%|█████▉    | 830/1410 [2:59:03<2:05:25, 12.98s/it]

{'loss': 0.3262, 'grad_norm': 3.476762056350708, 'learning_rate': 8.226950354609929e-09, 'epoch': 1.77}


 60%|█████▉    | 840/1410 [3:01:13<2:03:08, 12.96s/it]

{'loss': 0.2151, 'grad_norm': 1.9626638889312744, 'learning_rate': 8.085106382978722e-09, 'epoch': 1.79}


 60%|██████    | 850/1410 [3:03:23<2:01:29, 13.02s/it]

{'loss': 0.1751, 'grad_norm': 2.623939275741577, 'learning_rate': 7.943262411347518e-09, 'epoch': 1.81}


 61%|██████    | 860/1410 [3:05:33<1:58:51, 12.97s/it]

{'loss': 0.2709, 'grad_norm': 2.35152006149292, 'learning_rate': 7.801418439716312e-09, 'epoch': 1.83}


 62%|██████▏   | 870/1410 [3:07:43<1:56:00, 12.89s/it]

{'loss': 0.3952, 'grad_norm': 3.101929187774658, 'learning_rate': 7.659574468085106e-09, 'epoch': 1.85}


 62%|██████▏   | 880/1410 [3:09:52<1:54:14, 12.93s/it]

{'loss': 0.2602, 'grad_norm': 2.123598098754883, 'learning_rate': 7.5177304964539e-09, 'epoch': 1.87}


 63%|██████▎   | 890/1410 [3:12:02<1:52:17, 12.96s/it]

{'loss': 0.2639, 'grad_norm': 2.9031379222869873, 'learning_rate': 7.375886524822695e-09, 'epoch': 1.89}


 64%|██████▍   | 900/1410 [3:14:12<1:49:49, 12.92s/it]

{'loss': 0.2736, 'grad_norm': 2.5881357192993164, 'learning_rate': 7.23404255319149e-09, 'epoch': 1.91}


 65%|██████▍   | 910/1410 [3:16:22<1:47:47, 12.93s/it]

{'loss': 0.384, 'grad_norm': 3.3363046646118164, 'learning_rate': 7.092198581560284e-09, 'epoch': 1.94}


 65%|██████▌   | 920/1410 [3:18:32<1:45:54, 12.97s/it]

{'loss': 0.258, 'grad_norm': 3.13215708732605, 'learning_rate': 6.950354609929078e-09, 'epoch': 1.96}


 66%|██████▌   | 930/1410 [3:20:42<1:43:35, 12.95s/it]

{'loss': 0.3316, 'grad_norm': 5.4038262367248535, 'learning_rate': 6.808510638297872e-09, 'epoch': 1.98}


 67%|██████▋   | 940/1410 [3:22:46<1:26:56, 11.10s/it]

{'loss': 0.2826, 'grad_norm': 3.9561991691589355, 'learning_rate': 6.666666666666667e-09, 'epoch': 2.0}


 67%|██████▋   | 950/1410 [3:24:56<1:38:52, 12.90s/it]

{'loss': 0.3667, 'grad_norm': 2.5829577445983887, 'learning_rate': 6.5248226950354616e-09, 'epoch': 2.02}


 68%|██████▊   | 960/1410 [3:27:06<1:37:00, 12.93s/it]

{'loss': 0.3099, 'grad_norm': 3.6017038822174072, 'learning_rate': 6.3829787234042555e-09, 'epoch': 2.04}


 69%|██████▉   | 970/1410 [3:29:15<1:34:45, 12.92s/it]

{'loss': 0.2809, 'grad_norm': 4.402766704559326, 'learning_rate': 6.24113475177305e-09, 'epoch': 2.06}


 70%|██████▉   | 980/1410 [3:31:25<1:32:34, 12.92s/it]

{'loss': 0.359, 'grad_norm': 2.4762418270111084, 'learning_rate': 6.099290780141844e-09, 'epoch': 2.09}


 70%|███████   | 990/1410 [3:33:35<1:31:22, 13.05s/it]

{'loss': 0.2567, 'grad_norm': 3.2493839263916016, 'learning_rate': 5.957446808510638e-09, 'epoch': 2.11}


 71%|███████   | 1000/1410 [3:35:45<1:28:30, 12.95s/it]

{'loss': 0.2933, 'grad_norm': 0.9806045889854431, 'learning_rate': 5.815602836879432e-09, 'epoch': 2.13}


 72%|███████▏  | 1010/1410 [3:37:55<1:26:05, 12.91s/it]

{'loss': 0.3059, 'grad_norm': 3.304204225540161, 'learning_rate': 5.673758865248228e-09, 'epoch': 2.15}


 72%|███████▏  | 1020/1410 [3:40:05<1:24:10, 12.95s/it]

{'loss': 0.2964, 'grad_norm': 3.2387781143188477, 'learning_rate': 5.531914893617022e-09, 'epoch': 2.17}


 73%|███████▎  | 1030/1410 [3:42:15<1:22:09, 12.97s/it]

{'loss': 0.2211, 'grad_norm': 2.555340528488159, 'learning_rate': 5.3900709219858155e-09, 'epoch': 2.19}


 74%|███████▍  | 1040/1410 [3:44:25<1:19:52, 12.95s/it]

{'loss': 0.3352, 'grad_norm': 5.9702911376953125, 'learning_rate': 5.2482269503546095e-09, 'epoch': 2.21}


 74%|███████▍  | 1050/1410 [3:46:35<1:17:41, 12.95s/it]

{'loss': 0.3204, 'grad_norm': 2.9228947162628174, 'learning_rate': 5.106382978723404e-09, 'epoch': 2.23}


 75%|███████▌  | 1060/1410 [3:48:44<1:15:05, 12.87s/it]

{'loss': 0.273, 'grad_norm': 2.0135598182678223, 'learning_rate': 4.964539007092199e-09, 'epoch': 2.26}


 76%|███████▌  | 1070/1410 [3:50:53<1:12:59, 12.88s/it]

{'loss': 0.3343, 'grad_norm': 1.9246928691864014, 'learning_rate': 4.822695035460993e-09, 'epoch': 2.28}


 77%|███████▋  | 1080/1410 [3:53:02<1:10:38, 12.84s/it]

{'loss': 0.2092, 'grad_norm': 2.4619996547698975, 'learning_rate': 4.680851063829788e-09, 'epoch': 2.3}


 77%|███████▋  | 1090/1410 [3:55:11<1:08:43, 12.89s/it]

{'loss': 0.3313, 'grad_norm': 1.899316430091858, 'learning_rate': 4.539007092198582e-09, 'epoch': 2.32}


 78%|███████▊  | 1100/1410 [3:57:20<1:06:29, 12.87s/it]

{'loss': 0.2674, 'grad_norm': 3.0258049964904785, 'learning_rate': 4.397163120567376e-09, 'epoch': 2.34}


 79%|███████▊  | 1110/1410 [3:59:29<1:04:15, 12.85s/it]

{'loss': 0.3905, 'grad_norm': 2.411266565322876, 'learning_rate': 4.25531914893617e-09, 'epoch': 2.36}


 79%|███████▉  | 1120/1410 [4:01:38<1:02:21, 12.90s/it]

{'loss': 0.3214, 'grad_norm': 2.0877304077148438, 'learning_rate': 4.113475177304964e-09, 'epoch': 2.38}


 80%|████████  | 1130/1410 [4:03:47<1:00:11, 12.90s/it]

{'loss': 0.3042, 'grad_norm': 3.4119675159454346, 'learning_rate': 3.971631205673759e-09, 'epoch': 2.4}


 81%|████████  | 1140/1410 [4:05:56<58:06, 12.91s/it]  

{'loss': 0.2136, 'grad_norm': 2.55631160736084, 'learning_rate': 3.829787234042553e-09, 'epoch': 2.43}


 82%|████████▏ | 1150/1410 [4:08:05<55:53, 12.90s/it]

{'loss': 0.2996, 'grad_norm': 2.486405849456787, 'learning_rate': 3.6879432624113473e-09, 'epoch': 2.45}


 82%|████████▏ | 1160/1410 [4:10:15<53:37, 12.87s/it]

{'loss': 0.2489, 'grad_norm': 3.591240644454956, 'learning_rate': 3.546099290780142e-09, 'epoch': 2.47}


 83%|████████▎ | 1170/1410 [4:12:24<51:17, 12.82s/it]

{'loss': 0.3362, 'grad_norm': 3.161414861679077, 'learning_rate': 3.404255319148936e-09, 'epoch': 2.49}


 84%|████████▎ | 1180/1410 [4:14:33<49:21, 12.88s/it]

{'loss': 0.2778, 'grad_norm': 4.795899391174316, 'learning_rate': 3.2624113475177308e-09, 'epoch': 2.51}


 84%|████████▍ | 1190/1410 [4:16:42<47:13, 12.88s/it]

{'loss': 0.2898, 'grad_norm': 1.5021796226501465, 'learning_rate': 3.120567375886525e-09, 'epoch': 2.53}


 85%|████████▌ | 1200/1410 [4:18:51<45:03, 12.87s/it]

{'loss': 0.2666, 'grad_norm': 2.0192489624023438, 'learning_rate': 2.978723404255319e-09, 'epoch': 2.55}


 86%|████████▌ | 1210/1410 [4:21:00<42:56, 12.88s/it]

{'loss': 0.2837, 'grad_norm': 2.947258710861206, 'learning_rate': 2.836879432624114e-09, 'epoch': 2.57}


 87%|████████▋ | 1220/1410 [4:23:10<41:00, 12.95s/it]

{'loss': 0.234, 'grad_norm': 2.3881194591522217, 'learning_rate': 2.6950354609929078e-09, 'epoch': 2.6}


 87%|████████▋ | 1230/1410 [4:25:19<38:56, 12.98s/it]

{'loss': 0.3402, 'grad_norm': 1.739137887954712, 'learning_rate': 2.553191489361702e-09, 'epoch': 2.62}


 88%|████████▊ | 1240/1410 [4:27:29<36:30, 12.88s/it]

{'loss': 0.2575, 'grad_norm': 2.6193735599517822, 'learning_rate': 2.4113475177304965e-09, 'epoch': 2.64}


 89%|████████▊ | 1250/1410 [4:29:38<34:19, 12.87s/it]

{'loss': 0.3383, 'grad_norm': 3.617649793624878, 'learning_rate': 2.269503546099291e-09, 'epoch': 2.66}


 89%|████████▉ | 1260/1410 [4:31:47<32:19, 12.93s/it]

{'loss': 0.2744, 'grad_norm': 0.03982473909854889, 'learning_rate': 2.127659574468085e-09, 'epoch': 2.68}


 90%|█████████ | 1270/1410 [4:33:57<30:07, 12.91s/it]

{'loss': 0.3501, 'grad_norm': 2.5984485149383545, 'learning_rate': 1.9858156028368795e-09, 'epoch': 2.7}


 91%|█████████ | 1280/1410 [4:36:06<28:01, 12.94s/it]

{'loss': 0.2134, 'grad_norm': 2.943904399871826, 'learning_rate': 1.8439716312056737e-09, 'epoch': 2.72}


 91%|█████████▏| 1290/1410 [4:38:16<25:50, 12.92s/it]

{'loss': 0.2477, 'grad_norm': 4.242483615875244, 'learning_rate': 1.702127659574468e-09, 'epoch': 2.74}


 92%|█████████▏| 1300/1410 [4:40:25<23:43, 12.94s/it]

{'loss': 0.3038, 'grad_norm': 2.574753999710083, 'learning_rate': 1.5602836879432626e-09, 'epoch': 2.77}


 93%|█████████▎| 1310/1410 [4:42:34<21:25, 12.85s/it]

{'loss': 0.2397, 'grad_norm': 2.763162136077881, 'learning_rate': 1.418439716312057e-09, 'epoch': 2.79}


 94%|█████████▎| 1320/1410 [4:44:44<19:19, 12.89s/it]

{'loss': 0.2637, 'grad_norm': 3.131744384765625, 'learning_rate': 1.276595744680851e-09, 'epoch': 2.81}


 94%|█████████▍| 1330/1410 [4:46:53<17:12, 12.90s/it]

{'loss': 0.3402, 'grad_norm': 2.8588297367095947, 'learning_rate': 1.1347517730496454e-09, 'epoch': 2.83}


 95%|█████████▌| 1340/1410 [4:49:02<15:00, 12.87s/it]

{'loss': 0.161, 'grad_norm': 1.5570167303085327, 'learning_rate': 9.929078014184398e-10, 'epoch': 2.85}


 96%|█████████▌| 1350/1410 [4:51:12<12:54, 12.91s/it]

{'loss': 0.2393, 'grad_norm': 2.7229833602905273, 'learning_rate': 8.51063829787234e-10, 'epoch': 2.87}


 96%|█████████▋| 1360/1410 [4:53:21<10:44, 12.89s/it]

{'loss': 0.3475, 'grad_norm': 6.58236026763916, 'learning_rate': 7.092198581560285e-10, 'epoch': 2.89}


 97%|█████████▋| 1370/1410 [4:55:31<08:37, 12.94s/it]

{'loss': 0.324, 'grad_norm': 3.146263360977173, 'learning_rate': 5.673758865248227e-10, 'epoch': 2.91}


 98%|█████████▊| 1380/1410 [4:57:40<06:27, 12.93s/it]

{'loss': 0.2159, 'grad_norm': 1.8104416131973267, 'learning_rate': 4.25531914893617e-10, 'epoch': 2.94}


 99%|█████████▊| 1390/1410 [4:59:50<04:18, 12.94s/it]

{'loss': 0.295, 'grad_norm': 7.297471523284912, 'learning_rate': 2.8368794326241135e-10, 'epoch': 2.96}


 99%|█████████▉| 1400/1410 [5:02:00<02:09, 12.96s/it]

{'loss': 0.2406, 'grad_norm': 2.1115171909332275, 'learning_rate': 1.4184397163120568e-10, 'epoch': 2.98}


100%|██████████| 1410/1410 [5:04:03<00:00, 12.94s/it]

{'loss': 0.2684, 'grad_norm': 7.029175758361816, 'learning_rate': 0.0, 'epoch': 3.0}
{'train_runtime': 18243.5609, 'train_samples_per_second': 1.853, 'train_steps_per_second': 0.077, 'train_loss': 0.28651398555606816, 'epoch': 3.0}





TrainOutput(global_step=1410, training_loss=0.28651398555606816, metrics={'train_runtime': 18243.5609, 'train_samples_per_second': 1.853, 'train_steps_per_second': 0.077, 'train_loss': 0.28651398555606816, 'epoch': 3.0})