In [2]:
import logging
import torch
import sys
import os
from tqdm import tqdm

# 添加环境
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../MyExpr")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../FedML")))
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), "../../")))

print(sys.path)

# 查看GPU
print(torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    print("GPU[{:d}]: {:s}".format(i, torch.cuda.get_device_name(i)))

# 选择GPU
os.environ['CUDA_VISIBLE_DEVICES'] =  "0"
print(torch.cuda.device_count())

['/home/guest/Fed_Expr', '/home/guest/Fed_Expr/FedML', '/home/guest/Fed_Expr/MyExpr', '/home/guest/Fed_Expr/MyExpr/ml', '/home/guest/miniconda/envs/fedml/lib/python37.zip', '/home/guest/miniconda/envs/fedml/lib/python3.7', '/home/guest/miniconda/envs/fedml/lib/python3.7/lib-dynload', '', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages', '/home/guest/miniconda/envs/fedml/lib/python3.7/site-packages/IPython/extensions', '/home/guest/.ipython']
True
GPU[0]: GeForce RTX 2080 Ti
GPU[1]: GeForce RTX 2080 Ti
GPU[2]: GeForce RTX 2080 Ti
GPU[3]: GeForce RTX 2080 Ti
4


In [13]:
from MyExpr.dfl.Args import add_args
from MyExpr.data import Dataset
from torch.utils.data import DataLoader

parser = add_args()

# args = parser.parse_args()
args = parser.parse_known_args()[0]
args.epochs = 200

In [None]:
train_set, test_set = Dataset(args)


train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)
test_loader = DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers)

In [None]:
from MyExpr.dfl.model.resnet import resnet18
from MyExpr.dfl.model.resnet import resnet34
from MyExpr.dfl.model.resnet import resnet50
from MyExpr.dfl.model.resnet import resnet101
from MyExpr.dfl.model.resnet import resnet152
from MyExpr.dfl.model.cnn import BaseConvNet

import torch.nn as nn

# 加入测试模型
model_list = [resnet18(num_classes=10), resnet34(num_classes=10), 
              resnet50(num_classes=10), resnet101(num_classes=10), 
              resnet152(num_classes=10)]

criterion_CE = nn.CrossEntropyLoss()

In [4]:
from MyExpr.dfl.model.resnet import resnet18
from MyExpr.dfl.model.resnet import resnet34
from MyExpr.dfl.model.resnet import resnet50
from MyExpr.dfl.model.resnet import resnet101
from MyExpr.dfl.model.resnet import resnet152
from MyExpr.dfl.model.cnn import BaseConvNet

import torch.nn as nn

# 加入测试模型
model_list = [resnet18(num_classes=10), resnet34(num_classes=10), 
              resnet50(num_classes=10), resnet101(num_classes=10), 
              resnet152(num_classes=10)]

criterion_CE = nn.CrossEntropyLoss()

In [7]:
import wandb
import time

def run(model, model_name):
    name = "{:s}-lr{:3f}-bs{:d}".format(model_name, args.lr, args.batch_size)
    print(name)

    wandb.init(project="classic-ml",
               entity="kyriegyj",
               name=name,
               config=args)
    
    total_train_iteration = 0
    
    model.to(args.device)
    optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr)
    for epoch in range(args.epochs):
        # train
        start_time =  time.perf_counter()
        total_loss = 0
        total_correct = 0
        for iteration, (train_X, train_Y) in enumerate(train_loader):
            optimizer.zero_grad()
            train_X, train_Y = train_X.to(args.device), train_Y.to(args.device)
            outputs = model(train_X)
            loss = criterion_CE(outputs, train_Y)

            pred = outputs.argmax(dim=1)
            correct = pred.eq(train_Y.view_as(pred)).sum()

            loss.backward()
            optimizer.step()

            if "cuda" in args.device:
                loss = loss.cpu()

            loss = loss.detach().numpy()
            acc = (correct / args.batch_size)
            # wandb.log(step=total_train_iteration, data={"loss":loss, "acc:":acc})

            total_loss += loss
            total_correct += correct
            total_train_iteration += 1

        total_acc = (total_correct / (len(train_loader) * args.batch_size))
        wandb.log(step=epoch, data={"total_loss":total_loss, "total_acc":total_acc})
        end_time =  time.perf_counter()
        print("epoch[{:d}] spends {:f}s".format(epoch, (end_time - start_time)))

        # test
        total_test_loss = 0
        total_test_correct = 0
        with torch.no_grad():
            for iteration, (test_X, test_Y) in enumerate(test_loader):
                test_X, test_Y = test_X.to(args.device), test_Y.to(args.device)
                outputs = model(test_X)
                loss = criterion_CE(outputs, test_Y)
                pred = outputs.argmax(dim=1)

                if "cuda" in args.device:
                    loss = loss.cpu()

                loss = loss.detach().numpy()
                correct = pred.eq(test_Y.view_as(pred)).sum()

                total_test_loss += loss
                total_test_correct += correct

            total_test_acc = (total_test_correct / (len(test_loader) * args.batch_size))
            wandb.log(step=epoch, data={"total_test_loss":total_test_loss, "total_test_acc":total_test_acc})

    wandb.finish()
    

In [None]:
# for model in model_list:
#     run(model)
model = BaseConvNet(num_classes=10)
run(model, "BaseConvNet")

BaseConvNet-lr0.010000-bs64


epoch[0] spends 7.566878s
epoch[1] spends 7.935749s
epoch[2] spends 8.304039s
epoch[3] spends 8.388716s
epoch[4] spends 8.119140s
epoch[5] spends 8.206184s
epoch[6] spends 8.258157s
epoch[7] spends 8.333662s
epoch[8] spends 7.504549s
epoch[9] spends 8.302049s
epoch[10] spends 7.650483s
epoch[11] spends 8.691345s
epoch[12] spends 7.783510s
epoch[13] spends 8.107292s
epoch[14] spends 7.846443s
epoch[15] spends 8.022875s
epoch[16] spends 8.329089s
epoch[17] spends 7.657982s
epoch[18] spends 7.830966s
epoch[19] spends 8.289500s
epoch[20] spends 8.158520s
epoch[21] spends 8.573558s
epoch[22] spends 8.149045s
epoch[23] spends 8.451627s
epoch[24] spends 7.956380s
epoch[25] spends 8.195102s
epoch[26] spends 7.753339s
epoch[27] spends 7.985140s
epoch[28] spends 7.665067s
epoch[29] spends 7.289629s
epoch[30] spends 7.695361s
epoch[31] spends 7.816504s
epoch[32] spends 7.824942s
epoch[33] spends 7.722792s
epoch[34] spends 7.743649s
epoch[35] spends 7.805339s
epoch[36] spends 7.360278s
epoch[37] s