In [1]:
from datasets import load_dataset, DatasetDict, Dataset
import pandas as pd
import ast
import datasets
from tqdm import tqdm
import time
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import ExponentialLR

data_sample = load_dataset("QuyenAnhDE/Diseases_Symptoms")
updated_data = [{'Name': item['Name'], 'Symptoms': item['Symptoms'], 'Treatments': item['Treatments']} for item in data_sample['train']]
df = pd.DataFrame(updated_data)
print(df)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
tokenizer = GPT2Tokenizer.from_pretrained('distilgpt2')

model = GPT2LMHeadModel.from_pretrained('distilgpt2').to(device)
BATCH_SIZE = 8

class LanguageDataset(Dataset):
    def __init__(self, df, tokenizer):
        self.labels = df.columns #устанавливаем метки столбцов
        self.data = df.to_dict(orient='records')
        self.tokenizer = tokenizer
        #x = self.average_len(df)
        self.max_length = 128 #в нашем лучае max_lenght  - средняя длина

    def average_len(self,df):
        sum_ = 0
        for example in df[self.labels[2]]:
          sum_ += len(example)
        x  = 2
        while x < sum_/len(df):
          x = x * 2
        return x
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        x = self.data[idx][self.labels[0]]
        y = self.data[idx][self.labels[1]]
        z = self.data[idx][self.labels[2]]
        text = f"{x} | {y} | {z}"

        tokens = self.tokenizer.encode_plus(text, return_tensors='pt', max_length=self.max_length, padding='max_length', truncation=True) 
        
        return tokens

data_sample = LanguageDataset(df, tokenizer)

train_size = int(0.8 * len(data_sample))
valid_size = len(data_sample) - train_size

train_data, valid_data = random_split(data_sample, [train_size, valid_size])
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True) #дополнительно перемешаем данные
valid_loader = DataLoader(valid_data, batch_size=BATCH_SIZE)
num_epochs = 20
batch_size = BATCH_SIZE
model_name = 'distilgpt2'
gpu = 0
reshuffle_every = 6
optimizer = optim.Adam(model.parameters())
tokenizer.pad_token = tokenizer.eos_token
results = pd.DataFrame(columns=['epoch', 'transformer', 'batch_size', 'gpu',
                                'training_loss', 'validation_loss', 'epoch_duration_sec'])

def reshuffle_data(dataset):
    train_data, test_data = random_split(dataset, [int(0.8 * len(dataset)), len(dataset) - int(0.8 * len(dataset))])
    return train_data, test_data

def train_model(model, num_epochs, train_loader, batch_size, model_name, sheduler, tokenizer, device):
  for epoch in range(num_epochs):
      start_time = time.time()  # Start the timer for the epoch
      #переводим модель в режим обучения
      # if epoch % reshuffle_every == 0:
      #   train_data, test_data = reshuffle_data(data_sample)
      #   train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
      #   test_loader = DataLoader(test_data, batch_size=32, shuffle=False)
      model.train()
      epoch_training_loss = 0

      train_iterator = tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs} Batch Size: {batch_size}, Transformer: {model_name}")

      for batch in train_iterator:
          optimizer.zero_grad()
          inputs = batch['input_ids'].squeeze(1).to(device)
          targets = inputs.clone()

          outputs = model(input_ids=inputs, labels=targets)

          loss = outputs.loss
          
          #выполняем обратный переход
          loss.backward()
          #обновляем веса
          optimizer.step()

          train_iterator.set_postfix({'Training Loss': loss.item()})
          epoch_training_loss += loss.item()

      avg_epoch_training_loss = epoch_training_loss / len(train_iterator)

      #переводим модель в режим ответов
      model.eval()
      
      epoch_validation_loss = 0
      total_loss = 0
      valid_iterator = tqdm(valid_loader, desc=f"Validation Epoch {epoch+1}/{num_epochs}")
      with torch.no_grad():
          for batch in valid_iterator:
              inputs = batch['input_ids'].squeeze(1).to(device)
              targets = inputs.clone()
              outputs = model(input_ids=inputs, labels=targets)
              loss = outputs.loss
              total_loss += loss
              valid_iterator.set_postfix({'Validation Loss': loss.item()})
              epoch_validation_loss += loss.item()

      avg_epoch_validation_loss = epoch_validation_loss / len(valid_loader)

      end_time = time.time()  # закончилась одна эпоха
      epoch_duration_sec = end_time - start_time

      new_row = {'transformer': model_name,
                'batch_size': batch_size,
                'gpu': gpu,
                'epoch': epoch+1,
                'training_loss': avg_epoch_training_loss,
                'validation_loss': avg_epoch_validation_loss,
                'epoch_duration_sec': epoch_duration_sec}  

      results.loc[len(results)] = new_row
      print(f"Epoch: {epoch+1}, Validation Loss: {total_loss/len(valid_loader)}")

      print('last lr', sheduler.get_last_lr())
      sheduler.step()

