In [None]:
# @title Install packages and download weights (takes ~2mins)

!pip uninstall aria-amt
!pip install git+https://github.com/EleutherAI/aria-amt.git
!pip install yt-dlp

import os
import sys

MODEL_NAME = "medium-stacked"
CHECKPOINT_NAME = f"piano-medium-stacked-1.0"

if not os.path.isfile(f"{CHECKPOINT_NAME}.safetensors"):
  !wget https://storage.googleapis.com/aria-checkpoints/amt/{CHECKPOINT_NAME}.safetensors
else:
  print(f"Checkpoint already exists at {CHECKPOINT_NAME} - skipping download")


In [None]:
# Download audio from YouTube

YOUTUBE_LINK = "https://www.youtube.com/watch?v=HZ-TKo2oxHE" # @param Add this yourself after uploading

!yt-dlp --no-playlist --force-overwrites --audio-format mp3 --extract-audio --audio-quality 0 {YOUTUBE_LINK} -o audio.mp3


In [None]:
# Transcribe

print("NOTE: The progress bar tracks transcription of each 10s interval")
print("NOTE: This code will wait for 30s after finishing the transcription")
print("NOTE: Removing the compile flag will remove the initial cost associated with compilation, but will slow down inference\n")

!aria-amt transcribe {MODEL_NAME} {CHECKPOINT_NAME}.safetensors -load_path=audio.mp3 -save_dir=. -bs=1

In [None]:
# Send the MIDI transcription of audio.mp3 (written as audio.mid by the
# transcribe cell) to the browser.
from google.colab import files

files.download("audio.mid")

# **Run aria-amt on a JSON file of YouTube links**

In [None]:
# Upload the JSON file of YouTube links (one JSON object with a "url" key per line).
# Colab may rename an upload whose name already exists, so the saved names are
# printed below; use one of them as FILE_PATH when loading the links.
from google.colab import files

uploaded = files.upload()
for uploaded_name in uploaded:
    print(f"uploaded: {uploaded_name}")

In [None]:
# load json file to list / txt file

import json
import os
import sys


def _url_from_record(record):
    """Return the stripped URL held by one parsed JSON record, or None.

    A record may be an object with a string ``"url"`` field or a bare URL string;
    anything else (missing/empty/non-string url, numbers, lists, null) yields None.
    """
    if isinstance(record, dict):
        record = record.get("url")
    if isinstance(record, str) and record.strip():
        return record.strip()
    return None


def _read_json_records(json_file):
    """Parse ``json_file`` into a list of ``(label, record)`` pairs.

    A file that is one JSON array yields one record per element; any other
    single JSON value yields one record. Otherwise the file is read as JSON Lines
    (one JSON value per line): blank lines are skipped, and unparsable lines are
    reported and skipped. ``label`` ("item N" / "line N") is only used in messages.
    """
    # utf-8-sig also strips a leading BOM, which json.loads would reject.
    with open(json_file, encoding="utf-8-sig") as src:
        text = src.read()

    try:
        document = json.loads(text)
    except json.JSONDecodeError:
        pass  # not a single JSON document: fall through to JSON Lines
    else:
        items = document if isinstance(document, list) else [document]
        return [(f"item {n}", item) for n, item in enumerate(items)]

    records = []
    for lineno, line in enumerate(text.splitlines(), start=1):
        if not line.strip():
            continue
        try:
            records.append((f"line {lineno}", json.loads(line)))
        except json.JSONDecodeError as err:
            print(f"ERROR: json line fail (line {lineno}: {err.msg})")
    return records


def load_json(json_file, path="yt-links.txt"):
    """Collect YouTube links from ``json_file`` and save them to ``path``.

    ``json_file`` is normally JSON Lines with one object per line carrying a
    ``"url"`` key; a file holding a single JSON array of such objects (or of
    bare URL strings) is accepted too. Records without a usable URL are
    reported and skipped, so the returned list never contains None.

    Args:
        json_file: path of the JSON / JSON Lines file to read.
        path: text file that receives the links, one per line. It is
            overwritten on every call.

    Returns:
        The links, in file order.
    """
    links = []
    for label, record in _read_json_records(json_file):
        url = _url_from_record(record)
        if url is None:
            print(f'ERROR: json line fail ({label}: no "url" string)')
            continue
        links.append(url)
        print(url)

    if os.path.isfile(path):
        print("rewriting file...")
    # Open once and write every link on its own line.
    with open(path, "w", encoding="utf-8") as out:
        out.writelines(f"{link}\n" for link in links)
    return links

In [None]:
# download yt link from json file

def download_from_json(yt_links, i):
    print("downloading links..")
    try:
        x = yt_links[i]
    except:
        print("out of range")
        return
    if not os.path.isfile(f"audio-{i}.mp3"):
      !yt-dlp --no-playlist --force-overwrites --audio-format mp3 --extract-audio --audio-quality 0 {yt_links[i]} -o audio-{i}.mp3
      print("downloaded audio " + yt_links[i])
    else:
      print("already downloaded: " + yt_links[i])

In [None]:
# run aria on mp3 file

def run_aria_amt(path, directory="."):
    !aria-amt transcribe {MODEL_NAME} {CHECKPOINT_NAME}.safetensors -load_path={path} -save_dir={directory} -bs=1

In [None]:
# load/reload json file

# Name of the uploaded JSON file to read links from. With type "string", the
# form takes a bare filename; "raw" required typing a quoted Python literal.
FILE_PATH = "test.json" # @param {type:"string"}
LINKS = load_json(FILE_PATH)
print(f"loaded {len(LINKS)} links from {FILE_PATH}")

In [None]:
# test yt-dlp & aria-amt on the first link before running the whole list
os.makedirs("midi", exist_ok=True)

download_from_json(LINKS, 0)

# Transcribe into midi/ (the folder zipped by the last cell), not the current
# directory, and only if the download actually produced a file.
if os.path.isfile("audio-0.mp3"):
    run_aria_amt("audio-0.mp3", "midi")
else:
    print("audio-0.mp3 was not downloaded - skipping transcription")

In [None]:
# run on all links in json: download, transcribe into midi/, then delete the audio

START = 0
END = len(LINKS)  # exclusive

os.makedirs("midi", exist_ok=True)

for i in range(START, END):
    print(f"downloading link {i}")
    download_from_json(LINKS, i)

    audio_path = f"audio-{i}.mp3"
    if not os.path.isfile(audio_path):
        print(f"skipping link {i}: {audio_path} was not downloaded")
        continue

    # Save into midi/, the folder the last cell zips and downloads.
    run_aria_amt(audio_path, "midi")

    # `!rm` never raises a Python exception, so use os.remove to make failures visible.
    try:
        os.remove(audio_path)
    except OSError:
        print("file was not found or could not be removed")

In [None]:
# download midi folder as midi.zip

import os
import shutil

# Imported here so this cell does not depend on an earlier cell having run.
from google.colab import files

midi_count = (
    sum(name.endswith(".mid") for name in os.listdir("midi"))
    if os.path.isdir("midi")
    else 0
)
if midi_count == 0:
    print("WARNING: midi/ contains no .mid files - the zip will be empty or fail")

shutil.make_archive("midi", 'zip', "midi")
files.download("midi.zip")