HFT Model Download and Upload options - Uploading and Downloading trained and pre-trained models (#314)

Co-authored-by: Harry Keightley <harrykeightley@outlook.com>
Co-authored-by: Ben Foley <ben@cbmm.io>
3 people committed Oct 14, 2022
1 parent 127ff7a commit 1304d05
Showing 19 changed files with 4,361 additions and 4,010 deletions.
2 changes: 2 additions & 0 deletions .gitignore
@@ -1,4 +1,6 @@
.tool-versions
# Python
deps/
./env/
venv_test/
*.pyc
101 changes: 92 additions & 9 deletions elpis/endpoints/model.py
@@ -1,10 +1,19 @@
import os
import shutil
import subprocess
from pathlib import Path
from typing import Callable, Dict

from flask import current_app as app
from flask import jsonify, request, send_file
from loguru import logger
from werkzeug.utils import secure_filename

from elpis.blueprint import Blueprint
from elpis.engines import Interface, ENGINES
from elpis.engines.common.errors import InterfaceError
from elpis.engines.common.objects.model import Model
from elpis.engines.hft.objects.model import TRAINING_STATUS, MODEL_PATH, HFTModel

MISSING_MODEL_MESSAGE = "No current model exists (perhaps create one first)"
MISSING_MODEL_RESPONSE = {"status": 404, "data": MISSING_MODEL_MESSAGE}
@@ -21,9 +30,7 @@ def run(cmd: str) -> str:
    """Runs the command, capturing stdout/stderr, and returns the combined output."""
args = shlex.split(cmd)
process = subprocess.run(args, check=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
return process.stdout
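
For reference, a quick usage sketch of this helper (illustrative only, not part of the commit):

    # The command string is split shell-style; the result is the captured bytes.
    output = run("ls /tmp")
    print(output.decode())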


@@ -54,8 +61,10 @@ def new():
def load():
interface = app.config["INTERFACE"]
model = interface.get_model(request.json["name"])
if model.dataset:
app.config["CURRENT_DATASET"] = model.dataset
if model.pron_dict:
app.config["CURRENT_PRON_DICT"] = model.pron_dict
app.config["CURRENT_MODEL"] = model
data = {"config": model.config._load(), "log": model.log}
return jsonify({"status": 200, "data": data})
@@ -91,6 +100,8 @@ def list_existing():

@bp.route("/settings", methods=["POST"])
def settings():
logger.info(request.json["settings"])

def setup(model: Model):
model.settings = request.json["settings"]

@@ -141,6 +152,78 @@ def results():
return jsonify({"status": 200, "data": data})


@bp.route("/download", methods=["GET", "POST"])
def download():
"""Downloads the model files to the frontend"""
model: HFTModel = app.config["CURRENT_MODEL"]
if model is None:
logger.error("No current model exists")
return jsonify(MISSING_MODEL_RESPONSE)

zipped_model_path = Path("/tmp", "model.zip")
logger.info(f"Creating zipped model at path: {zipped_model_path}")
shutil.make_archive(
str(zipped_model_path.parent / zipped_model_path.stem), "zip", model.path / MODEL_PATH
)
logger.info(f"Zipped model created at path: {zipped_model_path}")
try:
return send_file(zipped_model_path, as_attachment=True, cache_timeout=0)
except InterfaceError as e:
return jsonify({"status": 500, "error": e.human_message})
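
For reference, a minimal client-side sketch of this endpoint (the base URL is an assumption, not part of the commit):

    import requests

    BASE_URL = "http://localhost:5000/api/model"  # assumed dev server address

    # GET /download streams the zipped wav2vec2 model directory as an attachment.
    response = requests.get(f"{BASE_URL}/download")
    response.raise_for_status()
    with open("model.zip", "wb") as f:
        f.write(response.content)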


@bp.route("/upload", methods=["POST"])
def upload():
logger.info("Upload endpoint started")
engine = ENGINES["hft"]
interface: Interface = app.config["INTERFACE"]
interface.set_engine(engine)

# Save files to model directory
zip_file = request.files.getlist("file")[0]
filename = secure_filename(str(zip_file.filename))

if filename == "" or Path(filename).suffix != ".zip":
return jsonify({"status": 500, "error": "Invalid filename or not a zip-file"})

try:
model: HFTModel = interface.new_model(Path(filename).stem)
logger.info(f"New model created {model.name} {model.hash}")
app.config["CURRENT_MODEL"] = model
except InterfaceError as e:
return jsonify({"status": 500, "error": e.human_message})

zip_path = model.output_dir / filename
logger.info(f"Saving the zipped model at {zip_path}")
zip_path.parent.mkdir(parents=True, exist_ok=True)

zip_file.save(zip_path)
shutil.unpack_archive(zip_path, model.output_dir)
os.remove(zip_path)
    logger.info("Zipped model unpacked and deleted")

    # Flatten the unpacked archive if it resolved to a single top-level directory.
    folder_path = zip_path.parent / zip_path.stem
    if folder_path.exists() and folder_path.is_dir():
        for file in os.listdir(folder_path):
            (folder_path / file).rename(folder_path.parent / file)
        folder_path.rmdir()

# Update model state
model.status = TRAINING_STATUS.trained.name
    model_list = [
        {
            "name": m["name"],
            "engine_name": m["engine_name"],
            "status": m["status"],
        }
        for m in interface.list_models_verbose()
    ]

data = {"name": model.name, "list": model_list}
return jsonify({"status": 200, "data": data})
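
The matching client-side upload sketch (again, the URL and filename are assumptions):

    import requests

    BASE_URL = "http://localhost:5000/api/model"  # assumed dev server address

    # POST a multipart "file" field; the zip's stem becomes the new model's name.
    with open("model.zip", "rb") as f:
        response = requests.post(f"{BASE_URL}/upload", files={"file": ("model.zip", f)})
    print(response.json())  # expected shape: {"status": 200, "data": {"name": ..., "list": [...]}}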


def _model_response(
build_data: Callable[[Model], Dict],
setup: Callable[[Model], None] = (lambda model: None),
4 changes: 3 additions & 1 deletion elpis/engines/common/objects/interface.py
@@ -249,7 +249,9 @@ def get_model(self, mname):
raise InterfaceError(f'Tried to load a model called "{mname}" that does not exist')
hash_dir = self.config["models"][mname]
m = self.engine.model.load(self.models_path.joinpath(hash_dir))
logger.info(f"{m.config}")
if m.config["dataset_name"] is not None:
m.dataset = self.get_dataset(m.config["dataset_name"])
if m.config["pron_dict_name"] is not None:
m.pron_dict = self.get_pron_dict(m.config["pron_dict_name"])
return m
136 changes: 105 additions & 31 deletions elpis/engines/hft/objects/model.py
@@ -2,23 +2,26 @@
Support for training Hugging Face Transformers (wav2vec2) models.
"""
import json
import os
import random
import re
import string
import sys
import time
from enum import Enum
from dataclasses import dataclass, field
from os.path import isfile
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Set, Union

import datasets
import numpy as np
import torch
import torchaudio
from huggingface_hub import snapshot_download
from loguru import logger
from packaging import version
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.tensorboard import SummaryWriter
from transformers import (
@@ -38,7 +41,6 @@
from elpis.engines.common.objects.dataset import Dataset
from elpis.engines.common.objects.model import Model as BaseModel


if is_apex_available():
from apex import amp

@@ -49,6 +51,7 @@

# Used to reduce training time when debugging
DEBUG = False
BASE_MODEL = "facebook/wav2vec2-large-xlsr-53"
QUICK_TRAIN_BUILD_ARGUMENTS = {
"num_train_epochs": "3",
"model_name_or_path": "facebook/wav2vec2-base",
@@ -64,8 +67,11 @@

TRAINING_STAGES = [TOKENIZATION, PREPROCESSING, TRAIN, EVALUATION]

UNFINISHED = "untrained"
FINISHED = "trained"
TRAINING_STATUS = Enum("TRAINING_STATUS", "untrained trained")

MODEL_PATH = "wav2vec2"
CACHE_DIR = "/state/huggingface_models/"
DOWNLOADED_MODELS = CACHE_DIR + "model_path_index.json"
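
A quick sketch of how the functional Enum API behaves here (illustrative, mirroring the definition above):

    from enum import Enum

    TRAINING_STATUS = Enum("TRAINING_STATUS", "untrained trained")

    # .name round-trips to the plain strings previously stored in config["status"].
    assert TRAINING_STATUS.trained.name == "trained"
    assert TRAINING_STATUS["untrained"] is TRAINING_STATUS.untrained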


def list_field(default=None, metadata=None):
@@ -87,6 +93,10 @@ def __init__(self, **kwargs):
self.config["status"] = "untrained"
self.config["results"] = {}
self.settings = {
"uses_huggingface_api_key": False,
"huggingface_api_token": "",
"uses_custom_model": False,
"huggingface_model_name": "facebook/wav2vec2-large-xlsr-53",
"word_delimiter_token": " ",
"num_train_epochs": 10,
"min_duration_s": 0,
@@ -126,11 +136,17 @@ def log(self):
with open(self.config["run_log_path"]) as logs:
return logs.read()

@property
def output_dir(self) -> Path:
return self.path / MODEL_PATH

def _set_finished_training(self, has_finished: bool) -> None:
self.status = (
TRAINING_STATUS.trained.name if has_finished else TRAINING_STATUS.untrained.name
)

def has_been_trained(self):
return self.status == TRAINING_STATUS.trained.name

def link(self, dataset: Dataset, _pron_dict):
self.dataset = dataset
@@ -153,7 +169,7 @@ def get_arguments(self):
            "train_size": "0.8",
            "split_seed": "42",
            "model_name_or_path": "facebook/wav2vec2-large-xlsr-53",
            "output_dir": self.output_dir,
"overwrite_output_dir": True,
"num_train_epochs": int(self.settings["num_train_epochs"]),
"per_device_train_batch_size": int(self.settings["batch_size"]),
@@ -199,14 +215,14 @@ def get_last_checkpoint(self):
"""
last_checkpoint = None
if (
self.output_dir.is_dir()
and self.training_args.do_train
and not self.training_args.overwrite_output_dir
):
last_checkpoint = get_last_checkpoint(self.output_dir)
if last_checkpoint is None and len(os.listdir(self.output_dir)) > 0:
                raise ValueError(
                    f"Output directory ({self.output_dir}) already exists and is not empty. "
                    "Use --overwrite_output_dir to overcome."
                )
)
elif last_checkpoint is not None:
@@ -403,21 +419,79 @@ def get_processor(self, feature_extractor, tokenizer):
return Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

def get_model(self):
args = []
kwargs = {
"cache_dir": self.model_args.cache_dir,
"activation_dropout": self.model_args.activation_dropout,
"attention_dropout": self.model_args.attention_dropout,
"hidden_dropout": self.model_args.hidden_dropout,
"feat_proj_dropout": self.model_args.feat_proj_dropout,
"mask_time_prob": self.model_args.mask_time_prob,
"gradient_checkpointing": self.model_args.gradient_checkpointing,
"layerdrop": self.model_args.layerdrop,
"ctc_loss_reduction": "mean",
"pad_token_id": self.processor.tokenizer.pad_token_id,
"vocab_size": len(self.processor.tokenizer),
"ctc_zero_infinity": True,
}
if self.settings["uses_custom_model"]:
logger.info("==== Loading a custom model ====")
if not os.path.isdir(CACHE_DIR):
os.makedirs(CACHE_DIR)

# Create the model index if it doesn't already exist
if not os.path.isfile(DOWNLOADED_MODELS):
logger.info("==== Creating custom model index file ====")
with open(DOWNLOADED_MODELS, "w") as model_info:
model_info.write(json.dumps({}))

            # Load the local model index
logger.info("==== Searching for model within index file ====")
model_name = self.settings["huggingface_model_name"]
with open(DOWNLOADED_MODELS) as model_info:
downloaded_models = json.load(model_info)

# Attempt to find the model name within the index
folder_path = downloaded_models.get(model_name, None)
if folder_path is None:
logger.info("==== Model not found locally :( ====")
logger.info("==== Downloading the custom model from HuggingFace ====")
download_arguments = {
"cache_dir": CACHE_DIR,
}
if self.settings["uses_huggingface_api_key"]:
download_arguments["use_auth_token"] = self.settings["huggingface_api_token"]
                    logger.info("Using a HuggingFace auth token for the download")
folder_path = snapshot_download(
self.settings["huggingface_model_name"], **download_arguments
)
logger.info("==== Downloaded custom model ====")
# Update the custom model index
with open(DOWNLOADED_MODELS, "w") as model_info:
downloaded_models[model_name] = folder_path
model_info.write(json.dumps(downloaded_models))

# Load the downloaded model
logger.info(f"==== Loading model from {folder_path} ====")
pytorch_model = os.path.join(folder_path, "pytorch_model.bin")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
state_dict = torch.load(pytorch_model, map_location=device)
state_dict.pop("lm_head.weight")
state_dict.pop("lm_head.bias")
logger.info(f"==== Model loaded and modified {folder_path} ====")
kwargs["state_dict"] = state_dict
            except Exception:
logger.info("This is not a fine-tuned model. Switching to default behaviour.")
args = [folder_path]
else:
logger.info("==== Loading the base/default model ====")
args = [self.model_args.model_name_or_path]

return Wav2Vec2ForCTC.from_pretrained(*args, **kwargs)

def preprocess_dataset(self):
logger.info("==== Preprocessing Dataset ====")
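Two notes on the custom-model path in get_model above. The lm_head weights are stripped from a fine-tuned checkpoint because its CTC head is sized to the source model's vocabulary; dropping them lets from_pretrained re-initialise the head to len(self.processor.tokenizer). The cache index itself is a flat JSON mapping from model names to downloaded snapshot folders; a sketch of its shape (paths are hypothetical):

    import json

    # Hypothetical contents of /state/huggingface_models/model_path_index.json
    index = {
        "facebook/wav2vec2-large-xlsr-53": "/state/huggingface_models/snapshots/0123abcd",
        "someuser/wav2vec2-finetuned": "/state/huggingface_models/snapshots/4567efgh",
    }
    print(json.dumps(index, indent=2))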

