Copyright 2017 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

# MusicVAE: A Hierarchical Latent Vector Model for Learning Long-Term Structure in Music.
### ___Adam Roberts, Jesse Engel, Colin Raffel, Curtis Hawthorne, and Douglas Eck___

[MusicVAE](https://g.co/magenta/music-vae) learns a latent space of musical scores, providing different modes
of interactive musical creation, including:

* Random sampling from the prior distribution.
* Interpolation between existing sequences.
* Manipulation of existing sequences via attribute vectors.

Examples of these interactions can be generated below, and selections can be heard in our
[YouTube playlist](https://www.youtube.com/playlist?list=PLBUMAYA6kvGU8Cgqh709o5SUvo-zHGTxr).

For short sequences (e.g., 2-bar "loops"), we use a bidirectional LSTM encoder
and LSTM decoder. For longer sequences, we use a novel hierarchical LSTM
decoder, which helps the model learn longer-term structures.

We also model the interdependencies between instruments by training multiple
decoders on the lowest-level embeddings of the hierarchical decoder.

For additional details, check out our [blog post](https://g.co/magenta/music-vae) and [paper](https://goo.gl/magenta/musicvae-paper).
___

This colab notebook is self-contained and should run natively on google cloud. The [code](https://github.com/tensorflow/magenta/tree/master/magenta/models/music_vae) and [checkpoints](http://download.magenta.tensorflow.org/models/music_vae/checkpoints.tar.gz) can be downloaded separately and run locally, which is required if you want to train your own model.

# Basic Instructions

1. Double click on the hidden cells to make them visible, or select "View > Expand Sections" in the menu at the top.
2. Hover over the "`[ ]`" in the top-left corner of each cell and click on the "Play" button to run it, in order.
3. Listen to the generated samples.
4. Make it your own: copy the notebook, modify the code, train your own models, upload your own MIDI, etc.!

# Environment Setup
Includes package installation for sequence synthesis. Will take a few minutes.


In [1]:
#@title Connect to Google Drive
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
%cd /content/drive/MyDrive/Thesis/Code/Magenta/magenta/

/content/drive/MyDrive/University of Alberta/Thesis/Code/Magenta/magenta


In [3]:
!python --version

Python 3.8.10


In [4]:
!pip install -qU google-cloud note-seq==0.0.2 pyfluidsynth

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m209.6/209.6 KB[0m [31m19.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.6/3.6 MB[0m [31m99.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.6/5.6 MB[0m [31m113.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.2/20.2 MB[0m [31m84.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m51.1/51.1 KB[0m [31m6.9 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m86.6 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m109.7 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.1/3.1 MB[0m [31m114.0 MB/s[0m et

In [5]:
import glob

BASE_DIR = "gs://download.magenta.tensorflow.org/models/music_vae/colab2"

print('Installing dependencies...')
!apt-get update -qq && apt-get install -qq libfluidsynth1 fluid-soundfont-gm build-essential libasound2-dev libjack-dev
!pip install -q pyfluidsynth

# Hack to allow python to pick up the newly-installed fluidsynth lib.
# This is only needed for the hosted Colab environment.
import ctypes.util
orig_ctypes_util_find_library = ctypes.util.find_library
def proxy_find_library(lib):
  if lib == 'fluidsynth':
    return 'libfluidsynth.so.1'
  else:
    return orig_ctypes_util_find_library(lib)
ctypes.util.find_library = proxy_find_library


print('Importing libraries and defining some helper functions...')
from google.colab import files

Installing dependencies...
E: Package 'libfluidsynth1' has no installation candidate
Importing libraries and defining some helper functions...


In [6]:
!pip install tensor2tensor

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting tensor2tensor
  Downloading tensor2tensor-1.15.7-py2.py3-none-any.whl (1.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m58.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting gevent
  Downloading gevent-22.10.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (6.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m110.5 MB/s[0m eta [36m0:00:00[0m
Collecting gunicorn
  Downloading gunicorn-20.1.0-py3-none-any.whl (79 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.5/79.5 KB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
Collecting tf-slim
  Downloading tf_slim-1.1.0-py2.py3-none-any.whl (352 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m352.1/352.1 KB[0m [31m38.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting bz2file
  Downloading bz2file-0.98.tar.gz 

In [7]:
!pip install note_seq

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [8]:
!pip install -e .

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Obtaining file:///content/drive/MyDrive/University%20of%20Alberta/Thesis/Code/Magenta/magenta
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting dm-sonnet
  Downloading dm_sonnet-2.0.1-py3-none-any.whl (268 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m268.4/268.4 KB[0m [31m25.3 MB/s[0m eta [36m0:00:00[0m
Collecting librosa<0.8.0,>=0.6.2
  Downloading librosa-0.7.2.tar.gz (1.6 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m76.9 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting mido==1.2.6
  Downloading mido-1.2.6-py2.py3-none-any.whl (69 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m69.8/69.8 KB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting mir_eval>=0.4
  Downloading mir_eval-0.7.tar.gz (90 kB)
[2K     [90m━━━━━━━━━━━━━━━━━

In [9]:
!pip install fluidsynth

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fluidsynth
  Downloading fluidsynth-0.2.tar.gz (3.7 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fluidsynth
  Building wheel for fluidsynth (setup.py) ... [?25l[?25hdone
  Created wheel for fluidsynth: filename=fluidsynth-0.2-py3-none-any.whl size=4512 sha256=678d814c8ec8327e240ad6102fbdddbbdf2f9b00c38cdcef7a63c56291771590
  Stored in directory: /root/.cache/pip/wheels/d4/e6/bf/921b2deb780e2681b0e1626a13995e504dbbd455b47e7eedd4
Successfully built fluidsynth
Installing collected packages: fluidsynth
Successfully installed fluidsynth-0.2


In [10]:
import magenta.music as mm
from magenta.models.music_vae import configs
from magenta.models.music_vae.trained_model import TrainedModel
import numpy as np
import os
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()
# tf.enable_eager_execution()

# Necessary until pyfluidsynth is updated (>1.2.5).
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

def play(note_sequence):
  mm.play_sequence(note_sequence, synth=mm.fluidsynth)

def interpolate(model, start_seq, end_seq, num_steps, max_length=32,
                assert_same_length=True, temperature=0.5,
                individual_duration=4.0):
  """Interpolates between a start and end sequence."""
  note_sequences = model.interpolate(
      start_seq, end_seq,num_steps=num_steps, length=max_length,
      temperature=temperature,
      assert_same_length=assert_same_length)

  print('Start Seq Reconstruction')
  play(note_sequences[0])
  print('End Seq Reconstruction')
  play(note_sequences[-1])
  print('Mean Sequence')
  play(note_sequences[num_steps // 2])
  print('Start -> End Interpolation')
  interp_seq = mm.sequences_lib.concatenate_sequences(
      note_sequences, [individual_duration] * len(note_sequences))
  play(interp_seq)
  mm.plot_sequence(interp_seq)
  return interp_seq if num_steps > 3 else note_sequences[num_steps // 2]

def download(note_sequence, filename):
  mm.sequence_proto_to_midi_file(note_sequence, filename)
  files.download(filename)

print('Done')

Instructions for updating:
non-resource variables are not supported in the long term


Done


In [11]:
from datetime import datetime
%load_ext tensorboard

In [12]:
from glob import glob

# Notes

A few important functions are:


*   
```
# tf.train.list_variables(checkpoint_path)
```
This will list every variable in the checkpoint including tensors. 


*   
```
# tf.train.load_checkpoint(path).get_variable_to_shape_map()
```


*   
```
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
print_tensors_in_checkpoint_file(path, all_tensors=True, tensor_name=name)
```
This will print all or a specific tensor and their values in checkpoint.










# Persian V1 Dataset


In [13]:
from magenta.models.music_vae import music_vae_mcts_train
import pandas as pd
import shutil

In [82]:
config_name = 'cat-mel_2bar_big'
midi_root= '/content/drive/MyDrive/MIDI/Persian/persian_dataset_v1'
run_dir = './data/tmp/Persian/dataset/test1'
mel_2bar_big_ckpt_path = '/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt'

## Create tfrecords

In [83]:
dataset_path = os.path.join(midi_root, '*.mid')
song_paths = glob(dataset_path)
song_names = [os.path.basename(path)[:-4] for path in song_paths]

In [84]:
df = pd.DataFrame({
    "song_name": song_names,
    "song_path": song_paths
})
df.set_index("song_name", inplace=True)

In [None]:
df.loc['gole_yakh']['song_path']

In [85]:
df

Unnamed: 0_level_0,song_path
song_name,Unnamed: 1_level_1
gole_yakh,/content/drive/MyDrive/MIDI/Persian/persian_da...
delyar,/content/drive/MyDrive/MIDI/Persian/persian_da...
tavalod,/content/drive/MyDrive/MIDI/Persian/persian_da...
dokhtare_boyerahmadi,/content/drive/MyDrive/MIDI/Persian/persian_da...
bidade_zaman,/content/drive/MyDrive/MIDI/Persian/persian_da...
...,...
ba_man_sanama,/content/drive/MyDrive/MIDI/Persian/persian_da...
majnoon_naboodom,/content/drive/MyDrive/MIDI/Persian/persian_da...
soze_rokh,/content/drive/MyDrive/MIDI/Persian/persian_da...
gole_pamchal,/content/drive/MyDrive/MIDI/Persian/persian_da...


In [None]:
for name in song_names:
  dest = os.path.join(midi_root, 'temp', name)
  tfrecord_path = os.path.join(midi_root, 'tfrecords', name + '.tfrecord')
  os.makedirs(dest)
  shutil.copy(df.loc[name]['song_path'], dest)
  os.system(
      f'convert_dir_to_note_sequences \
      --input_dir={dest} \
      --output_file={tfrecord_path}'
  )

## Evaluate songs

In [86]:
df["accuracy"] = None

In [87]:
for name in song_names:
  try:
    res = music_vae_mcts_train.run(
      run_dir=run_dir,
      config=config_name,
      mode='eval',
      hparams='batch_size=1',
      cache_dataset=False,
      examples_path=os.path.join(midi_root, 'tfrecords', f'{name}.tfrecord'),
      ckpt_path=mel_2bar_big_ckpt_path
    ) 
    df.loc[name, 'accuracy'] = res['metrics/accuracy']
  except:
    print(f"Failed to evaluate {name}.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/zmqshell.py", line 536, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2854, in run_cell
    result = self._run_cell(
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 2881, in _run_cell
    return runner(coro)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/async_helpers.py", line 68, in _pseudo_sync_runner
    coro.send(None)
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3057, in run_cell_async
    has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3249, in run_ast_nodes
    if (await

In [88]:
df

Unnamed: 0_level_0,song_path,accuracy
song_name,Unnamed: 1_level_1,Unnamed: 2_level_1
gole_yakh,/content/drive/MyDrive/MIDI/Persian/persian_da...,1.0
delyar,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
tavalod,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
dokhtare_boyerahmadi,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
bidade_zaman,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
...,...,...
ba_man_sanama,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
majnoon_naboodom,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
soze_rokh,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0
gole_pamchal,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.0


In [None]:
non_zero_df_new = df[df['accuracy'] != 0]

In [None]:
np.count_nonzero(df['accuracy'])

0

In [None]:
df.to_csv(os.path.join(midi_root, 'persian_data.csv'), sep='\t')

In [None]:
res = music_vae_mcts_train.run(
    run_dir=run_dir,
    config=config_name,
    mode='eval',
    hparams='batch_size=1',
    cache_dataset=False,
    examples_path=os.path.join(midi_root, 'tfrecords', '[MihanDownload.com].tfrecord'),
    ckpt_path=mel_2bar_big_ckpt_path
)

INFO:tensorflow:Total examples: 0
--- Logging error ---
Traceback (most recent call last):
  File "/usr/lib/python3.8/logging/__init__.py", line 1085, in emit
    msg = self.format(record)
  File "/usr/lib/python3.8/logging/__init__.py", line 929, in format
    return fmt.format(record)
  File "/usr/lib/python3.8/logging/__init__.py", line 668, in format
    record.message = record.getMessage()
  File "/usr/lib/python3.8/logging/__init__.py", line 373, in getMessage
    msg = msg % self.args
TypeError: not all arguments converted during string formatting
Call stack:
  File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.8/dist-packages/traitlets/config/application.py", line 992, in launc

KeyboardInterrupt: ignored

In [None]:
len(song_names)

87

In [None]:
song_names

['gole_yakh',
 'delyar',
 'tavalod',
 'dokhtare_boyerahmadi',
 'bidade_zaman',
 'ey_iran',
 'pari_kojaee',
 'esfehan_reng_2',
 'saghiname',
 'ghoghaye_setaregan',
 'age_ye_rooz',
 'age_eshgh_hamine',
 'alamatsoal',
 'che_khube_adam',
 'daryache_noor',
 'delam_ino_bavar_nadare',
 'eay',
 'entezar',
 'eshgh',
 'gharibe_ashena',
 'gol_bi_goldoon',
 'hamsafar',
 'gozashtehaye_door',
 'hala_kheili_dire',
 'jashn_tavalod',
 'khodaye_mastoon',
 'kolbeye_man',
 'makhloogh',
 'mara_beboos',
 'sokute_gham',
 'tavalodet_mobarak',
 'to_in_sine_dele_man',
 'too_in_zamuneh',
 'mahour_charshanbesouri',
 'kuchelere_su_sepmishem_harmonica',
 'jane_maryam_2',
 'char_mezrabe_esfahan',
 'morghe_sahar',
 'golden_dreams',
 'soghati',
 'chargah',
 'khuneye_ma',
 'gole_sangam',
 'ashke_man_hoveyda_shod',
 'bahar_bahare',
 'taghatam_deh_1',
 'navaee',
 'elahehye_naz',
 'parandeye_mohajer',
 'dareneh_jan',
 'gole_goldoon',
 'toloo',
 'esfehan_reng_1',
 'esfehan_overture_1',
 'marde_tanha',
 'sar_oomad_zemestoon

# Old Persian Dataset

In [15]:
from magenta.models.music_vae import music_vae_mcts_train
import pandas as pd
import shutil

In [16]:
config_name = 'cat-mel_2bar_big'
midi_root= '/content/drive/MyDrive/MIDI/Persian/Big'
run_dir = './data/tmp/Persian/dataset/test1'
mel_2bar_big_ckpt_path = '/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt'

## Create tfrecords

In [41]:
dataset_paths = (os.path.join(midi_root, '*.mid'), os.path.join(midi_root, '*.MID'))
song_paths = glob(dataset_paths[0]) + glob(dataset_paths[1])
song_names = [os.path.basename(path)[:-4] for path in song_paths]

In [43]:
len(song_names)

309

In [45]:
df = pd.DataFrame({
    "song_name": song_names,
    "new_name": ["persian_"+str(i) for i in range(len(song_names))],
    "song_path": song_paths,
})
df.set_index("new_name", inplace=True)

In [34]:
df.loc['[MihanDownload.com] (4)']['song_path']

'/content/drive/MyDrive/MIDI/Persian/Big/[MihanDownload.com] (4).mid'

In [79]:
df

Unnamed: 0_level_0,song_name,song_path,accuracy
new_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
persian_0,[MihanDownload.com] (4),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_1,[MihanDownload.com] (3),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375
persian_2,[MihanDownload.com] (5),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_3,[MihanDownload.com] (131),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.5
persian_4,[MihanDownload.com] (58),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.84375
...,...,...,...
persian_304,[MihanDownload.com] (260),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.96875
persian_305,[MihanDownload.com] (304),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_306,[MihanDownload.com] (234),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.0
persian_307,[MihanDownload.com] (261),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375


In [58]:
for i in range(len(song_names)):
  name = f"persian_{i}"
  dest = os.path.join(midi_root, 'temp', name)
  tfrecord_path = os.path.join(midi_root, 'tfrecords', name + '.tfrecord')
  os.makedirs(dest)
  shutil.copy(df.loc[name]['song_path'], os.path.join(dest, name + '.mid'))
  os.system(
      f'convert_dir_to_note_sequences \
      --input_dir={dest} \
      --output_file={tfrecord_path}'
  )

## Evaluate songs

In [59]:
df["accuracy"] = None

In [60]:
for i in range(len(song_names)):
  name = f"persian_{i}"
  try:
    res = music_vae_mcts_train.run(
      run_dir=run_dir,
      config=config_name,
      mode='eval',
      hparams='batch_size=1',
      cache_dataset=False,
      examples_path=os.path.join(midi_root, 'tfrecords', f'{name}.tfrecord'),
      ckpt_path=mel_2bar_big_ckpt_path
    ) 
    df.loc[name, 'accuracy'] = res['metrics/accuracy']
  except:
    print(f"Failed to evaluate {name}.")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
    exec(code, run_globals)
  File "/usr/local/lib/python3.8/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.8/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/kernelapp.py", line 612, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.8/dist-packages/tornado/platform/asyncio.py", line 149, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 570, in run_forever
    self._run_once()
  File "/usr/lib/python3.8/asyncio/base_events.py", line 1859, in _run_once
    handle._run()
  File "/usr/lib/python3.8/asyncio/events.py", line 81, in _run
    self._context.run(self._callback, *self._args)
  File "/usr/local/lib/python3.8/dist-packages/tornado/ioloop.py", line 690, in <lambda>
    lambda f: s

In [78]:
df

Unnamed: 0_level_0,song_name,song_path,accuracy
new_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
persian_0,[MihanDownload.com] (4),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_1,[MihanDownload.com] (3),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375
persian_2,[MihanDownload.com] (5),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_3,[MihanDownload.com] (131),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.5
persian_4,[MihanDownload.com] (58),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.84375
...,...,...,...
persian_304,[MihanDownload.com] (260),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.96875
persian_305,[MihanDownload.com] (304),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.0
persian_306,[MihanDownload.com] (234),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.0
persian_307,[MihanDownload.com] (261),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375


In [64]:
np.count_nonzero(df['accuracy'])

266

In [80]:
no_zero_df = df[df['accuracy'] != 0]

In [81]:
no_zero_df

Unnamed: 0_level_0,song_name,song_path,accuracy
new_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
persian_1,[MihanDownload.com] (3),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375
persian_3,[MihanDownload.com] (131),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.5
persian_4,[MihanDownload.com] (58),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.84375
persian_6,[MihanDownload.com] (199),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375
persian_8,[MihanDownload.com] (168),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.0
...,...,...,...
persian_303,[MihanDownload.com] (208),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.75
persian_304,[MihanDownload.com] (260),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.96875
persian_306,[MihanDownload.com] (234),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.0
persian_307,[MihanDownload.com] (261),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.9375


In [73]:
df.dtypes

song_name    object
song_path    object
accuracy     object
dtype: object

In [74]:
no_zero_df['accuracy'].astype(float).describe()

count    266.000000
mean       0.851269
std        0.144175
min        0.250000
25%        0.781250
50%        0.906250
75%        0.968750
max        1.000000
Name: accuracy, dtype: float64

In [76]:
df.to_csv(os.path.join(midi_root, 'old_persian_data.csv'), sep='\t')

In [None]:
res = music_vae_mcts_train.run(
    run_dir=run_dir,
    config=config_name,
    mode='eval',
    hparams='batch_size=1',
    cache_dataset=False,
    examples_path=os.path.join(midi_root, 'tfrecords', '[MihanDownload.com].tfrecord'),
    ckpt_path=mel_2bar_big_ckpt_path
)

# Combine Datasets

In [13]:
from magenta.models.music_vae import music_vae_mcts_train
import pandas as pd
import shutil

In [42]:
config_name = 'cat-mel_2bar_big'
midi_root_new = '/content/drive/MyDrive/MIDI/Persian/persian_dataset_v1'
midi_root_old = '/content/drive/MyDrive/MIDI/Persian/Big'
run_dir = './data/tmp/Persian/dataset/test1'
mel_2bar_big_ckpt_path = '/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt'
main_root = '/content/drive/MyDrive/MIDI/Persian/persian_100_v1'

In [27]:
df_new = pd.read_csv(os.path.join(midi_root_new, 'persian_data.csv'), sep='\t', index_col='song_name')
df_new = df_new[df_new['accuracy'] != 0]

In [28]:
df_new

Unnamed: 0_level_0,song_path,accuracy
song_name,Unnamed: 1_level_1,Unnamed: 2_level_1
gole_yakh,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.96875
age_ye_rooz,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.9375
age_eshgh_hamine,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.9375
daryache_noor,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.96875
delam_ino_bavar_nadare,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.96875
entezar,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.96875
eshgh,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.9375
gharibe_ashena,/content/drive/MyDrive/MIDI/Persian/persian_da...,1.0
gol_bi_goldoon,/content/drive/MyDrive/MIDI/Persian/persian_da...,0.9375
hamsafar,/content/drive/MyDrive/MIDI/Persian/persian_da...,1.0


In [33]:
df_new.shape

(24, 2)

In [35]:
100 - 24

76

In [31]:
df_old = pd.read_csv(os.path.join(midi_root_old, 'old_persian_data.csv'), sep='\t', index_col='new_name')
df_old = df_old[df_old['accuracy'] != 0]

In [32]:
df_old

Unnamed: 0_level_0,song_name,song_path,accuracy
new_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
persian_1,[MihanDownload.com] (3),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.93750
persian_3,[MihanDownload.com] (131),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.50000
persian_4,[MihanDownload.com] (58),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.84375
persian_6,[MihanDownload.com] (199),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.93750
persian_8,[MihanDownload.com] (168),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.00000
...,...,...,...
persian_303,[MihanDownload.com] (208),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.75000
persian_304,[MihanDownload.com] (260),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.96875
persian_306,[MihanDownload.com] (234),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.00000
persian_307,[MihanDownload.com] (261),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.93750


In [34]:
df_old.shape

(266, 3)

In [39]:
df_old_selection = df_old.sample(n=76)

In [40]:
df_old_selection

Unnamed: 0_level_0,song_name,song_path,accuracy
new_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
persian_75,[MihanDownload.com] (299),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.00000
persian_74,[MihanDownload.com] (106),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.56250
persian_6,[MihanDownload.com] (199),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.93750
persian_151,[MihanDownload.com] (290),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.78125
persian_39,[MihanDownload.com] (258),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.90625
...,...,...,...
persian_218,[MihanDownload.com] (88),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.43750
persian_105,[MihanDownload.com] (230),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.87500
persian_239,[MihanDownload.com] (7),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,0.87500
persian_159,[MihanDownload.com] (272),/content/drive/MyDrive/MIDI/Persian/Big/[Mihan...,1.00000


In [41]:
df_old_selection['accuracy'].astype(float).describe()

count    76.000000
mean      0.845395
std       0.153848
min       0.250000
25%       0.781250
50%       0.890625
75%       0.937500
max       1.000000
Name: accuracy, dtype: float64

In [55]:
song_names = list()

In [56]:
# move new files to new dir
counter = 0
for name in df_new.index:
  print(name)
  song_names.append(name)
  shutil.copy(os.path.join(midi_root_new, "temp", name, f"{name}.mid"), 
              os.path.join(main_root))
  counter += 1
print(f"{counter} file(s) moved.")

gole_yakh
age_ye_rooz
age_eshgh_hamine
daryache_noor
delam_ino_bavar_nadare
entezar
eshgh
gharibe_ashena
gol_bi_goldoon
hamsafar
gozashtehaye_door
kolbeye_man
sokute_gham
char_mezrabe_esfahan
chargah
parandeye_mohajer
toloo
esfehan_overture_1
marde_tanha
prelude_no5
zarbi_homayoun
yek_goli_saye_kamar
in_del_dige_del_nemishe
bade_bahari
24 file(s) moved.


In [58]:
# move old files to new dir
counter = 0
for name in df_old_selection.index:
  print(name)
  song_names.append(name)
  shutil.copy(os.path.join(midi_root_old, "temp", name, f"{name}.mid"), 
              os.path.join(main_root))
  counter += 1
print(f"{counter} file(s) moved.")

persian_75
persian_74
persian_6
persian_151
persian_39
persian_136
persian_60
persian_286
persian_86
persian_164
persian_198
persian_91
persian_208
persian_279
persian_285
persian_185
persian_220
persian_223
persian_251
persian_242
persian_111
persian_157
persian_155
persian_42
persian_65
persian_201
persian_230
persian_148
persian_189
persian_13
persian_171
persian_301
persian_232
persian_17
persian_179
persian_212
persian_293
persian_35
persian_79
persian_127
persian_62
persian_64
persian_140
persian_87
persian_112
persian_243
persian_103
persian_163
persian_233
persian_294
persian_193
persian_187
persian_236
persian_195
persian_106
persian_282
persian_273
persian_138
persian_271
persian_177
persian_28
persian_137
persian_73
persian_33
persian_162
persian_225
persian_46
persian_10
persian_152
persian_149
persian_52
persian_218
persian_105
persian_239
persian_159
persian_24
76 file(s) moved.


In [60]:
# create new dataset of length 100
df_100 = pd.DataFrame({
    "song_name": song_names,
    "fold": 0
})

In [61]:
df_100

Unnamed: 0,song_name,fold
0,gole_yakh,0
1,age_ye_rooz,0
2,age_eshgh_hamine,0
3,daryache_noor,0
4,delam_ino_bavar_nadare,0
...,...,...
95,persian_218,0
96,persian_105,0
97,persian_239,0
98,persian_159,0


In [66]:
# shuffle the data
df_100 = df_100.sample(frac=1).reset_index(drop=True)

In [73]:
df_100

Unnamed: 0,song_name,fold
0,marde_tanha,1
1,persian_285,1
2,persian_74,1
3,persian_225,1
4,parandeye_mohajer,1
...,...,...
95,persian_282,5
96,persian_152,5
97,persian_218,5
98,toloo,5


# Create folds

In [72]:
# assign fold values
df_100['fold'] = 0
df_100.loc[0:19, 'fold'] = 1
df_100.loc[20:39, 'fold'] = 2
df_100.loc[40:59, 'fold'] = 3
df_100.loc[60:79, 'fold'] = 4
df_100.loc[80:99, 'fold'] = 5

In [92]:
## generate tfrecords for each fold
# tfrecord_path = os.path.join(main_root, "tfrecords")
# os.makedirs(tfrecord_path)
for fold in [1,2,3,4,5]:
  train_path = os.path.join(main_root, "train")
  test_path = os.path.join(main_root, "test")
  os.makedirs(train_path)
  os.makedirs(test_path)

  # send copies of files to either train or test according to their fold
  for i in range(100):
    name = df_100.loc[i, "song_name"]
    if df_100.loc[i, "fold"] == fold:
      shutil.copy(os.path.join(main_root, name+'.mid'), test_path)
    else:
      shutil.copy(os.path.join(main_root, name+'.mid'), train_path)

  # create tfrecord for train data  
  os.system(
    f'convert_dir_to_note_sequences \
    --input_dir={train_path} \
    --output_file={os.path.join(tfrecord_path, f"fold_{fold}_train.tfrecord")}'
  )
  # create tfrecord for test data
  os.system(
    f'convert_dir_to_note_sequences \
    --input_dir={test_path} \
    --output_file={os.path.join(tfrecord_path, f"fold_{fold}_test.tfrecord")}'
  )
  shutil.rmtree(train_path)
  shutil.rmtree(test_path)


In [None]:
for name in song_names:
  dest = os.path.join(midi_root, 'temp', name)
  tfrecord_path = os.path.join(midi_root, 'tfrecords', name + '.tfrecord')
  os.makedirs(dest)
  shutil.copy(df.loc[name]['song_path'], dest)
  os.system(
      f'convert_dir_to_note_sequences \
      --input_dir={dest} \
      --output_file={tfrecord_path}'
  )

### Train

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --mode=$mode \
# --finetune=$finetune \
# --trainable_vars=$trainable_vars \
# --examples_path=$train_example_path \
# --num_steps=$num_steps \
# --hparams=batch_size=$batch_size,learning_rate=$learning_rate

In [None]:
# train_ckpt_paths = glob(run_dir+'train/model.ckpt-*.index')
# ckpt_nums = [int(s[len(run_dir)+len('train/model.ckpt-'):-6]) for s in train_ckpt_paths]

In [None]:
# train_ckpt_paths

### Validation

In [None]:
# for n in ckpt_nums:
#   print("Running validation for checkpoint {} ...".format(n))
#   os.system("python magenta/models/music_vae/new_music_vae_train.py \
#             --config=" + config_name + " \
#             --run_dir=" +run_dir + " \
#             --eval_dir_suffix=train \
#             --mode=eval \
#             --examples_path=" + train_example_path + " \
#             --ckpt_no=" + str(n) + " \
#             --hparams=batch_size=1 \
#             --cache_dataset=False"   
#   )

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --eval_dir_suffix=train \
# --mode=eval \
# --examples_path=$train_example_path \
# --ckpt_no=1398 \
# --hparams=batch_size=1 \
# --cache_dataset=False

### Evaluation

In [None]:
# for n in ckpt_nums:
#   print("Running validation for checkpoint {} ...".format(n))
#   os.system("python magenta/models/music_vae/new_music_vae_train.py \
#             --config=" + config_name + " \
#             --run_dir=" +run_dir + " \
#             --mode=eval \
#             --examples_path=" + eval_example_path + " \
#             --ckpt_no=" + str(n) + " \
#             --hparams=batch_size=1 \
#             --cache_dataset=False"   
#   )

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --mode=eval \
# --examples_path=$eval_example_path \
# --ckpt_no=1398 \
# --hparams=batch_size=1 \
# --cache_dataset=False

In [None]:
# res = music_vae_mcts_train.run(
#     run_dir=run_dir,
#     config=config_name,
#     mode='eval',
#     hparams='batch_size=1',
#     cache_dataset=False,
#     examples_path=train_example_path,
#     ckpt_path=mel_2bar_big_ckpt_path
# )

In [None]:
# res

### Results

In [None]:
# !kill 18229

In [None]:
# %tensorboard --logdir $run_dir

## Finetune last layer

In [None]:
# # Experiment config 
# config_name = 'cat-mel_2bar_big'
# mode = 'train' # mode = {train | eval}
# finetune = 'True' # if mode==train, finetune = {True | False}
# # if finetune==True, a comma-separated list of variable names to be finetuned
# # or 'last_layer' or 'all'.
# trainable_vars = 'last_layer' 
# num_steps = '1000'
# batch_size = '16'
# learning_rate = '0.001'

# run_dir = './data/tmp/Persian/CE_MCTS_test_last_{}'.format(num_steps)
# train_example_path = './data/tfrecord/Persian/dataset-1.tfrecord'
# eval_example_path = './data/tfrecord/Persian/dataset-1.tfrecord'
# ckpt_path = './data/test/ckpt_test'

In [None]:
# # Add datetime info to run_dir
# run_dir += datetime.now().strftime('-%y-%m-%d-%H-%M/')
# print("New run_dir is: ", run_dir)

In [None]:
# run_dir = './data/tmp/Bo_Burnham/finetune_big_lastlayer_2000-22-06-06-09-06/'

###Train

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --mode=$mode \
# --finetune=$finetune \
# --trainable_vars=$trainable_vars \
# --examples_path=$train_example_path \
# --num_steps=$num_steps \
# --ckpt_path=$ckpt_path \
# --hparams=batch_size=$batch_size,learning_rate=$learning_rate

In [None]:
# train_ckpt_paths = glob(run_dir+'train/model.ckpt-*.index')
# ckpt_nums = [int(s[len(run_dir)+len('train/model.ckpt-'):-6]) for s in train_ckpt_paths]

### Validation

In [None]:
# for n in ckpt_nums:
#   print("Running validation for checkpoint {} ...".format(n))
#   os.system("python magenta/models/music_vae/new_music_vae_train.py \
#             --config=" + config_name + " \
#             --run_dir=" +run_dir + " \
#             --eval_dir_suffix=train \
#             --mode=eval \
#             --examples_path=" + train_example_path + " \
#             --ckpt_no=" + str(n) + " \
#             --hparams=batch_size=1 \
#             --cache_dataset=False"   
#   )

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --eval_dir_suffix=train \
# --mode=eval \
# --examples_path=$train_example_path \
# --ckpt_no=2000 \
# --hparams=batch_size=1 \
# --cache_dataset=False

### Evaluation

In [None]:
# for n in ckpt_nums:
#   print("Running validation for checkpoint {} ...".format(n))
#   os.system("python magenta/models/music_vae/new_music_vae_train.py \
#             --config=" + config_name + " \
#             --run_dir=" +run_dir + " \
#             --mode=eval \
#             --examples_path=" + eval_example_path + " \
#             --ckpt_no=" + str(n) + " \
#             --hparams=batch_size=1 \
#             --cache_dataset=False"   
#   )

In [None]:
# !python magenta/models/music_vae/new_music_vae_train.py \
# --config=$config_name \
# --run_dir=$run_dir \
# --mode=eval \
# --examples_path=$eval_example_path \
# --ckpt_no=2000 \
# --hparams=batch_size=1 \
# --cache_dataset=False

### Results

In [None]:
# !kill 10239

In [None]:
# %tensorboard --logdir $run_dir

### Generate


In [None]:
# path = os.path.abspath(run_dir+'train/model.ckpt-850')

In [None]:
# model = TrainedModel(config=configs.CONFIG_MAP[config_name],
#                      batch_size=4,
#                      checkpoint_dir_or_path=path)

In [None]:
# #@title Random Samples

# temperature = 0.52 #@param {type:"slider", min:0.01, max:1.5, step:0.01}
# seqs = model.sample(n=4, length=512, temperature=temperature)

# # play(seqs)


In [None]:
# play(seqs[0])

In [None]:
# play(seqs[1])

In [None]:
# mm.plot_sequence(seqs[1])

In [None]:
# download(seqs[1], filename='persian_finetune_last_1325.mid')

In [None]:
# play(seqs[2])

In [None]:
# play(seqs[3])

# Setup

In [None]:
import sys
from magenta.models.music_vae import music_vae_mcts_train

In [None]:
# Experiment config 
config_name = 'cat-mel_2bar_big'
mode = 'train' # mode = {train | eval}
finetune = 'True' # if mode==train, finetune = {True | False}
# if finetune==True, a comma-separated list of variable names to be finetuned
# or 'last_layer' or 'all'.
trainable_vars = 'all' 
num_steps = '2000'
batch_size = '32'
learning_rate = '0.001'

run_dir = './data/mcts/CE_MCTS_{}'.format(num_steps)
train_example_path = './data/tfrecord/Persian/dataset-1.tfrecord'
eval_example_path = './data/tfrecord/Persian/dataset-1.tfrecord'

In [None]:
# Add datetime info to run_dir
run_dir += datetime.now().strftime('-%y-%m-%d-%H-%M/')
print("New run_dir is: ", run_dir)

New run_dir is:  ./data/mcts/CE_MCTS_2000-23-01-23-05-35/


In [None]:
# train_example_path = './data/tfrecord/Video_game.tfrecord'
# eval_example_path = './data/tfrecord/Video_game.tfrecord'

In [None]:
mel_2bar_big_ckpt_path = '/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt'
save_path = './data/mcts/ckpt_test'
log_path = os.path.join(run_dir, 'log.txt')

In [None]:
import logging
log = logging.getLogger()
if not os.path.exists(run_dir):
    os.makedirs(run_dir)
fh = logging.FileHandler(log_path)
log.addHandler(fh)

In [None]:
MODEL_VARIABLES = [
    'decoder/multi_rnn_cell/cell_0/lstm_cell/bias',
    'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel',
    'decoder/multi_rnn_cell/cell_1/lstm_cell/bias',
    'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel',
    'decoder/multi_rnn_cell/cell_2/lstm_cell/bias',
    'decoder/multi_rnn_cell/cell_2/lstm_cell/kernel',
    'decoder/output_projection/bias',
    'decoder/output_projection/kernel',
    'decoder/z_to_initial_state/bias',
    'decoder/z_to_initial_state/kernel',
    'encoder/cell_0/bidirectional_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/bias',
    'encoder/cell_0/bidirectional_rnn/bw/multi_rnn_cell/cell_0/lstm_cell/kernel',
    'encoder/cell_0/bidirectional_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/bias',
    'encoder/cell_0/bidirectional_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel',
    'encoder/mu/bias',
    'encoder/mu/kernel',
    'encoder/sigma/bias',
    'encoder/sigma/kernel'
]


In [None]:
# Importing gc module
import gc
 
# Returns the number of
# objects it has collected
# and deallocated
# collected = gc.collect()
 
# # Prints Garbage collector
# # as 0 object
# print("Garbage collector: collected",
#           "%d objects." % collected)

# Checkpoint

In [None]:
# def checkpoint_to_variable_list(ckpt_path):
#   # tf.keras.backend.clear_session()
#   ckpt_reader = tf.train.load_checkpoint(ckpt_path)
#   name_shape_list = tf.train.list_variables(ckpt_path)

#   tf.reset_default_graph()
#   var_list = list()
#   var_names = list()
#   for name, _ in name_shape_list:
#     if name in MODEL_VARIABLES:
#       print(name)
#       var_list.append(tf.Variable(ckpt_reader.get_tensor(name), name=name))
#       var_names.append(name)

#   return var_names, var_list

In [None]:
def checkpoint_to_variable_list(ckpt_path):
  # tf.keras.backend.clear_session()
  ckpt_reader = tf.train.load_checkpoint(ckpt_path)
  name_shape_list = tf.train.list_variables(ckpt_path)

  tf.reset_default_graph()
  var_list = list()
  var_names = list()
  for name, _ in name_shape_list:
    if name in MODEL_VARIABLES:
      # print(name)
      var_list.append(ckpt_reader.get_tensor(name))
      var_names.append(name)

  return var_names, var_list

In [None]:
# def variable_list_to_checkpoint(variable_list, save_path):
#   saver = tf.train.Saver(variable_list)
#   sess = tf.Session()
#   sess.run(tf.global_variables_initializer())
#   saver.save(sess, save_path)
#   tf.keras.backend.clear_session()

In [None]:
def variable_list_to_checkpoint(var_names, variable_list, save_path):
  saver = tf.train.Saver([tf.Variable(val, name=name) for val, name in zip(variable_list, var_names)])
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())
  saver.save(sess, save_path)
  tf.keras.backend.clear_session()

In [None]:
# var_names, var_list = checkpoint_to_variable_list(mel_2bar_big_ckpt_path)

In [None]:
# tmp = var_list[-4]

In [None]:
# variable_list_to_checkpoint(var_names, var_list, save_path)

In [None]:
# tf.reset_default_graph()
# var_list = list()
# for name, _ in name_shape_list:
#   if name in MODEL_VARIABLES:
#     print(name)
#     var_list.append(tf.Variable(ckpt_reader.get_tensor(name), name=name))

# saver = tf.train.Saver(var_list)
# sess = tf.Session()
# sess.run(tf.global_variables_initializer())
# saver.save(sess, save_path)

# CE-MCTS


## CE Neighbor Functions

In [None]:
import random
random.seed(0)

In [None]:
# Conceptual expansion neighborhood no. 0
# Multiplying alpha by random number in range [-2, 2] for an index in a layer

def neighbor_0(weights, alphas):
  model_weights = weights.copy()
  model_alphas = alphas.copy()

  for i in range(len(model_weights)): 
    model_weights[i] = np.divide(model_weights[i], model_alphas[i])

  idxa = random.randint(0, len(model_weights) - 1)
  current_layer = model_weights[idxa]
  idxb = random.randint(0, current_layer.shape[0] - 1)
  x = np.random.uniform(-2,2)

  if len(current_layer.shape)==1:
    model_alphas[idxa][idxb] *= x 
  else:
    for i in range(current_layer.shape[1]):
      model_alphas[idxa][idxb][i] *= x

  for i in range(len(model_weights)): 
    model_weights[i] = np.multiply(model_weights[i], model_alphas[i])

  return model_weights, model_alphas

In [None]:
# Conceptual expansion neighborhood no. 1
# Multiplying alpha by random number in range [-2, 2] for a layer

def neighbor_1(weights, alphas):
  model_weights = weights.copy()
  model_alphas = alphas.copy()

  for i in range(len(model_weights)): 
    model_weights[i] = np.divide(model_weights[i], model_alphas[i])

  layer_idx = random.randint(0, len(model_weights) - 1)
  current_layer = model_weights[layer_idx]
  
  x = np.random.uniform(-2,2) * np.ones(shape=current_layer.shape)
  model_alphas[layer_idx] = np.multiply(model_alphas[layer_idx], x)

  for i in range(len(model_weights)): 
    model_weights[i] = np.multiply(model_weights[i], model_alphas[i])

  return model_weights, model_alphas

In [None]:
# Conceptual expansion neighborhood no. 2
# Replace a random f with another random f
# The weights must contain at least two layers of the same shape. 
# (To be more efficient for each layer, there should be at least another layer with the same shape.)

def neighbor_2(weights, alphas):
  model_weights = weights.copy()
  model_alphas = alphas.copy()

  for i in range(len(model_weights)): 
    model_weights[i] = np.divide(model_weights[i], model_alphas[i])

  flag = True
  while flag:
    source_idx = random.randint(0, len(model_weights) - 1)
    idx_choices = [idx for idx in range(len(model_weights)) 
                  if (idx != source_idx and 
                      model_weights[idx].shape == model_weights[source_idx].shape)]
    if len(idx_choices):
      flag = False
      
  target_idx = np.random.choice(idx_choices)
  model_weights[target_idx] = model_weights[source_idx]

  for i in range(len(model_weights)): 
    model_weights[i] = np.multiply(model_weights[i], model_alphas[i])
  
  return model_weights, model_alphas

In [None]:
# Conceptual expansion neighborhood no. 3
# Add a random f and alpha to a random target f and alpha
# The weights must contain at least two layers of the same shape. 
# (To be more efficient for each layer, there should be at least another layer with the same shape.)

def neighbor_3(weights, alphas):
  model_weights = weights.copy()
  model_alphas = alphas.copy()

  for i in range(len(model_weights)): 
    model_weights[i] = np.divide(model_weights[i], model_alphas[i])

  flag = True
  while flag:
    source_idx = random.randint(0, len(model_weights) - 1)
    source_layer = model_weights[source_idx]
    source_alpha = model_alphas[source_idx]

    idx_choices = [idx for idx in range(len(model_weights)) 
                  if (idx != source_idx and 
                      model_weights[idx].shape == source_layer.shape)]
    if len(idx_choices):
      flag = False
      
  target_idx = np.random.choice(idx_choices)
  model_weights[target_idx] += source_layer
  model_alphas[target_idx] += source_alpha

  for i in range(len(model_weights)): 
    model_weights[i] = np.multiply(model_weights[i], model_alphas[i])

  return model_weights, model_alphas

In [None]:
# Conceptual expansion neighborhood no. 4
# Swap two random f and alpha
# The weights must contain at least two layers of the same shape. 
# (To be more efficient for each layer, there should be at least another layer with the same shape.)

def neighbor_4(weights, alphas):
  model_weights = weights.copy()
  model_alphas = alphas.copy()

  for i in range(len(model_weights)): 
    model_weights[i] = np.divide(model_weights[i], model_alphas[i])

  flag = True
  while flag:
    source_idx = random.randint(0, len(model_weights) - 1)
    idx_choices = [idx for idx in range(len(model_weights)) 
                  if (idx != source_idx and 
                      model_weights[idx].shape == model_weights[source_idx].shape)]
    if len(idx_choices):
      flag = False
      
  target_idx = np.random.choice(idx_choices)

  # swap
  model_weights[source_idx], model_weights[target_idx] = model_weights[target_idx], model_weights[source_idx]
  model_alphas[source_idx], model_alphas[target_idx] = model_alphas[target_idx], model_alphas[source_idx]

  for i in range(len(model_weights)): 
    model_weights[i] = np.multiply(model_weights[i], model_alphas[i])

  return model_weights, model_alphas

## MCTS Node Class

The difference in our code is that we work with different checkpoints.

**Caveat:** The memory might become an issue. Might want to use some hacks to not load everything.

In [None]:
# MCTS Node
class MCTSNode:
  def __init__(self, idx, ckpt_path, alpha_values=None, fitness_score=None, 
               parent=None, child_nodes=list()):
    self.idx = idx                        # To keep track of nodes
    self.ckpt_path = ckpt_path            # Checkpoint path for the model corresponding to the node
    # self.f_values = f_values              # List of the weights for each layer
    self.alpha_values = alpha_values      # Model alpha values
    self.fitness_score = fitness_score    # Absolute model score
    self.cummulative_score = None         # Score calculated during the backprop
    self.parent = parent                  # Parent node info
    self.child_nodes = child_nodes        # List of all childs to the current node

    if not fitness_score:
      self.set_fitness()

    if not alpha_values:
      self.alpha_values = list()
      for name, shape in tf.train.list_variables(self.ckpt_path):
        if name in MODEL_VARIABLES:
          self.alpha_values.append(np.ones(shape))
  
  def add_child(self, child):
    self.child_nodes.append(child)
  
  # str representation of the node is: Model id <index>
  def __repr__(self):
    return repr('Node id ' + str(self.idx))
  
  # update cummmulative score
  def update_cummulative_score(self, cummulative_score):
    self.cummulative_score = cummulative_score

  # returns model accuracy on training data (Q: loss vs accuracy)
  def set_fitness(self):
    # with open(os.path.abspath(log_path), mode='a+') as sys.stdout:
    res = music_vae_mcts_train.run(
            run_dir=run_dir,
            config=config_name,
            mode='eval',
            hparams='batch_size=1',
            cache_dataset=False,
            examples_path=train_example_path,
            ckpt_path=mel_2bar_big_ckpt_path,
            log='FATAL'
          )

    self.fitness_score = res['metrics/accuracy']

  def create_neighbor_node(self, id=0):
    save_path = os.path.join(run_dir + f'ckpt/ckpt_{id}')
    names, vars = checkpoint_to_variable_list(self.ckpt_path) #load source variables
    choice = np.random.randint(1, 5)
    if choice == 1:
      print("node generated: neighbor type 1")
      vars, alphas = neighbor_1(vars, self.alpha_values)
    elif choice == 2:
      print("node generated: neighbor type 2")
      vars, alphas = neighbor_2(vars, self.alpha_values)
    elif choice == 3:
      print("node generated: neighbor type 3")
      vars, alphas = neighbor_3(vars, self.alpha_values)
    elif choice == 4:
      print("node generated: neighbor type 4")
      vars, alphas = neighbor_4(vars, self.alpha_values)

    variable_list_to_checkpoint(names, vars, save_path)

    # create the new node
    neighbor_node = MCTSNode(id, save_path, alpha_values=alphas, parent=self, child_nodes = list())
    # print(f"node {neighbor_node} was created as a neighbor to {self}")
    self.add_child(neighbor_node)
    print(f"new noded added to the children of {self} --> {self.child_nodes}")
    del names, vars, alphas
    # garbage collect
    collected = gc.collect()
    print("Garbage collector: collected",
          "%d objects." % collected)

    return neighbor_node
  
  def display_tree(self, root):
    pass

## Train

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
# helper functions

# delete an unwanted tree
def delete_tree(root):
    if root:
      print(f"Deleting {root} tree ...")
      for child in root.child_nodes:
          delete_tree(child)
      del root
      gc.collect()

In [None]:
def explore(rollout_idx, depth, HEAD, best_node, best_fitness):
  explore = True
  print(f'Creating a new branch to {HEAD}\t rollout: {rollout_idx}, depth level: {depth}')
  # create a new neighbor from HEAD and with HEAD as its parent
  current = HEAD.create_neighbor_node(root_id)


  # check against the best node
  if current.fitness_score > best_fitness:
    best_node = current
    best_fitness = current.fitness_score

  # make the current node, the HEAD node
  HEAD = current
  # print(f"Head is {HEAD}")

  print(f'The current node: {HEAD}, fitness: {current.fitness_score} ==> parent node: {HEAD.parent}')
  return HEAD, best_node, best_fitness


In [None]:
def exploit(rollout_idx, depth, HEAD):
  print(f'Exploiting nodes at {HEAD} \t rollout: {rollout_idx}, depth: {depth}')

  # pick the child with the best fitness score
  best_score = - np.inf
  next_node = None
  leaf_nodes = list()
  
  for node in HEAD.child_nodes:
    if node.cummulative_score: 
      if node.cummulative_score > best_score:
        best_Score = node.cummulative_score
        next_node = node
    else: # node is a leaf
      leaf_nodes.append(node)

  if len(leaf_nodes):
    next_node = random.choice(leaf_nodes)
    print(f'A random leaf {next_node} has been chosen.')
  else:
    print(f'There are no leafs. The child with the best cummulative score {next_node} was chosen.')

  HEAD = next_node
  del leaf_nodes
  gc.collect()

  return HEAD


In [None]:
def update_cummulative_score(HEAD):
  tmp_head = HEAD.parent
  while tmp_head and tmp_head != root:
    print(f'Update {tmp_head} commulative score.')
    tmp_head.update_cummulative_score(
        tmp_head.fitness_score + 
        discount_factor * tmp_head.child_nodes[-1].fitness_score
    )
    tmp_head = tmp_head.parent

  del tmp_head

In [None]:
def change_root_to_best(root, best_node):
  previous_root = root
  root = best_node 

  if root.parent!= None:
    print(20*'*' + " Deleting previous root tree... ")
    root.parent.child_nodes.remove(root)
    root.parent = None
    delete_tree(previous_root)
  else:
    print("The root is the best node!")

  return root

In [None]:
# Setup and variables
num_generations = 3 #10
no_of_rollouts = 10 #20
rollout_length = 5 #10
discount_factor = 0.3
epsilon = 0.5

# setup the root of the MCTS tree
root_id = 1
root = MCTSNode(idx=1, ckpt_path=mel_2bar_big_ckpt_path)
root_fitness = root.fitness_score
HEAD = None
# all_nodes = [root_node] # To keep track of all of the nodes

# setup best node
best_node = root
best_fitness = root.fitness_score # represents loss

In [None]:
best_fitness

0.875

In [None]:
# start iterations
for gen in range(num_generations):
  print(33*'=' + f' Generation {gen} ' + 33*'=')
  for rollout_idx in range(no_of_rollouts):
    print(f'Rollout no {rollout_idx} ---> best node: {best_node} with fitness {best_fitness}. root={root}')
    HEAD = root         # used to traverse the tree
    explore_mode = False     # selecting explore/exploit

    for depth in range(rollout_length):
      print(50 * '-')
      if gen == 0 and rollout_idx == 0:
        # at the very beginning we want to create a branch
        explore_mode = True
      p = random.uniform(0, 1)
      if explore_mode == False and p < epsilon: # exploit
        HEAD = exploit(rollout_idx, depth, HEAD)
      else: # explore by adding a chain of rollouts / extend to branch to depth length
        explore_mode = True
        root_id += 1
        HEAD, best_node, best_fitness = explore(rollout_idx, depth, HEAD, best_node, best_fitness)
        gc.collect()
    if explore:
      update_cummulative_score(HEAD)

  # choose the best node as new root
  previous_root = root
  root = best_node 

  if root.parent!= None:
    print(20*'*' + " Deleting previous root tree... ")
    root.parent.child_nodes.remove(root)
    root.parent = None
    delete_tree(previous_root)
  else:
    print("The root is the best node!")
  

Rollout no 0 ---> best node: 'Node id 1' with fitness 0.875. root='Node id 1'
--------------------------------------------------
Creating a new branch to 'Node id 1'	 rollout: 0, depth level: 0
node generated: neighbor type 2
new noded added to the children of 'Node id 1' --> ['Node id 2']
Garbage collector: collected 16634 objects.
The current node: 'Node id 2', fitness: 1.0 ==> parent node: 'Node id 1'
--------------------------------------------------
Creating a new branch to 'Node id 2'	 rollout: 0, depth level: 1
node generated: neighbor type 1
new noded added to the children of 'Node id 2' --> ['Node id 3']
Garbage collector: collected 15522 objects.
The current node: 'Node id 3', fitness: 0.9375 ==> parent node: 'Node id 2'
--------------------------------------------------
Creating a new branch to 'Node id 3'	 rollout: 0, depth level: 2
node generated: neighbor type 4
new noded added to the children of 'Node id 3' --> ['Node id 4']
Garbage collector: collected 15522 objects.
Th

In [None]:
root


NameError: ignored

In [None]:
del root

In [None]:
root

NameError: ignored

In [None]:
root.child_nodes[2].child_nodes

['Node id 2', 'Node id 2', 'Node id 3']

In [None]:
delete_tree(root)

Deleting 'Node id 1' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting 'Node id 2' tree ...
Deleting '

ERROR:root:Internal Python error in the inspect module.
Below is the traceback from this internal error.



Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/IPython/core/interactiveshell.py", line 3326, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-86-72e37384f674>", line 1, in <module>
    delete_tree(root)
  File "<ipython-input-71-a079a3a2261b>", line 8, in delete_tree
    delete_tree(child)
  File "<ipython-input-71-a079a3a2261b>", line 8, in delete_tree
    delete_tree(child)
  File "<ipython-input-71-a079a3a2261b>", line 8, in delete_tree
    delete_tree(child)
  [Previous line repeated 2954 more times]
  File "<ipython-input-71-a079a3a2261b>", line 6, in delete_tree
    print(f"Deleting {root} tree ...")
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/iostream.py", line 404, in write
    self.pub_thread.schedule(lambda : self._buffer.write(string))
  File "/usr/local/lib/python3.8/dist-packages/ipykernel/iostream.py", line 202, in schedule
    if self.thread.is_alive():
  File "/usr/lib/python3.8/thre

RecursionError: ignored

In [None]:
root.child_nodes

['Node id 2', 'Node id 2', 'Node id 2', 'Node id 2', 'Node id 2', 'Node id 2']

In [None]:
HEAD.child_nodes[0].child_nodes

['Node id 2', 'Node id 2', 'Node id 2', 'Node id 2', 'Node id 2', 'Node id 2']

In [None]:
best_node.fitness_score

0.90625

In [None]:
best_fitness

0.9375

In [None]:
def test(t):
  t.idx = 9

In [None]:
test(best_node)

In [None]:
del root, best_node, HEAD

In [None]:
gc.collect()

287

# Checks

In [None]:
var_name = 'encoder/cell_0/bidirectional_rnn/fw/multi_rnn_cell/cell_0/lstm_cell/kernel'
# var_name = 'decoder/output_projection/kernel'

checkpoint_path = '/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt'

In [None]:
Old_ckpt = tf.train.load_checkpoint(checkpoint_path)

In [None]:
import tensorflow as tf
tf.reset_default_graph()
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1')
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2')
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, './data/test/my_test_model')

In [None]:
print(tf.get_default_session())

None


In [None]:
test_ckpt_reader = tf.train.load_checkpoint('./data/test/my_test_model')


In [None]:
tf.train.list_variables('./data/test/my_test_model')

[('w1', [2]), ('w2', [5])]

In [None]:
test_ckpt_reader.get_tensor('w1')

array([-0.929671  , -0.27028295], dtype=float32)

In [None]:
test_ckpt_reader.get_tensor('w1_1')

array([-0.46935463,  1.9606491 ], dtype=float32)

In [None]:
Old_A = tf.train.load_checkpoint(checkpoint_path).get_tensor(var_name)

In [None]:
New_A = tf.train.load_checkpoint(run_dir + 'train/model.ckpt-100').get_tensor(var_name)

In [None]:
Old_A

array([[ 0.01359052, -0.08663205,  0.03307157, ...,  0.07788762,
         0.01152594,  0.24845648],
       [-0.06781328, -0.17682482,  0.03815088, ...,  0.41707203,
         0.17010953, -0.2761376 ],
       [-0.00600539,  0.150931  ,  0.00929637, ...,  0.02380793,
        -0.06370527, -0.233501  ],
       ...,
       [-0.14717636, -0.00371401, -0.04210675, ..., -0.04117652,
        -0.08962385, -0.01789565],
       [-0.00903243,  0.03428619,  0.02984675, ..., -0.01778605,
         0.02633332, -0.04182264],
       [-0.0650212 ,  0.05001441,  0.02747146, ..., -0.08395307,
        -0.09573532,  0.02805939]], dtype=float32)

In [None]:
New_A

array([[ 0.01359052, -0.08663205,  0.03307157, ...,  0.07788762,
         0.01152594,  0.24845648],
       [-0.06781328, -0.17682482,  0.03815088, ...,  0.41707203,
         0.17010953, -0.2761376 ],
       [-0.00600539,  0.150931  ,  0.00929637, ...,  0.02380793,
        -0.06370527, -0.233501  ],
       ...,
       [-0.14717636, -0.00371401, -0.04210675, ..., -0.04117652,
        -0.08962385, -0.01789565],
       [-0.00903243,  0.03428619,  0.02984675, ..., -0.01778605,
         0.02633332, -0.04182264],
       [-0.0650212 ,  0.05001441,  0.02747146, ..., -0.08395307,
        -0.09573532,  0.02805939]], dtype=float32)

In [None]:
(New_A==Old_A).all()

True

In [None]:
tf.train.list_variables(checkpoint_path)

[('beta1_power', []),
 ('beta2_power', []),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/bias', [8192]),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam', [8192]),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1', [8192]),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/kernel', [2650, 8192]),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam', [2650, 8192]),
 ('decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1', [2650, 8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/bias', [8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam', [8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1', [8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/kernel', [4096, 8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam', [4096, 8192]),
 ('decoder/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1', [4096, 8192]),
 ('decoder/multi_rnn_cell/cell_2/lstm_cell/bias', [8192]),
 ('decoder/multi_rnn_cell/cell_2/lstm_cell/bias/Adam', [8192]),
 ('decoder/multi_rnn_cell/cel

In [None]:
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

In [None]:
print_tensors_in_checkpoint_file('/content/drive/MyDrive/Magenta/magenta/data/tmp/persian-finetune-11-21-01/train/model.ckpt-500', all_tensors=True, tensor_name='decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam')

tensor: beta1_power (float32) []
1.18984805e-23
tensor: beta2_power (float32) []
0.6057766
tensor: decoder/multi_rnn_cell/cell_0/lstm_cell/bias (float32) [8192]
[-5.46724489e-03  4.73049423e-03 -9.11294576e-03 ... -1.55035285e-02
 -6.62921369e-03  3.34413999e-05]
tensor: decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam (float32) [8192]
[ 7.5676769e-07  1.7772088e-08  7.1647941e-07 ... -8.9731898e-08
 -1.5525816e-07  1.7042743e-07]
tensor: decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1 (float32) [8192]
[1.0983309e-11 7.4119937e-13 6.5951307e-12 ... 2.3710599e-12 3.7374895e-12
 1.2077212e-11]
tensor: decoder/multi_rnn_cell/cell_0/lstm_cell/kernel (float32) [2650, 8192]
[[-0.02101763 -0.0324685  -0.06040208 ...  0.0023125  -0.0393163
  -0.02289756]
 [-0.10999614 -0.05015213 -0.11449713 ... -0.02922615 -0.02492424
  -0.08833413]
 [ 0.00557702  0.01264934 -0.01215714 ...  0.00178084 -0.02203693
  -0.02272074]
 ...
 [-0.02724702 -0.01943417 -0.02045782 ...  0.01954178  0.00792572
  -0.

In [None]:
tf.train.load_checkpoint('/content/drive/MyDrive/Code/cat-mel_2bar_big.ckpt').get_variable_to_shape_map()

{'beta1_power': [],
 'beta2_power': [],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/bias': [8192],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam': [8192],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/bias/Adam_1': [8192],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel': [2650, 8192],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam': [2650, 8192],
 'decoder/multi_rnn_cell/cell_0/lstm_cell/kernel/Adam_1': [2650, 8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/bias': [8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam': [8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/bias/Adam_1': [8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel': [4096, 8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam': [4096, 8192],
 'decoder/multi_rnn_cell/cell_1/lstm_cell/kernel/Adam_1': [4096, 8192],
 'decoder/multi_rnn_cell/cell_2/lstm_cell/bias': [8192],
 'decoder/multi_rnn_cell/cell_2/lstm_cell/bias/Adam': [8192],
 'decoder/multi_rnn_cell/cell_2/lstm_cell/bias/Adam_1': [8192