In [8]:
%load_ext autoreload
%autoreload 2
%cd ..

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
/home/teo/userdata


In [9]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from fractions import Fraction

import x_transformers
from omegaconf import OmegaConf
from torch.nn.utils.rnn import pack_sequence, PackedSequence, pad_packed_sequence, pack_padded_sequence, pad_sequence 

from sejong_music.model_zoo import get_emb_total_size
from sejong_music.yeominrak_processing import ShiftedAlignedScore, Tokenizer
from sejong_music.loss import nll_loss



In [10]:
class SumEmbedding(nn.Module):
  def __init__(self, vocab_sizes: dict, emb_size: int) -> None:
    super().__init__()
    self.layers = []
    self.emb_size = emb_size
    for vocab_size in vocab_sizes.values():
      self.layers.append(nn.Embedding(vocab_size, self.emb_size))
    self.layers = nn.ModuleList(self.layers)
  
  def forward(self, x):
    return torch.sum(torch.stack([module(x[..., i]) for i, module in enumerate(self.layers)], dim=-1), dim=-1)

multiemb = SumEmbedding({'a': 10, 'b': 20}, 4)
dummy = torch.randint(0, 9, (17, 31, 2))

In [11]:
def pad_collate_transformer(raw_batch):
  # source, target= zip(*raw_batch)[0], zip(*raw_batch)[1]
    source, target, shi_target = zip(*raw_batch)
    padded_src = pad_sequence(source, batch_first=True, padding_value=0)
    padded_target = pad_sequence(target, batch_first=True, padding_value=0)
    shi_target = pad_sequence(shi_target, batch_first=True, padding_value=0)
    return [padded_src, padded_target, shi_target]

In [6]:
dataset = ShiftedAlignedScore(xml_path='music_score/yeominlak.musicxml')
dataloader = DataLoader(dataset, batch_size=8, shuffle=False, collate_fn=pad_collate)

In [7]:
src, tgt, shi_tgt = next(iter(dataloader))
vocab_size_dict = dataset.tokenizer.vocab_size_dict
vocab_size_dict
config = OmegaConf.load('yamls/transformer.yaml')
# config= get_emb_total_size(config)
src.shape

torch.Size([8, 39, 6])

In [1]:
class TransSeq2seq(nn.Module):
  def __init__(self, tokenizer, config):
    super().__init__()
    self.tokenizer = tokenizer
    self.vocab_size_dict = self.tokenizer.vocab_size_dict
    self.encoder = Encoder(self.vocab_size_dict, config)
    self.decoder = Decoder(self.vocab_size_dict, config)
  
  def forward(self, src, tgt):
    enc_out, src_mask = self.encoder(src)
    dec_out = self.decoder(tgt, enc_out, src_mask)
    return dec_out

class Encoder(nn.Module):
  def __init__(self, vocab_size_dict: dict, config):
    super().__init__()
    self.param = config
    self.vocab_size = [x for x in vocab_size_dict.values()]
    self.vocab_size_dict = vocab_size_dict
    self._make_embedding_layer()
    self.layers = x_transformers.Encoder(dim=self.param.dim ,depth=self.param.depth, num_heads=self.param.num_heads)
    
  def _make_embedding_layer(self):
    self.embedding = SumEmbedding(self.vocab_size_dict, self.param.dim)
    
  def forward(self, x:torch.LongTensor):
    '''
    x: (batch, seq_len, num_features)
    '''
    mask = (x != 0)[..., 0] # squeeze num_features dimension
    embedding = self.embedding(x)
    return self.layers(embedding, mask=mask), mask

class Decoder(nn.Module):
  def __init__(self, vocab_size_dict: dict, config):
    super().__init__()
    self.vocab_size = [x for x in vocab_size_dict.values()]
    self.param = config
    self.embedding = SumEmbedding(vocab_size_dict, config.dim)
    self.layers = x_transformers.Decoder(dim=config.dim, depth=config.depth, num_heads=config.num_heads, cross_attend=True)
    self._make_projection_layer()
    
  def forward(self, x, enc_out, src_mask):
    mask = (x != 0)[..., 0]
    embedding = self.embedding(x)
    output = self.layers(embedding, context=enc_out, mask=mask, context_mask=src_mask)
    logit = self.proj(output)
    dec_out = self._apply_softmax(logit)
    return dec_out
  
  def _make_projection_layer(self):
    # total_vocab_size = sum([x for x in self.vocab_size_dict.values()])
    self.proj = nn.Linear(self.param.dim, self.vocab_size[1] + self.vocab_size[2])

  def _apply_softmax(self, logit):
    pitch_output = torch.softmax(logit[...,:self.vocab_size[1]], dim=-1)
    # print(logit.shape)
    dur_output = torch.softmax(logit[...,self.vocab_size[1] :], dim=-1)
    # print(f"{pitch_output.shape, dur_output.shape}!!")
    return torch.cat([pitch_output, dur_output], dim=-1)
  
  def _select_token(self, prob: torch.Tensor):
    tokens = []
    # print(prob.shape)
    pitch_token = prob[0, :, :self.vocab_size[1]].multinomial(num_samples=1)
    dur_token = prob[0, :, self.vocab_size[1]:].multinomial(num_samples=1)
    # dur_token = torch.argmax(prob.data[:, model.vocab_size[0]:], dim=-1)
    return torch.cat([pitch_token, dur_token], dim=-1)
  
