Skip to content

I-am-a-rookie/transfer_learning

main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Code

Latest commit

 

Git stats

Files

Permalink
Failed to load latest commit information.
Type
Name
Latest commit message
Commit time
 
 
 
 
 
 
 
 
 
 
 
 
 
 

项目详解

get_data.py (获取数据)

目前支持 MNIST, Fashion MNIST, CIFAR 10 和 CIFAR 100 数据集. 可以在```get_data.py``下自行替换成自己需要的数据集.

传入数据的格式为:

data_loader = {"train": train_loader, "valid": test_loader}

get_model (获取模型)

目前支持:

  • resnet18
  • resnet34
  • resnet50
  • resnet101
  • resnet152
  • alexnet
  • squeezenet
  • vgg11
  • vgg13
  • vgg16
  • vgg19

替换模型的方法:

python main.py --model_name "模型名称"

例如, 使用 vgg 13:

python main.py --model_name vgg13

例如, 使用 resnet 152:

python main.py --model_name resnet152

参数详解

必填参数:

  • model_name: 模型名称, 类型为 string
  • num_classes: 输出类别数, 类型为 int (例如 MNIST 是 10 分类, CIFAR 100 是 100 分类)

重要参数:

  • data_name: 数据名称, 类型为 string, 默认为 CIFAR10
  • data_gray: 是否为灰度图, 类型为 boolean, 默认为 False
  • num_epochs: 迭代次数, 类型为 int, 默认为 20
  • batch_size: 一个批次的样本数目, 默认为 512

可选参数 (不建议修改):

  • feature_exact: 是否冻层, 类型为 boolean, 默认为 False
  • use_pretrained: 是否使用预训练权重, 类型为 boolean, 默认为 True
  • pretrained_model_path: 预训练权重, 类型为 string, 默认为 pretrained_model/
  • model_save_path: 模型保存路径, 类型为 string, 默认为 "checkpoint/"
  • visualize: 模型可视化, 类型为 boolean, 默认为 True

使用说明

首先我们需要cd到文件路径, 例如:

cd C:\Users\Windows\Desktop\Project\transfer_learning-main

训练 MNIST

使用 vgg13 训练 MNIST 数据集:

python main.py --data_name MNIST --data_gray True --model_name vgg13 --num_classes 10 --batch_size 2048

训练 Fashion MNIST

使用 vgg19 训练 Fashion MNIST 数据集:

python main.py --data_name FashionMNIST --data_gray True --model_name vgg19 --num_classes 10 --batch_size 2048

训练 CIFAR 10

使用 resnet18 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR10 --model_name resnet18 --num_classes 10 --batch_size 2048

训练 CIFAR 100

使用 resnet152 训练 CIFAR 10 数据集:

python main.py --data_name CIFAR100 --model_name resnet152 --num_classes 100 --batch_size 2048

训练自己的数据

python main.py --data_name other --model_name ? --num_classes ? --batch_size ? --epochs ?

About

No description, website, or topics provided.

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages