In [1]:
import numpy as np
import pandas as pd
from pprint import pprint

from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
# from torch.autograd import grad
import torchvision
from torchvision import models, datasets, transforms
from torch.utils.data import DataLoader, Dataset
import time
import scipy.stats

import random
import math
from torch.utils.tensorboard import SummaryWriter

import copy
from typing import Any, Callable, Optional, Tuple

from resnet import resnet18
from mobilenet import mobilenet_v2

def setup_seed(seed):
    """Seed every RNG used by the experiments so that runs are reproducible.

    Seeds Python's ``random``, NumPy's global RNG, torch's CPU RNG and the RNGs of
    all CUDA devices (in that order), then asks cuDNN for deterministic kernels.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.deterministic = True


# Record the torch / torchvision versions this notebook was run with.
print(f"{torch.__version__} {torchvision.__version__}")

1.12.1 0.13.1


In [2]:
# Use the GPU whenever torch can see one; otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Running on %s" % device)

Running on cuda


In [3]:
# Printed summary of the experiment settings (runtime text, kept in Chinese). English gist:
#   This script produces the per-round accuracy data needed for quality inference.
#   Settings: (1) noise-free data with an IID split, (2) noise-free data with a
#   non-IID split, (3) noisy data.
#   IID: each client's data is sampled uniformly from the full dataset and all
#   clients get the same amount. Non-IID: also sampled uniformly, but client sizes
#   differ. In the noisy setting every client holds the same amount of data.
# NOTE(review): trainWithoutUpdate below logs the aggregated global model's *test-set*
# accuracy each round, not training accuracy; initDst also accepts NOISE and ISIID
# independently, so "noisy + non-IID" is possible in code.
description = """
此脚本产生质量推理所需的每轮训练准确度数据，实验设定包括：
    1.无噪音的，iid设置下的训练数据
    2.无噪音的，非iid设置下的训练数据
    3.有噪音的，训练数据
