Skip to content

Commit

Permalink
refactor data (#603)
Browse files Browse the repository at this point in the history
## Description

* 重构数据处理流程和modules,保存行为交由操作员实现, for #592 
* 为上传线程增加安全锁,应该能解决409问题 #565 
* 为上传线程中媒体文件上传部分增加了一些报错log,方便调试
* 设置了上传key的检查,不允许以 `.`、`/` 或空格开头,不能全为空格,收尾空格会被删除

---------

Co-authored-by: ZeYi Lin <944270057@qq.com>
  • Loading branch information
SAKURA-CAT and Zeyi-Lin committed Jun 5, 2024
1 parent e634898 commit 21b6f36
Show file tree
Hide file tree
Showing 47 changed files with 2,268 additions and 2,308 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# data process
soundfile
pillow
matplotlib
numpy

# web server
Expand All @@ -13,7 +14,7 @@ click

# database
ujson
PyYAML
pyyaml
peewee


Expand Down
38 changes: 21 additions & 17 deletions swanlab/api/cos.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from qcloud_cos.cos_threadpool import SimpleThreadPool
from datetime import datetime, timedelta
from typing import List, Dict, Union
from swanlab.data.modules import MediaBuffer
from swanlab.log import swanlog


class CosClient:
Expand All @@ -37,32 +39,34 @@ def __init__(self, data):
)
self.__client = CosS3Client(config)

def upload(self, key: str, local_path):
def upload(self, buffer: MediaBuffer):
"""
上传文件,需要注意的是file_path应该为unix风格而不是windows风格
开头不能有/
:param key: 上传到cos的文件名称
:param local_path: 本地文件路径,一般用绝对路径
:param buffer: 本地文件的二进制数据
"""
key = self.__prefix + '/' + key
self.__client.upload_file(
Bucket=self.__bucket,
Key=key,
LocalFilePath=local_path,
EnableMD5=False,
progress_callback=None
)
key = "{}/{}".format(self.__prefix, buffer.file_name)
try:
swanlog.debug("Uploading file: {}".format(key))
self.__client.put_object(
Bucket=self.__bucket,
Key=key,
Body=buffer.getvalue(),
EnableMD5=False,
# 一年
CacheControl="max-age=31536000",
)
except Exception as e:
swanlog.error("Upload error: {}".format(e))

def upload_files(self, keys: List[str], local_paths: List[str]) -> Dict[str, Union[bool, List]]:
def upload_files(self, buffers: List[MediaBuffer]) -> Dict[str, Union[bool, List]]:
"""
批量上传文件,keys和local_paths的长度应该相等
:param keys: 上传到cos的文件名称集合
:param local_paths: 本地文件路径,需用绝对路径
:param buffers: 本地文件的二进制对象集合
"""
assert len(keys) == len(local_paths), "keys and local_paths should have the same length"
pool = SimpleThreadPool()
for key, local_path in zip(keys, local_paths):
pool.add_task(self.upload, key, local_path)
for buffer in buffers:
self.upload(buffer)
pool.wait_completion()
result = pool.get_result()
return result
Expand Down
20 changes: 8 additions & 12 deletions swanlab/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from swanlab.log import swanlog
import requests

from swanlab.data.modules import MediaBuffer


def decode_response(resp: requests.Response) -> Union[Dict, AnyStr]:
"""
Expand Down Expand Up @@ -158,31 +160,25 @@ def __get_cos(self):
cos = self.get(f"/project/{self.groupname}/{self.projname}/runs/{self.exp_id}/sts")
self.__cos = CosClient(cos)

def upload(self, key: str, local_path):
def upload(self, buffer: MediaBuffer):
"""
上传文件,需要注意的是file_path应该为unix风格而不是windows风格
开头不能有/,即使有也会被去掉
:param key: 上传到cos的文件名称
:param local_path: 本地文件路径,一般用绝对路径
:param buffer: 自定义文件内存对象
"""
if key.startswith("/"):
key = key[1:]
if self.__cos.should_refresh:
self.__get_cos()
return self.__cos.upload(key, local_path)
return self.__cos.upload(buffer)

def upload_files(self, keys: list, local_paths: list) -> Dict[str, Union[bool, List]]:
def upload_files(self, buffers: List[MediaBuffer]) -> Dict[str, Union[bool, List]]:
"""
批量上传文件,keys和local_paths的长度应该相等
:param keys: 上传到cos
:param local_paths: 本地文件路径,需用绝对路径
:param buffers: 文件内存对象
:return: 返回上传结果, 包含success_all和detail两个字段,detail为每一个文件的上传结果(通过index索引对应)
"""
if self.__cos.should_refresh:
swanlog.debug("Refresh cos...")
self.__get_cos()
keys = [key[1:] if key.startswith("/") else key for key in keys]
return self.__cos.upload_files(keys, local_paths)
return self.__cos.upload_files(buffers)

def mount_project(self, name: str, username: str = None) -> ProjectInfo:
self.__username = self.__username if username is None else username
Expand Down
47 changes: 15 additions & 32 deletions swanlab/api/upload/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
上传相关接口
"""
from ..http import get_http, sync_error_handler
from .model import ColumnModel
from typing import List, Tuple, Dict
from .model import ColumnModel, MediaModel, ScalarModel
from typing import List
from swanlab.error import FileError, ApiError
from swanlab.log import swanlog
import json
Expand Down Expand Up @@ -47,47 +47,27 @@ def upload_logs(logs: List[dict], level: str = "INFO"):


