In [1]:
from gpt2.modeling_gpt2 import GPT2LMHeadModel
import torch
from torch.utils.data.dataset import Dataset
from torch.utils.data.dataloader import DataLoader
from transformers import BertTokenizer, get_linear_schedule_with_warmup
from tqdm import tqdm
from datetime import datetime
import os
import numpy as np
import logging
import random
from torch.utils.tensorboard import SummaryWriter

In [14]:
def seed_everything(seed: int = 42):
    """Util to make training reproducible"""
    random.seed(seed)

    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if os.getenv("CUBLAS_WORKSPACE_CONFIG") is not None:
        torch.use_deterministic_algorithms(True)
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def worker_init(worked_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

def set_logger(path):

    logger = logging.getLogger()
    handler = logging.FileHandler(path + "/train_log.txt")
    logger.setLevel(level=logging.INFO)
    handler.setLevel(logging.INFO)
    formatter = logging.Formatter(
        "%(asctime)s - %(filename)s - %(funcName)s - %(lineno)s - %(levelname)s\n%(message)s",
        "%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(handler)
    logger.addHandler(console)


class Chinese_Medical_DS(Dataset):
    def __init__(self, path, tokenizer, max_len=1024):
        self.path = path
        sentence = []
        self.private_positions_list = []
        
        self.total_private_tokens = 0
        self.total_tokens = 0
        self.private_tokens_per_sentence = []
        
        with open(self.path, 'r', encoding='utf-8') as f:
            for line in f:
                sen_ids = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(line.strip()))
                full_sen = []
                
                full_sen.append(tokenizer.convert_tokens_to_ids('[MASK]'))
                full_sen.extend(sen_ids)
                full_sen.append(tokenizer.convert_tokens_to_ids('[CLS]'))
                private_positions = [i for i, token_id in enumerate(full_sen) if tokenizer.decode([token_id]).isdigit()]
                self.private_positions_list.append(private_positions)
                if len(full_sen) <= max_len:
                    sentence.append(full_sen)
                
                num_private_tokens = len(private_positions)
                self.total_private_tokens += num_private_tokens
                self.total_tokens += len(full_sen)
                self.private_tokens_per_sentence.append(num_private_tokens)
                
        self.data = sentence
        
        # 计算平均每句话中的隐私token数
        average_private_tokens_per_sentence = sum(self.private_tokens_per_sentence) / len(self.private_tokens_per_sentence)

        # 计算总的隐私token占总token的比例
        private_token_ratio = self.total_private_tokens / self.total_tokens

        # 计算隐私token个数的均值和方差
        private_token_mean = np.mean(self.private_tokens_per_sentence)
        private_token_variance = np.var(self.private_tokens_per_sentence)

        # 打印统计信息
        print(f"Average private tokens per sentence: {average_private_tokens_per_sentence:.2f}")
        print(f"Private token ratio: {private_token_ratio:.4f}")
        print(f"Private token count mean: {private_token_mean:.2f}")
        print(f"Private token count variance: {private_token_variance:.2f}")
        
    # need to overload
    def __len__(self):
        return len(self.data)

    # need to overload
    def __getitem__(self, idx):
        input = self.data[idx]
        target = input
        private_positions = self.private_positions_list[idx]
        return input, target, private_positions
    
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, save_path, patience=2, verbose=True, delta=0):
        """
        Args:
            save_path : 模型保存文件夹
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
        """
        self.save_path = save_path
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f"EarlyStopping counter: {self.counter} out of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Saves model when validation loss decrease."""
        if self.verbose:
            print(
                f"Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ..."
            )
        model_to_save = model.module if hasattr(model, 'module') else model
        model_to_save.save_pretrained(self.save_path + 'best_model')
        
        # path = os.path.join(self.save_path, "best_network.pth")
        # torch.save(model.state_dict(), path)  # 这里会存储迄今最优模型的参数
        self.val_loss_min = val_loss

In [15]:
seed_everything()
tok_path = '..\\..\\Raw_GPT2\\vocab.txt'
pretrain_model_path = "..\\..\\Raw_GPT2\\"
output_dir = "model\\"

epochs = 50
warmup_steps = 1000
lr = 1e-5
gradient_accumulation = 18
max_grad_norm = 1.0
log_step = 10000
set_logger(output_dir)
logger = logging.getLogger(__name__)