iid指各客户端上的数据都是从整体数据集上均匀采样得出的，且数量相同
非iid指各客户端上数据从整体数据集上均匀采样得出，但数量不相同
在有噪声的设置下，各客户端数据量均相同
"""
print(description)


此脚本产生质量推理所需的每轮训练准确度数据，实验设定包括：
    1.无噪音的，iid设置下的训练数据
    2.无噪音的，非iid设置下的训练数据
    3.有噪音的，训练数据
iid指各客户端上的数据都是从整体数据集上均匀采样得出的，且数量相同
非iid指各客户端上数据从整体数据集上均匀采样得出，但数量不相同
在有噪声的设置下，各客户端数据量均相同



In [4]:
from turtle import forward


class MlpNet(nn.Module):
    """Fully connected classifier: in -> 64 -> 64 -> 10, with ReLU between layers.

    ``dstName == "MINIST"`` expects 28*28 = 784 features per sample after
    flattening; any other name expects 32*32*3 = 3072 features.
    """

    def __init__(self,dstName):
        super().__init__()
        in_features = 28 * 28 if dstName == "MINIST" else 32 * 32 * 3
        hidden, classes = 64, 10
        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, classes),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self,x):
        # Collapse every non-batch dimension into one feature axis.
        flat = x.view(x.shape[0], -1)
        return self.body(flat)

class CnnNet(nn.Module):
    """Two conv blocks (Conv5x5 -> BatchNorm -> ReLU -> MaxPool2) and two Linear layers.

    Every Conv2d uses kernel_size=5, padding=1, stride=1, so it shrinks each spatial
    side by 2; every MaxPool2d(2, 2, 0) halves it (rounding down). The shape comments
    below are per sample, as [channels, height, width].

    Args:
        dstName: 'CIFAR10' (input [3, 32, 32]) or 'MINIST' (input [1, 28, 28]).

    Raises:
        ValueError: for any other ``dstName`` (instead of building a model that has
            no layers and only fails later inside ``forward``).

    NOTE(review): ``fc`` has no activation between its two Linear layers, so it is a
    single affine map. Left unchanged to keep state_dict keys (fc.0.*, fc.1.*)
    compatible with previously saved models.
    """
    def __init__(self, dstName):
        super().__init__()
        if dstName == 'CIFAR10':
            # input [3, 32, 32]
            self.body = nn.Sequential(
                nn.Conv2d(3, 10, kernel_size=5, padding=1, stride=1), # [10, 30, 30]
                nn.BatchNorm2d(10),
                nn.ReLU(),
                nn.MaxPool2d(2,2,0),# [10, 15, 15]

                nn.Conv2d(10, 20, kernel_size=5, padding=1, stride=1), # [20, 13, 13]
                nn.BatchNorm2d(20),
                nn.ReLU(),
                nn.MaxPool2d(2,2,0),# [20, 6, 6]
            )
            self.fc = nn.Sequential(
                nn.Linear(20*6*6, 84),
                nn.Linear(84, 10)
            )
        elif dstName == 'MINIST':
            # input [1, 28, 28]
            self.body = nn.Sequential(
                nn.Conv2d(1, 5, kernel_size=5, padding=1, stride=1), # [5, 26, 26]
                nn.BatchNorm2d(5),
                nn.ReLU(),
                nn.MaxPool2d(2,2,0),# [5, 13, 13]

                nn.Conv2d(5, 10, kernel_size=5, padding=1, stride=1), # [10, 11, 11]
                nn.BatchNorm2d(10),
                nn.ReLU(),
                nn.MaxPool2d(2,2,0),# [10, 5, 5]
            )
            self.fc = nn.Sequential(
                nn.Linear(10*5*5, 84),
                nn.Linear(84, 10)
            )
        else:
            raise ValueError(f"CnnNet: unsupported dataset {dstName!r}; expected 'CIFAR10' or 'MINIST'")

    def forward(self, x):
        out = self.body(x)
        # Flatten the conv features per sample before the fully connected head.
        out = out.view(out.size(0), -1)
        return self.fc(out)

class ResNet18(nn.Module):
    """Wrap the project's ``resnet18`` as an untrained 10-class classifier.

    'CIFAR10' builds it with 3 input channels and 'MINIST' with 1. Any other
    ``dstName`` leaves ``self.body`` unset, so ``forward`` would fail.
    """
    def __init__(self, dstName):
        super().__init__()
        for name, channels in (('CIFAR10', 3), ('MINIST', 1)):
            if dstName == name:
                self.body = resnet18(pretrained=False, n_classes=10, input_channels=channels)

    def forward(self, x):
        logits = self.body(x)
        # Collapse any trailing dimensions into a single axis per sample.
        return logits.view(logits.size(0), -1)

class MobileNetV2(nn.Module):
    """Wrap the project's ``mobilenet_v2`` as an untrained 10-class classifier.

    'CIFAR10' passes ``i_channel=3, input_size=32``; 'MINIST' passes
    ``i_channel=1, input_size=28``. Any other ``dstName`` leaves ``self.body`` unset.
    """
    def __init__(self, dstName):
        super().__init__()
        for name, channels, size in (('CIFAR10', 3, 32), ('MINIST', 1, 28)):
            if dstName == name:
                self.body = mobilenet_v2(pretrained=False, n_class=10, i_channel=channels, input_size=size)

    def forward(self,x):
        features = self.body(x)
        # Collapse any trailing dimensions into a single axis per sample.
        return features.view(features.size(0), -1)

class Classifier(nn.Module):
    """Build a backbone by name for a given dataset and delegate ``forward`` to it.

    ``netName`` is one of 'CNN', 'MLP', 'RESNET18', 'MOBILENETV2'; the matching
    wrapper is constructed with ``dstName``. An unrecognised ``netName`` leaves
    ``self.body`` unset.
    """
    def __init__(self, netName, dstName):
        super().__init__()
        builders = (
            ('CNN', CnnNet),
            ('MLP', MlpNet),
            ('RESNET18', ResNet18),
            ('MOBILENETV2', MobileNetV2),
        )
        for name, build in builders:
            if netName == name:
                self.body = build(dstName)

    def forward(self,x):
        return self.body(x)

In [5]:
def diff(origin_net, update_net):
    """Return ``{key: update_state[key] - origin_state[key]}`` over ``update_net``'s state_dict.

    Covers parameters and buffers alike; ``origin_net`` must have every key that
    ``update_net`` has.
    """
    before = origin_net.state_dict()
    after = update_net.state_dict()
    return {name: tensor - before[name] for name, tensor in after.items()}

def update(net, grad):
    """Add ``grad[key]`` to every state_dict entry of ``net``, load it back, return ``net``.

    ``net`` is modified in place through ``load_state_dict(strict=True)``; ``grad``
    must supply a delta for every state_dict key.
    """
    state = net.state_dict()
    for key in list(state.keys()):
        state[key] = state[key] + grad[key]
    net.load_state_dict(state, strict=True)
    return net

def weights_init(m):
    """Re-initialise ``m.weight`` and ``m.bias`` in place from U(-0.5, 0.5).

    Takes a single module, so it can be passed to ``nn.Module.apply``. Attributes
    that are missing, ``None`` (e.g. ``nn.Linear(..., bias=False)`` or
    ``nn.BatchNorm2d(..., affine=False)``) or not tensors are skipped. Weight is
    drawn before bias, so the RNG consumption order is unchanged.
    """
    with torch.no_grad():  # in-place init of leaf parameters must not be tracked by autograd
        for attr in ("weight", "bias"):
            tensor = getattr(m, attr, None)
            if isinstance(tensor, torch.Tensor):
                tensor.uniform_(-0.5, 0.5)

def aggregate(net, clientGrads, clientN):
    """Apply the mean delta of the first ``clientN`` clients to ``net`` (FedAvg step).

    For each key of ``clientGrads[0]`` the deltas of clients ``0..clientN-1`` are
    summed in client order and divided by ``clientN``; ``update`` then adds the
    result to ``net`` in place, and ``net`` is returned.
    """
    mean_grad = {
        key: sum(clientGrads[cId][key] for cId in range(clientN)) / clientN
        for key in clientGrads[0]
    }
    return update(net, mean_grad)

In [6]:
def selectParticipants(participantsNumber, clientsNumber, scores, goodRate):
    """Choose the client ids that take part in one round.

    If ``goodRate`` does not have exactly 3 entries, draw ``participantsNumber``
    distinct ids uniformly from ``range(clientsNumber)``.

    Otherwise ``goodRate = [bad, mid, good]`` gives the share of participants
    (truncated with ``int``) taken, without replacement, from three id tiers:
    good ids from ``[int(0.7*N), N)``, bad ids from ``[0, int(0.5*N))`` and mid ids
    from ``[int(0.5*N), int(0.7*N))``. They are drawn in that order, so the
    result is good + bad + mid. ``scores`` is accepted but not used.
    """
    if len(goodRate) != 3:
        return np.random.choice(clientsNumber, participantsNumber, False).tolist()

    nBad, nMid, nGood = (int(participantsNumber * goodRate[i]) for i in range(3))
    midStart = int(clientsNumber * 0.5)
    goodStart = int(clientsNumber * 0.7)
    tiers = (
        (goodStart, clientsNumber, nGood),
        (0, midStart, nBad),
        (midStart, goodStart, nMid),
    )
    chosen = []
    for lo, hi, count in tiers:
        chosen.extend(np.random.choice(range(lo, hi), count, False).tolist())
    return chosen

In [7]:
# Shared image transform for every dataset: only ToTensor for now.
# Append further transforms to the list if needed.
transform = transforms.Compose([transforms.ToTensor()])

class CustomSubset(torch.utils.data.dataset.Subset):
    """A client's ``Subset`` whose labels may be replaced at random on each access.

    Args:
        dataset, indices: forwarded to ``Subset``.
        id: client id, stored as ``self.id``.
        threshold: on every ``__getitem__`` a value ``dice ~ U(0, len(indices))`` is
            drawn, and when ``dice < threshold`` the label is replaced by a random int
            in [0, 10) (which may equal the true label). ``threshold <= 0`` therefore
            never corrupts a label, and ``threshold >= len(indices)`` always does.
            See Eq. 4 in the reference paper.
    """
    def __init__(self,  dataset, indices, id, threshold):
        super().__init__(dataset, indices)
        self.id, self.threshold = id, threshold

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        img, target = self.dataset[self.indices[index]]
        # Both draws happen on every access, so NumPy's global RNG advances by
        # exactly two values per call whether or not the label is replaced.
        dice = np.random.uniform(0, len(self.indices), 1)[0]
        random_label = int(np.random.uniform(0, 10, 1)[0])
        return (img, random_label) if dice < self.threshold else (img, target)

In [8]:
def _loadDst(DST_NAME, dstPath, train):
    """Load the train or test split of MNIST ('MINIST') or CIFAR10 from ``dstPath``.

    Uses the module-level ``transform`` and never downloads. Raises RuntimeError
    for any other dataset name.
    """
    if DST_NAME == 'MINIST':
        return datasets.MNIST(dstPath, train=train, download=False, transform=transform)
    if DST_NAME == 'CIFAR10':
        return datasets.CIFAR10(dstPath, train=train, download=False, transform=transform)
    raise RuntimeError('MUST specific a dataset')


def initDst(DST_NAME,CLIENT_N, ISIID=True, NOISE=True):
    """Split a training set across ``CLIENT_N`` clients and load the matching test set.

    The training indices are shuffled with Python's ``random`` and cut into
    consecutive chunks:
      * ``ISIID=True``: every client gets ``len(train) // CLIENT_N`` samples.
      * ``ISIID=False``: client ``c`` gets
        ``int((c+1) / (CLIENT_N*(CLIENT_N+1)/2) * len(train))`` samples, so sizes
        grow linearly with the client id.
    Samples left over from integer truncation are not assigned to any client.

    With ``NOISE=True`` client ``c`` (holding ``n_c`` samples) gets
    ``threshold = int(n_c * (CLIENT_N-1-c) / (CLIENT_N-1))``: client 0 has every label
    access randomised and the last client none (see ``CustomSubset``). With
    ``NOISE=False`` every threshold is -1 (no noise).

    Returns:
        (list of per-client ``CustomSubset``, test dataset)

    Raises:
        RuntimeError: unknown ``DST_NAME``.
        ValueError: ``CLIENT_N < 1``, or ``NOISE=True`` with ``CLIENT_N < 2`` (the
            noise schedule divides by ``CLIENT_N - 1``).
    """
    # Validate before touching the filesystem so bad arguments fail fast and clearly.
    if DST_NAME not in ('MINIST', 'CIFAR10'):
        raise RuntimeError('MUST specific a dataset')
    if CLIENT_N < 1:
        raise ValueError(f'CLIENT_N must be >= 1, got {CLIENT_N}')
    if NOISE and CLIENT_N < 2:
        raise ValueError('NOISE=True needs CLIENT_N >= 2: client c gets noise level (CLIENT_N-1-c)/(CLIENT_N-1)')

    dstPath = "./data"
    trainDst = _loadDst(DST_NAME, dstPath, train=True)

    indices = list(range(len(trainDst)))
    random.shuffle(indices)
    if ISIID:
        numPerClient = len(trainDst) // CLIENT_N
        splitList = [numPerClient] * CLIENT_N
    else:
        total = (1+CLIENT_N)*CLIENT_N / 2
        splitList = [int((i+1)*1.0 / total * len(indices)) for i in range(CLIENT_N)]

    clientTrainDstSet = []
    start_idx = 0
    for c, size in enumerate(splitList):
        if NOISE:
            threshold = size * (CLIENT_N-1 - c)*1.0/(CLIENT_N-1)
        else:
            threshold = -1
        clientTrainDstSet.append(CustomSubset(trainDst, indices[start_idx:start_idx+size], c, int(threshold)))
        start_idx += size

    testDst = _loadDst(DST_NAME, dstPath, train=False)
    return clientTrainDstSet, testDst

In [9]:
def updateScoresV4(args):
    """Score rule v4: carry scores forward; -1 to each participant when accuracy drops.

    Reads ``CLIENT_N``, ``acc_logs``, ``scores`` (per-client lists indexed by round,
    mutated in place), ``prev_participants`` (unused), ``participants``, ``round``,
    ``deltaAcc`` and ``expAcc_logs`` from ``args``. Round 0 changes nothing.

    Returns:
        dict with ``scores``, ``deltaAcc`` and ``expAcc_logs`` (the latter two unchanged).
    """
    (CLIENT_N, acc_logs, scores, _prev_participants, participants,
     rnd, deltaAcc, expAcc_logs) = (
        args[key] for key in ('CLIENT_N', 'acc_logs', 'scores', 'prev_participants',
                              'participants', 'round', 'deltaAcc', 'expAcc_logs'))
    result = {'scores': scores, 'deltaAcc': deltaAcc, 'expAcc_logs': expAcc_logs}
    if rnd == 0:
        return result

    for cId in range(CLIENT_N):
        row = scores[cId]
        row[rnd] = row[rnd - 1]
    if acc_logs[rnd] < acc_logs[rnd - 1]:
        for cId in participants:
            scores[cId][rnd] -= 1
    return result

def updateScoresV3(args):
    """Score rule v3: compare a new best accuracy's gain against the growth of an
    adaptive expected accuracy.

    What the code does for round ``r = args['round']``:
      * r == 0: nothing changes.
      * Every client's score for round r starts as its score for round r-1.
      * r == 1: if accuracy dropped, each participant gets -2; otherwise
        ``deltaAcc`` becomes ``(1 - acc[1]) * (1 - exp(-(acc[1] - acc[0])))``.
      * r >= 2: the expected accuracy ``expAcc`` is the best accuracy before round
        r-1 if round r-1 fell below it, else ``acc[r-1] + deltaAcc``; it is appended
        to ``expAcc_logs`` (mutated in place). If ``acc[r]`` is a new best,
        ``deltaAcc`` is smoothed with beta = 0.1 and then squashed to
        ``(1 - acc[r]) * (1 - exp(-deltaAcc))``. With gain = acc[r] - best_so_far
        and expected gain = expAcc_logs[r] - expAcc_logs[r-1], participants get:
          - +0.5 if new best and gain > expected gain
          - +2   if new best and gain < expected gain
          - -2   otherwise (no new best, or gain == expected gain)

    Original docstring, translated from Chinese. These rules do NOT match this
    function's code:
        rule 1: if accuracy reaches a new high, +2
        rule 2: if accuracy is above the expected value but below the maximum, +0.5
        rule 3: if accuracy is below the expected value and the gap is smaller than
                the gap between the expected value and the maximum, -0.5
        rule 4: if accuracy is below the expected value and the gap is larger than
                the gap between the expected value and the maximum, -2

    NOTE(review): a larger-than-expected gain earns less (+0.5) than a smaller one
    (+2); confirm that this is intended. ``expAcc_logs[round]`` assumes the list held
    exactly ``round`` entries before the append (e.g. seeded with two placeholders).

    Returns:
        dict with ``scores`` (mutated in place), ``deltaAcc`` and ``expAcc_logs``.
    """
    CLIENT_N = args['CLIENT_N']
    acc_logs = args['acc_logs']
    scores = args['scores']
    prev_participants = args['prev_participants']  # read but unused by this rule
    participants = args['participants']
    round = args['round']
    deltaAcc = args['deltaAcc']
    expAcc_logs = args['expAcc_logs']
    
    result = {
        'scores':scores,
        'deltaAcc':deltaAcc,
        'expAcc_logs':expAcc_logs
    }
    

    if round == 0:
        return result
    # Carry every client's previous score into this round before applying changes.
    for cId in range(CLIENT_N):
        scores[cId][round] = scores[cId][round-1]
    if round == 1:
        if acc_logs[round] < acc_logs[round-1]:
            for cId in participants:
                scores[cId][round] -= 2
        else:
            # First observed gain, squashed so it shrinks as accuracy approaches 1.
            deltaAcc = acc_logs[round]-acc_logs[round-1]
            deltaAcc = (1-acc_logs[round])*(1-math.exp(-deltaAcc))
        result['deltaAcc'] = deltaAcc
        return result
    #calculate expAcc for this round
    # print(f'round={round}, deltaAcc={deltaAcc}, acc_logs[round-1]={acc_logs[round-1]}, acc_logs[round]={acc_logs[round]}, maxAcc={np.max(acc_logs[:round])}, preMaxAcc={np.max(acc_logs[:round-1])}')

    # If the previous round fell below the earlier best, expect a return to that best;
    # otherwise expect the previous accuracy plus the estimated per-round gain.
    preMaxAcc = np.max(acc_logs[:round-1])
    if acc_logs[round-1] < preMaxAcc:
        expAcc = preMaxAcc
    else:
        expAcc = acc_logs[round-1] + deltaAcc
    expAcc_logs.append(expAcc)

    nowAcc = acc_logs[round]
    maxAcc = np.max(acc_logs[:round])
    beta = 0.1
    # deltaAcc only changes on a new best: smoothed towards the improvement, then squashed.
    if acc_logs[round] > maxAcc:
        deltaAcc = (1-beta)*deltaAcc + beta*(acc_logs[round]-maxAcc)
        deltaAcc = (1-acc_logs[round])*(1 - math.exp(-deltaAcc))
    result['deltaAcc'] = deltaAcc
    
    # New best with a gain above the growth of the expected accuracy -> +0.5.
    if nowAcc > maxAcc and nowAcc-maxAcc > expAcc_logs[round]-expAcc_logs[round-1]:
        for cId in participants:
            scores[cId][round] += 0.5
        return result
    # New best with a gain below the growth of the expected accuracy -> +2.
    if nowAcc > maxAcc and nowAcc-maxAcc < expAcc_logs[round]-expAcc_logs[round-1]:
        for cId in participants:
            scores[cId][round] += 2
        return result
    # Everything else (no new best, or an exact tie) -> -2.
    for cId in participants:
        scores[cId][round] -= 2
    return result

def updateScoresV2(args):
    """Score rule v2: compare accuracy with a momentum-based expected accuracy.

    For round ``r = args['round']`` (``scores`` and ``expAcc_logs`` are mutated in place):
      * r == 0: nothing changes.
      * Every client carries its round r-1 score into round r.
      * r == 1: participants get -2 if accuracy dropped; otherwise ``deltaAcc``
        becomes ``acc[1] - acc[0]``.
      * r >= 2: ``expAcc = acc[r-1] + deltaAcc`` is appended to ``expAcc_logs`` and
        ``deltaAcc`` becomes ``(1-beta)*deltaAcc + beta*(acc[r]-acc[r-1])``.
        With ``maxAcc = max(acc[:r])`` participants get:
          - +2   if acc[r] >= maxAcc
          - -2   if maxAcc < expAcc
          - -2   if maxAcc > expAcc and acc[r] < expAcc
          - +0.5 if maxAcc > expAcc and acc[r] >= expAcc
          - no change otherwise (maxAcc == expAcc)
    ``prev_participants`` is read but unused.

    Returns:
        dict with ``scores``, the updated ``deltaAcc`` and ``expAcc_logs``.
    """
    (CLIENT_N, acc_logs, scores, _prev_participants, participants,
     rnd, deltaAcc, expAcc_logs, beta) = (
        args[key] for key in ('CLIENT_N', 'acc_logs', 'scores', 'prev_participants',
                              'participants', 'round', 'deltaAcc', 'expAcc_logs', 'beta'))
    result = {'scores': scores, 'deltaAcc': deltaAcc, 'expAcc_logs': expAcc_logs}
    if rnd == 0:
        return result

    for cId in range(CLIENT_N):
        scores[cId][rnd] = scores[cId][rnd - 1]

    if rnd == 1:
        if acc_logs[rnd] < acc_logs[rnd - 1]:
            for cId in participants:
                scores[cId][rnd] -= 2
        else:
            deltaAcc = acc_logs[rnd] - acc_logs[rnd - 1]
        result['deltaAcc'] = deltaAcc
        return result

    nowAcc = acc_logs[rnd]
    maxAcc = np.max(acc_logs[:rnd])
    expAcc = acc_logs[rnd - 1] + deltaAcc
    expAcc_logs.append(expAcc)
    # Exponential moving average of the per-round accuracy change.
    result['deltaAcc'] = (1-beta)*deltaAcc + beta*(acc_logs[rnd]-acc_logs[rnd-1])

    if nowAcc >= maxAcc:
        change = 2
    elif maxAcc < expAcc:
        change = -2
    elif maxAcc > expAcc and nowAcc < expAcc:
        change = -2
    elif maxAcc > expAcc and nowAcc >= expAcc:
        change = 0.5
    else:
        change = 0
    if change:
        for cId in participants:
            scores[cId][rnd] += change
    return result

def updateScoresV1(args):
    """Score rule v1: like v2 but without the +0.5 reward and without logging expAcc.

    For round ``r = args['round']`` (``scores`` is mutated in place):
      * r == 0: nothing changes.
      * Every client carries its round r-1 score into round r.
      * r == 1: participants get -2 if accuracy dropped; otherwise ``deltaAcc``
        becomes ``acc[1] - acc[0]``.
      * r >= 2: ``expAcc = acc[r-1] + deltaAcc`` and ``deltaAcc`` becomes
        ``(1-beta)*deltaAcc + beta*(acc[r]-acc[r-1])``. With ``maxAcc = max(acc[:r])``
        participants get:
          - +2 if acc[r] >= maxAcc
          - -2 if maxAcc < expAcc, or if maxAcc > expAcc and acc[r] < expAcc
          - no change otherwise
    ``prev_participants`` is read but unused; ``expAcc_logs`` is passed through.

    Returns:
        dict with ``scores``, the updated ``deltaAcc`` and ``expAcc_logs``.
    """
    (CLIENT_N, acc_logs, scores, _prev_participants, participants,
     rnd, deltaAcc, expAcc_logs, beta) = (
        args[key] for key in ('CLIENT_N', 'acc_logs', 'scores', 'prev_participants',
                              'participants', 'round', 'deltaAcc', 'expAcc_logs', 'beta'))
    result = {'scores': scores, 'deltaAcc': deltaAcc, 'expAcc_logs': expAcc_logs}
    if rnd == 0:
        return result

    for cId in range(CLIENT_N):
        scores[cId][rnd] = scores[cId][rnd - 1]

    if rnd == 1:
        if acc_logs[rnd] < acc_logs[rnd - 1]:
            for cId in participants:
                scores[cId][rnd] -= 2
        else:
            deltaAcc = acc_logs[rnd] - acc_logs[rnd - 1]
        result['deltaAcc'] = deltaAcc
        return result

    nowAcc = acc_logs[rnd]
    maxAcc = np.max(acc_logs[:rnd])
    expAcc = acc_logs[rnd - 1] + deltaAcc
    # Exponential moving average of the per-round accuracy change.
    result['deltaAcc'] = (1-beta)*deltaAcc + beta*(acc_logs[rnd]-acc_logs[rnd-1])

    if nowAcc >= maxAcc:
        change = 2
    elif maxAcc < expAcc or (maxAcc > expAcc and nowAcc < expAcc):
        change = -2
    else:
        change = 0
    if change:
        for cId in participants:
            scores[cId][rnd] += change
    return result

def updateScoresV0(args):
    """Score rule v0: compare accuracy with the best so far and a fixed blend ``expAcc``.

    Unlike the other versions this returns a ``(scores, deltaAcc)`` tuple, and
    ``deltaAcc`` is passed through unchanged. For round ``r``:
      * r == 0: returns immediately.
      * Every client carries its round r-1 score into round r.
      * r == 1: participants get -2 if accuracy dropped.
      * r >= 2, with ``maxAcc = max(acc[:r])`` and
        ``expAcc = 0.3 * (sum(acc[:r]) - maxAcc) + 0.7 * maxAcc``:
          - acc[r] > maxAcc                                        -> +2
          - acc[r] > expAcc and acc[r] - expAcc < maxAcc - expAcc  -> +0.5
          - acc[r] < expAcc and expAcc - acc[r] < maxAcc - expAcc  -> -0.5
          - acc[r] < expAcc and expAcc - acc[r] > maxAcc - expAcc  -> -2
          - anything else (ties)                                   -> unchanged

    NOTE(review): ``expAcc`` weights the *sum* (not the mean) of the non-best
    accuracies, so it exceeds ``maxAcc`` as soon as those sum to more than
    ``maxAcc``; confirm whether a mean was intended.
    """
    CLIENT_N, acc_logs, scores, _prev_participants, participants, rnd, deltaAcc = (
        args[key] for key in ('CLIENT_N', 'acc_logs', 'scores', 'prev_participants',
                              'participants', 'round', 'deltaAcc'))
    if rnd == 0:
        return scores, deltaAcc

    for cId in range(CLIENT_N):
        scores[cId][rnd] = scores[cId][rnd - 1]

    if rnd == 1:
        if acc_logs[rnd] < acc_logs[rnd - 1]:
            for cId in participants:
                scores[cId][rnd] -= 2
        return scores, deltaAcc

    nowAcc = acc_logs[rnd]
    maxAcc = np.max(acc_logs[:rnd])
    expAcc = (np.sum(acc_logs[:rnd]) - maxAcc) * 0.3 + maxAcc * 0.7
    headroom = maxAcc - expAcc
    if nowAcc > maxAcc:
        change = 2
    elif nowAcc > expAcc and (nowAcc - expAcc) < headroom:
        change = 0.5
    elif nowAcc < expAcc and (expAcc - nowAcc) < headroom:
        change = -0.5
    elif nowAcc < expAcc and (expAcc - nowAcc) > headroom:
        change = -2
    else:
        change = 0
    if change:
        for cId in participants:
            scores[cId][rnd] += change
    return scores, deltaAcc

def updateScores(args):
    """Baseline score rule.

    For round r >= 1 every client first carries over its round r-1 score, then:
      * if r > 1 and this round's accuracy gain is larger than the previous round's
        gain, the previous round's participants lose 1 and this round's gain 1;
      * if accuracy fell this round, this round's participants lose 1.
    Round 0 changes nothing. ``scores`` is mutated in place; ``deltaAcc`` and
    ``expAcc_logs`` are passed through unchanged.

    Returns:
        dict with ``scores``, ``deltaAcc`` and ``expAcc_logs``.
    """
    (CLIENT_N, acc_logs, scores, prev_participants, participants,
     rnd, deltaAcc, expAcc_logs) = (
        args[key] for key in ('CLIENT_N', 'acc_logs', 'scores', 'prev_participants',
                              'participants', 'round', 'deltaAcc', 'expAcc_logs'))
    result = {'scores': scores, 'deltaAcc': deltaAcc, 'expAcc_logs': expAcc_logs}
    if rnd == 0:
        return result

    for cId in range(CLIENT_N):
        scores[cId][rnd] = scores[cId][rnd - 1]

    gain = acc_logs[rnd] - acc_logs[rnd - 1]
    if rnd > 1 and gain > acc_logs[rnd - 1] - acc_logs[rnd - 2]:
        # Credit this round's participants for the faster improvement, at the
        # expense of the previous round's participants.
        for cId in prev_participants:
            scores[cId][rnd] -= 1
        for cId in participants:
            scores[cId][rnd] += 1
    if gain < 0:
        for cId in participants:
            scores[cId][rnd] -= 1
    return result

In [10]:
def _trainClient(globalNet, dataset, Epoch, BatchSize, device, criterion, lr):
    """Train a deep copy of ``globalNet`` on one client's data with plain SGD; return the copy."""
    cNet = copy.deepcopy(globalNet)
    optimizer = torch.optim.SGD(cNet.parameters(), lr=lr)
    train_loader = DataLoader(dataset, batch_size=BatchSize)  # no shuffle: samples visited in subset order
    cNet.train()
    for _ in range(Epoch):
        for inputs, labels in train_loader:
            optimizer.zero_grad()
            batch_loss = criterion(cNet(inputs.to(device)), labels.to(device))
            batch_loss.backward()
            optimizer.step()
    return cNet


