In [1]:
import torch,torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np
from model import Generator, iterate_minibatches, compute_loss, train

  from ._conv import register_converters as _register_converters


In [2]:
OCTAVE_NUM = 4
NOTE_NUM = 12
TIME_SCALE = 128


class LSTM_discriminator(nn.Module):
    def __init__(self,hidden_size = 1000,last_dim = 3):
        super(self.__class__, self).__init__()
        self.last_dim = last_dim
        self.hidden_size = hidden_size
        self.note_lstm = nn.LSTM(input_size = OCTAVE_NUM*last_dim,hidden_size = hidden_size)
        self.time_lstm = nn.LSTM(input_size = hidden_size,hidden_size = hidden_size)
        self.dense = nn.Linear(hidden_size,1)

    def forward(self,data):
        # data.size() =  (batch_size, TIME_SCALE, NOTE_NUM*OCTAVE_NUM, last_dim)
        # octave_data.size() =  (batch_size, TIME_SCALE, NOTE_NUM,OCTAVE_NUM*last_dim)
        batch_size,_,_,_ = data.size()
        octave_data = data.view(batch_size,TIME_SCALE,NOTE_NUM,OCTAVE_NUM,self.last_dim)\
                          .view(batch_size,TIME_SCALE,NOTE_NUM,OCTAVE_NUM*self.last_dim)
            
        # note_lstm_input.size() = (NOTE_NUM, batch_size*TIME_SCALE,OCTAVE_NUM*last_dim)
        note_lstm_input = octave_data.view(batch_size*TIME_SCALE,NOTE_NUM,OCTAVE_NUM*self.last_dim)\
                                     .transpose(0,1)
        # note_lstm_output.size() = (NOTE_NUM,batch_size*TIME_SCALE,hidden_size)
        note_lstm_output, _ = self.note_lstm(note_lstm_input)
        # time_lstm_input.size() = (TIME_SCALE,batch_size,hidden_size)
        time_lstm_input = note_lstm_output[-1].view(batch_size,TIME_SCALE,self.hidden_size)\
                                          .transpose(0,1)\
        # time_lstm_output.size() = (TIME_SCALE,batch_size,1000)
        time_lstm_output, _  = self.time_lstm(time_lstm_input)
        # dense_input.size() = (batch_size,1000)
        dense_input = time_lstm_output[-1]
        # dense_output.size() = (batch_size,1)
        dense_output = self.dense(dense_input)
        probs = F.sigmoid(dense_output)
        return probs
        
        
        

In [3]:
# device = torch.device("cuda:4" if torch.cuda.is_available() else "cpu")
# # device = torch.device("cpu")
# discriminator = LSTM_discriminator(hidden_size=10).to(device)
# np_data = np.random.randn(10,TIME_SCALE,NOTE_NUM*OCTAVE_NUM,3)
# data = torch.FloatTensor(np_data).to(device)
# discriminator(data)

In [4]:
class LSTM_baseline(nn.Module):
    def __init__(self,hidden_size = 1000):
        super(self.__class__, self).__init__()
        self.hidden_size = hidden_size
        self.note_lstm = nn.LSTM(input_size = OCTAVE_NUM*3,hidden_size = hidden_size)
        self.time_lstm = nn.LSTM(input_size = hidden_size,hidden_size = hidden_size)
        self.dense = nn.Linear(hidden_size,1)

    def forward(self,data,_):
        # data.size() =  (batch_size, TIME_SCALE, NOTE_NUM*OCTAVE_NUM, 3)
        # octave_data.size() =  (batch_size, TIME_SCALE, NOTE_NUM,OCTAVE_NUM*3)
        batch_size,_,_,_ = data.size()
        octave_data = data.view(batch_size,TIME_SCALE,NOTE_NUM,OCTAVE_NUM,3)\
                          .view(batch_size,TIME_SCALE,NOTE_NUM,OCTAVE_NUM*3)
            
        # note_lstm_input.size() = (NOTE_NUM, batch_size*TIME_SCALE,OCTAVE_NUM*3)
        note_lstm_input = octave_data.view(batch_size*TIME_SCALE,NOTE_NUM,OCTAVE_NUM*3)\
                                     .transpose(0,1)
        # note_lstm_output.size() = (NOTE_NUM,batch_size*TIME_SCALE,hidden_size)
        note_lstm_output, _ = self.note_lstm(note_lstm_input)
        # time_lstm_input.size() = (TIME_SCALE,batch_size,hidden_size)
        time_lstm_input = note_lstm_output[-1].view(batch_size,TIME_SCALE,self.hidden_size)\
                                          .transpose(0,1)\
        # time_lstm_output.size() = (TIME_SCALE,batch_size,1000)
        time_lstm_output, _  = self.time_lstm(time_lstm_input)
        # dense_input.size() = (batch_size,1000)
        dense_input = time_lstm_output[-1]
        # dense_output.size() = (batch_size,1)
        dense_output = self.dense(dense_input)
        probs = F.sigmoid(dense_output)
        return probs

