In [40]:
import torch
import pandas as pd 
from datasets import load_dataset
from transformers import AutoTokenizer, BertGenerationEncoder, BertGenerationDecoder, EncoderDecoderModel
import torch.nn as nn
import logging 

In [6]:
dataset = load_dataset("/home/featurize/data",data_files={'train':"filtered_joke_128.csv"})

Using custom data configuration .-55ca6b61aad76f43


Downloading and preparing dataset csv/. to C:/Users/Xiang/.cache/huggingface/datasets/csv/.-55ca6b61aad76f43/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

  return pd.read_csv(xopen(filepath_or_buffer, "rb", use_auth_token=use_auth_token), **kwargs)


Dataset csv downloaded and prepared to C:/Users/Xiang/.cache/huggingface/datasets/csv/.-55ca6b61aad76f43/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

In [16]:
encoder_max_length = 128
decoder_max_length = 128
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def process_data_to_model_inputs(batch):
  # tokenize the inputs and labels
  inputs = tokenizer(batch["title"], padding="max_length", truncation=True, max_length=encoder_max_length)
  outputs = tokenizer(batch["selftext"], padding="max_length", truncation=True, max_length=decoder_max_length)

  batch["input_ids"] = inputs.input_ids
  batch["attention_mask"] = inputs.attention_mask
  # batch["decoder_input_ids"] = outputs.input_ids
  batch["decoder_attention_mask"] = outputs.attention_mask
  batch["labels"] = outputs.input_ids.copy()

  # because BERT automatically shifts the labels, the labels correspond exactly to `decoder_input_ids`. 
  # We have to make sure that the PAD token is ignored
  batch["labels"] = [[-100 if token == tokenizer.pad_token_id else token for token in labels] for labels in batch["labels"]]

  return batch

In [18]:
batchsize = 128

# train_data = dataset['train'].select(range(32))
train_data = dataset['train'].map(
    process_data_to_model_inputs, 
    batched=True, 
    batch_size= batchsize, 
    remove_columns=["selftext", "title", "Unnamed: 0"]
)

  0%|          | 0/8 [00:00<?, ?ba/s]

In [20]:
train_data.set_format(
    type="torch", columns=["input_ids", "attention_mask", "decoder_input_ids", "decoder_attention_mask", "labels"],
)

In [22]:
bert2bert = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertLMHeadModel: ['cls.seq_relationship.weight', 'cls.seq_relatio

In [23]:
# bert2bert.save_pretrained("bert2bert")
# bert2bert = EncoderDecoderModel.from_pretrained("bert2bert")
bert2bert.config.decoder_start_token_id = tokenizer.cls_token_id
bert2bert.config.eos_token_id = tokenizer.sep_token_id
bert2bert.config.pad_token_id = tokenizer.pad_token_id
bert2bert.config.vocab_size = bert2bert.config.encoder.vocab_size

bert2bert.config.max_length = 128
bert2bert.config.min_length = 1
bert2bert.config.no_repeat_ngram_size = 3
bert2bert.config.early_stopping = True
bert2bert.config.length_penalty = 2.0
bert2bert.config.num_beams = 4

In [27]:
from torch.utils.data import DataLoader
from torch.optim import AdamW

train_loader = DataLoader(train_data, shuffle= True, batch_size = 2)
optimizer = AdamW(bert2bert.parameters(), lr = 1e-5)

In [38]:
def get_log(file_name):
    logger = logging.getLogger('train')  # 设定logger的名字
    logger.setLevel(logging.INFO)  # 设定logger得等级

    ch = logging.StreamHandler()  # 输出流的hander，用与设定logger的各种信息
    ch.setLevel(logging.INFO)  # 设定输出hander的level

    fh = logging.FileHandler(file_name, mode='a')  # 文件流的hander，输出得文件名称，以及mode设置为覆盖模式
    fh.setLevel(logging.INFO)  # 设定文件hander得lever



    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    ch.setFormatter(formatter)  # 两个hander设置个是，输出得信息包括，时间，信息得等级，以及message
    fh.setFormatter(formatter)
    logger.addHandler(fh)  # 将两个hander添加到我们声明的logger中去
    logger.addHandler(ch)
    return logger


In [41]:
logger = get_log('log.txt')

In [42]:
epoch = 1000
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
ACC_STEP = 4

bert2bert.to(device)

for i in range(epoch):
    bert2bert.train()
    total_loss = 0
    for idx, X in enumerate(train_loader):
        # X is a large batch
        batch = {k: v.to(device) for k, v in X.items()}
        outputs = bert2bert(**batch)

        loss = outputs.loss / ACC_STEP
        loss.backward()
        total_loss += loss.item()

        if (idx+1) % ACC_STEP or (idx + 1 == len(train_loader)):
            logger.info(f'Batch {idx}, Loss: {loss.item()}')
            optimizer.step()
    
    bert2bert.eval()
    with torch.no_grad():
        output = bert2bert.generate(tokenizer('I ate pizza this evening.', return_tensors = 'pt').input_ids.cuda())
        logger.info(tokenizer.batch_decode(output.cpu()))
        
    bert2bert.save_pretrained("bert2bert")

2023-02-16 21:00:36,658 - INFO - Batch 0, Loss: 1.2841315269470215
2023-02-16 21:00:36,850 - INFO - Batch 1, Loss: 1.194342017173767
2023-02-16 21:00:37,035 - INFO - Batch 2, Loss: 1.1022460460662842
2023-02-16 21:00:37,333 - INFO - Batch 4, Loss: 0.7442519664764404
2023-02-16 21:00:37,521 - INFO - Batch 5, Loss: 1.6024678945541382
2023-02-16 21:00:37,705 - INFO - Batch 6, Loss: 0.703491747379303
2023-02-16 21:00:38,001 - INFO - Batch 8, Loss: 0.468960702419281
2023-02-16 21:00:38,183 - INFO - Batch 9, Loss: 0.8585889935493469
2023-02-16 21:00:38,368 - INFO - Batch 10, Loss: 0.447396457195282
2023-02-16 21:00:38,662 - INFO - Batch 12, Loss: 0.8794702887535095
2023-02-16 21:00:38,849 - INFO - Batch 13, Loss: 0.7277068495750427
2023-02-16 21:00:39,034 - INFO - Batch 14, Loss: 0.6288974285125732
2023-02-16 21:00:39,219 - INFO - Batch 15, Loss: 0.4530305564403534
2023-02-16 21:00:42,659 - INFO - ['[CLS] bat on on on off off off of of one one one has has had had had got got got get get get 