In [None]:
import sys
import time
sys.path.append('')
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
from tqdm.auto import tqdm
import os
from funcs import get_student_teacher

## 训练模型

In [None]:
import argparse
# Arguments are declared CLI-style so the same code can run as a script.
# Jupyter has no real command line, so the values for this run are passed
# explicitly to parse_args() below — edit those to change a run.
parse = argparse.ArgumentParser()
parse.add_argument('--dataset_path', type=str,
                   default='../StealingVerification-main/data/cifar10/',
                   help="dataset path")
# Dataset name: built-in CIFAR10 or a private ImageFolder dataset ("self").
parse.add_argument('--dataset', type=str, default='CIFAR10',
                   choices=['CIFAR10', 'self'], help="dataset的名称")
# Training mode (help text: "do not modify for now").
parse.add_argument('--mode', type=str, default='teacher',
                   help="训练方式（暂时不要修改）")
# Model name; becomes part of the save directory.
parse.add_argument('--model_id', type=str, default='0',
                   help="模型名称，关系到保存地址")
parse.add_argument('--batch_size', type=int, default=128)
parse.add_argument('--epoch', type=int, default=135)
parse.add_argument('--device', type=str, default='cuda')
parse.add_argument('--learning_rate', type=float, default=0.01)
# Root directory under which trained models are saved.
parse.add_argument('--save_path', type=str, default='./model_train/',
                   help="保存的大路径")
# Path of a saved checkpoint to load (empty = none).
parse.add_argument('--model_root', type=str, default='',
                   help="模型地址")
parse.add_argument('--num_classes', type=int, default=10)
parse.add_argument('--normalize', type=int, default=1)

# (flag, value) pairs for this run, flattened into an argv-style token list.
args = parse.parse_args(args=[
    token
    for flag_and_value in (
        ('--dataset_path', '../StealingVerification-main/data/cifar10_seurat_10%/'),
        ('--dataset', 'CIFAR10'),
        ('--mode', 'teacher'),
        ('--model_id', 'test'),
        ('--batch_size', '128'),
        ('--epoch', '135'),
        ('--device', 'cuda'),
        ('--learning_rate', '0.01'),
        ('--save_path', './trained/'),
        ('--model_root', ''),
        ('--num_classes', '10'),
        ('--normalize', '1'),
    )
    for token in flag_and_value
])

print(args)

In [None]:
def load_dataset(args):
    """Build the train and test DataLoaders for ``args.dataset``.

    Args:
        args: namespace providing ``dataset`` ("CIFAR10" or "self"),
            ``batch_size``, and ``dataset_path``. ``dataset_path`` is only used
            for "self" and must contain ``train`` and ``test`` sub-directories
            readable by ``torchvision.datasets.ImageFolder``. CIFAR10 is
            downloaded to / read from ``'../data/'``.

    Returns:
        ``(train_loader, test_loader)``. The train loader shuffles and drops
        the last incomplete batch. The test loader keeps a fixed order and
        every sample.

    Raises:
        ValueError: if ``args.dataset`` is not a supported dataset name.
    """
    # Validate first so an unknown name fails clearly instead of an
    # UnboundLocalError on train_dataset further down.
    if args.dataset not in ("CIFAR10", "self"):
        raise ValueError(
            "unsupported dataset {!r}; expected 'CIFAR10' or 'self'".format(args.dataset))

    # Per-channel (R, G, B) mean and standard deviation used for normalization.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),  # zero-pad 4 px on every side, then randomly crop back to 32x32
        transforms.RandomHorizontalFlip(),     # flip horizontally with probability 0.5
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == "CIFAR10":
        train_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                                     train=True,
                                                     transform=transform_train,
                                                     download=True)
        test_dataset = torchvision.datasets.CIFAR10(root='../data/',
                                                    train=False,
                                                    transform=transform_test)
    else:
        # Private dataset: change transform_train / transform_test above if it
        # needs different preprocessing.
        train_dataset = torchvision.datasets.ImageFolder(
            root=os.path.join(args.dataset_path, 'train'), transform=transform_train)
        test_dataset = torchvision.datasets.ImageFolder(
            root=os.path.join(args.dataset_path, 'test'), transform=transform_test)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                               num_workers=0, drop_last=True)
    # Evaluation does not need shuffling; a fixed order keeps runs reproducible.
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                                              num_workers=0)
    return train_loader, test_loader

