In [1]:
import torch
import sys
import torch.nn.functional as f

sys.path.append('..')

from torch.utils.data import DataLoader
from utils.data_util import *
from utils.model_util import *
from utils.train_util import *

torch.set_printoptions(precision=2,
                       threshold=1000,
                       edgeitems=5,
                       linewidth=1000,
                       sci_mode=False)

# Seed torch's global RNG so weight init and DataLoader shuffling are repeatable
# under Restart & Run All. NOTE(review): SplitData may draw from a different RNG
# (random / numpy) — confirm and seed that too if so. cudnn.benchmark below can
# still cause small run-to-run differences on CUDA.
SEED = 0
torch.manual_seed(SEED)

# Whether to use GPU acceleration: prefer CUDA, then Apple MPS, then CPU.
# `torch.backends.mps` only exists in newer PyTorch builds, hence the getattr guard.
_mps_backend = getattr(torch.backends, 'mps', None)
if torch.cuda.is_available():
    device = 'cuda'
    if torch.backends.cudnn.is_available():
        # benchmark mode autotunes convolution algorithms; it pays off when input shapes stay constant.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True
elif _mps_backend is not None and _mps_backend.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(device)

cuda


In [2]:
# Load the data; c, h, w are presumably image channels/height/width (they feed CNN(h, w, c, ...)).
train_dataset, test_dataset, c, h, w = get_dataset()

batch_size = 32
DataSplit = SplitData(train_dataset)
# Split the training set into 5 non-IID client shards (num_client_data=1000 per client, presumably):
# three teachers, one student, and one shared "distill" shard.
# NOTE(review): proportion=0.6 presumably is the share of each shard taken from that
# client's main class (reported in client_main_target) — confirm against SplitData.
[teacher_dataset1, teacher_dataset2, teacher_dataset3, student_dataset, distill_dataset], client_main_target = DataSplit.all_non_iid(
    num_client=5,
    num_client_data=1000,
    proportion=0.6)
print(client_main_target)
num_target = DataSplit.num_target

# One shuffled training loader per client shard.
(teacher_dataloader1,
 teacher_dataloader2,
 teacher_dataloader3,
 student_dataloader,
 distill_dataloader) = (
    DataLoader(dataset=shard, batch_size=batch_size, shuffle=True)
    for shard in (teacher_dataset1, teacher_dataset2, teacher_dataset3,
                  student_dataset, distill_dataset))

# The test loader does not need shuffling; a fixed order keeps evaluation deterministic.
test_dataloader = DataLoader(
    dataset=test_dataset,
    batch_size=200,
    shuffle=False)

# Teacher schedule: epochs on the teacher's own shard, then epochs on the shared distill shard.
epochs_teacher = 10
epochs_teacher_p = 30

[8, 0, 1, 2, 7]


In [3]:
# Teacher 1: train on its own non-IID shard, then continue on the shared distill shard.
initial_teacher_model1 = CNN(h, w, c, num_target)
teacher_model1 = initial_teacher_model1

# Phase 1: own shard. Each epoch feeds the previously returned model back in,
# so training accumulates across epochs even if train_model returns a new object.
# NOTE(review): in the recorded output this phase stays flat at ~0.10 test accuracy,
# which looks like chance level — worth checking train_model / the shard's class balance.
acc_teacher = []
for epoch in range(epochs_teacher):
    teacher_model1 = train_model(
        model=teacher_model1,
        dataloader=teacher_dataloader1,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model1,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

# Phase 2: continue training the same model on the shared distill shard.
acc_teacher = []
for epoch in range(epochs_teacher_p):
    teacher_model1 = train_model(
        model=teacher_model1,
        dataloader=distill_dataloader,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model1,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10)]
tensor(0.10)
[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.17), tensor(0.36), tensor(0.54), tensor(0.68), tensor(0.74), tensor(0.74), tensor(0.79), tensor(0.75), tensor(0.81), tensor(0.83), tensor(0.84), tensor(0.84), tensor(0.82), tensor(0.81), tensor(0.86), tensor(0.86), tensor(0.80), tensor(0.86), tensor(0.84), tensor(0.84), tensor(0.84), tensor(0.85), tensor(0.87), tensor(0.87), tensor(0.86), tensor(0.85), tensor(0.87)]
tensor(0.87)