sheduler  =  ExponentialLR(optimizer, gamma=0.85)
train_model(model, num_epochs, train_loader, batch_size, model_name, sheduler, tokenizer, device)
input_str = "Panic disorder"
input_ids = tokenizer.encode(input_str, return_tensors='pt').to(device)

output = model.generate(
    input_ids,
    max_length=70,
    num_return_sequences=1,
    do_sample=True,
    top_k=10,
    top_p=0.8,
    temperature=1,
    repetition_penalty=1.2
)
decoded_output = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded_output)

  from .autonotebook import tqdm as notebook_tqdm
Repo card metadata block was not found. Setting CardData to empty.


                               Name  \
0                    Panic disorder   
1                  Vocal cord polyp   
2                   Turner syndrome   
3                    Cryptorchidism   
4       Ethylene glycol poisoning-1   
..                              ...   
395  Urinary Stones (Kidney Stones)   
396                    Osteoporosis   
397            Rheumatoid Arthritis   
398                 Type 1 Diabetes   
399                 Type 2 Diabetes   

                                              Symptoms  \
0    Palpitations, Sweating, Trembling, Shortness o...   
1             Hoarseness, Vocal Changes, Vocal Fatigue   
2    Short stature, Gonadal dysgenesis, Webbed neck...   
3    Absence or undescended testicle(s), empty scro...   
4    Nausea, vomiting, abdominal pain, General mala...   
..                                                 ...   
395  Severe abdominal or back pain, blood in urine,...   
396  Fragile bones, loss of height over time, back ...   
397  Join

