In [2]:
import torch
import torchvision

from utils.data_util import *
from utils.train_util import *
from torch.utils.data import DataLoader
from copy import deepcopy

# Compact tensor printing for the notebook: two decimals, fixed-point notation
# (no scientific notation), and very wide lines so rows are not wrapped.
torch.set_printoptions(precision=2, threshold=1000, edgeitems=5, linewidth=1000, sci_mode=False)


In [3]:
# Load CIFAR-10; get_dataset also returns c, h, w, which are forwarded to the
# model constructors below (presumably channels / height / width -- confirm in utils).
train_dataset, test_dataset, c, h, w = get_dataset('cifar10')

# Three different architectures, all with a 10-class output, moved to the GPU.
net1 = CNN(h, w, c, num_classes=10)
net2 = LeNet5(h, w, c, num_classes=10)
net3 = torchvision.models.resnet18(weights=None, num_classes=10)
model1, model2, model3 = net1.cuda(), net2.cuda(), net3.cuda()

# Both loaders use the same batching / shuffling / worker settings.
loader_options = dict(batch_size=160, shuffle=True, pin_memory=True, num_workers=8)
trainloader = DataLoader(dataset=train_dataset, **loader_options)
testloader = DataLoader(dataset=test_dataset, **loader_options)

Files already downloaded and verified


In [4]:
# Train the three models side by side on the same batches for 10 epochs.
# Each model has its own Adam optimizer; they only share the input batch,
# the labels and the loss function.
loss_func = torch.nn.CrossEntropyLoss().cuda()
model1.train()
model2.train()
model3.train()
optimizer1 = torch.optim.Adam(model1.parameters())
optimizer2 = torch.optim.Adam(model2.parameters())
optimizer3 = torch.optim.Adam(model3.parameters())
for i in range(10):
    for data, target in trainloader:
        # Move the batch to the GPU once and reuse it for all three models
        # (the labels used to be copied three times per batch).
        data_device = data.cuda()
        target_device = target.cuda()
        optimizer1.zero_grad()
        optimizer2.zero_grad()
        optimizer3.zero_grad()
        output1 = model1(data_device)
        output2 = model2(data_device)
        output3 = model3(data_device)
        loss1 = loss_func(output1, target_device)
        loss2 = loss_func(output2, target_device)
        loss3 = loss_func(output3, target_device)
        loss1.backward()
        loss2.backward()
        loss3.backward()
        optimizer1.step()
        optimizer2.step()
        optimizer3.step()

In [None]:
# Run configuration; the device is forced to the GPU.
args = get_args()
args.device = 'cuda'

# Hand-written topology for 3 servers and 9 clients (presumably: clients per
# server, and the neighbouring servers of each server -- confirm against the
# training utilities that consume these lists).
server_client = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
neighbor_server = [[1], [2], [0]]
all_client = list(range(args.num_all_client))
all_server = list(range(args.num_all_server))
num_server_client = args.num_all_client // args.num_all_server

# Training data is split according to target_list (server index -> class labels).
train_dataset_o, test_dataset_o, c, h, w = get_dataset(args.dataset)
target_list = {0: [0, 1, 2], 1: [3, 4, 5], 2: [6, 7, 8, 9]}
num_target, train_dataloader, validate_dataloader = split_dataset(
    train_dataset_o, target_list, args)

# The original test set is divided into a public part of args.num_public_data
# samples and a held-out test part with the remaining samples.
num_test_data = len(test_dataset_o) - args.num_public_data
public_dataset, test_dataset = split_parts_random(
    test_dataset_o, [args.num_public_data, num_test_data])
public_dataloader = DataLoader(
    public_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    pin_memory=True,
    num_workers=args.num_workers)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=args.batch_size,
    pin_memory=True,
    num_workers=args.num_workers)

# One entry per server, built from the three trained models.
client_model = [list_same_term(3, net) for net in (model1, model2, model3)]
server_model = [model1, model2, model3]
client_accuracy = list_same_term(args.num_all_client)
validate_accuracy = list_same_term(args.num_all_client)
client_loss = deepcopy(client_accuracy)
weight_server = list_same_term(args.num_all_server, 1/args.num_all_server)
weight_list = list_same_term(args.num_all_server, weight_server)

# %% Model training
# Bundle everything the training routines need into a single dict.
args_train = {
    'server_model': server_model,
    'train_dataloader': train_dataloader,
    'test_dataloader': test_dataloader,
    'validate_dataloader': validate_dataloader,
    'public_dataloader': public_dataloader,
    'num_target': num_target,
    'client_accuracy': client_accuracy,
    'client_loss': client_loss,
    'validate_accuracy': validate_accuracy,
    'weight_list': weight_list,
    'weight_server': weight_server,
    'server_client': server_client,
    'all_server': all_server,
    'client_model': client_model,
    'target_list': target_list,
    'public_dataset': public_dataset,
}
# Keep the parallel key / value lists available under their original names.
keys = list(args_train)
values = list(args_train.values())

In [17]:
# Switch all three models to evaluation mode and collect them in server order.
server_model = [net.eval() for net in (model1, model2, model3)]

