In [15]:
import os
import sys
import pandas as pd
import random
from datetime import datetime
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.optim as optim

import transformers
from transformers import BartTokenizer, BartForConditionalGeneration

torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'

SAVE_DIR = ""  #= "#/home/alta/summary/pm574/podcast_sum0/lib/trained_models"
MODEL_NAME = "bart-baseline"

In [33]:
def train():
    # Model & Optimizer
    bart_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-cnn')
    bart = BartForConditionalGeneration.from_pretrained('facebook/bart-large-cnn')

    optimizer = optim.Adam(filter(lambda p: p.requires_grad, bart.parameters()), lr=0.001,betas=(0.9,0.999),eps=1e-08,weight_decay=0)
    optimizer.zero_grad()

    bart_config = bart.model.config
    print(bart)
    print(bart_config)
    if torch_device == 'cuda': bart.cuda()
    print("#parameters:", sum(p.numel() for p in bart.parameters() if p.requires_grad))
    bart.train()

    # Data
    #podcasts = load_podcast_data(sets=-1) # -1 means set0,..,set9 (excluding 10)
    load_podcasts = pd.read_pickle("bart_train_regex.pkl")
    train_podcasts = load_podcasts[:5000]
    
    #batcher = BartBatcher(bart_tokenizer, bart.model.config, podcasts, torch_device)

    # Validation
    val_podcasts = load_podcasts[5000:5500]
    #val_batcher = BartBatcher(bart_tokenizer, bart.model.config, val_podcasts, torch_device)

    # Criterion
    criterion = nn.CrossEntropyLoss(reduction='none') # This criterion combines nn.LogSoftmax() and nn.NLLLoss() in one single class.

    training_step  = 0
    batch_size     = 1
    gradient_accum = 2
    valid_step     = 20000 # every a few hours on lapaz machine (1GPU - 1080Ti)
    total_step     = 20000 * 1000
    best_val_loss  = 99999999
    random_seed    = 777
    stop_counter   = 0

    print("batch_size:", batch_size)
    print("training_step:", training_step)
    print("gradient_accum:", gradient_accum)
    print("total_step:", total_step)
    print("valid_step:", valid_step)
    print("random_seed:", random_seed)

    # Randomness
    random.seed(random_seed)
    torch.manual_seed(random_seed)

    # shuffle data
    train_podcasts.sample(frac=1)

    if torch.cuda.device_count() > 1:
        print("Multiple GPUs: {}".format(torch.cuda.device_count()))
        bart = nn.DataParallel(bart)

    while training_step < total_step:
        # get a batch
        #input_ids, attention_mask, target_ids, target_attention_mask = batcher.get_a_batch(batch_size=batch_size)
        encoded_inputs = bart_tokenizer.batch_encode_plus(train_podcasts['transcript'].tolist(), add_special_tokens=True, pad_to_max_length=True,
            max_length= bart_config.max_position_embeddings, return_tensors='pt')
        print(type(train_podcasts['transcript'].tolist()))
        #input_ids = bart_tokenizer.tokenize(train_podcasts)
        
        #decoder_input_ids = bart_tokenizer.encode(input_ids, return_tensors='pt')
        print(type(train_podcasts['episode_description'].tolist()))
        
        batch_encoded_targets = bart_tokenizer.encode_plus(train_podcasts['episode_description'].tolist(), add_special_tokens=True, pad_to_max_length=True,
            max_length= bart_config.max_position_embeddings, return_tensors='pt')
        
        input_ids = encoded_inputs['input_ids']
        #print(type(input_ids))
        attention_mask = encoded_inputs['attention_mask']
        decoder_input_ids = batch_encoded_targets['input_ids']
        decoder_attention_mask = batch_encoded_targets['attention_mask']
        
        shifted_target_ids = torch.zeros(decoder_input_ids.shape, dtype=decoder_input_ids.dtype)
        shifted_target_attention_mask = torch.zeros(decoder_attention_mask.shape, dtype=torch.float)
        shifted_target_ids[:,:-1] = decoder_input_ids.clone().detach()[:,1:]
        shifted_target_attention_mask[:,:-1] = decoder_attention_mask.clone().detach()[:,1:]
        
        
        #shifted_target_ids, shifted_target_attention_mask = batcher.shifted_target_left(target_ids, target_attention_mask)
        # BART forward
        x = bart(
            input_ids=input_ids,
            attention_mask = attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
        )
        # x[0] # decoder output
        # x[1] # encoder output
        lm_logits = x[0]

        loss = criterion(lm_logits.view(-1, bart_config.vocab_size), shifted_target_ids.view(-1))
        shifted_target_attention_mask = shifted_target_attention_mask.view(-1)
        loss = (loss * shifted_target_attention_mask).sum() / shifted_target_attention_mask.sum()
        loss.backward()

        if training_step % gradient_accum == 0:
            adjust_lr(optimizer, training_step)
            optimizer.step()
            optimizer.zero_grad()

        if training_step % 1 == 0:
            print("[{}] step {}/{}: loss = {:.5f}".format(str(datetime.now()), training_step, total_step, loss))
            sys.stdout.flush()

        # if training_step % 5 == 0:
        #     tgt_len = target_attention_mask[0].sum().item()
        #     print("REF: {}".format(bart_tokenizer.decode(shifted_target_ids[0,:tgt_len].cpu().numpy())))
        #     print("HYP: {}".format(bart_tokenizer.decode(torch.argmax(lm_logits[0,:tgt_len].cpu(), dim=-1).numpy())))

        if training_step % valid_step == 0 and training_step > 5:
            bart.eval()
            with torch.no_grad():
                valid_loss = validation(bart, bart_config, val_podcasts, 1, 1)
            print("Valid Loss = {:.5f}".format(valid_loss))
            bart.train()
            if valid_loss < best_val_loss:
                stop_counter = 0
                best_val_loss = valid_loss
                print("Model improved".format(stop_counter))
            else:
                stop_counter += 1
                print("Model not improved #{}".format(stop_counter))
                if stop_counter == 3:
                    print("Stop training!")
                    return

            state = {
                'training_step': training_step,
                'model': bart.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_val_loss': best_val_loss
            }
            savepath = "{}/{}-step{}.pt".format(SAVE_DIR, MODEL_NAME, training_step)
            torch.save(state, savepath)
            print("Saved at {}".format(savepath))

        training_step += 1
    print("Finish Training")