In [5]:
# discriminator = LSTM_baseline(hidden_size=1000).to(device)
# np_data = np.random.randn(10,TIME_SCALE,NOTE_NUM*OCTAVE_NUM,3)
# data = torch.FloatTensor(np_data).to(device)
# discriminator(data)

In [6]:
class BasicGenerator(nn.Module):
    def __init__(self,hidden_size = 1000):
        super(self.__class__, self).__init__()
        self.dense_in = nn.Linear(TIME_SCALE*NOTE_NUM*OCTAVE_NUM*3,hidden_size)
        self.dense_out = nn.Linear(hidden_size,TIME_SCALE*NOTE_NUM*OCTAVE_NUM*3)

    def forward(self,data,_):
        batch_size,_,_,_ = data.size()
        data = data.view(batch_size,-1)
        hid_data = self.dense_in(data)
        out_data = self.dense_out(hid_data)
        output = F.sigmoid(out_data.view(batch_size, TIME_SCALE, NOTE_NUM*OCTAVE_NUM, 3))
        return output
        

In [7]:
class BasicDiscriminator(nn.Module):
    def __init__(self,hidden_size = 1000):
        super(self.__class__, self).__init__()
        self.dense_in = nn.Linear(TIME_SCALE*NOTE_NUM*OCTAVE_NUM*3,hidden_size)
        self.dense_out = nn.Linear(hidden_size,1)

    def forward(self,data):
        batch_size,_,_,_ = data.size()
        data = data.view(batch_size,-1)
        hid_data = self.dense_in(data)
        out_data = self.dense_out(hid_data)
        output = F.sigmoid(out_data)
        return output

In [8]:
def g_loss(p_fake,sound,in_probs,baseline_pred,eps = 1e-8):
#     probs = sound[:,TIME_SCALE//2:,:,:2]*in_probs[:,TIME_SCALE//2:,:,:2]\
#             +(1-sound[:,TIME_SCALE//2:,:,:2])*(1-in_probs[:,TIME_SCALE//2:,:,:2])
    probs = sound[:,:,:,:2]*in_probs[:,:,:,:2]\
            +(1-sound[:,:,:,:2])*(1-in_probs[:,:,:,:2])
    print(p_fake.mean(),probs.mean())
    return -((probs+eps).log().sum(dim =-1).sum(dim =-1).sum(dim =-1)*(p_fake-baseline_pred)).mean()

# -(in_probs[:,:,:,2].mean()-1)
#     return -(p_fake+eps).log().mean()

# loss = g_loss(discriminator(false_example),sound,data_gen,baseline(x_batch,ch_batch))

#,sound.data,data_gen)


def d_loss(p_fake, p_true,eps = 1e-8):
     return -((1-p_fake+eps).log().mean()-(p_true+eps).log().mean())
    
def bl_loss(bl_pred,real_reward):
    return (bl_pred-real_reward).pow(2).mean()


In [9]:
import torch.utils.data

def sample_sound(data_gen):
    size = data_gen.size()
    rand = torch.rand(*size).cuda()
    sample = (rand<data_gen).type(torch.FloatTensor).cuda()
#     sample[:,:,:,2] = data_gen[:,:,:,2]
    sample[:,:,:,2] = 1
    return sample
    

def train_GAN(generator,discriminator,baseline,X_loader,Y_loader,num_epochs = 3,g_lr = 0.001, d_lr = 0.001,bl_lr = 0.001):
    generator.train()
    discriminator.train()
    g_optimizer = torch.optim.Adam(generator.parameters(),     lr=g_lr)#, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr, betas=(0.5, 0.999))#, betas=(0.5, 0.999))
    bl_optimizer = torch.optim.Adam(baseline.parameters(), lr=bl_lr)
    
    d_losses = []
    g_losses = []
    bl_losses = []
    for epoch in range(num_epochs):
        for [x_batch,ch_batch],[y_batch] in zip(X_loader,Y_loader):
            x_batch = x_batch.cuda()
            ch_batch = ch_batch.cuda()
            y_batch = y_batch.cuda()
            x_batch[:,:,:,2] = 1
            ch_batch[:,:,:,2] = 1
            y_batch[:,:,:,2] = 1
            # Optimize D

            data_gen = generator(x_batch,ch_batch)
            sound = sample_sound(data_gen).data
            #concat_sound = torch.cat([x_batch[:,1:TIME_SCALE//2+1,:,:],sound[:,TIME_SCALE//2:,:,:]],dim = 1)
            loss = d_loss(discriminator(sound), discriminator(y_batch))
            d_optimizer.zero_grad()
            loss.backward()