model = TransSeq2seq(dataset.tokenizer, config.model)
prob = model(src, tgt)

NameError: name 'nn' is not defined

In [46]:
shi_tgt[..., 1]

tensor([[12,  5, 12,  ...,  0,  0,  0],
        [ 9, 12,  5,  ...,  0,  0,  0],
        [ 7, 10,  9,  ...,  0,  0,  0],
        ...,
        [ 5,  6,  6,  ...,  0,  0,  0],
        [ 5,  6,  6,  ...,  0,  0,  0],
        [ 8,  6,  8,  ...,  0,  0,  0]])

In [118]:
prob == prob.data

tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

In [86]:
def nll_loss(pred, target, eps=1e-8):
  if pred.ndim == 3:
    pred = pred.flatten(0, 1)
  if target.ndim > 1:
    target = target.flatten()
  mask = target != 0
  return (-torch.log(pred[torch.arange(len(target)), target] + eps) * mask) / mask.sum()


In [109]:
dummy = torch.randint(0, 6, (3,3,6))
model.encoder(dummy)

(tensor([[[ 1.7003,  0.1071,  0.9161,  0.3386, -0.2893,  1.2277,  0.4057,
           -0.7975,  0.9658, -1.0138,  0.8950,  1.5130, -0.2053,  0.2745,
           -0.4462, -0.4711,  0.1771,  1.4812, -1.4967,  0.9481,  0.3569,
            0.1669, -0.8546, -0.7160,  0.3229,  0.6983, -0.2871, -1.6515,
           -0.0788, -1.4268, -1.9439,  0.0548, -0.7643, -1.0192, -2.1964,
           -0.6969, -1.2136,  0.6118,  0.6727, -0.9330,  1.1635, -0.0994,
           -0.8573, -0.0378,  0.4520,  0.8292, -0.2709,  0.8911, -0.6651,
            0.7408, -1.2015, -0.1403, -0.1171, -0.7316, -1.9446,  0.7199,
            0.2564,  0.4507, -0.3684,  0.7075,  1.9839,  0.1413, -0.3816,
            3.1468],
          [-0.1188, -1.1370,  1.0609, -0.1823, -0.4064,  2.0915, -1.0626,
            0.8713,  0.9061, -0.6444, -0.9870,  1.0683,  0.3534, -1.4500,
            0.0784,  1.8547, -0.7242,  1.0078,  0.2001, -1.4425, -0.2128,
           -0.2776, -1.6377,  0.2994, -1.2781, -1.4076, -0.1014,  0.0884,
            0.180

In [94]:
loss_fn = nll_loss
pitch_loss = loss_fn(prob[..., :vocab_size[1]], shi_tgt[..., 1])
dur_loss = loss_fn(prob[..., vocab_size[1]:], shi_tgt[..., 2])
total_loss = (pitch_loss + dur_loss) / 2         

In [None]:
def shifted_inference(self, src, part_idx, prev_generation=None, fix_first_beat=False, compensate_beat=(0.0, 0.0)):
    dev = src.device
    assert src.ndim == 2
    enc_out, enc_mask = self.encoder(src.unsqueeze(0)) 

    # Setup for 0th step
    # start_token = torch.LongTensor([[part_idx, 1, 1, 3, 3, 4]]) # start token idx is 1
    start_token = self.get_start_token(part_idx).to(dev)
  
    #-------------
    # not_triplet_vocab_list = [ i for i, dur in enumerate(self.tokenizer.vocab['duration'][3:], 3) if (dur%0.5) != 0.0]
    #-------------
    measure_duration = self._get_measure_duration(part_idx)
    current_measure_idx = 0
    measure_changed = 1

    final_tokens = [start_token]
    if prev_generation is not None:
      final_tokens, current_measure_idx, last_hidden = self._apply_prev_generation(prev_generation, final_tokens, last_hidden, enc_out)
      # emb = self.decoder._get_embedding(prev_generation)
      # decode_out, last_hidden = self.decoder.rnn(emb.unsqueeze(0), last_hidden)
      # start_token = torch.cat([prev_generation[-1:, :3].to(dev),  torch.LongTensor([[3, 3, 5]]).to(dev)], dim=-1)
      # final_tokens.append(start_token)
      # current_measure_idx = 2
    selected_token = start_token
    
    current_beat = Fraction(0, 1)

    total_attention_weights = []
    # while True:
    triplet_sum = 0 
    condition_tokens = self.encode_condition_token(part_idx, current_beat, current_measure_idx, measure_changed)


    for i in range(300):
      # decode_out, last_hidden = self.decoder.rnn(emb.unsqueeze(0), last_hidden)
      # attention_vectors, attention_weight = self._get_attention_vector(encode_out, decode_out)
      # combined_value = torch.cat([decode_out, attention_vectors], dim=-1)
      # combined_value, last_hidden, attention_weight = self._run_inference_on_step(emb, last_hidden, encode_out)
      prob = self.decoder(torch.cat(final_tokens, dim=0).unsqueeze(0), enc_out, enc_mask)
      selected_token = self.decoder._select_token(prob)
      
      # for rule_base fixing
      if fix_first_beat and current_beat - compensate_beat[1] == 0.0:
        for src_token in src:
          if src_token[-1] == condition_tokens[-1] and self.tokenizer.vocab['offset'][src_token[3].item()] - compensate_beat[0] == 0.0:
            # if measure_idx is same and offset is same
            selected_token[0, 0] = src_token[1] # get pitch from src
            break
      
      if selected_token[0, 0]==self.tokenizer.tok2idx['pitch']['end'] or selected_token[0, 1]==self.tokenizer.tok2idx['duration']['end']:
        break
      
      # update beat position and beat strength
      current_dur = self.tokenizer.vocab['duration'][selected_token[0, 1]]
      if Fraction(current_dur).limit_denominator(3).denominator == 3:
        triplet_sum += Fraction(current_dur).limit_denominator(3).numerator
      elif triplet_sum != 0:
        # print("Triplet has to appear but not", current_dur,)
        triplet_sum += 1
        # print(current_dur, selected_token)
        current_dur = Fraction(1,3)
        selected_token = torch.LongTensor([[selected_token[0,0], self.tokenizer.tok2idx['duration'][current_dur]]]).to(dev)
        # print(current_dur, selected_token)
      else:
        triplet_sum = 0
      if triplet_sum == 3:
        triplet_sum = 0
      current_beat, current_measure_idx, measure_changed = self.update_condition_info(current_beat, current_measure_idx, measure_duration, current_dur)
      if 'measure_idx' in self.tokenizer.tok2idx and current_measure_idx > self.tokenizer.vocab['measure_idx'][-1]:
        break
      if float(current_beat) == 10.0:
        print(f"Current Beat ==10.0 detected! cur_beat: {current_beat}, measure_dur: {measure_duration}, cur_measure_idx: {current_measure_idx}, cur_dur: {current_dur}, triplet_sum: {triplet_sum}")
      condition_tokens = self.encode_condition_token(part_idx, current_beat, current_measure_idx, measure_changed)
      
      # make new token for next rnn timestep
      selected_token = torch.cat([torch.LongTensor([[part_idx]]).to(dev), selected_token, torch.LongTensor([condition_tokens]).to(dev)], dim=1)
      
      final_tokens.append(selected_token)
      
    if len(total_attention_weights) == 0:
      attention_map = None
    else:
      attention_map = torch.stack(total_attention_weights, dim=1)
      
    cat_out = torch.cat(final_tokens[1:], dim=0) #token shape: [1,6] => [24,6]
    newly_generated = cat_out.clone()
    if prev_generation is not None:
      cat_out = torch.cat([prev_generation[1:], cat_out], dim=0)
    #  self.converter(src[1:-1]), self.converter(torch.cat(final_tokens, dim=0)), attention_map

    return self._decode_inference_result(src, cat_out, (attention_map, cat_out, newly_generated))

In [5]:
batch = next(iter(dataloader))
out_a = encoder(batch[0])

modified_x = batch[0].clone()
modified_x[0, 4, 1:] = 3
out_b = encoder(modified_x)

out_a == out_b

NameError: name 'dataloader' is not defined

RuntimeError: The size of tensor a (6) must match the size of tensor b (240) at non-singleton dimension 3