def distill_public(args, args_train):
    """Build per-sample ensemble logits for the public dataset.

    For every ``(data, target)`` item of ``args_train['public_dataset']`` the
    sample is passed (as a batch of one) through each server model whose entry
    in ``args_train['target_list']`` contains ``target``; the resulting outputs
    are averaged.

    Args:
        args: object with a ``device`` attribute; each sample is moved to that
            device before inference.
        args_train: dict providing
            ``'all_server'``     -- iterable of server indices,
            ``'server_model'``   -- models indexable by server index,
            ``'target_list'``    -- server index -> labels that server covers,
            ``'public_dataset'`` -- iterable of ``(data, target)`` pairs.

    Returns:
        tuple ``(dataset_, target_logits)`` where ``dataset_`` is a list of
        ``(data, i)`` pairs (``i`` is the position in the public dataset) and
        ``target_logits`` maps ``i`` to ``(target, logits)``.

    Raises:
        ValueError: if a sample's label is not covered by any server. This
            used to produce NaN logits silently (a 0 / 0 division).
    """
    all_server = args_train['all_server']
    target_list = args_train['target_list']
    server_model = {
        server: args_train['server_model'][server].eval()
        for server in all_server
    }
    dataset_ = []
    target_logits = {}
    # Inference only: without no_grad every stored logits tensor would keep its
    # autograd graph alive, so memory would grow with the public dataset size.
    with torch.no_grad():
        for i, (data, target) in enumerate(args_train['public_dataset']):
            data_ = torch.unsqueeze(data.to(args.device), dim=0)
            outputs = [
                server_model[server](data_)
                for server in all_server
                if target in target_list[server]
            ]
            if not outputs:
                raise ValueError(
                    f'public sample {i}: target {target!r} is not covered by '
                    'any server in target_list')
            # Averaging the stacked outputs keeps whatever class dimension the
            # models produce instead of assuming exactly 10 classes.
            logits = torch.stack(outputs).sum(dim=0) / len(outputs)
            dataset_.append((data, i))
            target_logits[i] = (target, logits)
    return dataset_, target_logits

{0: [], 1: [], 2: [], 3: [], 4: [], 5: [], 6: [], 7: [], 8: [], 9: []}
{0: [0], 1: [0], 2: [0], 3: [1], 4: [1], 5: [1], 6: [2], 7: [2], 8: [2], 9: [2]}
10000


In [28]:
# distill_public(args, args_train) expects the run configuration and the
# training-state dict; it reads 'public_dataset' and 'target_list' from the dict.
dataset_, target_logits = distill_public(args, args_train)

<class 'torch.utils.data.dataloader.DataLoader'>
<class 'torch.Tensor'> <class 'torch.Tensor'> tensor([3150, 3305, 6479, 9536,  539, 6011,  715, 1457, 3301, 3087, 9503, 4522, 6461, 5851, 3707, 5580, 2413, 1108, 9043, 8149, 4837, 8376, 6903, 7396, 8072, 2541, 2226, 4346, 1973, 3891, 1862, 1746, 9097, 3015,  595, 5264, 6971, 5999, 4055, 3208,  375, 9180, 2679, 3298, 6545, 2891, 4135, 5979, 2561, 8473, 1986, 8949,  931, 9679, 8281, 2673, 7734, 8743, 7407, 5518, 1332, 7029, 8797, 7263, 1389, 7220, 9578, 3070, 2108, 2700, 9059, 3363, 4888, 6527, 3088, 8484, 5172, 1992, 3299, 2689, 8205, 9990, 7633, 2020, 7883, 6617, 6935, 8764,  149, 6378, 2180, 9068, 4510, 2457, 9236, 5973,  637, 9831, 9695, 2903, 5640, 9761, 3772,  934, 4580, 6434, 8115, 5503, 5435,  391, 6582, 3399, 7231, 6953, 8911, 7067, 8011, 5299, 5892, 2398, 1946, 9410, 1681, 9505, 1200, 4186, 7476, 5965, 6702, 3034, 7531,  970, 2396, 2471, 7933, 4382, 5199, 5094, 3811, 9670, 1116, 3413, 9252, 4009, 4630, 1315, 3187, 3218, 9074, 981

In [30]:
# Scratch check: concatenating onto an empty tensor yields just the appended elements.
a = torch.tensor([])
print(a)
b = torch.cat([a, torch.tensor([1, 2, 3])])
print(b.shape)
# Shapes of the last model1 output and label batch left over from the training cell.
print(output1.shape, target.shape)

tensor([])
torch.Size([3])
torch.Size([80, 10]) torch.Size([160])


In [37]:
# Scratch check: write a one-element tensor into a single slot, then cast to int32.
a = torch.zeros(2)
b = torch.ones(1)
a[0] = b
a = a.to(torch.int)
print(a)

tensor([1, 0], dtype=torch.int32)


In [42]:
# Smoke test: run one optimisation step on a copy of model1 (the original model
# is left untouched) and print the shape / dtype of the output, input and labels.
model = deepcopy(model1).train()
optimizer = torch.optim.Adam(model.parameters())
for data, target in trainloader:
    data_device = data.cuda()
    optimizer.zero_grad()
    output = model(data_device)
    loss = loss_func(output, target.cuda())
    loss.backward()
    optimizer.step()
    print(output.shape, output.dtype)
    print(data.shape, data.dtype)
    print(target.shape, target.dtype)
    # Only the first batch is needed; the previous surrounding 10-iteration
    # loop always exited on its first pass as well.
    break

torch.Size([160, 10]) torch.float32
torch.Size([160, 3, 32, 32]) torch.float32
torch.Size([160]) torch.int64
