In [1]:
from transformers import (
    GPT2PreTrainedModel, 
    GPT2Config, 
    GPT2Model, 
    GPT2TokenizerFast, 
    DataCollatorForLanguageModeling
)
from transformers import Trainer, TrainingArguments
from transformers.models.gpt2.modeling_gpt2 import GPT2Block
from transformers.modeling_outputs import (
    BaseModelOutputWithPastAndCrossAttentions,
    CausalLMOutputWithCrossAttentions
)
from tokenizers import Tokenizer
from torch import nn
from torch.utils.data import Dataset
from pathlib import Path
import torch
from packaging import version
from tqdm import tqdm
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import wandb
import math
import os

os.environ["CUDA_LAUNCH_BLOCKING"] = "0"

In [2]:
torch.cuda.set_device(0)

In [3]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmacosta[0m (use `wandb login --relogin` to force relogin)


True

In [4]:
TOKENIZER_SAVEDIR = Path('/home/macosta/ttmp/primus-data/primus-semantic/semantic-tokenizer-v2/')
LM_MODEL_SAVEDIR = Path('/home/macosta/ttmp/primus-models/gpt2-lm-semantic-rhythm-v2/')
Path(LM_MODEL_SAVEDIR).mkdir(exist_ok=True)
TXT_FILES = Path('/home/macosta/ttmp/primus-data/primus-semantic/semantic-cleaned-v2')

In [5]:
def separate_wt_from_rhythm(tensor, rhythmic_bits, device):
        tensor.to(device)
        shift_amount = torch.ones(tensor.shape, dtype=torch.int64, device=device) * rhythmic_bits
        mask = torch.ones(tensor.shape, dtype=torch.int64, device=device) * (2 ** rhythmic_bits - 1)
        wt = torch.bitwise_right_shift(tensor, shift_amount)
#         print("WT:", torch.min(wt), torch.max(wt))
        rhythms = torch.bitwise_and(tensor, mask)
#         print("RHYTHM:", torch.min(rhythms), torch.max(rhythms))
        return wt, rhythms

In [6]:
class CustomGPT2Model(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = ["attn.masked_bias"]

    def __init__(self, config):
        super().__init__(config)

        self.embed_dim = config.hidden_size
        ''' NEW '''
        self.rhythmic_granularity = config.rhythmic_granularity
        self.rhythmic_bits = math.ceil(math.log(self.rhythmic_granularity, 2))
        self.vocab_size = config.vocab_size

        self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
        self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
        ''' NEW '''
        self.re = nn.Embedding(self.rhythmic_granularity, self.embed_dim)

        self.drop = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([GPT2Block(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def parallelize(self, device_map=None):
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.h))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        self.wte = self.wte.to(self.first_device)
        self.wpe = self.wpe.to(self.first_device)
        ''' NEW '''
        self.re = self.re.to(self.first_device)
    
        # Load onto devices
        for k, v in self.device_map.items():
            for block in v:
                cuda_device = "cuda:" + str(k)
                self.h[block] = self.h[block].to(cuda_device)
        # ln_f to last
        self.ln_f = self.ln_f.to(self.last_device)

    def deparallelize(self):
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"

        self.wte = self.wte.to("cpu")
        self.wpe = self.wpe.to("cpu")
        ''' NEW '''
        self.re = self.re.to("cpu")
        
        for index in range(len(self.h)):
            self.h[index] = self.h[index].to("cpu")
        self.ln_f = self.ln_f.to("cpu")
        torch.cuda.empty_cache()

    def get_input_embeddings(self):
        return self.wte

    def set_input_embeddings(self, new_embeddings):
        self.wte = new_embeddings

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
            batch_size = input_ids.shape[0]
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            batch_size = inputs_embeds.shape[0]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.h))
        else:
            past_length = past_key_values[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])

        # GPT2Attention mask.
        if attention_mask is not None:
            if batch_size <= 0:
                raise ValueError("batch_size has to be defined and > 0")
            attention_mask = attention_mask.view(batch_size, -1)
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask[:, None, None, :]

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=self.dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.add_cross_attention and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        head_mask = self.get_head_mask(head_mask, self.config.n_layer)

        ''' OLD '''
