<a href="https://colab.research.google.com/github/buganart/descriptor-transformer/blob/main/descriptor_model_train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#@markdown Before starting, please save the notebook to your Drive by clicking `File -> Save a copy in Drive`.

In [None]:
#@markdown Check GPU, should be a Tesla V100
!nvidia-smi -L
import os
print(f"We have {os.cpu_count()} CPU cores.")

In [None]:
#@markdown Mount google drive
from pathlib import Path

from google.colab import drive
from google.colab import output  # used by later cells to clear cell output

# Mount the user's Google Drive under /content/drive.
drive.mount('/content/drive')

# The notebook reads from and writes to a shared folder that must be reachable
# through a shortcut in the user's own Drive; fail early with instructions if
# the shortcut is missing.
shared_db_dir = Path("/content/drive/My Drive/IRCMS_GAN_collaborative_database")
if not shared_db_dir.exists():
    raise RuntimeError(
        "Shortcut to our shared drive folder doesn't exist.\n\n"
        "\t1. Go to the google drive web UI\n"
        "\t2. Right click shared folder IRCMS_GAN_collaborative_database and click \"Add shortcut to Drive\""
    )

def clear_on_success(msg="Ok!"):
    """Clear the cell output and print ``msg`` if the last shell command succeeded.

    Reads the global ``_exit_code`` (IPython stores the exit status of the
    most recent ``!``/shell command there) and the ``output`` module imported
    from ``google.colab``. When the exit code is non-zero nothing happens, so
    the command's output stays visible for debugging.

    Args:
        msg: Text printed after the output has been cleared.
    """
    # Guard clause: leave the cell output untouched on failure.
    if _exit_code != 0:
        return
    output.clear()
    print(msg)

In [None]:
#@markdown Install wandb and log in
%pip install wandb
output.clear()
import wandb
from pathlib import Path
wandb_drive_netrc_path = Path("drive/My Drive/colab/.netrc")
wandb_local_netrc_path = Path("/root/.netrc")
if wandb_drive_netrc_path.exists():
    import shutil

    print("Wandb .netrc file found, will use that to log in.")
    shutil.copy(wandb_drive_netrc_path, wandb_local_netrc_path)
else:
    print(
        f"Wandb config not found at {wandb_drive_netrc_path}.\n"
        f"Using manual login.\n\n"
        f"To use auto login in the future, finish the manual login first and then run:\n\n"
        f"\t!mkdir -p '{wandb_drive_netrc_path.parent}'\n"
        f"\t!cp {wandb_local_netrc_path} '{wandb_drive_netrc_path}'\n\n"
        f"Then that file will be used to login next time.\n"
    )

!wandb login
output.clear()
print("ok!")

In [None]:
#@title Configuration

# The annotated assignments below are rendered by Colab as editable form
# fields. Their values are collected into `colab_config` further down in this
# cell and then exposed as attributes of `config`.

#@markdown Directories can be found via file explorer on the left by navigating into `drive` to the desired folders. 
#@markdown Then right-click and `Copy path`.
audio_db_dir = "/content/drive/My Drive/AUDIO DATABASE/MUSIC TRANSFORMER/Transformer Corpus" #@param {type:"string"}
# Alternative corpus path, currently disabled.
# audio_db_dir = "/content/drive/My Drive/AUDIO DATABASE/TESTING" #@param {type:"string"}
experiment_dir = "/content/drive/My Drive/IRCMS_GAN_collaborative_database/Experiments/colab-violingan/descriptor-model" #@param {type:"string"}

#@markdown ### Resumption of previous runs
#@markdown Optional resumption arguments below, leaving both empty will start a new run from scratch. 
#@markdown - The ID can be found on wandb. 
#@markdown - It's 8 characters long and may contain a-z letters and digits (for example `1t212ycn`).

#@markdown Resume a previous run 
# Validated by check_wandb_id() later in this cell; an empty string skips the check.
resume_run_id = "" #@param {type:"string"}

#@markdown train argument
remove_outliers=True#@param {type: "boolean"}
descriptor_size = 5 #@param {type: "integer"}
# Increased by forecast_size at the end of this cell before being stored on `config`.
window_size = 15 #@param {type: "integer"}
learning_rate = 1e-4 #@param {type: "number"}
batch_size = 64 #@param {type: "integer"}
epochs = 3000 #@param {type: "integer"}

# log_interval and n_test_samples are disabled and are not part of colab_config.
# log_interval = 10 #@param {type: "integer"}
save_interval = 10 #@param {type: "integer"}
# n_test_samples = 8 #@param {type: "integer"}


selected_model = "TransformerEncoder" #@param ["LSTM", "LSTMEncoderDecoderModel", "TransformerEncoder"]
notes = "" #@param {type: "string"}
#@markdown model specific argument
#@markdown - LSTM
hidden_size=100 #@param {type: "integer"}
num_layers=3 #@param {type: "integer"}
#@markdown - LSTMEncoderDecoder
# Only kept when selected_model is "LSTMEncoderDecoderModel"; for any other
# model it is reset to 0 at the end of this cell.
forecast_size=10 #@param {type: "integer"}
#@markdown - TransformerEncoder
dim_pos_encoding=50     #@param {type: "integer"}
nhead=5     #@param {type: "integer"}
num_encoder_layers=1    #@param {type: "integer"}
dropout=0.1     #@param {type: "number"}
positional_encoding_dropout=0       #@param {type: "number"}
dim_feedforward=128     #@param {type: "integer"}

import re
from pathlib import Path
from argparse import Namespace

# Convert the form strings to Path objects for the checks below and for the
# rest of the notebook.
audio_db_dir = Path(audio_db_dir)
experiment_dir = Path(experiment_dir)

# Validate the input corpus first so that a wrong path fails before anything
# is created on Drive.
if not audio_db_dir.exists():
    raise RuntimeError(f"audio_db_dir {audio_db_dir} does not exist.")

# Output directories that must exist before the run starts.
for path in [experiment_dir]:
    path.mkdir(parents=True, exist_ok=True)

def check_wandb_id(run_id):
    """Validate a wandb run ID entered in the resume form field.

    An empty value (``""`` or ``None``) means "no run to resume" and is
    accepted without further checks.

    Args:
        run_id: Candidate run ID string, or an empty/falsy value.

    Raises:
        RuntimeError: If ``run_id`` is non-empty and is not exactly 8
            characters drawn from the ASCII letters ``a-z`` and digits ``0-9``.
    """
    # Use fullmatch with an explicit ASCII class: `^...$` with re.match also
    # accepts a trailing newline, and `\d` also matches non-ASCII digits.
    if run_id and not re.fullmatch(r"[0-9a-z]{8}", run_id):
        raise RuntimeError(
            "Run ID needs to be 8 characters long and contain only letters a-z and digits.\n"
            # repr() keeps stray whitespace/newlines in the rejected value visible.
            f"Got {run_id!r}"
        )

# Reject a malformed resume ID before assembling the configuration.
check_wandb_id(resume_run_id)

# Snapshot of the form values, keyed by the attribute name used on `config`.
colab_config = dict(
    audio_db_dir=audio_db_dir,
    experiment_dir=experiment_dir,
    resume_run_id=resume_run_id,
    remove_outliers=remove_outliers,
    descriptor_size=descriptor_size,
    window_size=window_size,
    forecast_size=forecast_size,
    learning_rate=learning_rate,
    batch_size=batch_size,
    epochs=epochs,
    save_interval=save_interval,
    selected_model=selected_model,
    notes=notes,
    hidden_size=hidden_size,
    num_layers=num_layers,
    dim_pos_encoding=dim_pos_encoding,
    nhead=nhead,
    num_encoder_layers=num_encoder_layers,
    dropout=dropout,
    positional_encoding_dropout=positional_encoding_dropout,
    dim_feedforward=dim_feedforward,
)

# Echo the settings as entered in the form (before the window adjustment below).
for key, value in colab_config.items():
    print(f"=> {key:20}: {value}")

# Attribute-style view of the settings plus a fixed seed.
config = Namespace(**colab_config, seed=1234)

# forecast_size is only kept for the encoder-decoder LSTM; every other model
# gets 0, which leaves window_size unchanged by the addition below.
if config.selected_model != "LSTMEncoderDecoderModel":
    config.forecast_size = 0
config.window_size += config.forecast_size

In [None]:
#@markdown Install dependency
%pip install --upgrade git+https://github.com/buganart/descriptor-transformer.git#egg=desc
import torch
from desc.train_function import save_model_args, get_resume_run_config, init_wandb_run, setup_datamodule, setup_model, train
clear_on_success()

# Train

In [None]:
# Create the wandb run, then build the data module and model from the config.
# NOTE(review): the original call carried a commented-out `mode="offline"`
# argument; presumably it makes init_wandb_run log offline. Confirm against
# desc.train_function before enabling it.
run = init_wandb_run(config, run_dir=experiment_dir)  # optionally add: mode="offline"
datamodule = setup_datamodule(config, run)
model, extra_trainer_args = setup_model(config, run)

# Request GPUs only when CUDA is available.
# NOTE(review): -1 presumably means "all available GPUs" for the trainer that
# receives extra_trainer_args. Confirm against desc.train_function.
cuda_available = torch.cuda.is_available()
if cuda_available:
    extra_trainer_args["gpus"] = -1

train(config, run, model, datamodule, extra_trainer_args)