def _evaluate(net, test_loader, n_samples, criterion, device):
    """Return ``(accuracy, mean per-sample loss)`` of ``net`` over ``test_loader``.

    Switches ``net`` to eval mode so BatchNorm layers use, and do not update, their
    running statistics. Assumes ``criterion`` returns a batch mean (as the default
    ``nn.CrossEntropyLoss()`` does).
    """
    net.eval()
    correct = 0
    loss_sum = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            logits = net(inputs.to(device))
            # Weight each batch mean by its size so dividing by n_samples gives a per-sample mean.
            loss_sum += criterion(logits, labels.to(device)).item() * labels.size(0)
            correct += int(np.sum(np.argmax(logits.cpu().numpy(), axis=1) == labels.numpy()))
    return correct / n_samples, loss_sum / n_samples


def trainWithoutUpdate(args):
    """Run ``Round`` rounds of federated averaging over a fixed participant schedule.

    Each round, every client id in ``ParticipantSet[round]`` trains a copy of the
    global model for ``Epoch`` epochs on ``clientTrainDstSet[id]`` (SGD, learning rate
    ``args.get('lr', 0.01)``). The state_dict deltas are averaged into the global model
    with ``aggregate``, and the result is evaluated on ``testDst`` (in eval mode).
    Client scores are not computed here.

    Required ``args`` keys: globalNet, Epoch, BatchSize, ParticipantSet, Round,
    testDst, clientTrainDstSet, device, goodRate, Client_n, net_name, dst_name, iid,
    noise. Optional: lr (default 0.01). ``writer`` is accepted but unused.

    Returns:
        dict with the run settings, ``acc_logs`` (test accuracy of the global model
        per round), ``ParticipantSet`` and the final model under ``net`` and ``globalNet``.
    """
    globalNet = args['globalNet']
    Epoch = args['Epoch']
    BatchSize = args['BatchSize']
    ParticipantSet = args['ParticipantSet']
    Round = args['Round']
    testDst = args['testDst']
    clientTrainDstSet = args['clientTrainDstSet']
    device = args['device']
    goodRate = args['goodRate']
    client_n = args['Client_n']
    net_name = args['net_name']
    dst_name = args['dst_name']
    iid = args['iid']
    noise = args['noise']
    lr = args.get('lr', 0.01)

    criterion = nn.CrossEntropyLoss()
    test_loader = DataLoader(testDst, batch_size=BatchSize, shuffle=False)
    acc_logs = []
    for rnd in range(Round):
        participants = ParticipantSet[rnd]
        gradsFromClient = []
        for cId in participants:
            cNet = _trainClient(globalNet, clientTrainDstSet[cId], Epoch, BatchSize, device, criterion, lr)
            # globalNet is untouched until aggregation, so it is the common origin of every delta.
            gradsFromClient.append(diff(globalNet, cNet))
        globalNet = aggregate(globalNet, gradsFromClient, len(participants))
        test_acc, test_loss = _evaluate(globalNet, test_loader, len(testDst), criterion, device)
        print('[%03d/%03d round] Test Acc: %3.6f Loss: %3.6f' % (rnd, Round, test_acc, test_loss))
        acc_logs.append(test_acc)

    result = {
        'goodRate':goodRate,
        'dst_name':dst_name,
        'net_name':net_name,
        'client_n':client_n,
        'net':globalNet,
        'acc_logs':acc_logs,
        'ParticipantSet':ParticipantSet,
        'iid':iid,
        'noise':noise,
        'globalNet':globalNet
    }
    return result