@sync_error_handler
def upload_media_metrics(media_metrics: List[Tuple[dict, str, str, str]]):
def upload_media_metrics(media_metrics: List[MediaModel]):
"""
上传指标的媒体数据
:param media_metrics: 媒体指标数据,
每个元素为元组,第一个元素为指标信息,
第二个元素为指标的名称key,经过URI编码
第三个元素为指标类型
第四个元素为media文件夹路径
:param media_metrics: 媒体指标数据集合
"""
http = get_http()
# 需要上传的文件路径[key, local_path]
file_paths: Dict[str, str] = {}
for metric, key, data_type, media_folder in media_metrics:
if data_type == "text":
# 字符串类型没有文件路径
continue
if isinstance(metric["data"], str):
local_path = metric["data"]
metric["data"] = "{}/{}".format(key, metric["data"])
# 将文件路径添加到files_path中
file_paths[metric["data"]] = os.path.join(media_folder, key, local_path)
else:
local_paths = metric['data']
metric['data'] = ["{}/{}".format(key, x) for x in local_paths]
for i, local_path in enumerate(local_paths):
file_paths[metric['data'][i]] = os.path.join(media_folder, key, local_path)
# 上传文件,先上传资源文件,再上传指标信息
keys = list(file_paths.keys())
local_paths = list(file_paths.values())
http.upload_files(keys, local_paths)
buffers = []
for media in media_metrics:
media.buffers and buffers.extend(media.buffers)
http.upload_files(buffers)
# 上传指标信息
http.post(house_url, create_data([x[0] for x in media_metrics], "media"))
http.post(house_url, create_data([x.to_dict() for x in media_metrics], MediaModel.type.value))


@sync_error_handler
def upload_scalar_metrics(scalar_metrics: List[dict]):
def upload_scalar_metrics(scalar_metrics: List[ScalarModel]):
"""
上传指标的标量数据
"""
http = get_http()
data = create_data(scalar_metrics, "scalar")
data = create_data([x.to_dict() for x in scalar_metrics], ScalarModel.type.value)
http.post(house_url, data)


Expand Down Expand Up @@ -153,5 +133,8 @@ def upload_column(columns: List[ColumnModel]):
"upload_media_metrics",
"upload_scalar_metrics",
"upload_files",
"upload_column"
"upload_column",
"ScalarModel",
"MediaModel",
"ColumnModel"
]
94 changes: 92 additions & 2 deletions swanlab/api/upload/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@
@Description:
上传请求模型
"""
from enum import Enum
from typing import List

from swanlab.data.modules import MediaBuffer


class ColumnModel:
Expand All @@ -21,12 +25,13 @@ def __init__(self, key, column_type: str, error: dict = None):
:param error: 错误信息,如果错误信息不为None
"""
self.key = key
if column_type == "DEFAULT":
column_type = "FLOAT"
self.column_type = column_type
self.error = error

def to_dict(self):
"""
序列化为Dict
"""
return {
"key": self.key,
"type": self.column_type,
Expand All @@ -35,3 +40,88 @@ def to_dict(self):
"type": self.column_type,
"error": self.error
}


class MetricType(Enum):
    """Enumeration of the metric kinds used by the upload models.

    Each member's value is the plain string form of the kind.
    """

    # Scalar metric.
    SCALAR = "scalar"
    # Media metric.
    MEDIA = "media"
    # Log data.
    LOG = "log"


class MediaModel:
    """Upload model describing one media metric record.

    The class attribute ``type`` marks instances as media metrics.
    """

    type = MetricType.MEDIA

    def __init__(
        self,
        metric: dict,
        key: str,
        key_encoded: str,
        step: int,
        epoch: int,
        buffers: List[MediaBuffer] = None
    ):
        """Store the metric payload together with its identifying fields.

        :param metric: metric payload; its entries are spread into the
            serialized dict by ``to_dict``
        :param key: the real metric name
        :param key_encoded: encoded, path-safe form of the metric name
        :param step: step of the record (serialized as ``index``)
        :param epoch: epoch of the record
        :param buffers: raw data of the record, may be None
        """
        self.metric = metric
        self.step = step
        self.epoch = epoch
        # The real metric name.
        self.key = key
        # Encoded, path-safe metric name.
        self.key_encoded = key_encoded
        # Raw data; may be None.
        self.buffers = buffers

    def to_dict(self):
        """Serialize the record to a dict.

        The entries of ``metric`` come first, so the explicit ``key``,
        ``index`` and ``epoch`` fields take precedence over same-named
        entries in ``metric``. ``key_encoded`` and ``buffers`` are not
        included.
        """
        return {
            **self.metric,
            "key": self.key,
            "index": self.step,
            "epoch": self.epoch
        }


class ScalarModel:
    """Upload model describing one scalar metric record.

    The class attribute ``type`` marks instances as scalar metrics.
    """

    type = MetricType.SCALAR

    def __init__(self, metric: dict, key: str, step: int, epoch: int):
        """Store the metric payload together with its identifying fields.

        :param metric: metric payload; its entries are spread into the
            serialized dict by ``to_dict``
        :param key: metric name
        :param step: step of the record (serialized as ``index``)
        :param epoch: epoch of the record
        """
        self.metric = metric
        self.key = key
        self.step = step
        self.epoch = epoch

    def to_dict(self):
        """Serialize the record to a dict.

        The entries of ``metric`` come first, so the explicit ``key``,
        ``index`` and ``epoch`` fields take precedence over same-named
        entries in ``metric``.
        """
        return {
            **self.metric,
            "key": self.key,
            "index": self.step,
            "epoch": self.epoch
        }
Loading

0 comments on commit 21b6f36

Please sign in to comment.