In [None]:
import numpy as np
import pandas as pd
import json

In [None]:
# import encoder and decoder model
from cau_prediction import Encoder,Decoder
from cau_prediction import cal_conv_out_size
from utils import unpack_music_features,load_data,get_song_idx,get_song_info,load_CAU_dict,get_beat_of_cau,align_cau_gt,align_cau_gen

In [None]:
import torch
from torch.utils.data import Dataset,DataLoader
from torch import nn, optim
from torchvision import transforms, datasets
import torch.nn.functional as F

In [None]:
def get_beat_of_cau(cau_file,cau_idx_to_name,cau_name,idx=False):
    """Return how many beats a single CAU spans.

    This notebook-level definition shadows the ``get_beat_of_cau`` that is
    imported from ``utils`` in an earlier cell.

    Args:
        cau_file: DataFrame with ``movement_tag`` and ``beats`` columns.
        cau_idx_to_name: Mapping from CAU index to movement tag. Only read
            when ``idx`` is ``True``, and then with key ``cau_name + 1``.
        cau_name: Movement tag, or an integer CAU id when ``idx`` is ``True``.
        idx: Pass ``True`` when ``cau_name`` is an integer id.

    Returns:
        ``4`` for the special tokens (``'HOLD'``, ``'SOD'``, ``'EOD'``,
        ``'NIL'`` or the ids 41-44); otherwise the ``beats`` value of the
        first row whose ``movement_tag`` matches.

    Raises:
        IndexError: If no row of ``cau_file`` matches the movement tag.
    """
    special_tokens = ('HOLD', 'SOD', 'EOD', 'NIL', 41, 42, 43, 44)
    if cau_name in special_tokens:
        return 4

    # Equality (not truthiness) is kept on purpose: only True/1 selects the
    # index-based lookup, exactly as before.
    tag = cau_idx_to_name[cau_name + 1] if idx == True else cau_name
    matching_beats = cau_file.loc[cau_file['movement_tag'] == tag, 'beats']
    return matching_beats.values[0]

def align_cau_gen(song_len, cau_seq, prob_seq):
    """Stretch a generated CAU sequence to one entry per song position.

    Each CAU (and its probability row) is repeated for as many positions as
    it lasts, i.e. ``get_beat_of_cau(...) * interval``. Output is cut off at
    ``song_len`` and, if the sequence ends early, padded up to ``song_len``.

    Reads the notebook globals ``cau_file``, ``cau_idx_to_name`` and
    ``interval``; they must be defined before this function is called.
    This definition shadows the ``align_cau_gen`` imported from ``utils``.

    Args:
        song_len: Number of positions in the song.
        cau_seq: Generated CAU ids, one per decoding step.
        prob_seq: Per-step probability rows, indexable in step order.

    Returns:
        ``(aligned_cau_gen, aligned_prob_gen)``: two lists with one entry per
        covered song position.
    """
    frame_caus, frame_probs = [], []

    cursor = 0
    for step, cau in enumerate(cau_seq):
        prob = prob_seq[step]
        span = get_beat_of_cau(cau_file, cau_idx_to_name, cau, True) * interval
        for position in range(int(cursor), int(cursor + span)):
            if position >= song_len:
                break
            frame_caus.append(cau)
            frame_probs.append(prob)
        cursor += span

    # Pad the remainder of the song with id 41 (presumably 'HOLD' - TODO
    # confirm) and an all-zero 45-wide probability row.
    for _ in range(int(cursor), song_len):
        frame_caus.append(41)
        frame_probs.append(torch.zeros(45))
    return frame_caus, frame_probs

def align_cau_gt(song_len, cau_seq):
    """Stretch a ground-truth CAU sequence to one entry per song position.

    Each CAU id is repeated for ``get_beat_of_cau(...) * interval`` positions.
    Output is cut off at ``song_len`` and padded with id 41 when the sequence
    ends before the song does.

    Reads the notebook globals ``cau_file``, ``cau_idx_to_name`` and
    ``interval``; they must be defined before this function is called.
    This definition shadows the ``align_cau_gt`` imported from ``utils``.

    Args:
        song_len: Number of positions in the song.
        cau_seq: Ground-truth CAU ids in playback order.

    Returns:
        List of CAU ids, one per covered song position.
    """
    frame_caus = []
    cursor = 0
    for cau in cau_seq:
        span = get_beat_of_cau(cau_file, cau_idx_to_name, cau, True) * interval
        for position in range(int(cursor), int(cursor + span)):
            if position >= song_len:
                break
            frame_caus.append(cau)
        cursor += span

    # Fill whatever is left of the song with the padding id 41.
    for _ in range(int(cursor), song_len):
        frame_caus.append(41)
    return frame_caus

