In [1]:
import argparse
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

In [25]:
parser = argparse.ArgumentParser()
parser.add_argument('--num_epochs', default=10, type=int)
parser.add_argument('--batchsize', default=2, type=int, help='number of songs')
parser.add_argument('--noise_dim', default=30, type=int)
parser.add_argument('--hidden_units', default=400, type=int)
parser.add_argument('--num_layers', default=2, type=int)
parser.add_argument('--lr', default=0.1, type=float)
parser.add_argument('--beta1', default=0.9, type=float)
parser.add_argument('--beta2', default=0.999, type=float)
args = parser.parse_args(args=[])
print(args)

Namespace(batchsize=2, beta1=0.9, beta2=0.999, hidden_units=400, lr=0.1, noise_dim=30, num_epochs=10, num_layers=2)


In [3]:
cuda = True if torch.cuda.is_available() else False
print(cuda)

False


In [10]:
class Generator(nn.Module):
    def __init__(self, input_dim=20, hidden_units=400, dropout_prob=0.9, noise_dim=30, midi_dim=3):
        super(Generator, self).__init__()

        self.lstmcell1 = nn.LSTMCell(input_size=input_dim, hidden_size=hidden_units)
        self.lstmcell2 = nn.LSTMCell(input_size=hidden_units, hidden_size=hidden_units)
        self.dropout = nn.Dropout(p=dropout_prob)
        self.fc = nn.Linear(in_features=hidden_units, out_features=midi_dim)

    def forward(self, condition, noise, g_states):
        length, batch_size, input_dim = condition.shape  # L,N,D
        (h1, c1), (h2, c2) = g_states
        output = []

        for len in range(length):
            if len == 0:
                input_dim = input_dim + noise_dim
                h1, c1 = self.lstmcell1(torch.cat((condition[len], noise), dim=1), (h1, c1))
            else:
                input_dim = input_dim + midi_dim
                h1, c1 = self.lstmcell1(torch.cat(condition[len], out), (h1, c1))

            h1 = self.dropout(h1)
            h2, c2 = self.lstmcell2(h1, (h2, c2))
            out = self.fc(h2)
            output.append(out)  # output element should be (N,3)

        gen_midi = torch.stack(output, dim=1)  # (L,N,3)
        #gen_midi = torch.split(gen_midi,1,dim=0)
        #gen_dimi = torch.squeeze(torch.cat(gen_midi,dim=2))  # (N,L*3)
        #g_states = ((h1, c1), (h2, c2))
        return gen_midi

In [45]:
class Discriminator(nn.Module):
    def __init__(self, input_dim=20, hidden_units=400, output_dim=1, num_layers=2, dropout_prob=0.9):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            nn.LSTM(input_size=input_dim, hidden_size=hidden_units, num_layers=num_layers, dropout=dropout_prob),
            nn.Linear(in_features=hidden_units, out_features=output_dim),
            nn.Sigmoid()
        )

    def forward(self, condition, midi, d_state):
        (h3, c3) = d_state
        D_input = torch.cat((condition, midi), dim=2)
        validity = self.model(D_input, (h3, c3))
        return validity

In [6]:
# loss
def bce_loss(input, target):
    neg_abs = - input.abs()
    loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
    return loss.mean()

In [7]:
def G_loss(logits_fake):
    size = logits_fake.size()
    true_labels = torch.ones(size).type(FloatTensor)
    loss = bce_loss(logits_fake, true_labels)
    return loss

In [8]:
def D_loss(logits_real, logits_fake):
    size = logits_real.size()
    true_labels = torch.ones(size).type(FloatTensor)
    loss = bce_loss(logits_real, true_lables) + bce_loss(logits_fake, true_lables-1)
    return loss

In [14]:
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    loss.cuda()

In [41]:
train_data = np.load('.\\data\\dataset_matrices\\train_data_matrix.npy')
np.random.shuffle(train_data)
data = train_data  # (11149, 460)
batch_size = 10  # choose first 10 songs to test
data = data[:batch_size]  # (10, 460)
midi = data[:,:60]  # (10, 60)
syll = data[:, 60:]  # (10, 400)

syll = torch.from_numpy(syll)
midi = torch.from_numpy(midi)
print(midi.shape)

torch.Size([10, 60])


In [17]:
optimizer_G = torch.optim.Adam(generator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=args.lr, betas=(args.beta1, args.beta2))

FloatTensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if cuda else torch.LongTensor

In [None]:
# 不存fake music先不写

In [42]:
condition = torch.stack(torch.split(syll, 20, dim=1))  # syllable embedding condition
print(condition.shape)
midi = torch.stack(torch.split(midi, 3, dim=1))
print(midi.shape)
noise_dim = args.noise_dim
hidden_units = args.hidden_units
num_layers = args.num_layers
noise = torch.rand(batch_size, noise_dim).type(FloatTensor)

g_states = (h1, c1), (h2, c2)
h3 = torch.rand(num_layers, batch_size, hidden_units).type(FloatTensor)
c3 = torch.rand(num_layers, batch_size, hidden_units).type(FloatTensor)
d_state = (h3, c3)


torch.Size([20, 10, 20])
torch.Size([20, 10, 3])


In [46]:
for epoch in range(args.num_epochs):
    for i in range(batch_size):
        logits_real = discriminator(condition, midi, d_state).type(FloatTensor)
        
        optimizer_G.zero_grad()
        gen_midi = generator(condition, noise, g_states).type(FloatTensor)
        logits_fake = discriminator(condition, gen_midi, d_state)
        g_loss = G_loss(logits_fake)
        g_loss.backward()
        optimizer_G.step()
        
        optimizer_D.zero_grad()
        midi = midi.type(FloatTensor)
        logits_real = discriminator(condition, midi, d_state).type(FloatTensor)
        d_loss = D_loss(logits_real, logits_fake)
        d_loss.backward()
        optimizer_D.step()
        
        print(
            "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
            % (epoch, opt.num_epochs, i,batch_size, d_loss.item(), g_loss.item())
        )



TypeError: forward() takes 2 positional arguments but 3 were given

In [3]:
import torch
print(torch.__version__)

1.3.1


In [5]:
a = torch.Tensor(5,3)
a=a.cuda()
print(a)

AssertionError: Torch not compiled with CUDA enabled