In [3]:

def adjust_lr(optimizer, step, warmup=10000):
    """to adjust the learning rate"""
    step = step + 1 # plus 1 to avoid ZeroDivisionError
    lr = 2e-3 * min(step**(-0.5), step*(warmup**(-1.5))) # 0.5 for effecetive batch_size
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return


In [25]:
def validation(bart, bart_config, val_podcasts, epoch_counter, batch_size):
    print("start validating")
    criterion = nn.CrossEntropyLoss(reduction='none')
    sum_loss = 0
    sum_token = 0
    while epoch_counter < 1:
        #input_ids, attention_mask, target_ids, target_attention_mask = val_batcher.get_a_batch(batch_size=batch_size)
        #shifted_target_ids, shifted_target_attention_mask = val_batcher.shifted_target_left(target_ids, target_attention_mask)
        
        
        encoded_inputs = bart_tokenizer.batch_encode_plus(val_podcasts['transcript'].tolist(), add_special_tokens=True, pad_to_max_length=True,
            max_length= bart_config.max_position_embeddings, return_tensors='pt')
        
        #input_ids = bart_tokenizer.tokenize(train_podcasts)
        input_ids = encoded_inputs['input_ids']
        attention_mask = encoded_inputs['attention_mask']
        decoder_input_ids = bart_tokenizer.encode(input_ids, return_tensors='pt')
        
        batch_encoded_targets = bart_tokenizer.encode(val_podcasts['episode_description'].tolist(),
            add_special_tokens=True, pad_to_max_length=True,
            max_length=144, return_tensors='pt')

        target_ids = batch_encoded_targets['input_ids']
        target_attention_mask = batch_encoded_targets['attention_mask']
        
        shifted_target_ids = torch.zeros(decoder_input_ids.shape, dtype=decoder_input_ids.dtype)
        shifted_target_attention_mask = torch.zeros(decoder_attention_mask.shape, dtype=torch.float)
        shifted_target_ids[:,:-1] = decoder_input_ids.clone().detach()[:,1:]
        shifted_target_attention_mask[:,:-1] = decoder_attention_mask.clone().detach()[:,1:]
        
        x = bart(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=target_ids,
            decoder_attention_mask=target_attention_mask,
        )
        lm_logits = x[0]
        loss = criterion(lm_logits.view(-1, bart_config.vocab_size), shifted_target_ids.view(-1))
        shifted_target_attention_mask = shifted_target_attention_mask.view(-1)
        sum_loss += (loss * shifted_target_attention_mask).sum().item()
        sum_token += shifted_target_attention_mask.sum().item()
        print("#", end="")
        sys.stdout.flush()
    print()
    epoch_counter = 0
    print("finish validating")

    return sum_loss / sum_token

In [None]:
if __name__ == "__main__":
    train()

BartForConditionalGeneration(
  (model): BartModel(
    (shared): Embedding(50264, 1024, padding_idx=1)
    (encoder): BartEncoder(
      (embed_tokens): Embedding(50264, 1024, padding_idx=1)
      (embed_positions): BartLearnedPositionalEmbedding(1026, 1024)
      (layers): ModuleList(
        (0): BartEncoderLayer(
          (self_attn): BartAttention(
            (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
            (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
          )
          (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
   

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


batch_size: 1
training_step: 0
gradient_accum: 2
total_step: 20000000
valid_step: 20000
random_seed: 777




In [9]:
load_podcasts = pd.read_pickle("bart_train_regex.pkl")
train_podcasts = load_podcasts[:5000]

lst = (train_podcasts['transcript'].tolist())
ls2 = (train_podcasts['episode_description'].tolist())
for i in range(len(lst)):
    if type(lst[1]) != type(ls2[2]):
        print("Yes")
    #print(type(lst[1]), type(ls2[2]) )
#train_podcasts['episode_description'].tolist()



In [10]:
print(len(lst), len(ls2))

5000 5000


In [16]:
print(torch.version)
print(transformers.version)

<module 'torch.version' from 'C:\\Users\\anany\\AppData\\Local\\Programs\\Python\\Python39\\lib\\site-packages\\torch\\version.py'>


AttributeError: module transformers has no attribute version