#             print(loss.grad)
            d_optimizer.step()
            d_losses.append(loss.data.cpu().numpy())
        
            # Optimize BL
            data_gen = generator(x_batch,ch_batch)
            sound = sample_sound(data_gen).data
            loss = bl_loss(baseline(x_batch,ch_batch),discriminator(sound))
            bl_optimizer.zero_grad()
            loss.backward()
            bl_optimizer.step()
            bl_losses.append(loss.data.cpu().numpy())
            
            # Optimize G
            
            data_gen = generator(x_batch,ch_batch)
            sound = sample_sound(data_gen).data
            #concat_sound = torch.cat([x_batch[:,1:TIME_SCALE//2+1,:,:],sound[:,TIME_SCALE//2:,:,:]],dim = 1)
            loss = g_loss(discriminator(sound),sound,data_gen,baseline(x_batch,ch_batch))#,sound.data,data_gen)
            g_optimizer.zero_grad()
            loss.backward()
#             print(loss.grad)
            g_optimizer.step()
            g_losses.append(loss.data.cpu().numpy())
    return generator,discriminator,baseline,np.array(g_losses),np.array(d_losses),np.array(bl_losses)

In [10]:
from dataset import load_all
from constants import *

styles= [['data/Bach1']]
train_data, train_labels = load_all(styles, BATCH_SIZE, TIME_SCALE)
N = 2500
X_tr = train_data[0][:N]
y_tr = train_labels[0][:N]
# X_te = train_data[0][N:2*N]
train_data[0].shape,
# X_te.shape,y_tr.shape,N
#y_te = train_labels[0][-1:]

((81, 128, 48, 3),)

In [11]:
X_loader = torch.utils.data.DataLoader(\
            torch.utils.data.TensorDataset(\
            *(torch.FloatTensor(X_tr),
            torch.FloatTensor(y_tr))),\
            batch_size=10,shuffle=True)
# Y_loader = torch.utils.data.DataLoader(\
#             torch.utils.data.TensorDataset(\
#             torch.FloatTensor(X_te)),\
#             batch_size=10,shuffle=True)

In [14]:
import torch.utils.data

    

def train_GAN2(generator,discriminator,baseline,X_loader,num_epochs = 3,g_lr = 0.001, d_lr = 0.001,bl_lr = 0.001):
    generator.train()
    discriminator.train()
    g_optimizer = torch.optim.Adam(generator.parameters(),     lr=g_lr)#, betas=(0.5, 0.999))
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=d_lr,weight_decay = 1)#, betas=(0.5, 0.999))
    bl_optimizer = torch.optim.Adam(baseline.parameters(), lr=bl_lr)
    
    d_losses = []
    g_losses = []
    bl_losses = []
    for epoch in range(num_epochs):
        for x_batch,ch_batch in X_loader:
            x_batch = x_batch.cuda()
            ch_batch = ch_batch.cuda()

            x_batch[:,:,:,2] = 1
            ch_batch[:,:,:,2] = 1

            # Optimize D
#             print(1)

            data_gen = generator(x_batch,ch_batch)
            sound = sample_sound(data_gen).data
            #concat_sound = torch.cat([x_batch[:,1:TIME_SCALE//2+1,:,:],sound[:,TIME_SCALE//2:,:,:]],dim = 1)
            false_example = torch.cat([sound,x_batch],dim = -1)
            true_example = torch.cat([x_batch,ch_batch],dim = -1)
            loss = d_loss(discriminator(false_example), discriminator(true_example))
            d_optimizer.zero_grad()
            loss.backward()
#             print(loss.grad)
            d_optimizer.step()
            d_losses.append(loss.data.cpu().numpy())
        
            # Optimize BL
            data_gen = generator(x_batch,ch_batch)
#             sound = sample_sound(data_gen).data
            false_example = torch.cat([sound,x_batch],dim = -1)
            loss = bl_loss(baseline(x_batch,ch_batch),discriminator(false_example))
            bl_optimizer.zero_grad()
            loss.backward()
            bl_optimizer.step()
            bl_losses.append(loss.data.cpu().numpy())
            
            # Optimize G
            
            data_gen = generator(x_batch,ch_batch)
#             sound = sample_sound(data_gen).data
            #concat_sound = torch.cat([x_batch[:,1:TIME_SCALE//2+1,:,:],sound[:,TIME_SCALE//2:,:,:]],dim = 1)
            false_example = torch.cat([sound,x_batch],dim = -1)
            loss = g_loss(discriminator(false_example),sound,data_gen,baseline(x_batch,ch_batch))#,sound.data,data_gen)
            g_optimizer.zero_grad()
            loss.backward()