In [16]:
tokenizer = BertTokenizer(vocab_file=tok_path)
model = GPT2LMHeadModel.from_pretrained(pretrain_model_path)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
logger.info('using device:{}'.format(device))
model.train()
model.to(device)
logger.info(model)

GPT2LMHeadModel::init
config =  GPT2Config {
  "_name_or_path": "..\\..\\Raw_GPT2\\",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "gradient_checkpointing": false,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "output_past": true,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 320
    }
  },
  "tokenizer_class": "BertTokenizer",
  "transformers_version": "4.24.0",
  "use_cache

2023-04-01 14:36:42 - 945538542.py - <module> - 4 - INFO
using device:cuda
2023-04-01 14:36:42 - 945538542.py - <module> - 4 - INFO
using device:cuda
2023-04-01 14:36:42 - 945538542.py - <module> - 4 - INFO
using device:cuda
2023-04-01 14:36:43 - 945538542.py - <module> - 7 - INFO
GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(21128, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
   

In [17]:
train_dataset = Chinese_Medical_DS("..\\..\\Data\\tiny_train.txt", tokenizer)
train_dataloader = DataLoader(dataset=train_dataset, worker_init_fn=worker_init)
valid_dataset = Chinese_Medical_DS("..\\..\\Data\\tiny_valid.txt", tokenizer)
valid_dataloader = DataLoader(dataset=valid_dataset, worker_init_fn=worker_init)
logger.info("len(train_dataloader), len(valid_dataloader) = {}, {}".format(len(train_dataloader), len(valid_dataloader)))

Average private tokens per sentence: 1.22
Private token ratio: 0.0055
Private token count mean: 1.22
Private token count variance: 4.91


2023-04-01 14:37:13 - 3687209076.py - <module> - 5 - INFO
len(train_dataloader), len(valid_dataloader) = 9000, 1500
2023-04-01 14:37:13 - 3687209076.py - <module> - 5 - INFO
len(train_dataloader), len(valid_dataloader) = 9000, 1500
2023-04-01 14:37:13 - 3687209076.py - <module> - 5 - INFO
len(train_dataloader), len(valid_dataloader) = 9000, 1500


Average private tokens per sentence: 1.22
Private token ratio: 0.0055
Private token count mean: 1.22
Private token count variance: 5.27


In [8]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                          num_training_steps=len(train_dataloader))
tb_path = output_dir + "/tb"
if not os.path.exists(tb_path):
    os.mkdir(tb_path)
writer = SummaryWriter(tb_path)

In [10]:
train_num_log = len(train_dataloader) * np.log(len(train_dataloader))
delta = 1.0 / train_num_log if 1.0 / train_num_log < 1e-5 else 1e-5
epsilon = 0.5
sigma = 0.810546875
epsilon, delta, sigma

(0.5, 1e-05, 0.810546875)