In [None]:
# The seed must be set to a value >= 3.
seed = 3
setup_seed(seed)

# ---- experiment configuration ----
log_dir = 'acc_data10'
writer = SummaryWriter(log_dir)
e_name = 'random'
clientSetting = [100, 25, 5]
participantSetting = {5: 2, 25: 5, 100: 10}  # number of clients -> participants per round
BatchSize = 32
Round = 101
Epoch = 1
goodRates = [[], [], []]  # empty entries make selectParticipants sample uniformly
done = []

for net_name in ['RESNET18', 'CNN', 'MLP']:
    for dst_name in ['CIFAR10', 'MINIST']:
        for client_n in clientSetting:
            for noise in [True, False]:
                for iid in [True, False]:
                    noise_str = 'noise' if noise else 'nonise'
                    iid_str = 'iid' if iid else 'noniid'
                    run_tag = f'{net_name}-{dst_name}-{client_n}-{noise_str}-{iid_str}'
                    # Only the noisy + IID setting is run; settings already finished are skipped.
                    if not (noise and iid) or run_tag in done:
                        continue
                    setup_seed(seed)  # reseed so every setting starts from the same RNG state
                    print(run_tag)

                    participant_n = participantSetting[client_n]
                    clientTrainDstSet, testDst = initDst(dst_name, client_n, iid, noise)
                    globalNet = Classifier(net_name, dst_name).to(device)
                    ParticipantSet = [
                        selectParticipants(participant_n, client_n, 0, goodRates[r % 3])
                        for r in range(Round)
                    ]

                    args = dict(
                        net_name=net_name,
                        dst_name=dst_name,
                        Client_n=client_n,
                        globalNet=copy.deepcopy(globalNet),
                        Epoch=Epoch,
                        BatchSize=BatchSize,
                        ParticipantSet=ParticipantSet,
                        Round=len(ParticipantSet),
                        testDst=testDst,
                        clientTrainDstSet=clientTrainDstSet,
                        device=device,
                        writer=writer,
                        goodRate=e_name,
                        goodRates=goodRates,
                        noise=noise_str,
                        iid=iid_str,
                    )
                    res = trainWithoutUpdate(args)
                    torch.save(res, f'{log_dir}/{run_tag}.save')
                    done.append(run_tag)