#         if inputs_embeds is None:
#             inputs_embeds = self.wte(input_ids)
#         position_embeds = self.wpe(position_ids)
#         hidden_states = inputs_embeds + position_embeds
#         if token_type_ids is not None:
#             token_type_embeds = self.wte(token_type_ids)
#             hidden_states = hidden_states + token_type_embeds
        ''' NEW '''
        if input_ids is None:
            raise Exception("Need input ids")
        wt_tokens, rhythm_tokens = separate_wt_from_rhythm(input_ids, self.rhythmic_bits, device=device)
        wt_tokens = torch.clamp(wt_tokens, min=0, max=self.vocab_size-1)
        rhythm_tokens = torch.clamp(rhythm_tokens, min=0, max=self.rhythmic_granularity-1)
        inputs_embeds = self.wte(wt_tokens)
        position_embeds = self.wpe(position_ids)
        rhythmic_embeds = self.re(rhythm_tokens)
        hidden_states = inputs_embeds + position_embeds + rhythmic_embeds

        hidden_states = self.drop(hidden_states)

        output_shape = input_shape + (hidden_states.size(-1),)

        presents = () if use_cache else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
        all_hidden_states = () if output_hidden_states else None
        for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure layer_past is on same device as hidden_states (might not be correct)
                if layer_past is not None:
                    layer_past = tuple(past_state.to(hidden_states.device) for past_state in layer_past)
                # Ensure that attention_mask is always on the same device as hidden_states
                if attention_mask is not None:
                    attention_mask = attention_mask.to(hidden_states.device)
                if isinstance(head_mask, torch.Tensor):
                    head_mask = head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            if self.gradient_checkpointing and self.training:

                if use_cache:
                    logger.warning(
                        "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache, output_attentions)

                    return custom_forward

                outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    None,
                    attention_mask,
                    head_mask[i],
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                outputs = block(
                    hidden_states,
                    layer_past=layer_past,
                    attention_mask=attention_mask,
                    head_mask=head_mask[i],
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_attention_mask=encoder_attention_mask,
                    use_cache=use_cache,
                    output_attentions=output_attentions,
                )

            hidden_states = outputs[0]
            if use_cache is True:
                presents = presents + (outputs[1],)

            if output_attentions:
                all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
                if self.config.add_cross_attention:
                    all_cross_attentions = all_cross_attentions + (outputs[3 if use_cache else 2],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.ln_f(hidden_states)

        hidden_states = hidden_states.view(*output_shape)
        # Add last hidden state
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, presents, all_hidden_states, all_self_attentions, all_cross_attentions]
                if v is not None
            )

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )

In [7]:
class CustomGPT2LMHeadModel(GPT2PreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.transformer = CustomGPT2Model(config)
        ''' OLD '''
#         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        ''' NEW '''
        self.lm_wt_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
        self.lm_rhythm_head = nn.Linear(config.n_embd, config.rhythmic_granularity, bias=False)
        self.rhythmic_granularity = config.rhythmic_granularity
        self.rhythmic_bits = math.ceil(math.log(self.rhythmic_granularity, 2))
        self.vocab_size = config.vocab_size

        # Model parallel
        self.model_parallel = False
        self.device_map = None

        # Initialize weights and apply final processing
        self.post_init()
    
    def parallelize(self, device_map=None):
        self.device_map = (
            get_device_map(len(self.transformer.h), range(torch.cuda.device_count()))
            if device_map is None
            else device_map
        )
        assert_device_map(self.device_map, len(self.transformer.h))
        self.transformer.parallelize(self.device_map)
        ''' OLD '''
#         self.lm_head = self.lm_head.to(self.transformer.first_device)
        ''' NEW '''
        self.lm_wt_head = self.lm_wt_head.to(self.transformer.first_device)
        self.lm_rhythm_head = self.lm_rhythm_head.to(self.transformer.first_device)
        
        self.model_parallel = True

    def deparallelize(self):
        self.transformer.deparallelize()
        self.transformer = self.transformer.to("cpu")
        ''' OLD '''
#         self.lm_head = self.lm_head.to("cpu")
        ''' NEW '''
        self.lm_wt_head = self.lm_wt_head.to("cpu")
        self.lm_rhythm_head = self.lm_rhythm_head.to("cpu")
    
        self.model_parallel = False
        torch.cuda.empty_cache()

    def get_output_embeddings(self):
        ''' OLD '''
#         return self.lm_head
        ''' NEW '''
        print("Sure hope this doesn't get called (get_output_embeddings)")
        return self.lm_wt_head

    def set_output_embeddings(self, new_embeddings):
        ''' OLD '''
        self.lm_head = new_embeddings
        ''' NEW '''
        print("Sure hope this doesn't get called (set_output_embeddings)")
        self.lm_wt_head = new_embeddings

    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
        token_type_ids = kwargs.get("token_type_ids", None)
        # only last token for inputs_ids if past is defined in kwargs
        if past:
            input_ids = input_ids[:, -1].unsqueeze(-1)
            if token_type_ids is not None:
                token_type_ids = token_type_ids[:, -1].unsqueeze(-1)

        attention_mask = kwargs.get("attention_mask", None)
        position_ids = kwargs.get("position_ids", None)

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past:
                position_ids = position_ids[:, -1].unsqueeze(-1)
        else:
            position_ids = None
        return {
            "input_ids": input_ids,
            "past_key_values": past,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "token_type_ids": token_type_ids,
        }
    
    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
            `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
            are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
        """
#         print(labels)
        
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]

        # Set device for model parallelism
        if self.model_parallel:
            torch.cuda.set_device(self.transformer.first_device)
            ''' OLD '''
#             hidden_states = hidden_states.to(self.lm_head.weight.device)
            ''' NEW '''
            hidden_states = hidden_states.to(self.lm_wt_head.weight.device)

        ''' OLD '''
#         lm_logits = self.lm_head(hidden_states)
        ''' NEW '''
        lm_wt_logits = self.lm_wt_head(hidden_states)
        lm_rhythm_logits = self.lm_rhythm_head(hidden_states)
        
        ''' NEW '''
        wt_labels, rhythm_labels = separate_wt_from_rhythm(labels, self.rhythmic_bits, device=self.lm_wt_head.weight.device)
        wt_labels = torch.clamp(wt_labels, min=0, max=self.vocab_size-1)
        rhythm_labels = torch.clamp(rhythm_labels, min=0, max=self.rhythmic_granularity-1)

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            ''' OLD '''
#             shift_logits = lm_logits[..., :-1, :].contiguous()
#             shift_labels = labels[..., 1:].contiguous()
            ''' NEW '''
            shift_wt_logits = lm_wt_logits[..., :-1, :].contiguous()
            shift_rhythm_logits = lm_rhythm_logits[..., :-1, :].contiguous()
            shift_wt_labels = wt_labels[..., 1:].contiguous()
            shift_rhythm_labels = rhythm_labels[..., 1:].contiguous()
            
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            ''' OLD '''
#             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            ''' NEW '''
            wt_loss = loss_fct(shift_wt_logits.view(-1, shift_wt_logits.size(-1)), shift_wt_labels.view(-1))
            rhythm_loss = loss_fct(shift_rhythm_logits.view(-1, shift_rhythm_logits.size(-1)), shift_rhythm_labels.view(-1))
            loss = wt_loss + rhythm_loss

        if not return_dict:
            raise Exception("hope we don't get here")
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithCrossAttentions(
            loss=loss,
#             logits=lm_logits,
            logits=[shift_wt_logits, shift_rhythm_logits],
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
            cross_attentions=transformer_outputs.cross_attentions,
        )

In [8]:
class CustomGPT2Config(GPT2Config):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.rhythmic_granularity = kwargs.get("rhythmic_granularity")

In [9]:
def get_tss(token):
    if token == 'whole' or token == 'double_whole' or token == 'quadruple_whole':
        return 32
    elif token == 'half':
        return 16
    elif token == 'quarter':
        return 8
    elif token == 'eighth':
        return 4
    elif token == 'sixteenth':
        return 2
    elif token == 'thirty_second':
        return 1
    return 0

In [10]:
def parse_time_sig(time_sig):
    time_sig = time_sig.split('-')[1]
    if time_sig == 'C':
        return 4, 4
    elif time_sig == 'C/':
        return 2, 2
    elif '/' in time_sig:
        top, bottom = time_sig.split('/')
        return int(top), int(bottom)
    else:
        return 4, 4

In [11]:
def time_sig_to_tss(time_sig):
    top, bottom = parse_time_sig(time_sig)
    return (32 // bottom) * top

In [12]:
def token_to_action(token):
    if token == 'barline':
        return ("RESET", 0)
    elif token in ['<s>', '</s>']:
        return ("CLEAR", 0)
    elif token == 'dot':
        return ("USE_LAST_DURATION", 0.5)
    elif token == 'dotdot':
        return ("USE_LAST_DURATION", 0.25)
    elif len(token) > 14 and token[:14] == 'timeSignature-':
        return ("SET_TIMESIG", time_sig_to_tss(token))
    else:
        return ("DECREMENT", get_tss(token))

In [13]:
def encode_rhythm(tokens):
    tts_left_arr = []
    tts_left_in_bar = 0
    tts_in_bar = 0
    last_duration = 0
    for i, token in enumerate(tokens):
        if i >= 2 and tokens[i - 2] == 'gracenote':
            action, data = ('DECREMENT', 0)
        else:
            action, data = token_to_action(token)
#         print(token, action, data)
        if action == 'RESET':
            tts_left_in_bar = tts_in_bar
        elif action == 'CLEAR':
            tts_left_in_bar = 0
            tts_in_bar = 0
        elif action == 'USE_LAST_DURATION':
            tts_left_in_bar -= int(last_duration * data)
        elif action == 'SET_TIMESIG':
            tts_in_bar = data
            tts_left_in_bar = tts_in_bar
        elif action == 'DECREMENT':
            tts_left_in_bar -= data
            last_duration = data
        tts_left_arr.append(tts_left_in_bar)
    return torch.tensor(tts_left_arr, dtype=torch.int64)

In [14]:
class CustomDataset(Dataset):
    def __init__(self, src_files, tokenizer, max_length, rhythmic_bits):
        self.examples = []
        pad_token = tokenizer.encode('<pad>')[0]
        for src_file in tqdm(src_files):
            words = src_file.read_text(encoding="utf-8")
            words = words.split()
            if 'sixty_fourth' in words:
                continue
            words = ['<s>'] + words + ['</s>']
            rhythm_tensor = torch.zeros(max_length, dtype=torch.int64)
            rhythm_encodings = encode_rhythm(words)
            rhythm_tensor[:len(rhythm_encodings)] = rhythm_encodings
            word_string = ' '.join(words)
            wt_tensor = tokenizer.encode(
                word_string, 
                return_tensors='pt', 
                max_length=max_length, 
                padding='max_length')[0].type(torch.int64)
            assert len(wt_tensor) == len(rhythm_tensor) == max_length
            encoded = wt_tensor * (2 ** rhythmic_bits) + rhythm_tensor
            example = torch.ones(max_length, dtype=torch.int64) * pad_token
            example[:len(encoded)] = encoded
            self.examples.append(example)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i):
        return self.examples[i]

In [15]:
def create_train_test_datasets(tokenizer, max_length, rhythmic_bits, fraction=1.0, test_size=0.1):
    src_files = list(Path(TXT_FILES).glob("**/*.semantic"))
    src_files = src_files[:int(len(src_files) * fraction)]
    split_index = int(len(src_files) * (1 - test_size))
    train_files = src_files[:split_index]
    test_files = src_files[split_index:]
    train_dataset = CustomDataset(train_files, tokenizer, max_length=max_length, rhythmic_bits=rhythmic_bits)
    test_dataset = CustomDataset(test_files, tokenizer, max_length=max_length, rhythmic_bits=rhythmic_bits)
    return train_dataset, test_dataset

In [16]:
temp_tokenizer = Tokenizer.from_file(str(TOKENIZER_SAVEDIR / 'tokenizer.json'))
tokenizer = GPT2TokenizerFast(tokenizer_object=temp_tokenizer, 
                                         unk_token='<unk>',
                                         pad_token='<pad>',
                                         bos_token='<s>',
                                         eos_token='</s>')

In [17]:
ACTUAL_VOCAB_SIZE = len(tokenizer.vocab)
MAX_LEN=256
RHYTHMIC_GRANULARITY = 64
RHYTHMIC_BITS = math.ceil(math.log(RHYTHMIC_GRANULARITY, 2))

In [18]:
print("ACTUAL VOCAB SIZE:", ACTUAL_VOCAB_SIZE)

ACTUAL VOCAB SIZE: 280


In [19]:
train_dataset, test_dataset = create_train_test_datasets(
    tokenizer, 
    MAX_LEN, 
    fraction=1, 
    test_size=0.05,
    rhythmic_bits=RHYTHMIC_BITS
)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 832/832 [00:00<00:00, 2370.13it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 44/44 [00:00<00:00, 2497.49it/s]


In [20]:
separate_wt_from_rhythm(train_dataset.__getitem__(120), rhythmic_bits=RHYTHMIC_BITS, device="cpu")

(tensor([  0,  33,  49,  41,  36, 109,   7,  10,   6,   4,  20,   6,   4,   9,
           6,   7,   4,  18,   6,   4,  23,   6,  10,   6,   7,  10,   6,   4,
          20,   6,   4,   9,   6,   7,   4,  18,   6,   4,  23,   6,   4,  18,
           6,   7,   4,  12,   5,  14,   4,  18,   8,   4,  12,   5,  14,   4,
           9,   8,   4,  30,   5,  14,   4,  12,   8,   7,   4,   9,   5,  14,
           4,  30,   8,   4,   9,   5,  14,   4,  12,   8,   4,   9,   5,  14,
           4,  18,   8,   7,   2,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
           1,   1,   1,   1,   1,   1,   1,   1,   1

In [21]:
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

In [22]:
config = CustomGPT2Config(
    vocab_size=ACTUAL_VOCAB_SIZE,
    bos_token_id=0,
    eos_token_id=2,
    n_positions=MAX_LEN,
    rhythmic_granularity=RHYTHMIC_GRANULARITY
)

In [23]:
model = CustomGPT2LMHeadModel(config=config)

Sure hope this doesn't get called (get_output_embeddings)


In [24]:
model.transformer.re.weight

Parameter containing:
tensor([[-0.0020,  0.0509,  0.0326,  ..., -0.0155, -0.0382,  0.0135],
        [ 0.0110,  0.0089, -0.0031,  ...,  0.0260,  0.0192,  0.0040],
        [ 0.0079, -0.0057, -0.0017,  ..., -0.0167, -0.0292, -0.0180],
        ...,
        [-0.0125,  0.0027,  0.0290,  ...,  0.0058, -0.0263,  0.0068],
        [ 0.0015,  0.0338,  0.0149,  ...,  0.0110,  0.0037,  0.0067],
        [-0.0172,  0.0127, -0.0336,  ..., -0.0306, -0.0051, -0.0022]],
       requires_grad=True)

In [25]:
training_args = TrainingArguments(
    output_dir=LM_MODEL_SAVEDIR,
    overwrite_output_dir=True,
    num_train_epochs=10,
    per_device_train_batch_size=32,
    save_steps=10000,
    logging_steps=3000,
    evaluation_strategy="steps",
    eval_steps=3000,
    save_total_limit=1,
    prediction_loss_only=False,
    report_to="wandb"
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

In [26]:
ret = trainer.train()

***** Running training *****
  Num examples = 831
  Num Epochs = 1
  Instantaneous batch size per device = 32
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 1
  Total optimization steps = 13
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


[34m[1mwandb[0m: wandb version 0.12.16 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade




Step,Training Loss,Validation Loss




Training completed. Do not forget to share your model on huggingface.co/models =)




In [27]:
trainer.save_model(LM_MODEL_SAVEDIR)

Saving model checkpoint to /home/macosta/ttmp/primus-models/gpt2-lm-semantic-rhythm-v2
Configuration saved in /home/macosta/ttmp/primus-models/gpt2-lm-semantic-rhythm-v2/config.json
Model weights saved in /home/macosta/ttmp/primus-models/gpt2-lm-semantic-rhythm-v2/pytorch_model.bin