In [None]:
def get_model(args):
    """Build the model via ``get_student_teacher`` and load weights from ``args.model_root``.

    The architecture is the first element of the pair returned by
    ``get_student_teacher(args)``. A checkpoint saved as a whole model
    (``torch.save(model)``) should instead be loaded directly with ``torch.load()``.

    Args:
        args: namespace providing ``model_root`` (state-dict path; empty means
            "no checkpoint"), ``device``, and whatever ``get_student_teacher`` reads.

    Returns:
        The model with the checkpoint's weights loaded, or ``None`` when
        ``args.model_root`` is empty.
    """
    if not args.model_root:
        print("未设置模型地址")
        return None

    model, _ = get_student_teacher(args)
    # map_location lets a checkpoint saved on GPU load when args.device is
    # 'cpu' (and vice versa).
    state_dict = torch.load(args.model_root, map_location=args.device)
    # strict=False tolerates key mismatches. Report them rather than dropping
    # them silently, since a mismatch usually means the wrong architecture.
    result = model.load_state_dict(state_dict, strict=False)
    if result is not None and (result.missing_keys or result.unexpected_keys):
        print("warning: missing keys {}, unexpected keys {}".format(
            result.missing_keys, result.unexpected_keys))
    return model

In [None]:
def update_lr(optimizer, lr):
    """Overwrite the learning rate of every parameter group of ``optimizer`` with ``lr``."""
    for group in optimizer.param_groups:
        group.update(lr=lr)

def train(model, args):
    """Train ``model`` with SGD on ``args.dataset`` and evaluate it after every epoch.

    The optimizer is SGD with momentum 0.9 and weight decay 5e-4. The learning
    rate starts at ``args.learning_rate`` and is divided by 3 after every 20th
    epoch. When training finishes, the state dict is saved to
    ``<save_path>/<dataset>/<model_id>/final.pt`` and the whole model to
    ``model.pt`` in the same directory.

    Args:
        model: network to train. Batches are moved to ``args.device`` but the
            model itself is not. It is presumably already placed there by
            ``get_student_teacher`` (TODO confirm).
        args: namespace with ``save_path``, ``dataset``, ``model_id``,
            ``learning_rate``, ``epoch``, ``device``, plus the fields
            ``load_dataset`` reads.

    Returns:
        The trained model, left in eval mode.
    """
    save_path = os.path.join(args.save_path, args.dataset, args.model_id)
    train_loader, test_loader = load_dataset(args)
    lr = args.learning_rate
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    model.train()
    with tqdm(range(args.epoch), leave=True) as pbar:
        for epoch in pbar:
            epoch_loss = 0.0
            num_batches = 0
            it_pbar = tqdm(train_loader, leave=False)
            for images, labels in it_pbar:
                images = images.to(args.device)
                labels = labels.to(args.device)

                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

                # .item() yields a float, so no reference to the autograd graph is kept.
                batch_loss = loss.item()
                epoch_loss += batch_loss
                num_batches += 1
                it_pbar.set_description('loss {}'.format(batch_loss))
            it_pbar.close()

            # Step decay once per epoch, outside the batch loop. Doing it per
            # batch divides lr by 3 on every batch of the epoch and drives it to ~0.
            if (epoch + 1) % 20 == 0:
                lr /= 3
                update_lr(optimizer, lr)

            accuracy = test(model, test_loader, args)  # test() switches to eval mode
            model.train()
            mean_loss = epoch_loss / num_batches if num_batches else float('nan')
            pbar.set_description('epoch {} loss {} accuracy {} %'.format(epoch, mean_loss,
                                                                          accuracy))
    os.makedirs(save_path, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_path, 'final.pt'))
    torch.save(model, os.path.join(save_path, 'model.pt'))
    model.eval()
    return model


def test(model, test_loader, args):
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = Variable(images)
            labels = Variable(labels)
            images = images.to(args.device)
            labels = labels.to(args.device)
            outputs = model(images)
            _,predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        accuracy = 100 * correct / total
        print("Accuracy: {} %".format(accuracy))
    return accuracy


In [None]:
# Build a fresh network and train it from scratch. get_student_teacher returns
# a pair; only the first element is trained here and the second is discarded.
model, _ = get_student_teacher(args)
train(model, args)

In [None]:
# Evaluate the checkpoint at args.model_root on the test split.
# get_model() returns None when no checkpoint path is configured; skip the
# evaluation then instead of crashing on a None model. test() needs a test
# loader as its second argument, so build one with load_dataset().
model = get_model(args)
if model is not None:
    _, test_loader = load_dataset(args)
    test(model, test_loader, args)