In [20]:
# Import timm here: on a fresh kernel this cell runs before the later cell that imports it.
import timm

timm.list_models()

['adv_inception_v3',
 'bat_resnext26ts',
 'beit_base_patch16_224',
 'beit_base_patch16_224_in22k',
 'beit_base_patch16_384',
 'beit_large_patch16_224',
 'beit_large_patch16_224_in22k',
 'beit_large_patch16_384',
 'beit_large_patch16_512',
 'beitv2_base_patch16_224',
 'beitv2_base_patch16_224_in22k',
 'beitv2_large_patch16_224',
 'beitv2_large_patch16_224_in22k',
 'botnet26t_256',
 'botnet50ts_256',
 'cait_m36_384',
 'cait_m48_448',
 'cait_s24_224',
 'cait_s24_384',
 'cait_s36_384',
 'cait_xs24_384',
 'cait_xxs24_224',
 'cait_xxs24_384',
 'cait_xxs36_224',
 'cait_xxs36_384',
 'coat_lite_mini',
 'coat_lite_small',
 'coat_lite_tiny',
 'coat_mini',
 'coat_tiny',
 'coatnet_0_224',
 'coatnet_0_rw_224',
 'coatnet_1_224',
 'coatnet_1_rw_224',
 'coatnet_2_224',
 'coatnet_2_rw_224',
 'coatnet_3_224',
 'coatnet_3_rw_224',
 'coatnet_4_224',
 'coatnet_5_224',
 'coatnet_bn_0_rw_224',
 'coatnet_nano_cc_224',
 'coatnet_nano_rw_224',
 'coatnet_pico_rw_224',
 'coatnet_rmlp_0_rw_224',
 'coatnet_rmlp_1_rw

In [117]:
import timm

# Untrained 10-class VGG19, placed on the notebook's selected device rather than a
# hard-coded 'cuda' (which fails on CPU-only machines).
net = timm.create_model('vgg19',pretrained=False,num_classes=10).to(device)

In [118]:
# Time one inference pass of `net` over the CIFAR10 training split (one client, IID, no label noise).
clientTrainDstSet, testDst = initDst("CIFAR10",1,True,False)
dataLoader = DataLoader(clientTrainDstSet[0],batch_size=32)
net.eval()
with torch.no_grad():  # pure inference: do not build autograd graphs while timing
    start = time.perf_counter()
    for images, _ in dataLoader:
        net(images.to(device))
    if device == "cuda":
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    stop = time.perf_counter()
stop - start

21.11084508895874

In [97]:
# Same timing loop for the small CNN, reusing `dataLoader` built in the timing cell above.
net = Classifier("CNN","CIFAR10").to(device)
net.eval()
with torch.no_grad():  # pure inference: do not build autograd graphs while timing
    start = time.perf_counter()
    for images, _ in dataLoader:
        net(images.to(device))
    if device == "cuda":
        # CUDA kernels run asynchronously; wait for them before reading the clock.
        torch.cuda.synchronize()
    stop = time.perf_counter()
stop - start

1.5756676197052002

In [19]:
# NOTE(review): leftover interactive inspection. This only evaluates the
# `torch.cuda.set_device` function object (it is never called), so it has no effect.
torch.cuda.set_device

torch.cuda.device

In [None]:
# Replay the accuracy log stored in `res` through the baseline rule `updateScores`,
# keeping every client's per-round score in `scoresBaseline` (clients x rounds).
Round = len(res['acc_logs'])
# trainWithoutUpdate's result has no 'scores' entry (len(res['scores']) raised KeyError);
# read the client count from 'client_n', falling back to 'scores' for results that carry it.
Client_n = res['client_n'] if 'client_n' in res else len(res['scores'])
scoresBaseline = [[0] * Round for _ in range(Client_n)]
sco_res = {'expAcc_logs': [0, 0], 'deltaAcc': 0}
for rnd in range(Round):
    sco_res = updateScores({
        'CLIENT_N': Client_n,
        'acc_logs': res['acc_logs'],
        'round': rnd,
        'scores': scoresBaseline,
        'deltaAcc': sco_res['deltaAcc'],
        'expAcc_logs': sco_res['expAcc_logs'],
        'prev_participants': res['ParticipantSet'][rnd - 1] if rnd > 0 else [],
        'participants': res['ParticipantSet'][rnd],
    })

In [64]:
# Replay the same accuracy log through rule `updateScoresV1`, keeping scores in `scoresV`.
BETA_V1 = 0.2  # smoothing factor of V1's expected-gain moving average
Round = len(res['acc_logs'])
# trainWithoutUpdate's result has no 'scores' entry (len(res['scores']) raised KeyError);
# read the client count from 'client_n', falling back to 'scores' for results that carry it.
Client_n = res['client_n'] if 'client_n' in res else len(res['scores'])
scoresV = [[0] * Round for _ in range(Client_n)]
sco_res = {'expAcc_logs': [0, 0], 'deltaAcc': 0}
for rnd in range(Round):
    sco_res = updateScoresV1({
        'CLIENT_N': Client_n,
        'acc_logs': res['acc_logs'],
        'round': rnd,
        'scores': scoresV,
        'deltaAcc': sco_res['deltaAcc'],
        'expAcc_logs': sco_res['expAcc_logs'],
        'beta': BETA_V1,
        'prev_participants': res['ParticipantSet'][rnd - 1] if rnd > 0 else [],
        'participants': res['ParticipantSet'][rnd],
    })

In [66]:
# For each round, Spearman-correlate every client's score with its client id under the
# baseline rule (s1) and under rule V1 (s2), and sum the improvement s2 - s1 in `v`.
# With NOISE=True, initDst gives lower client ids noisier labels, so presumably a good
# rule yields scores that increase with the id.
# Sizes come from the score tables themselves instead of the old hard-coded
# range(2, 249) / 100 clients, which overflowed for runs with fewer rounds.
n_clients = min(len(scoresBaseline), len(scoresV))
n_rounds = min(len(scoresBaseline[0]), len(scoresV[0]))
client_ids = list(range(n_clients))
v = 0
for r in range(2, n_rounds):
    s1 = scipy.stats.spearmanr([scoresBaseline[c][r] for c in client_ids], client_ids)[0]
    s2 = scipy.stats.spearmanr([scoresV[c][r] for c in client_ids], client_ids)[0]
    # spearmanr gives NaN when every score in a round is equal; skip such rounds so a
    # single constant round does not turn the running total into NaN.
    if not (np.isnan(s1) or np.isnan(s2)):
        v += s2 - s1
    print(r, 'baseline:', s1, 'new:', s2)
print(v)

2 baseline: 0.13391213504629793 new: 0.12009486088461248
3 baseline: 0.13391213504629793 new: 0.03637488574870473
4 baseline: 0.11561238069626845 new: 0.02798374335218187
5 baseline: 0.2003669628878393 new: 0.06757049764406133
6 baseline: 0.2003669628878393 new: 0.05895544097633943
7 baseline: 0.2003669628878393 new: 0.09585546881288594
8 baseline: 0.07002795012393423 new: 0.045077628605938766
9 baseline: 0.13599641601467816 new: 0.11605620380213914
10 baseline: 0.20448622543202136 new: 0.17329878986160643
11 baseline: 0.17090663469730286 new: 0.19299981713985076
12 baseline: 0.17090663469730286 new: 0.1942568459215994
13 baseline: 0.09303044945270524 new: 0.15546531875379743
14 baseline: 0.13632509803416457 new: 0.18899688892853894
15 baseline: 0.15738238759340978 new: 0.19277748840407866
16 baseline: 0.14772482478884152 new: 0.18250212884975228
17 baseline: 0.1393397169479339 new: 0.16607220757146043
18 baseline: 0.12128701543908726 new: 0.17112553939989983
19 baseline: 0.12128701543