In [None]:
def get_chroma_interval(chroma,t,chroma_len=200):
    """Cut a window of ``chroma_len`` rows out of ``chroma`` centred on ``t / 10``.

    Args:
        chroma: NumPy array indexed by chroma frame along axis 0.
        t: Position on the time axis used by the caller; it is divided by 10
            (and truncated) to obtain the chroma row index.
        chroma_len: Number of rows in the window.

    Returns:
        Tensor of shape ``(1, 12, chroma_len)`` sharing memory with ``chroma``.

    NOTE(review): ``reshape(12, chroma_len)`` reinterprets the slice's memory
    order; it is not a transpose. If ``chroma`` is stored as (time, 12) this
    interleaves time and pitch class - confirm that this is intended.
    """
    center = int(t / 10)
    half_window = int(chroma_len / 2)
    window = chroma[center - half_window:center + half_window]
    return torch.from_numpy(window).reshape(12, chroma_len).unsqueeze(0)

In [None]:
def get_beat_note_interval(beat_note,t,beat_note_len=2000):
    """Cut a window of ``beat_note_len`` rows out of ``beat_note`` centred on ``t``.

    Args:
        beat_note: NumPy array indexed by time step along axis 0.
        t: Centre of the window; truncated to an integer row index.
        beat_note_len: Number of rows in the window.

    Returns:
        Tensor of shape ``(1, 2, beat_note_len)`` sharing memory with
        ``beat_note``.

    NOTE(review): ``reshape(2, beat_note_len)`` reinterprets the slice's
    memory order; it is not a transpose. If ``beat_note`` is stored as
    (time, 2) the two feature columns end up interleaved - confirm intent.
    """
    center = int(t)
    half_window = int(beat_note_len / 2)
    window = beat_note[center - half_window:center + half_window]
    return torch.from_numpy(window).reshape(2, beat_note_len).unsqueeze(0)

In [None]:
class Encoder_Decoder(nn.Module):
    """Greedy, autoregressive CAU generator built from an encoder/decoder pair.

    At every step the music features around the current position ``t`` are
    encoded, the decoder predicts the next CAU from the previous one, and
    ``t`` is advanced by the duration of the predicted CAU.
    """

    def __init__(self, encoder,decoder,sliding_window_size):
        """Store the sub-modules.

        Args:
            encoder: Module called as ``encoder(chroma, beat_note)`` and
                returning the two encoded feature tensors.
            decoder: Module providing ``init_hidden()`` and called as
                ``decoder(last_output, hidden, acoustic)``.
            sliding_window_size: Stored on the instance; ``forward`` does not
                read it.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.sliding_window_size = sliding_window_size

    def forward(self,chroma,beat_note,last_output,cau_dict,start_t,end_t,beat_interval):
        """Generate a CAU sequence for one song.

        Reads the notebook globals ``cau_file`` and ``cau_idx_to_name`` (via
        ``get_beat_of_cau``).

        Args:
            chroma: Chroma feature array passed to ``get_chroma_interval``.
            beat_note: Beat/onset feature array passed to
                ``get_beat_note_interval``.
            last_output: Tensor holding the CAU id fed to the first step.
            cau_dict: Unused; kept for interface compatibility.
            start_t: First position to decode at.
            end_t: Decoding stops once ``t + 1000`` exceeds this value.
            beat_interval: Number of positions per beat, used to advance ``t``.

        Returns:
            ``(cau_gen, output_concat)``: the list of generated CAU ids and
            the list of per-step NumPy arrays ``exp(decoder output)``.
            Both are detached from autograd, so no gradient can flow back to
            the encoder or decoder through them.
        """
        t = start_t
        cau_gen = []
        this_cau_gen = 42  # 'SOD'
        output_concat = []
        # Use the decoder owned by this module, not a notebook-level global,
        # so the model works with whatever decoder it was constructed with.
        hidden = self.decoder.init_hidden()
        # Stop on id 43 (presumably 'EOD' - TODO confirm) or when fewer than
        # 1000 positions remain. The 1000 equals half of the 2000-wide
        # beat/onset window requested below.
        while this_cau_gen != 43 and t + 1000 <= end_t:
            # retrieve musical features around the current position
            chroma_interval = get_chroma_interval(chroma, t, chroma_len=200)
            beat_note_interval = get_beat_note_interval(beat_note, t, beat_note_len=2000)
            chroma_encoded, beat_note_encoded = self.encoder(chroma_interval.float(), beat_note_interval.float())
            # Going through NumPy detaches the encodings from the autograd graph.
            acoustic = torch.from_numpy(
                np.concatenate(
                    [chroma_encoded.detach().numpy(), beat_note_encoded.detach().numpy()],
                    axis=1,
                )
            )

            # decoder step; exp() is applied once and reused for both the
            # greedy choice and the returned per-step array
            output, hidden = self.decoder(last_output, hidden, acoustic)
            step_probs = np.exp(output.detach().numpy())
            this_cau_gen = np.argmax(step_probs)

            # update t, output and save generated CAU sequence
            cau_gen.append(this_cau_gen)
            t += get_beat_of_cau(cau_file, cau_idx_to_name, this_cau_gen, True) * beat_interval
            last_output = torch.from_numpy(np.array(int(this_cau_gen)))
            output_concat.append(step_probs)
        return cau_gen, output_concat

## Training

Hyperparameters to tune:
* `sliding_window_size`
* `out_len` of musical features
* learning rate (`lr`)
* dropout rate (`dropout_p`)
* optimizer (currently RMSprop)

In [None]:
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
# NOTE(review): none of the cells visible here move the model or any tensor
# to `device`; the forward pass also calls `.numpy()`, which needs CPU tensors.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

In [None]:
# Training hyperparameters.
epochs = 100
lr = 0.01
# Passed to the encoder as `out_len`; the decoder's `acoustic_size` is twice
# this because the forward pass concatenates the chroma and beat/onset
# encodings along axis 1.
out_len = 30
# Window sizes here match the `chroma_len=200` / `beat_note_len=2000` that
# Encoder_Decoder.forward passes to the interval helpers.
encoder = Encoder(chroma_len=200, beat_note_len=2000, out_len=out_len, dropout_p=0.2)
# output_size=45 matches the 45-wide probability rows used in the training loop.
decoder = Decoder(hidden_size=128, output_size=45, acoustic_size=out_len*2, dropout_p=0.2)
model = Encoder_Decoder(encoder,decoder,sliding_window_size=1000)
# NLLLoss expects log-probabilities as input.
criteon = nn.NLLLoss()
optimizer = optim.RMSprop(model.parameters(), lr=lr)
# Multiplies the learning rate by 0.9 once the monitored value has failed to
# improve for more than 8 consecutive `scheduler.step(...)` calls.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,factor=0.9,patience=8)

In [None]:
# Genre code used throughout the notebook; defined first so the same value
# feeds the CAU dictionary lookup instead of a repeated 'C' literal.
genre = 'C'
# Table with one row per movement tag and its length in beats.
cau_file = pd.read_csv('movement_interval.csv')
# Two lookup tables: CAU index -> name and CAU name -> index.
cau_idx_to_name, cau_name_to_idx = load_CAU_dict(genre)

In [None]:
# NOTE(review): as written, this loop cannot update the model.
#   * `model(...)` returns NumPy arrays built from detached tensors, and
#     `aligned_prob_gen` is re-created below as a fresh leaf tensor, so
#     `loss.backward()` only produces a gradient for that leaf and never
#     reaches the encoder/decoder parameters.
#   * `optimizer.step()` is still disabled.
#   * `nn.NLLLoss` expects log-probabilities, but the rows passed to it are
#     `exp(decoder output)` (or all-zero padding rows).
# Fixing this requires `Encoder_Decoder.forward` to return graph-attached
# log-probabilities; it is left untouched here to keep the interface stable.
genre = 'C'
for epoch in range(epochs):
    for dance_idx in get_song_idx(genre)[1]:
        # load features
        chroma_vector, beat_vector, onset_vector = unpack_music_features('./music_features/', genre, dance_idx+1)
        beat_note_vector = np.concatenate([beat_vector.reshape(-1, 1), onset_vector.reshape(-1, 1)], axis=1)
        # id 44 is fed to the decoder as the "previous CAU" of the first step
        last_output = torch.from_numpy(np.array(44))
        # `interval` is also read as a global by align_cau_gen / align_cau_gt
        dance_name, _, gt_cau, start_pos, end_pos, interval = get_song_info(cau_file, cau_idx_to_name, genre, dance_idx)
        song_len = beat_vector.shape[0]

        # generate a CAU sequence for the whole song
        cau_gen, prob_gen = model(chroma=chroma_vector,
                                  beat_note=beat_note_vector,
                                  last_output=last_output,
                                  cau_dict=cau_name_to_idx,
                                  start_t=1000,
                                  end_t=song_len,
                                  beat_interval=interval)
        cau_gen = np.array(cau_gen)
        print(cau_gen)

        # expand the per-step predictions to one row per song position
        prob_gen = torch.from_numpy(np.array(prob_gen).reshape(len(cau_gen), 45))
        aligned_cau_gen, aligned_prob_gen = align_cau_gen(song_len, cau_gen, prob_gen)
        aligned_prob_gen = np.array([t.numpy() for t in aligned_prob_gen]).reshape(len(aligned_cau_gen), 45)
        # new leaf tensor: gradients stop here (see the note at the top of the cell)
        aligned_prob_gen = torch.tensor(aligned_prob_gen, requires_grad=True)

        # prep ground truth
        gt_cau = [cau_name_to_idx[cau] for cau in gt_cau]
        aligned_gt_cau = align_cau_gt(song_len, gt_cau)
        aligned_gt_cau = torch.from_numpy(np.array(aligned_gt_cau))

        # loss
        loss = criteon(aligned_prob_gen, aligned_gt_cau)

        # backprop
        optimizer.zero_grad()
        loss.backward()
#         optimizer.step()
        # stepped once per song, so `patience` counts songs rather than epochs
        scheduler.step(loss)
        print('== Epoch', epoch, '==', 'Loss', loss.item())