In [13]:
running_loss = 0
early_stopping = EarlyStopping(output_dir)
train_step_per_epoch = len(train_dataloader)
valid_step_per_epoch = len(valid_dataloader)
for epoch in range(epochs):
    logger.info('epoch {}'.format(epoch + 1))
    now = datetime.now()
    logger.info('time: {}'.format(now))
    model.train()
    train_pbar = tqdm(train_dataloader)
    all_train_loss = 0.0
    train_pbar.set_description('epoch-' + str(epoch + 1))
    
    for step, (input, label, private_positions) in enumerate(train_pbar):
        input_ids = torch.tensor(label).long().to(device)
        label_ids = torch.tensor(input).long().to(device)
        private_positions = torch.tensor(private_positions, dtype=torch.long).to(device)

        #  forward pass
        outputs = model(input_ids=input_ids, labels=label_ids, private_positions=private_positions, sigma=sigma)
        loss, logits = outputs[:2]
        
        if gradient_accumulation > 1:
            loss = loss / gradient_accumulation
            
        #  loss backward
        # if fp16:
        #     with amp.scale_loss(loss, optimizer) as scaled_loss:
        #         scaled_loss.backward()
        #         torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
        # else:
        #     loss.backward()
        #     torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        
        loss.backward()
        loss = loss.detach()
        all_train_loss += loss
        
        writer.add_scalar('loss/train_step_loss', scalar_value=loss * gradient_accumulation, global_step=epoch * train_step_per_epoch+step)
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        #  optimizer step
        if (step + 1) % gradient_accumulation == 0:
            running_loss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            scheduler.step()
        if (step + 1) % log_step == 0:
            logger.info('now time: {}:{}. Step {} of epoch {}, loss {}'.format(
                datetime.now().hour,
                datetime.now().minute,
                (step + 1) // gradient_accumulation,
                epoch + 1,
                running_loss / log_step))
            running_loss = 0
        
        train_pbar.set_postfix({'loss': '{:.7f}'.format(loss*gradient_accumulation)})
        
    logger.info('train step = {}'.format(step))
    all_train_loss = all_train_loss / (step + 1)

    writer.add_scalar('loss/train_epoch_loss', scalar_value=all_train_loss * gradient_accumulation, global_step=epoch + 1)
    logger.info('saving model for epoch {}'.format(epoch + 1))
    if not os.path.exists(output_dir + 'model_epoch{}'.format(epoch + 1)):
        os.mkdir(output_dir + 'model_epoch{}'.format(epoch + 1))
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(output_dir + 'model_epoch{}'.format(epoch + 1))

    logger.info('epoch {} finished, train loss = {:.10f}'.format(epoch + 1, all_train_loss * gradient_accumulation))

    then = datetime.now()
    logger.info('time: {}'.format(then))
    logger.info('time for one epoch: {}'.format(then - now))
    
    logger.info('start validate')
    model.eval()
    all_valid_loss = 0.0
    valid_pbar = tqdm(valid_dataloader)
    valid_pbar.set_description('valid ' + str(epoch + 1))
    for step, (input, label) in enumerate(valid_pbar):
        input_ids = torch.tensor(label).long().to(device)
        label_ids = torch.tensor(input).long().to(device)

        #  forward pass
        outputs = model(input_ids=input_ids, labels=label_ids)
        loss = outputs[0].detach()
        writer.add_scalar('loss/valid_step_loss', scalar_value=loss, global_step=epoch * valid_step_per_epoch + step)
        all_valid_loss += loss
        valid_pbar.set_postfix({'loss': '{:.7f}'.format(loss)})
    
    logger.info('valid step = {}'.format(step))
    all_valid_loss = all_valid_loss / (step + 1)
    writer.add_scalar('loss/valid_epoch_loss', scalar_value=all_valid_loss, global_step=epoch+1)
    logger.info('valid finished, valid loss = {:.10f}'.format(all_valid_loss))
    early_stopping(all_valid_loss, model)
    if early_stopping.early_stop:
        logger.info("Early stopping")
        break

writer.close()    

logger.info('training finished')
if not os.path.exists(output_dir + 'final_model'):
    os.mkdir(output_dir + 'final_model')
model_to_save = model.module if hasattr(model, 'module') else model
model_to_save.save_pretrained(output_dir + 'final_model')

2023-04-01 14:31:40 - 2840535088.py - <module> - 6 - INFO
epoch 1
2023-04-01 14:31:40 - 2840535088.py - <module> - 6 - INFO
epoch 1
2023-04-01 14:31:40 - 2840535088.py - <module> - 8 - INFO
time: 2023-04-01 14:31:40.233154
2023-04-01 14:31:40 - 2840535088.py - <module> - 8 - INFO
time: 2023-04-01 14:31:40.233154
epoch-1:   0%|          | 0/9000 [00:00<?, ?it/s, loss=4.0664477]