#             print(loss.grad)
            g_optimizer.step()
            g_losses.append(loss.data.cpu().numpy())
    return generator,discriminator,baseline,np.array(g_losses),np.array(d_losses),np.array(bl_losses)

In [19]:
generator = Generator().cuda()
generator.load_state_dict(torch.load(os.path.join(OUT_DIR, 'model_canonical')))
discriminator = LSTM_discriminator(hidden_size=100, last_dim=6).cuda()
#     # generator = BasicGenerator().cuda()
baseline = LSTM_baseline(hidden_size=10).cuda()
# discriminator = BasicDiscriminator().cuda()

In [20]:
generator,discriminator,baseline,g_losses,d_losses,bl_losses =\
                    train_GAN2(generator,discriminator,baseline,\
                            X_loader,num_epochs = 10, g_lr = 1*1e-4,d_lr=1*1e-2, bl_lr = 1*1e-4)

tensor(0.4960, device='cuda:0') tensor(0.8759, device='cuda:0')
tensor(0.4908, device='cuda:0') tensor(0.8780, device='cuda:0')
tensor(0.4886, device='cuda:0') tensor(0.8793, device='cuda:0')
tensor(0.4881, device='cuda:0') tensor(0.8782, device='cuda:0')
tensor(0.4883, device='cuda:0') tensor(0.8755, device='cuda:0')
tensor(0.4888, device='cuda:0') tensor(0.8716, device='cuda:0')
tensor(0.4893, device='cuda:0') tensor(0.8659, device='cuda:0')
tensor(0.4896, device='cuda:0') tensor(0.8602, device='cuda:0')
tensor(0.4895, device='cuda:0') tensor(0.8590, device='cuda:0')
tensor(0.4888, device='cuda:0') tensor(0.8439, device='cuda:0')
tensor(0.4877, device='cuda:0') tensor(0.8374, device='cuda:0')
tensor(0.4860, device='cuda:0') tensor(0.8261, device='cuda:0')
tensor(0.4841, device='cuda:0') tensor(0.8180, device='cuda:0')
tensor(0.4817, device='cuda:0') tensor(0.8062, device='cuda:0')
tensor(0.4793, device='cuda:0') tensor(0.7955, device='cuda:0')
tensor(0.4767, device='cuda:0') tensor(0

In [21]:
import matplotlib.pyplot as plt
plt.plot(g_losses,label = "Generator loss")
plt.legend()
plt.show()
plt.plot(d_losses,label = "Discriminator loss")
plt.legend()
plt.show()
plt.plot(bl_losses,label = "Baseline loss")
plt.legend()
plt.show()

NameError: name 'g_losses' is not defined

In [22]:
g_losses

array([  184.20493  ,   142.3104   ,   120.018654 ,   114.6346   ,
         108.85241  ,   109.8381   ,   107.95536  ,   108.08594  ,
         101.41591  ,   106.14627  ,    99.209076 ,    97.43999  ,
          87.86245  ,    80.27351  ,    69.82189  ,    57.58156  ,
          42.455685 ,    31.59626  ,    20.003407 ,     6.9552355,
          -4.4558682,   -14.864026 ,   -28.5371   ,   -41.188553 ,
         -53.101532 ,   -65.043465 ,   -80.60386  ,   -98.00824  ,
        -122.73334  ,  -147.31163  ,  -166.10214  ,  -180.55585  ,
        -188.92476  ,  -193.17213  ,  -194.29262  ,  -203.89731  ,
        -208.39008  ,  -220.39517  ,  -227.20186  ,  -235.7833   ,
        -251.84921  ,  -267.08987  ,  -280.80362  ,  -276.62003  ,
        -298.37134  ,  -302.02597  ,  -317.96686  ,  -343.77298  ,
        -366.0577   ,  -387.9134   ,  -424.59454  ,  -432.597    ,
        -473.75946  ,  -509.76776  ,  -518.22876  ,  -567.5733   ,
        -604.7329   ,  -605.6982   ,  -658.8878   ,  -712.1158

In [23]:
from generate import write_file, generate
# import gc
# torch.cuda.empty_cache()
# gc.collect() 

# with torch.cuda.device(GPU):
write_file('output/rl_test', generate(generator, 16))

  0%|          | 1/256 [00:00<00:39,  6.53it/s]

Generating with no styles:


100%|██████████| 256/256 [00:21<00:00, 11.70it/s]

Writing file out/samples/output/rl_test_0.mid



