In [None]:
import torch
import numpy as np
import pandas as pd

from models.resnet import ResNet50
from utils.reproducibility import make_it_reproducible, seed_worker
from utils.datasets import get_datasets
from utils.sampling import get_user_groups

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Number of users that get_user_groups partitions the training set across.
tot_users = 100

In [None]:
# Reproducibility: a single seed shared by every experiment cell, plus one torch
# Generator that each experiment cell re-seeds and passes to its DataLoader.
seed = 0
g = torch.Generator()

In [None]:
# Datasets only; the DataLoaders are built inside each experiment cell.
# The training set is split across users, and the test set feeds the test loader.
(trainset, testset) = get_datasets()

In [None]:
# --- IID, balanced split -------------------------------------------------------
# Partition the training set across `tot_users` users, then report how many
# distinct classes each user holds (mean +- sample std over users).
iid = True
unbalanced = False

# Re-seed before the split so this cell's result does not depend on which cells
# ran earlier (make_it_reproducible presumably seeds the global RNGs -- confirm).
make_it_reproducible(seed)
g.manual_seed(seed)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128, shuffle=False, num_workers=2,
                                         worker_init_fn=seed_worker, generator=g)
user_groups, user_dict_cls_count = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=tot_users)
labels = trainset.targets

# Turn zero counts into NaN so `count` skips them, matching the non-IID cells.
# Without this step, a class stored with an explicit count of 0 would be counted
# as "seen".
# NOTE(review): assumes user_dict_cls_count is keyed by user, so users become the
# columns and count(axis=0) is per user -- confirm against get_user_groups.
df = pd.DataFrame(user_dict_cls_count).replace(0, np.nan)
m = df.count(axis=0)
print("seen classes: ", m.mean(), "+-", m.std())  # 10.0 +- 0.0

In [None]:
# --- Non-IID, balanced split ---------------------------------------------------
# Same measurement as the IID cell, but with iid=False: report how many distinct
# classes each user holds (mean +- sample std over users).
iid = False
unbalanced = False

# Re-seed before the split so this cell's result does not depend on which cells
# ran earlier (make_it_reproducible presumably seeds the global RNGs -- confirm).
make_it_reproducible(seed)
g.manual_seed(seed)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128, shuffle=False, num_workers=2,
                                         worker_init_fn=seed_worker, generator=g)
user_groups, user_dict_cls_count = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=tot_users)
labels = trainset.targets

# Zero counts -> NaN so `count` only counts classes the user actually holds.
# A new frame is built instead of mutating in place, so re-running the cell is safe.
# NOTE(review): assumes users become the columns of the frame -- confirm.
df = pd.DataFrame(user_dict_cls_count).replace(0, np.nan)
m = df.count(axis=0)
print("seen classes: ", m.mean(), "+-", m.std())  # 1.94 +- 0.23868325657594203

In [None]:
# --- Non-IID, unbalanced split -------------------------------------------------
# Same measurement with iid=False and unbalanced=True: report how many distinct
# classes each user holds (mean +- sample std over users).
iid = False
unbalanced = True

# Re-seed before the split so this cell's result does not depend on which cells
# ran earlier (make_it_reproducible presumably seeds the global RNGs -- confirm).
make_it_reproducible(seed)
g.manual_seed(seed)
testloader = torch.utils.data.DataLoader(testset,
                                         batch_size=128, shuffle=False, num_workers=2,
                                         worker_init_fn=seed_worker, generator=g)
user_groups, user_dict_cls_count = get_user_groups(trainset, iid=iid, unbalanced=unbalanced, tot_users=tot_users)
labels = trainset.targets

# Zero counts -> NaN so `count` only counts classes the user actually holds.
# A new frame is built instead of mutating in place, so re-running the cell is safe.
# NOTE(review): assumes users become the columns of the frame -- confirm.
df = pd.DataFrame(user_dict_cls_count).replace(0, np.nan)
m = df.count(axis=0)
print("seen classes: ", m.mean(), "+-", m.std())  # 1.84 +- 1.0609696346533988