In [1]:
import torch
import pickle
from dataclasses import dataclass
from torch.utils.data import Dataset
from transformers import BartForConditionalGeneration, BartTokenizer, TrainingArguments, WhisperProcessor, Trainer

bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
processor = WhisperProcessor.from_pretrained("openai/whisper-large-v3", language='en')
bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-large-cnn", language='en')

2024-08-24 04:25:12.810350: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-08-24 04:25:12.810473: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-08-24 04:25:12.973837: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


config.json:   0%|          | 0.00/1.63k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.02G [00:00<?, ?B/s]

  return self.fget.__get__(instance, owner)()


preprocessor_config.json:   0%|          | 0.00/340 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/283k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.48M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/494k [00:00<?, ?B/s]

normalizer.json:   0%|          | 0.00/52.7k [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/34.6k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/2.07k [00:00<?, ?B/s]

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.58k [00:00<?, ?B/s]

In [2]:
def total_params(model):
    return sum(p.numel() for p in model.parameters())

print(f'Memory used by model: {round(bart_model.get_memory_footprint()/1024/1024/1024, 2)} GB')
print(f'total number of parameters is {total_params(bart_model)}')

Memory used by model: 1.51 GB
total number of parameters is 406291456


In [3]:
with open('/kaggle/input/whisper-output-notebook/sequences.pkl', mode='rb') as f:
    sequences = pickle.load(f)
with open('/kaggle/input/whisper-output-notebook/labels.pkl', mode='rb') as f:
    labels = pickle.load(f)

In [4]:
remove_values = {50364, 50257, 50258, 50259, 50260, 50261, 50262, 50263, 50264, 50265, 50266, 50267, 50268, 50269, 50270, 50271, 50272, 50273, 50274, 50275, 50276,
                 50277, 50278, 50279, 50280, 50281, 50282, 50283, 50284, 50285, 50286, 50287, 50288, 50289, 50290, 50291, 50292, 50293, 50294, 50295, 50296, 50297, 
                 50299, 50300, 50301, 50302, 50303, 50304, 50305, 50306, 50307, 50308, 50309, 50310, 50311, 50312, 50313, 50314, 50315, 50316, 50317, 50318, 50319,
                 50321, 50322, 50323, 50324, 50325, 50326, 50327, 50328, 50329, 50330, 50331, 50332, 50333, 50334, 50335, 50336, 50337, 50338, 50339, 50340, 50341,
                 50343, 50344, 50345, 50346, 50347, 50348, 50349, 50350, 50351, 50352, 50353, 50354, 50355, 50356, 50357, 50358, 50359, 50360, 50361, 50362, 50363,
                 50298, 50320, 50342}

class training_dataset(Dataset) :
    def __init__(self, sequences, labels, remove_list) :
        super().__init__()
        self.sequences = sequences
        self.labels = labels
        self.remove_list = remove_list

    def __len__(self) :
        return self.labels.__len__()

    def __getitem__(self, idx) :
        output = {}
        data = self.sequences[idx]
        seq = [item for item in data if int(item) not in self.remove_list]
        output['sequences'] = processor.tokenizer.decode(seq)
        
        labels = self.labels[idx][0]
        labels = [item for item in labels if int(item) not in self.remove_list]
        output['labels'] = processor.tokenizer.decode(labels)
        
        return output

dataset = training_dataset(sequences, labels, remove_values)

In [5]:
@dataclass
class data_collator :
        
    def __call__(self, features: dict):
        
        src_sentence = [feature['sequences'] for feature in features]
        tgt_sentence = [feature['labels'] for feature in features]
        scr = bart_tokenizer(src_sentence, padding='longest', truncation=True, max_length=100, return_token_type_ids=False, return_tensors='pt')
        tgt = bart_tokenizer(tgt_sentence, padding='longest', truncation=True, max_length=100, return_token_type_ids=False, return_tensors='pt')
        output = {}
        output['input_ids'] = scr['input_ids']
        output['labels'] = tgt['input_ids']
        src_sentence = tgt_sentence = None
        return output
    
collator = data_collator()

In [7]:
training_args = TrainingArguments(
                    output_dir="/kaggle/working/",
                    report_to = 'none',
                    lr_scheduler_type='cosine',
                    per_device_train_batch_size=16,
                    learning_rate=3e-5,
                    weight_decay = 2e-4,
                    warmup_steps=40,
                    num_train_epochs=1,
                    max_steps=50,
                    fp16=True,
                    save_strategy = 'epoch',
                    logging_steps=16,
                    remove_unused_columns=False,
                )
trainer = Trainer(model = bart_model, args=training_args, data_collator = collator, train_dataset = dataset)
trainer.train()

max_steps is given, it will override any value given in num_train_epochs


Step,Training Loss
16,8.6733
32,6.6453
48,4.6667


Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}
Non-default generation parameters: {'early_stopping': True, 'num_beams': 4, 'no_repeat_ngram_size': 3, 'forced_bos_token_id': 0, 'forced_eos_token_id': 2}


NameError: name 'state_dict' is not defined