In [4]:
# Teacher 2: train on its own non-IID shard, then continue on the shared distill shard.
initial_teacher_model2 = CNN(h, w, c, num_target)
teacher_model2 = initial_teacher_model2

# Phase 1: own shard. Each epoch feeds the previously returned model back in,
# so training accumulates across epochs even if train_model returns a new object.
# NOTE(review): in the recorded output this phase stays flat at ~0.10 test accuracy,
# which looks like chance level — worth checking train_model / the shard's class balance.
acc_teacher = []
for epoch in range(epochs_teacher):
    teacher_model2 = train_model(
        model=teacher_model2,
        dataloader=teacher_dataloader2,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model2,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

# Phase 2: continue training the same model on the shared distill shard.
acc_teacher = []
for epoch in range(epochs_teacher_p):
    teacher_model2 = train_model(
        model=teacher_model2,
        dataloader=distill_dataloader,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model2,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10)]
tensor(0.10)
[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.16), tensor(0.39), tensor(0.63), tensor(0.65), tensor(0.75), tensor(0.76), tensor(0.72), tensor(0.80), tensor(0.82), tensor(0.81), tensor(0.82), tensor(0.83), tensor(0.84), tensor(0.84), tensor(0.84), tensor(0.85), tensor(0.85), tensor(0.85), tensor(0.85), tensor(0.85), tensor(0.84), tensor(0.85), tensor(0.87), tensor(0.87), tensor(0.86), tensor(0.87), tensor(0.87)]
tensor(0.87)


In [5]:
# Teacher 3: train on its own non-IID shard, then continue on the shared distill shard.
initial_teacher_model3 = CNN(h, w, c, num_target)
teacher_model3 = initial_teacher_model3

# Phase 1: own shard. Each epoch feeds the previously returned model back in,
# so training accumulates across epochs even if train_model returns a new object.
# NOTE(review): in the recorded output this phase stays flat at ~0.11 test accuracy,
# which looks like chance level — worth checking train_model / the shard's class balance.
acc_teacher = []
for epoch in range(epochs_teacher):
    teacher_model3 = train_model(
        model=teacher_model3,
        dataloader=teacher_dataloader3,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model3,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

# Phase 2: continue training the same model on the shared distill shard.
acc_teacher = []
for epoch in range(epochs_teacher_p):
    teacher_model3 = train_model(
        model=teacher_model3,
        dataloader=distill_dataloader,
        device=device)
    acc_teacher.append(eval_model(
        model=teacher_model3,
        dataloader=test_dataloader,
        device=device))
print(acc_teacher)
print(acc_teacher[-1])

[tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11), tensor(0.11)]
tensor(0.11)
[tensor(0.10), tensor(0.10), tensor(0.18), tensor(0.50), tensor(0.61), tensor(0.65), tensor(0.76), tensor(0.76), tensor(0.79), tensor(0.78), tensor(0.84), tensor(0.84), tensor(0.82), tensor(0.83), tensor(0.85), tensor(0.81), tensor(0.82), tensor(0.85), tensor(0.84), tensor(0.85), tensor(0.84), tensor(0.87), tensor(0.85), tensor(0.86), tensor(0.85), tensor(0.85), tensor(0.87), tensor(0.84), tensor(0.85), tensor(0.85)]
tensor(0.85)


In [6]:
# Baseline: a student trained only on its own non-IID shard.
initial_student_model = CNN(h, w, c, num_target)
epochs_student = 40
student_model = initial_student_model

# Each epoch feeds the previously returned model back in, so training
# accumulates across epochs even if train_model returns a new object.
acc_student = []
for epoch in range(epochs_student):
    student_model = train_model(
        model=student_model,
        dataloader=student_dataloader,
        device=device)
    acc_student.append(eval_model(
        model=student_model,
        dataloader=test_dataloader,
        device=device))
print(acc_student)
print(acc_student[-1])

[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10)]
tensor(0.10)