Training Epoch 1/20 Batch Size: 8, Transformer: distilgpt2:   0%|          | 0/40 [00:00<?, ?it/s]`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Training Epoch 1/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.81s/it, Training Loss=1.74] 
Validation Epoch 1/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.3] 


Epoch: 1, Validation Loss: 1.4001721143722534
last lr [0.001]


Training Epoch 2/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:41<00:00,  4.03s/it, Training Loss=1.14] 
Validation Epoch 2/20: 100%|██████████| 10/10 [00:11<00:00,  1.16s/it, Validation Loss=1.2] 


Epoch: 2, Validation Loss: 1.3118444681167603
last lr [0.00085]


Training Epoch 3/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:35<00:00,  3.89s/it, Training Loss=0.7]  
Validation Epoch 3/20: 100%|██████████| 10/10 [00:11<00:00,  1.18s/it, Validation Loss=1.23]


Epoch: 3, Validation Loss: 1.3455476760864258
last lr [0.0007224999999999999]


Training Epoch 4/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.82s/it, Training Loss=0.654]
Validation Epoch 4/20: 100%|██████████| 10/10 [00:11<00:00,  1.14s/it, Validation Loss=1.27]


Epoch: 4, Validation Loss: 1.4000673294067383
last lr [0.000614125]


Training Epoch 5/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.651]
Validation Epoch 5/20: 100%|██████████| 10/10 [00:11<00:00,  1.16s/it, Validation Loss=1.32]


Epoch: 5, Validation Loss: 1.4912079572677612
last lr [0.00052200625]


Training Epoch 6/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.80s/it, Training Loss=0.352]
Validation Epoch 6/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.4] 


Epoch: 6, Validation Loss: 1.571752905845642
last lr [0.00044370531249999997]


Training Epoch 7/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.24] 
Validation Epoch 7/20: 100%|██████████| 10/10 [00:11<00:00,  1.14s/it, Validation Loss=1.52]


Epoch: 7, Validation Loss: 1.7179937362670898
last lr [0.00037714951562499996]


Training Epoch 8/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.81s/it, Training Loss=0.278]
Validation Epoch 8/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.59]


Epoch: 8, Validation Loss: 1.810288667678833
last lr [0.00032057708828124994]


Training Epoch 9/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.197]
Validation Epoch 9/20: 100%|██████████| 10/10 [00:11<00:00,  1.16s/it, Validation Loss=1.76]


Epoch: 9, Validation Loss: 1.9638960361480713
last lr [0.0002724905250390624]


Training Epoch 10/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.80s/it, Training Loss=0.13] 
Validation Epoch 10/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.76]


Epoch: 10, Validation Loss: 1.9797369241714478
last lr [0.00023161694628320305]


Training Epoch 11/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.0892]
Validation Epoch 11/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.82]


Epoch: 11, Validation Loss: 2.0539543628692627
last lr [0.0001968744043407226]


Training Epoch 12/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.81s/it, Training Loss=0.0875]
Validation Epoch 12/20: 100%|██████████| 10/10 [00:11<00:00,  1.18s/it, Validation Loss=1.87]


Epoch: 12, Validation Loss: 2.087176561355591
last lr [0.0001673432436896142]


Training Epoch 13/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:31<00:00,  3.80s/it, Training Loss=0.104] 
Validation Epoch 13/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.85]


Epoch: 13, Validation Loss: 2.093751907348633
last lr [0.00014224175713617207]


Training Epoch 14/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.061] 
Validation Epoch 14/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.88]


Epoch: 14, Validation Loss: 2.109910488128662
last lr [0.00012090549356574625]


Training Epoch 15/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.81s/it, Training Loss=0.049] 
Validation Epoch 15/20: 100%|██████████| 10/10 [00:11<00:00,  1.16s/it, Validation Loss=1.91]


Epoch: 15, Validation Loss: 2.130234479904175
last lr [0.00010276966953088431]


Training Epoch 16/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.82s/it, Training Loss=0.0539]
Validation Epoch 16/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.91]


Epoch: 16, Validation Loss: 2.13869047164917
last lr [8.735421910125166e-05]


Training Epoch 17/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.82s/it, Training Loss=0.0403]
Validation Epoch 17/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.89]


Epoch: 17, Validation Loss: 2.1288273334503174
last lr [7.425108623606391e-05]


Training Epoch 18/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.0545]
Validation Epoch 18/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.9]


Epoch: 18, Validation Loss: 2.143779993057251
last lr [6.311342330065433e-05]


Training Epoch 19/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:33<00:00,  3.83s/it, Training Loss=0.046] 
Validation Epoch 19/20: 100%|██████████| 10/10 [00:11<00:00,  1.15s/it, Validation Loss=1.92]


Epoch: 19, Validation Loss: 2.142712116241455
last lr [5.3646409805556176e-05]


Training Epoch 20/20 Batch Size: 8, Transformer: distilgpt2: 100%|██████████| 40/40 [02:32<00:00,  3.81s/it, Training Loss=0.0942]
Validation Epoch 20/20: 100%|██████████| 10/10 [00:11<00:00,  1.17s/it, Validation Loss=1.93]
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


Epoch: 20, Validation Loss: 2.1489338874816895
last lr [4.559944833472275e-05]
Panic disorder | Palpitations, Sweating, Trembling, Shortness of breath, Fear of losing control, Dizziness | Antidepressant medications, Cognitive Behavioral Therapy, Relaxation Techniques
