In [1]:
import os
import sys
# Put the parent directory of the project package (MyHub) on sys.path so that
# `MyHub` itself can be imported as a top-level package.
# NOTE(review): the location is derived from os.getcwd() (two levels up), so it
# only resolves correctly when the kernel's working directory sits two levels
# below that parent — confirm where this notebook is launched from.
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir))
if project_root not in sys.path:
    sys.path.insert(0, project_root)

import json
import time
import argparse
import datetime
import numpy as np
from pathlib import Path
import timm.optim.optim_factory as optim_factory
from torch.utils.tensorboard.writer import SummaryWriter

from MyHub.core.registry import DATASETS, MODELS, POSTFUNCS, TRANSFORMS, EVALUATORS, build_from_registry
from MyHub.training_scripts.utils import misc
from MyHub.training_scripts.utils.yaml import load_yaml_config,split_config, add_attr
from MyHub.training_scripts.trainer import train_one_epoch
from MyHub.training_scripts.tester import test_one_epoch

  from .autonotebook import tqdm as notebook_tqdm


2.1.2
Current Torch version: 2.1


In [2]:
# Parse the training configuration. An explicit argument list is passed to
# parse_args because, without one, argparse reads sys.argv — which inside a
# notebook holds the kernel's own command line, not script arguments.
parser = argparse.ArgumentParser(description='Training Script')
parser.add_argument('--config', default='config/resnet_train.yaml', help='path to config file', type=str)

# Locate the config under the project root that the first cell put on sys.path
# (the parent of MyHub) instead of a hard-coded, user-specific absolute path.
config_path = os.path.join(project_root, 'MyHub', 'config', 'resnet_train.yaml')
cli_args = parser.parse_args(['--config', config_path])

# Split the YAML into its sections. From here on `args` is the first value
# returned by split_config (not the argparse namespace above), extended with
# two derived attributes.
config = load_yaml_config(cli_args.config)
args, model_args, train_dataset_args, test_dataset_args, transform_args, evaluator_args = split_config(config)
add_attr(args, output_dir=args.log_dir)
add_attr(args, if_not_amp=not args.use_amp)

# exist_ok=True already makes this a no-op when the directory exists, so no
# separate existence check is needed; a *file* at this path now raises instead
# of being silently skipped.
os.makedirs(args.output_dir, exist_ok=True)

In [6]:
# Build the transform bundle from the TRANSFORMS registry, then unpack its three
# pipelines: training-time, test-time, and the post transform applied afterwards.
transform = build_from_registry(TRANSFORMS, transform_args)
train_transform, test_transform, post_transform = (
    transform.get_train_transform(),
    transform.get_test_transform(),
    transform.get_post_transform(),
)

[Lazy import] 从MyHub.common.transforms.CrossTransform 加载 CrossTransform
[build_from_registry] 创建模型 'CrossTransform' 参数: {}


In [4]:
 # Look up the model's optional post-processing function in the POSTFUNCS registry.
# The key is "<stem>_post_func" lower-cased, where <stem> is
# model_args['post_func_name'] when that is set (not None), else model_args['name'].
# post_function stays None when the registry has no entry for that key.
post_func_stem = model_args['name']
if model_args.get('post_func_name') is not None:
    post_func_stem = model_args['post_func_name']
post_function_name = f"{post_func_stem}_post_func".lower()
print(f"Post function check: {post_function_name}")
post_function = POSTFUNCS.get(post_function_name) if POSTFUNCS.has(post_function_name) else None
print(post_function)

Post function check: resnet101_post_func
None


In [7]:
# Inject the runtime objects built in earlier cells (post function, training
# transform, post transform) into the dataset's init config, then build the
# training dataset from the DATASETS registry.
# NOTE: this mutates train_dataset_args["init_config"] in place.
train_dataset_args["init_config"].update(
    post_funcs=post_function,
    common_transform=train_transform,
    post_transform=post_transform,
)
train_dataset = build_from_registry(DATASETS, train_dataset_args)

[Lazy import] 从MyHub.common.datasets.CrossDataset 加载 CrossDataset
[build_from_registry] 创建模型 'CrossDataset' 参数: {'dataset_config': [{'name': 'LabelDataset', 'pic_nums': 12641, 'init_config': {'image_size': 256, 'path': '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/CASIAv2.json'}}, {'name': 'LabelDataset', 'pic_nums': 12641, 'init_config': {'image_size': 256, 'path': ['/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_ADM.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_BigGAN.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_glide.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_Midjourney.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_stable_diffusion_v_1_4.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_stable_diffusion_v_1_5.json', '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_VQDM.json', '/hom

In [14]:
# Print the summary of each sub-dataset held by the combined training dataset.
for sub_dataset in train_dataset.datasets:
    print(sub_dataset)

LabelDataset from /home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/CASIAv2.json
Total samples: 12614
Label 0 samples (real): 7491
Label 1 samples (fake): 5123
LabelDataset from /home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_ADM.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_BigGAN.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_glide.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_Midjourney.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_stable_diffusion_v_1_4.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_stable_diffusion_v_1_5.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_VQDM.json,/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/GenImage_wukong.json
Total samples: 2581167
Label 0 samples (real): 1281167
Label 1 samples (fake): 1300000
LabelDataset from /home/chengxiaozhen/Test/ForensicHub/M

In [None]:
# Alias the training dataset's init config for inspection; the cell displays
# its type as the last expression (rich display) rather than via print().
# NOTE(review): this rebinds `config`, shadowing the YAML config loaded in the
# argument-parsing cell — consider a distinct name such as train_init_config.
config = train_dataset_args['init_config']
type(config)

<class 'dict'>


In [24]:
# Build the first sub-dataset entry of the training init config on its own so
# its attributes can be inspected. Indexing train_dataset_args directly (rather
# than the `config` alias) keeps this cell working on a fresh kernel even if the
# aliasing cell was never executed.
labeldataset = build_from_registry(DATASETS, train_dataset_args['init_config']['dataset_config'][0])

[build_from_registry] 创建模型 'LabelDataset' 参数: {'image_size': 256, 'path': '/home/chengxiaozhen/Test/ForensicHub/MyHub/data/train_data/CASIAv2.json', 'common_transform': Compose([
  HorizontalFlip(always_apply=False, p=0.5),
  VerticalFlip(always_apply=False, p=0.5),
  RandomBrightnessContrast(always_apply=False, p=1, brightness_limit=(-0.1, 0.1), contrast_limit=(-0.1, 0.1), brightness_by_max=True),
  ImageCompression(always_apply=False, p=0.2, quality_lower=70, quality_upper=100, compression_type=0),
  RandomRotate90(always_apply=False, p=0.5),
  GaussianBlur(always_apply=False, p=0.2, blur_limit=(3, 7), sigma_limit=(0, 0)),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}), 'post_transform': Compose([
  Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255.0),
  ToTensorV2(always_apply=True, p=1.0, transpose_mask=True),
], p=1.0, bbox_params=None, keypoint_params=None, additional_targets={}), 'post_funcs'

In [26]:
# Display the image size the standalone LabelDataset was built with.
getattr(labeldataset, 'image_size')

256