In [7]:
# Student trained with the same two-phase schedule as the teachers:
# epochs_teacher epochs on its own shard, then epochs_teacher_p epochs on the distill shard.
# NOTE(review): this starts from initial_student_model, which the previous cell already
# trained if train_model updates models in place — use a fresh CNN(h, w, c, num_target)
# here if an independent baseline is intended.
student_model = initial_student_model

# Phase 1: own shard, chaining the returned model from epoch to epoch.
acc_student = []
for epoch in range(epochs_teacher):
    student_model = train_model(
        model=student_model,
        dataloader=student_dataloader,
        device=device)
    acc_student.append(eval_model(
        model=student_model,
        dataloader=test_dataloader,
        device=device))
print(acc_student)
print(acc_student[-1])

# Phase 2: continue the same model on the shared distill shard.
acc_student = []
for epoch in range(epochs_teacher_p):
    student_model = train_model(
        model=student_model,
        dataloader=distill_dataloader,
        device=device)
    acc_student.append(eval_model(
        model=student_model,
        dataloader=test_dataloader,
        device=device))
print(acc_student)
print(acc_student[-1])

[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10)]
tensor(0.10)
[tensor(0.10), tensor(0.10), tensor(0.21), tensor(0.54), tensor(0.66), tensor(0.70), tensor(0.75), tensor(0.79), tensor(0.81), tensor(0.82), tensor(0.82), tensor(0.85), tensor(0.83), tensor(0.83), tensor(0.84), tensor(0.86), tensor(0.85), tensor(0.86), tensor(0.85), tensor(0.85), tensor(0.87), tensor(0.85), tensor(0.86), tensor(0.84), tensor(0.86), tensor(0.86), tensor(0.86), tensor(0.86), tensor(0.86), tensor(0.87)]
tensor(0.87)


In [8]:
# Number of epochs for the distillation run below; also reused here as the
# length of the student's plain pre-training on its own shard.
epochs_distill = 30

# Start from a fresh CNN rather than initial_student_model: if train_model updates
# models in place, that object has already been fine-tuned on distill_dataloader
# above, which would contaminate the distillation comparison that follows.
student_model = CNN(h, w, c, num_target)

# Pre-train on the student's own shard only, chaining the returned model per epoch.
acc_student = []
for epoch in range(epochs_distill):
    student_model = train_model(
        model=student_model,
        dataloader=student_dataloader,
        device=device)
    acc_student.append(eval_model(
        model=student_model,
        dataloader=test_dataloader,
        device=device))
print(acc_student)
print(acc_student[-1])

[tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10), tensor(0.10)]
tensor(0.10)


In [9]:
# Distil the three teachers into the student, training on the student's own shard.
teachers = [teacher_model1, teacher_model2, teacher_model3]

# eval() switches modules such as dropout / batch-norm to inference behaviour for the teachers.
for teacher in teachers:
    teacher.eval()

# Equal weight per teacher, derived from the list length so adding/removing a teacher stays consistent.
teacher_weight = torch.full((len(teachers),), 1 / len(teachers))
# NOTE(review): presumably alpha mixes the hard-label and distillation losses and T is the
# softmax temperature — confirm against train_model_disti_weighted.
distill_alpha = 0.1
distill_T = 2

acc_student = []
for epoch in range(epochs_distill):
    student_model = train_model_disti_weighted(
        model=student_model,
        neighbor_server_model=teachers,
        weight=teacher_weight,
        dataloader=student_dataloader,
        alpha=distill_alpha,
        T=distill_T,
        device=device)
    acc_student.append(eval_model(
        model=student_model,
        dataloader=test_dataloader,
        device=device))
print(acc_student)
print(acc_student[-1])

[tensor(0.10), tensor(0.33), tensor(0.56), tensor(0.68), tensor(0.73), tensor(0.74), tensor(0.75), tensor(0.80), tensor(0.79), tensor(0.80), tensor(0.82), tensor(0.83), tensor(0.83), tensor(0.84), tensor(0.85), tensor(0.85), tensor(0.85), tensor(0.85), tensor(0.87), tensor(0.87), tensor(0.86), tensor(0.87), tensor(0.87), tensor(0.87), tensor(0.87), tensor(0.86), tensor(0.88), tensor(0.87), tensor(0.88), tensor(0.87)]
tensor(0.87)