input_ids =  tensor([ 103, 2140, 2140,  671, 4684, 4717,  833, 3221, 1538, 3694, 1036, 8043,
        2769, 2157, 2140, 2140,  738, 3221, 3300,  673,  702, 3299,  749, 8024,
         702, 2094, 2798, 8399, 8024,  860, 7028, 2798, 1282, 1061, 3165, 8024,
        5307, 2382, 2476, 2458, 1673, 2349,  679, 5052, 3221, 4717, 1286, 6230,
        6820, 3221, 4635, 1921, 8024, 2769, 6574, 4542,  800, 3221, 1538, 3694,
        1036, 8024, 6821, 1008, 1408, 8043, 6716, 3332, 4765, 2207, 8024, 1928,
        1741,  856,  754, 3633, 2382, 8024, 1928, 1184,  510, 1400, 2520, 4764,
        8024, 3359, 6956, 2398, 1439, 2793, 1928,  511, 7568, 4764,  510, 4649,
        5502, 2160, 3351,  511, 7755, 7977, 2382, 5862, 1400,  754, 2399, 7977,
        8024, 1139, 4280, 2454, 1400,  684, 2382, 7231,  855,  511, 1928, 1355,
        5301, 6763, 5445, 6772, 2208,  511, 1184, 1727, 1394, 2879, 3241, 8024,
        7553, 3359,  704, 5296, 1377, 3300, 5018,  753, 1727, 7305,  511, 1724,
        5501, 4764, 8024, 4

epoch-1:   0%|          | 1/9000 [00:00<32:41,  4.59it/s, loss=2.8775215]

input_ids =  tensor([ 103, 2111, 2094, 2793, 3425,  860, 1355, 4142, 1355, 4173, 2582, 3416,
        1394, 4415, 7650, 7608, 8043, 2769, 2157, 4638, 2111, 2094, 3221, 4511,
        2140, 2140, 8024,  124, 2259, 8024, 1157, 2458, 1993, 8024, 1624, 2094,
        4706, 3300, 4157, 4578, 8024, 1355, 4385, 8024,  845, 3300, 6768, 2544,
        4638, 1495, 1644, 8024, 1369, 1912, 8024, 3300, 4157, 1355, 4173,  738,
        3766, 5125, 4868, 8024, 6435, 7309, 8038, 2111, 2094, 2793, 3425,  860,
        1355, 4142, 1355, 4173, 2582, 3416, 1394, 4415, 7650, 7608,  511, 1355,
        4385, 2793, 3425,  860, 4142, 1218, 2553, 6206, 1350, 3198, 5314, 2111,
        2094, 3780, 4545, 8024, 1369, 1912, 3189, 2382, 7650, 7608, 1377,  809,
        1914, 1391,  671,  763, 5922, 5831, 1469, 3717, 3362, 8024, 7370,  749,
        3926, 3909,  679, 1173, 4080,  722, 1912, 8024, 6821, 5102, 7608, 4289,
        4638, 5852, 1075, 5162,  738, 3221, 3683, 6772,  705, 2168, 4638, 8024,
        3300, 1221,  754, 2

epoch-1:   0%|          | 5/9000 [00:00<14:41, 10.21it/s, loss=2.7803459]

input_ids =  tensor([ 103, 2207, 2140, 2140,  679, 4263, 1391, 7649, 2418, 6421, 2582,  720,
        1215, 8043, 2140, 2140,  671, 2259, 1288, 8024,  679, 4263, 1391, 7649,
        2418, 6421, 2582,  720, 1215, 8043, 2140, 6564, 4385, 1762, 3221,  122,
        2259, 8024,  679, 4263, 1391, 7649, 8024, 1328, 7608, 3683, 6772,  698,
        7028, 8024, 6821, 4905, 2658, 1105, 2418, 6421, 2582,  720, 1215, 1962,
        8043,  122, 2259, 1288, 2140, 2140,  679, 4263, 1391, 7649, 2418, 6421,
        2582,  720, 1215, 1450, 8043, 2644, 1962, 8024, 2140, 2140,  679, 4263,
        1391, 7649, 1333, 1728, 3300, 2523, 1914,  511, 5375, 7159, 5375, 7227,
         510, 1585, 1075,  679, 2496,  510, 3698,  952, 2512, 1510, 6963, 3221,
        2382, 6224, 4638, 1333, 1728,  511, 1968, 1968, 6206, 1914, 1217, 4522,
        2692, 8024, 6912, 1048, 2140, 2140, 1728, 6716,  860,  679, 6844, 7270,
        3309,  679, 1391, 7649,  511,  671, 3190, 2140, 2140, 1139, 4385, 7608,
        3617,  679, 2920,  

epoch-1:   0%|          | 7/9000 [00:00<12:53, 11.63it/s, loss=3.8173072]

input_ids =  tensor([ 103, 2207, 1036, 7481, 4611, 7151, 4132,  833, 4578, 1408, 8043, 2111,
        2094, 1920, 3519, 3221, 1762, 8108, 1384, 4638, 3198,  952,  678, 7433,
        1921, 1107, 1168, 1912, 7481, 1430, 6814, 7599, 8024, 1071,  800, 3766,
        3300,  679, 5653, 3302, 8024, 1962, 1008, 3221, 5018,  123, 1921, 2218,
        1355, 4385, 1381, 1673, 3639,  749, 2340, 4706, 4714, 7308,  679, 5165,
        8024, 7961, 5586, 2094, 2340, 6804, 4026, 3698, 8024, 2340, 6804, 1391,
         691, 6205,  738,  833, 1853,  511, 8124, 1384, 2458, 1993,  857, 7368,
        8024, 2802, 1396, 7151, 8024, 1400, 3341, 1126, 1921, 6820,  698, 7028,
         749, 6206, 7151, 4132, 8024, 2207, 1036, 7151,  833, 4578, 1408, 8043,
        7481, 4611, 2642, 5442, 2398, 3198, 2418, 3800, 2692, 7564, 7344, 8024,
         924, 2898, 5125, 4868, 2690, 2571, 8024,  924, 6395, 6844, 2496, 4638,
        4717, 4697, 1469,  828, 2622,  511, 1915, 7313, 6912, 1048, 1358, 1107,
        7599,  909, 6159,  

epoch-1:   0%|          | 11/9000 [00:00<11:39, 12.84it/s, loss=3.1658878]

input_ids =  tensor([ 103, 2048, 1036, 4570, 2908, 4568, 6206, 2582,  720, 2902, 3040, 8043,
        2769, 1995, 1995, 3221,  671,  855, 2802, 2339,  798, 8024, 2382, 2399,
        1139, 1912, 2802, 7439, 2339, 8024, 4385, 1762, 4495, 6814, 2111, 2094,
         749, 8024, 1377, 3221, 2111, 2094, 2642, 1196, 4164, 4563, 4578, 8024,
        1762, 1278, 7368,  976,  749, 3131, 3780, 8024, 2644, 1962, 2048, 1036,
        4570, 2908, 4568, 2582,  720, 2972, 2897, 8043, 2797, 6956, 3013, 2987,
        6237, 1104,  131, 6768, 6768, 2861, 4684, 2797, 2900,  511, 3030, 1220,
         872, 4638, 2797, 2900, 8024, 6768, 2902,  872, 4638, 2900, 2211, 8024,
        2972, 2897,  872, 4638, 5491, 5489,  511, 3131, 3780, 2207, 5597, 2853,
        5025,  131, 1348, 4917, 2207, 5597, 2853, 5025, 8024, 1348, 4917, 5578,
        5499, 5491, 5491, 5588,  511, 1920, 5597,  131, 1777, 1762, 1765, 3352,
         677, 8024, 1920, 5597,  847, 2398, 8024, 4197, 1400,  678, 1327, 5607,
        4667, 8024, 1920, 5

epoch-1:   0%|          | 13/9000 [00:01<10:49, 13.84it/s, loss=2.2564116]

input_ids =  tensor([ 103, 1063, 2259, 2207, 2111, 4706, 4714,  679,  977, 4699, 2582,  720,
        1215, 8043, 1063, 2259, 2207, 1036, 4706, 4714, 4699, 2582,  720, 1215,
        1343, 3389,  671,  678, 6228, 1213, 8024, 4692,  833,  679,  833, 2235,
        1045,  679, 3633,  511, 6206, 2831, 5165, 8024, 2207, 1377,  809, 4763,
        3633, 8024, 7370,  749, 1350, 3198, 3780, 4545, 2207, 1036, 1912, 4706,
        4567, 1912, 8024, 2642, 5442, 7444, 6206, 1914, 1486, 6418,  683, 2157,
        2456, 6379, 8024, 1469, 1278, 4495,  924, 2898, 3765, 6858, 8024, 2642,
        5442, 6820, 7444, 6206, 7028, 6228, 7650, 7608, 3175, 7481, 8024,  891,
        1963, 3189, 2382, 3926, 3909, 7650, 7608, 8024, 1914, 1912, 1139, 6817,
        1220,  511,  680, 3634, 1398, 3198, 2642, 5442, 6820, 6206, 3800, 2692,
        6848, 2885,  671, 2157, 3633, 6226, 1278, 7368, 6402, 3780, 8024, 6821,
        3416, 2798, 5543, 2533, 1168, 5679, 1962, 4638, 3780, 4545, 3126, 3362,
         511,  101], device

epoch-1:   0%|          | 17/9000 [00:01<10:48, 13.86it/s, loss=3.3116119]

input_ids =  tensor([ 103, 1044, 4495, 3381, 5682,  860,  679, 3633, 2382, 2582,  720, 3780,
        4545, 8043, 3766, 3300, 4568, 4307, 1355, 4567, 3198, 7313, 1350, 1333,
        1728, 1377, 5543, 3221, 1921, 4495, 4638, 4567, 2658, 1146, 3358, 8038,
        6821,  702, 2418, 6421, 3221, 1044, 1921, 4638, 8024, 2682, 1343, 2957,
        3221, 2523, 7410, 4638, 8013, 1377,  809, 1168,  683, 7305, 4638, 4511,
        2595,  683, 4906, 1278, 7368, 1343, 3466, 3389,  678, 4638,  679, 6206,
        6814,  754, 4638, 2857, 2552, 8013, 2900, 2193, 2692, 6224, 8038, 2456,
        6379, 2644, 4638, 1044, 4495, 1168, 1278, 7368, 1343, 3466, 3389,  678,
        8024, 3418, 2945, 3466, 3389, 5310, 3362, 6822, 6121, 3780, 4545, 8013,
        1147, 2555, 7390,  912, 3302, 5790, 8024, 1963, 3362, 3302, 5790, 6435,
        6905, 1278, 1671, 8013, 1398, 3198, 6206,  924, 2898,  671,  702, 5679,
        1962, 4638, 2552, 2578, 8024, 4916, 3353, 4638, 2552, 2578, 2190, 4565,
        4567, 4638, 3780, 4

epoch-1:   0%|          | 19/9000 [00:01<11:27, 13.05it/s, loss=2.9805856]

input_ids =  tensor([ 103, 1495, 1644,  862, 3198, 2798, 5543, 1962, 8043, 2207, 2111,  123,
        2259, 1288, 8024,  671,  702, 1914, 3299, 1184, 3466, 3389, 3221,  845,
        3300, 1596, 2622, 2595, 4638, 3118, 3698, 5052, 4142,  872, 4638, 6821,
        4905, 2658, 1105, 1377, 5543, 3221, 7599, 2170, 2697, 1088, 1495, 1644,
        8024, 7599, 2170, 1495, 1644,  100, 8024, 4567, 6395, 1399,  511, 2697,
        1358, 7599, 2170, 2792, 5636, 4638, 1495, 1644,  511, 3315, 6395,  809,
        1495, 1644, 7574,  868, 8024, 4588, 4921, 5682, 4635, 8024, 7965, 3837,
        3926, 3873, 8024, 5649, 3909, 5273, 4635, 8024, 5726, 5946, 1495, 1644,
        1898, 7028, 8024, 1928, 4578, 6716, 4178, 8024, 4493, 1156, 1596, 2593,
         711, 4294, 2519,  511, 1377,  809, 4500,  676, 2871, 3739, 1217, 7360,
        3698, 1265, 4588, 4638, 5790, 4289, 3780, 4545, 8024, 6863, 2768, 2207,
        1036, 1461, 1429, 5143, 5320, 4638, 1728, 5162, 6772, 1914, 8024, 1963,
        3362, 2111, 2094, 6

epoch-1:   0%|          | 23/9000 [00:01<11:28, 13.05it/s, loss=3.1657641]

input_ids =  tensor([ 103, 2207, 1036, 2793, 3425,  860, 4142, 4563, 6421, 2582, 3416, 3780,
        3126, 3362, 1962, 8043, 2769, 2157, 1957, 2140, 8024,  791, 2399,  122,
        2259, 8024,  671, 2458, 1993, 8024, 6432, 6413, 3198, 6230, 2533, 1624,
        2094, 4563, 8024, 2175, 6230, 1168, 8024, 1495, 1644, 3683, 6772, 1196,
        4164, 8024, 5445,  684, 8024,  671, 4684, 6963, 3300, 4157, 1355, 4173,
        8024, 6435, 7309, 8038, 2207, 1036, 2793, 3425,  860, 4142, 4563, 6421,
        2582, 3416, 3780, 3126, 3362, 1962,  511, 3780, 4545, 4638, 6413, 7674,
        1044, 1377, 3418, 2945, 2111, 2094, 4568, 4307, 5314,  750, 2190, 4568,
        3867, 4142, 5790, 8024,  738, 1377, 2229, 6956, 1103, 3819, 2772, 3221,
        2229, 6956, 1613, 5790, 8024, 2793, 3425,  860, 1079,  738, 1377, 3800,
        2198, 2190, 4568, 5790, 4289, 8024, 4545, 3126, 6963, 3221,  679, 7231,
        4638, 8024,  738, 1377,  809, 3418, 2945, 2111, 2094, 4638, 2658, 1105,
        5314, 4157, 7252, 4

epoch-1:   0%|          | 25/9000 [00:02<10:56, 13.68it/s, loss=3.5631337]

input_ids =  tensor([ 103, 1159, 4495, 1036, 1391, 1959, 5106,  677, 4125, 1408, 8043, 2769,
        2157, 2140, 2140, 2247,  754, 3193,  772, 2111, 2094, 8024,  671, 2458,
        1993, 2769, 2218, 2523, 2857, 2552, 2111, 2094, 2658, 1105, 8024, 3297,
        6818, 2111, 2094, 1139, 4385,  749, 6821,  763,  671, 5143, 1154, 4638,
        2658, 1105, 8024, 2769,  738, 2523, 2857, 2552, 8024, 2154, 2586, 2111,
        2094, 1359, 2533, 3291,  698, 7028, 8024, 6435, 7309, 1159, 4495, 1036,
        1391, 1959, 5106,  677, 4125, 1408, 1959, 5106, 1585, 1075, 4638, 2140,
        2140, 6206, 1762,  697, 3613, 1585, 1959,  722, 7313, 1585, 4157, 3946,
        2458, 3717, 8024, 3924, 1217, 4157, 1959,  845, 5868, 5843, 5131, 8024,
        3921, 1394, 1585, 1075, 4638, 2140, 2140, 8024, 1968, 1968,  738, 6206,
        6912, 1048, 1391,  763, 3211,  677, 4125, 4638, 7608, 4289,  511, 8024,
        1044, 1921, 2595, 5513, 4655, 6783, 2228, 5052, 6825, 2970, 1905, 3453,
        7349, 2792, 5636, 5

epoch-1:   0%|          | 29/9000 [00:02<10:16, 14.56it/s, loss=3.3587580]

input_ids =  tensor([ 103, 3466, 3389, 2544, 7030, 1039, 5162, 3633, 2382, 2869, 3688, 4568,
        1178, 1045, 3300, 3766, 3300, 4500, 8043,  798, 5301, 3466, 3389, 2544,
        7030, 1039, 5162, 3633, 2382, 8024, 2869, 3688, 4568, 1178, 1045, 3300,
        4500, 1408, 8043, 2111, 2094,  736, 2259,  749, 8024, 2869, 3688, 4568,
        3300,  671,  702, 3299, 8024, 3300,  784,  720, 1962, 4638, 2456, 6379,
        1408, 8043, 2769, 2682,  791, 1921, 2372,  800, 1343, 3092, 7490, 8024,
        1178, 2398, 1928, 4197, 1400,  743,  702, 2384, 2094, 2869, 3688, 4568,
        1377, 5543,  833, 6656, 2111, 2094, 4193, 5991, 6656, 5125, 4868, 6814,
        2428, 5165, 2476, 3300, 1068, 5143, 8013, 6820, 3221, 2372, 2111, 2094,
        1343,  976,  702, 2552, 4415, 1486, 6418, 8024, 5543, 6375, 2111, 2094,
        3138, 2458, 2552, 2796, 8013, 1963, 3362, 2111, 2094, 4568, 4307,  679,
         698, 7028, 8024, 1377,  809,  985, 6407, 4708, 1914, 1068, 2552, 6225,
        3800, 2111, 2094, 8

epoch-1:   0%|          | 33/9000 [00:02<09:39, 15.47it/s, loss=2.6727870]

input_ids =  tensor([ 103, 1063, 2259, 2207, 2111, 3801, 4157,  679, 6858, 2582,  720, 3780,
        4545, 8043, 1139, 4495, 2218, 3300, 3837, 3801,  872, 1962,  117, 2405,
        1036, 3801, 5593,  679, 6858, 3211, 2471, 6629, 3837, 3801, 5023, 4568,
        4307,  119, 1377,  809, 1044, 7023, 1357,  924, 2127, 4545, 3791, 8024,
        2229, 6956, 4017, 2834, 4495, 5162, 4017, 4706, 3890, 8024, 1968, 1968,
        1377,  809, 1762, 2140, 2140, 4638, 1079, 4706, 6235, 1905,  976, 2902,
        3040, 8024, 4500, 2797, 2900, 5592, 2518,  678, 2902, 8024, 3123, 2458,
        8024, 1086, 2902, 8024, 6825, 5330, 1282, 3613, 8024, 4197, 1400, 8024,
         794, 1079, 4706, 6235, 2518, 7965, 2629, 8024, 4507,  677, 2518,  678,
        2902, 3040, 8024,  738, 3221, 1282, 3613, 6821,  702, 3126, 3362, 6820,
         679, 7231,  511, 8024, 1728, 2193, 5636, 2207, 1036, 3801, 6887, 1843,
        1853, 4638, 1728, 5162, 1914, 4905, 1914, 3416, 8024, 3780, 4545, 1184,
        2553, 7557,  749, 6

epoch-1:   0%|          | 35/9000 [00:02<09:47, 15.27it/s, loss=2.9400237]

input_ids =  tensor([ 103, 1278, 7368, 2582,  720, 3780, 3255, 1213,  856,  678, 6821,  702,
        4567,  136, 2207, 2111, 2094, 2347, 5307,  758, 2259,  749, 8024, 2697,
        6230, 3300,  763, 2460, 2382, 8024,  679, 1008, 3633, 2382, 4638, 2111,
        2094, 8024, 2600, 3221, 4717, 6230, 8024, 1353, 2418, 6826, 7162, 8024,
        3221,  679, 3221, 3255, 1213,  856,  678, 8043, 4385, 1762, 2682, 4761,
        6887, 8024, 1278, 7368, 2582,  720, 3780, 3255, 1213,  856,  678, 6821,
         702, 4567,  136, 3301, 1351, 8024, 3418, 2945,  872, 4638, 1360, 6835,
        8024, 1036, 4997, 3255, 1213,  856,  678, 7444, 6206, 3300,  757, 4685,
        6981, 1394, 6817, 4500, 1278, 2110,  510, 4852,  833,  510, 3136, 2193,
        1469, 5466,  689, 4294, 6378, 5023, 2974, 3177, 8024, 2902, 2399, 7977,
        1920, 2207, 1469, 3255, 1213,  856,  678, 4638,  698, 7028, 4923, 2428,
        2190, 2642, 5442, 2141, 3177, 4294, 6378, 8024,  886, 1071, 6809, 1168,
        2226, 1377, 5543, 7

epoch-1:   0%|          | 36/9000 [00:02<11:30, 12.99it/s, loss=2.9400237]


input_ids =  tensor([ 103,  122, 2259,  130,  702, 3299, 1920,  912, 3198, 1526, 3221, 6768,
        2544,  912, 4908, 2582,  720, 1215, 8043, 6435, 7309,  117, 2207, 1957,
         122, 2259,  130,  702, 3299,  117, 1762, 1920,  912, 3198, 1962, 1526,
         117, 3221, 6768, 2544,  912, 4908,  117, 6421, 2582, 3416, 3780, 4545,
        1450,  136, 1961, 2398, 2382, 1391,  763, 5922, 5831, 7313, 7392, 1391,
        8024, 6820, 3300, 6007,  510, 3717, 7659,  510, 3739,  510, 7672, 1928,
         510, 4156, 2372, 7824,  510, 2207, 5101, 3739, 8024,  671, 5663, 7650,
        7608,  738, 2218, 6821,  763, 8024, 7370,  749, 1914, 1600, 3717, 8024,
        6820, 3300, 1166, 4638, 3175, 3791, 1408, 8043, 6468, 6468, 8013,  872,
        1962, 8024,  671, 2259, 1914, 2140, 2140,  912, 4908, 3680, 3613, 1920,
         912, 6963, 1526, 8024, 1377, 5543, 3221, 7270, 4574, 4555,  749,  117,
         679,  833, 6432, 6413,  117, 3680, 3613, 2861, 5107, 5107, 2218, 1526,
         117, 1377, 5543, 3

KeyboardInterrupt: 