文件读写


读取和保存数据

In [None]:
# Read from local disk with mmengine's unified load/dump interface.
from mmengine import load, dump

# Load data from files; the format is inferred from the file suffix.
data = load('test.json')
data = load('test.yaml')
data = load('test.pkl')

# Load data from an already-open file object. The format must be given
# explicitly because a file object carries no suffix to infer it from.
with open('test.json', 'r', encoding='utf-8') as f:
    data = load(f, file_format='json')

# Serialize data to a string.
json_str = dump(data, file_format='json')

# Save data to a file (the format is inferred from the file suffix).
dump(data, 'out.pkl')

# Save data to a file object. Write to a new file so that the 'test.yaml'
# input loaded above is not overwritten. Do not rebind `data`: when a file
# object is given, dump writes into it and its return value is not the data.
with open('out.yaml', 'w', encoding='utf-8') as f:
    dump(data, f, file_format='yaml')

# These helpers mirror Python's built-in json/pickle load/dump: they work
# directly on file paths and also on file objects.

扩展支持的文件格式 \
编写一个继承自 BaseFileHandler 的文件处理器（handler），实现其读写方法，再用 register_handler 以一个或多个文件格式（后缀）注册它。

In [None]:
from mmengine import register_handler, BaseFileHandler

# To register several file formats at once, pass a list instead, e.g.
# @register_handler(['txt', 'log'])
@register_handler('txt')
class TxtHandler1(BaseFileHandler):
    """Example handler that reads and writes plain text.

    Registered for the ``txt`` file format via ``register_handler``.
    Every method accepts ``**kwargs`` (and ignores it) so that options
    forwarded by ``load``/``dump`` do not raise ``TypeError``. This is
    consistent with ``dump_to_str``.
    """

    def load_from_fileobj(self, file, **kwargs):
        """Return the entire content of ``file`` as read by ``file.read()``."""
        return file.read()

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Write ``str(obj)`` into ``file``."""
        file.write(str(obj))

    def dump_to_str(self, obj, **kwargs):
        """Return ``str(obj)``."""
        return str(obj)
    
#支持pickle
from mmengine import BaseFileHandler
import pickle

class PickleHandler(BaseFileHandler):
    """Example handler that (de)serializes objects with :mod:`pickle`.

    Pickle data is binary, so the path-based methods ask the base class to
    open files with mode ``'rb'`` / ``'wb'``. Objects are written with
    pickle protocol 2 unless the caller supplies its own ``protocol``.

    Warning: unpickling can execute arbitrary code -- never load pickle
    files that come from an untrusted source.
    """

    def load_from_fileobj(self, file, **kwargs):
        """Unpickle and return one object from the file object ``file``."""
        return pickle.load(file, **kwargs)

    def load_from_path(self, filepath, **kwargs):
        """Load from ``filepath`` via the base class, in binary read mode."""
        return super().load_from_path(filepath, mode='rb', **kwargs)

    def dump_to_str(self, obj, **kwargs):
        """Return ``obj`` pickled to bytes (protocol 2 unless overridden)."""
        options = {'protocol': 2, **kwargs}
        return pickle.dumps(obj, **options)

    def dump_to_fileobj(self, obj, file, **kwargs):
        """Pickle ``obj`` into ``file`` (protocol 2 unless overridden)."""
        options = {'protocol': 2, **kwargs}
        pickle.dump(obj, file, **options)

    def dump_to_path(self, obj, filepath, **kwargs):
        """Dump ``obj`` to ``filepath`` via the base class, in binary write mode."""
        super().dump_to_path(obj, filepath, mode='wb', **kwargs)

将文本文件读取为列表或字典 \
使用 mmengine 的 list_from_file 和 dict_from_file 接口：list_from_file 把文件的每一行作为列表中的一个元素；dict_from_file 把每行按空白切分，第一列作为键，其余列作为值。 \
参考：https://mmengine.readthedocs.io/zh_CN/latest/advanced_tutorials/fileio.html#id5

读取和保存权重文件

In [None]:
# PyTorch style: each storage backend needs its own API.
import torch
from torch.hub import load_state_dict_from_url

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = 'http://path/of/your/checkpoint3.pth'

# Load a checkpoint from local disk.
checkpoint = torch.load(filepath1)
# Save the checkpoint to local disk.
torch.save(checkpoint, filepath1)

# Load a checkpoint from a remote URL. torch.utils.model_zoo.load_url is a
# legacy alias of this function ("Moved to torch.hub"); importing it from
# torch.hub explicitly avoids relying on `import torch` having loaded the
# torch.utils.model_zoo submodule.
checkpoint = load_state_dict_from_url(filepath2)

In [None]:
# MMEngine style: one interface for every storage backend (local disk,
# s3, http, ...), selected from the path prefix.
#
# The checkpoint helpers live in mmengine.runner; the top-level mmengine
# package does not re-export them. Note that mmengine.runner.load_checkpoint
# takes a *model* as its first argument and loads weights into it. To read a
# raw checkpoint from a path, use CheckpointLoader.load_checkpoint instead.
from mmengine.runner import CheckpointLoader, save_checkpoint

filepath1 = '/path/of/your/checkpoint1.pth'
filepath2 = 's3://bucket-name/path/of/your/checkpoint1.pth'
filepath3 = 'http://path/of/your/checkpoint3.pth'

# Load a checkpoint from local disk.
checkpoint = CheckpointLoader.load_checkpoint(filepath1)
# Save the checkpoint to local disk.
save_checkpoint(checkpoint, filepath1)

# Load a checkpoint from s3.
checkpoint = CheckpointLoader.load_checkpoint(filepath2)
# Save the checkpoint to s3.
save_checkpoint(checkpoint, filepath2)

# Load a checkpoint from a remote URL.
checkpoint = CheckpointLoader.load_checkpoint(filepath3)