In [1]:
import torch
from torch import nn
from gan.settings import *
# Pick the compute device once: the first CUDA GPU when one is visible,
# otherwise the CPU. Later cells move the networks onto `device`.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

cuda:0


In [2]:
from gan.dataset import Dataset
from torch.utils.data import DataLoader
# Build the training set from the 4-bar .npy file.
# NOTE(review): `max_num_files=300` presumably caps how many entries are read;
# the printed shape below starts with 300 -- confirm against gan.dataset.
dataset = Dataset(
    file='./gan/dataset_4_bar_strict_large.npy',
    max_num_files=300,
)

# Shuffled mini-batches of 64; `drop_last=True` discards the final incomplete
# batch so every batch the trainer sees has the same size.
dataloader = DataLoader(
    dataset,
    batch_size=64,
    shuffle=True,
    drop_last=True,
)

Dataset shape : (300, 4, 4, 16, 84)


In [3]:
from gan.gan import MuseGenerator, MuseCritic
from gan.utils import initialize_weights

# Generator network on the selected device, with its own Adam optimizer.
generator = MuseGenerator(
    z_dim=32,
    hid_channels=1024,
    hid_features=1024,
    out_channels=1,
).to(device)
g_optimizer = torch.optim.Adam(
    generator.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)

# Critic network on the same device, optimized with identical Adam settings.
critic = MuseCritic(
    hid_channels=128,
    hid_features=1024,
    out_features=1,
).to(device)
c_optimizer = torch.optim.Adam(
    critic.parameters(),
    lr=0.001,
    betas=(0.5, 0.9),
)

# Run the project's weight initializer over both networks.
# NOTE(review): the optimizers above were built from the pre-initialization
# parameters; this is only correct if `initialize_weights` modifies weights in
# place rather than replacing the Parameter objects -- confirm in gan.utils.
generator = generator.apply(initialize_weights)
critic = critic.apply(initialize_weights)

In [None]:
from gan.train import Trainer
# Number of training epochs; the save cell also embeds this value in the
# checkpoint filenames.
num_epochs = 500

# Wire both networks and their optimizers into the project trainer and run it.
trainer = Trainer(generator, critic, g_optimizer, c_optimizer)
trainer.train(dataloader, epochs=num_epochs)

# Keep a copy of whatever the trainer accumulated in `trainer.data`
# (turned into a DataFrame and written to CSV in the save cell).
losses = trainer.data.copy()

  0%|          | 0/500 [00:00<?, ?it/s]

Epoch 0/500 | Generator loss: 76.533 | Critic loss: -994.340 (fake: -39.247, real: -955.093, penalty: 460.273)
Epoch 10/500 | Generator loss: 2653.038 | Critic loss: -294.307 (fake: -2266.706, real: 1972.399, penalty: 114.135)


In [None]:
import pandas as pd
import os
# Persist the trained networks and the recorded training history.

# Put the generator in inference mode and move it to the CPU before pickling,
# so the checkpoint can be loaded on a machine without a GPU.
generator = generator.eval().cpu()
# NOTE(review): the critic is saved as-is, i.e. on whatever device it was
# trained on. Loading it on a CPU-only machine needs `map_location`.

# Directory for the pickled networks.
out_path = './gan/pre_trained_models/modelA'
# Location of the loss log. The plotting cell reads this exact path.
# NOTE(review): this directory is spelled "pre_trained_model" (singular) while
# the checkpoints go to "pre_trained_models" (plural) -- confirm which is
# intended and unify both cells together.
results_path = './gan/pre_trained_model/modelA/results.csv'

# Neither torch.save nor DataFrame.to_csv creates missing parent directories;
# without this, a fresh checkout fails only after the long training run.
os.makedirs(out_path, exist_ok=True)
os.makedirs(os.path.dirname(results_path), exist_ok=True)

# Whole-module pickles (not state_dicts); filenames encode epochs and dataset size.
torch.save(generator, os.path.join(out_path, f'generator_e{num_epochs}_s{len(dataset)}.pt'))
torch.save(critic, os.path.join(out_path, f'critic_e{num_epochs}_s{len(dataset)}.pt'))

# Write the trainer's recorded data to CSV for later plotting.
losses = trainer.data.copy()
df = pd.DataFrame.from_dict(losses)
df.to_csv(results_path, index=False)

In [None]:
import pandas as pd
from gan.utils import plot_losses
# Reload the training history written by the save cell and plot it.
results_csv = './gan/pre_trained_model/modelA/results.csv'
df = pd.read_csv(results_csv)

# `df` is already a DataFrame; the conversion is kept so `losses` stays a
# separate object from `df`, as before.
losses = pd.DataFrame.from_dict(df)
plot_losses(losses)

In [None]:
from gan.utils import parseToMidi

# Sampling configuration.
num_samples = 1
nump_tracks = 4
z_dim = 32  # must equal the z_dim the generator was constructed with

# Load the generator checkpoint written by the save cell. This used to point
# torch.load at results.csv, which is the loss log and not a model file.
# The path mirrors the save cell's filename, so `num_epochs` and `dataset`
# must be defined; substitute a literal filename to load a checkpoint without
# re-running training.
generator_path = f'./gan/pre_trained_models/modelA/generator_e{num_epochs}_s{len(dataset)}.pt'
# NOTE(review): the checkpoint is a whole pickled module, so only load files
# you trust. On PyTorch >= 2.6 torch.load defaults to weights_only=True and
# this call needs weights_only=False.
generator = torch.load(generator_path, map_location='cpu')
generator = generator.eval()  # inference mode, whatever state was pickled

# Random latent inputs: chords and style get one vector per sample, melody and
# groove get one vector per track per sample.
chords = torch.randn(num_samples, z_dim)
style = torch.randn(num_samples, z_dim)
melody = torch.randn(num_samples, nump_tracks, z_dim)
groove = torch.randn(num_samples, nump_tracks, z_dim)

# No gradients are needed for generation; skipping autograd saves memory.
with torch.no_grad():
    sample = generator(chords, style, melody, groove).detach()

# NOTE(review): this output location ("my_trained_models") differs from the
# checkpoint directory, and it is unclear whether parseToMidi creates it --
# confirm in gan.utils.
midi_out_path = './gan/my_trained_models/modelA'
parseToMidi(sample, midi_out_path, name="My Track")