In [1]:
import torch
from einsum import EinsumNetwork
from einsum import Graph
from einsum import EinetMixture
from torchvision.datasets import MNIST, CelebA
from torch.utils.data import DataLoader, TensorDataset, Subset
import numpy as np
import matplotlib.pyplot as plt
import os
import utils
from PIL import Image

In [2]:
mnist = MNIST('../../datasets/', train=True)

RuntimeError: Dataset not found. You can use download=True to download it

In [3]:
celeba = CelebA('/storage-01/datasets/', split='train')

In [5]:
celeba.attr

tensor([[0, 1, 1,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        ...,
        [1, 0, 1,  ..., 0, 1, 1],
        [0, 0, 0,  ..., 0, 0, 1],
        [0, 1, 1,  ..., 1, 0, 1]])

In [3]:
def distribute_uneven_to_clients(num_clients, labels_per_client, dataset):
    all_labels = list(torch.unique(dataset.targets).numpy())
    label_assignments = {c: [] for c in range(num_clients)}
    for c in range(num_clients):
        labels = np.random.choice(all_labels, labels_per_client, replace=False)
        all_labels = [l for l in all_labels if l not in labels]
        label_assignments[c] += list(labels)
    
    datasets = []
    for c, labels in label_assignments.items():
        inds = []
        for l in labels:
            idx = torch.argwhere(dataset.targets == l).flatten()
            inds.append(idx)
        inds = torch.concat(inds)
        subset = TensorDataset(dataset.data[inds], dataset.targets[inds])
        datasets.append(subset)
    return datasets

In [36]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#exponential_family = EinsumNetwork.BinomialArray
#exponential_family = EinsumNetwork.CategoricalArray
exponential_family = EinsumNetwork.NormalArray

classes = [7]
# classes = [2, 3, 5, 7]
# classes = None

K = 20

structure = 'poon-domingos'
# structure = 'binary-trees'

# 'poon-domingos'
pd_num_pieces = [4]
# pd_num_pieces = [7]
# pd_num_pieces = [7, 28]
width = 28
height = 28

# 'binary-trees'
depth = 4
num_repetitions = 20

num_epochs = 10
batch_size = 100
online_em_frequency = 1
online_em_stepsize = 0.1
############################################################################

exponential_family_args = None
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
if exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
if exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

client_datasets = distribute_uneven_to_clients(5, 2, mnist)

# Make EinsumNetwork
######################################
pd_delta = [[height / d, width / d] for d in pd_num_pieces]

client_spns = []
# Train
######################################
for ds in client_datasets:
    cluster_spns = []
    for l in torch.unique(ds.tensors[1]):
        idx = torch.argwhere(ds.tensors[1] == l).flatten()
        subset = TensorDataset(ds.tensors[0][idx], ds.tensors[1][idx])
        graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        args = EinsumNetwork.Args(
                num_var=width*height,
                num_dims=1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                online_em_frequency=online_em_frequency,
                online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        einet.initialize()
        einet.to(device)

        loader = DataLoader(subset, batch_size)

        for _ in range(num_epochs):

            total_ll = 0.0
            for x,  y in loader:
                x = x.to(device, dtype=torch.float32)
                x = x.reshape(x.shape[0], width*height)
                x = x / 255.
                x -= 0.5
                outputs = einet.forward(x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.mean()
                log_likelihood.backward()

                einet.em_process_batch()
                total_ll += log_likelihood.detach().item()
            
            print(total_ll / len(loader))

            einet.em_update()
        
        cluster_spns.append(einet)
    mixture = EinetMixture.EinetMixture([0.5, 0.5], cluster_spns)
    client_spns.append(mixture)

weights = [1/len(ds) for ds in client_datasets]
weights = np.array(weights) / sum(weights)
#mixture = EinetMixture.EinetMixture(weights, client_spns)


samples_dir = '../samples/demo_mnist/'
utils.mkdir_p(samples_dir)

#####################
# draw some samples #
#####################
samples = []
for _ in range(5):
    client_spn_idx = np.random.choice(np.arange(5), 1, p=weights)[0]
    mixture = client_spns[client_spn_idx]
    client_samples = mixture.sample(5)
    client_samples = client_samples.reshape((-1, 28, 28))
    samples.append(client_samples)
samples = np.vstack(samples)
utils.save_image_stack(samples, 5, 5, os.path.join(samples_dir, "samples.png"), margin_gray_val=0.)

{0}
-281.92253820510473
-0.09855388130338671
-4.180915618961729e-05
-2.3583152581101572e-07
-2.385027104962239e-07
-2.6208573089547826e-07
-2.428437926566323e-07
-2.1483619756382082e-07
-1.7981645161104525e-07
-1.895002744004494e-07
{0}
-319.06677160461743
-0.2593533400213346
-0.000416682500851806
-2.881304685805238e-07
-2.9907810793853666e-07
-2.4938733697392003e-07
-2.772230935477182e-07
-2.940706544052318e-07
-3.0396413563948954e-07
-3.328136382189465e-07
{784}
{0}
-319.6166894654433
-0.2592997021973133
-0.00041679473633469873
-3.309361588321735e-07
-3.5477808827029854e-07
-3.111978960153768e-07
-3.319295444725867e-07
-3.5378469031381126e-07
-3.7935418551645246e-07
-3.061010289684418e-07
{0}
-303.85892143892863
-0.18010421884837485
-0.00019555345598098664
-4.0209543851146703e-07
-4.1059597299589235e-07
-4.304932579161109e-07
-4.446702614731048e-07
-4.44699303497622e-07
-4.2388499031807845e-07
-4.1534086775870476e-07
{784}
{0}
-308.95159938835326
-0.20330479981437807
-0.0002528121326

  samples /= samples.max()
  img = Image.fromarray(np.round(framed_img * 255.).astype(np.uint8))


In [46]:
weights

array([0.09550203, 0.10110176, 0.11037082, 0.10101641, 0.10225948,
       0.08874521, 0.10241702, 0.09758934, 0.10057492, 0.100423  ])

In [47]:
for ds in client_datasets:
    print(torch.unique(ds.tensors[1]))

tensor([7])
tensor([6])
tensor([5])
tensor([0])
tensor([8])
tensor([1])
tensor([4])
tensor([3])
tensor([9])
tensor([2])


In [66]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#exponential_family = EinsumNetwork.BinomialArray
#exponential_family = EinsumNetwork.CategoricalArray
exponential_family = EinsumNetwork.NormalArray

classes = [7]
# classes = [2, 3, 5, 7]
# classes = None

K = 10

structure = 'poon-domingos'
# structure = 'binary-trees'

# 'poon-domingos'
pd_num_pieces = [4]
# pd_num_pieces = [7]
# pd_num_pieces = [7, 28]
width = 28
height = 28

# 'binary-trees'
depth = 3
num_repetitions = 20

num_epochs = 10
batch_size = 100
online_em_frequency = 5
online_em_stepsize = 0.2
############################################################################

exponential_family_args = None
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
if exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
if exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

client_datasets = distribute_uneven_to_clients(10, 1, mnist)

# Make EinsumNetwork
######################################
pd_delta = [[height / d, width / d] for d in pd_num_pieces]

client_spns = []
# Train
######################################
for ds in client_datasets:
    graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
    args = EinsumNetwork.Args(
            num_var=width*height,
            num_dims=1,
            num_classes=1,
            num_sums=K,
            num_input_distributions=K,
            exponential_family=exponential_family,
            exponential_family_args=exponential_family_args,
            online_em_frequency=online_em_frequency,
            online_em_stepsize=online_em_stepsize)
    einet = EinsumNetwork.EinsumNetwork(graph, args)
    einet.initialize()
    einet.to(device)
    loader = DataLoader(ds, batch_size)
    for _ in range(num_epochs):
        total_ll = 0.0
        for x,  y in loader:
            x = x.to(device, dtype=torch.float32)
            x = x.reshape(x.shape[0], width*height)
            x = x / 255.
            x -= 0.5
            outputs = einet.forward(x)
            ll_sample = EinsumNetwork.log_likelihoods(outputs)
            log_likelihood = ll_sample.mean()
            log_likelihood.backward()
            einet.em_process_batch()
            total_ll += log_likelihood.detach().item()
        
        print(total_ll / len(loader))
        einet.em_update()
    
    client_spns.append(einet)
weights = [1/len(ds) for ds in client_datasets]
weights = np.array(weights) / sum(weights)
mixture = EinetMixture.EinetMixture(weights, client_spns)

#mixture = EinetMixture.EinetMixture(weights, client_spns)


samples_dir = '../samples/demo_mnist/'
utils.mkdir_p(samples_dir)

#####################
# draw some samples #
#####################
# mixture acts as a server (i.e. sum node)
# root node samples client-spn and the subsequent sampling takes places in one specific client SPN
samples = mixture.sample(25)
samples = samples.reshape((-1, 28, 28))
utils.save_image_stack(samples, 5, 5, os.path.join(samples_dir, "samples.png"), margin_gray_val=0.)

{0}
-471.14187668844806
699.9791213213388
1387.4753211069915
2050.4711003707625
2686.35739704714
3005.1669011520125
3033.41798116393
3043.9710879568324
3052.5208760924256
3058.3666123212392
{0}
-433.8395733833313
778.3675145467122
1511.4768534342447
2231.1450256347657
2919.468473307292
3224.9679931640626
3250.142370605469
3260.20166015625
3268.3503580729166
3275.2489217122397
{0}
-386.98157767644005
865.1510872008308
1648.2354852585565
2412.012205093626
3103.8409481956846
3323.7051362537204
3348.2815987723216
3361.2775142609125
3370.6039225260415
3372.9791124131943
{0}
-471.668329501051
729.8643772965771
1435.273710606462
2108.7645408501057
2764.2479723914194
3105.0017007084216
3139.9830342955506
3156.3605170815677
3167.6781481726694
3177.6742129568324
{0}
-424.4090095066255
751.0130644767515
1466.7628882623487
2165.520978373866
2812.782470703125
3044.7651603452623
3073.3572998046875
3090.437058971774
3102.3704833984375
3110.2955952305947
{0}
-474.76156714757286
679.9155232747396
1351.

In [77]:
samples = mixture.sample(36)
samples = samples.reshape((-1, 28, 28)) + 0.5
samples[samples < 1e-2] = 0
utils.save_image_stack(samples, 6, 6, os.path.join(samples_dir, "samples.pdf"), margin_gray_val=0.)

In [82]:
for c, ds in enumerate(client_datasets):
    s = ds.tensors[0][0].numpy()
    s -= s.min()
    s = s / s.max()
    img = Image.fromarray(np.round(s * 255.).astype(np.uint8))

    #img = margin_gray_val * np.ones((height*num_rows + (num_rows-1)*margin, width*num_columns + (num_columns-1)*margin, 3))
    img.save(f'Client_{c}.pdf')

In [86]:
def distribute_to_clients_vertical(num_clients, dataset, width=28):
    width_idx = np.arange(0, width)
    width_idx = np.array_split(width_idx, num_clients)
    
    datasets = []
    for idx in width_idx:
        subset = dataset.data[:, idx, :]
        tensor_ds = TensorDataset(subset, dataset.targets)
        datasets.append(tensor_ds)
    return datasets

In [107]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

#exponential_family = EinsumNetwork.BinomialArray
#exponential_family = EinsumNetwork.CategoricalArray
exponential_family = EinsumNetwork.NormalArray

classes = [7]
# classes = [2, 3, 5, 7]
# classes = None

K = 10

structure = 'poon-domingos'
# structure = 'binary-trees'

# 'poon-domingos'
pd_num_pieces = [4]
# pd_num_pieces = [7]
# pd_num_pieces = [7, 28]
width = 14
height = 28

# 'binary-trees'
depth = 4
num_repetitions = 20

num_epochs = 10
batch_size = 100
online_em_frequency = 5
online_em_stepsize = 0.5
############################################################################

exponential_family_args = None
if exponential_family == EinsumNetwork.BinomialArray:
    exponential_family_args = {'N': 255}
if exponential_family == EinsumNetwork.CategoricalArray:
    exponential_family_args = {'K': 256}
if exponential_family == EinsumNetwork.NormalArray:
    exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

client_datasets = distribute_to_clients_vertical(2, mnist)

# Make EinsumNetwork
######################################
pd_delta = [[height / d, width / d] for d in pd_num_pieces]

client_spns = []
# Train
######################################
for ds in client_datasets:
    cluster_spns = []
    for l in torch.unique(ds.tensors[1]):
        idx = torch.argwhere(ds.tensors[1] == l).flatten()
        subset = TensorDataset(ds.tensors[0][idx], ds.tensors[1][idx])
        graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
        args = EinsumNetwork.Args(
                num_var=width*height,
                num_dims=1,
                num_classes=1,
                num_sums=K,
                num_input_distributions=K,
                exponential_family=exponential_family,
                exponential_family_args=exponential_family_args,
                online_em_frequency=online_em_frequency,
                online_em_stepsize=online_em_stepsize)

        einet = EinsumNetwork.EinsumNetwork(graph, args)
        einet.initialize()
        einet.to(device)

        loader = DataLoader(subset, batch_size)

        for _ in range(num_epochs):

            total_ll = 0.0
            for x,  y in loader:
                x = x.to(device, dtype=torch.float32)
                x = x.reshape(x.shape[0], width*height)
                x = x / 255.
                x -= 0.5
                outputs = einet.forward(x)
                ll_sample = EinsumNetwork.log_likelihoods(outputs)
                log_likelihood = ll_sample.mean()
                log_likelihood.backward()

                einet.em_process_batch()
                total_ll += log_likelihood.detach().item()
            
            print(total_ll / len(loader))

            einet.em_update()
        
        cluster_spns.append(einet)
    mixture = EinetMixture.EinetMixture([1/10]*10, cluster_spns)
    client_spns.append(mixture)

#mixture = EinetMixture.EinetMixture(weights, client_spns)
samples_dir = '../samples/demo_mnist/'
utils.mkdir_p(samples_dir)


{0}
290.69135278066
1298.9224344889324
1376.5811501820883
1434.7354400634765
1479.0014882405599
1496.64782816569
1499.2635864257813
1499.8894836425782
1501.0454671223958
1502.028109741211
{0}
512.2332072158289
1800.7119566973518
1952.5888317893532
1967.0436728982365
1969.349754333496
1965.7089649649226
1968.6639708350685
1967.5700237049775
1975.9390442792107
1971.7925196254955
{0}
299.5292706171671
1333.311328125
1586.0936264038087
1635.9005477905273
1646.0467641194662
1648.613887532552
1648.3629801432292
1649.8374420166015
1650.1348795572917
1650.3134012858072
{0}
333.2425199016448
1366.7570478377804
1561.09522468813
1606.9933294480848
1626.2393414897304
1624.0075413488573
1619.0718127835182
1615.9321948635964
1628.2056796166205
1656.8710287770916
{0}
305.7145930468026
1399.8753589694784
1697.3178255760063
1727.3938443458687
1737.810350321107
1739.8213056144068
1740.8557335805085
1742.4970537605932
1749.2369260626324
1747.2952156713454
{0}
244.97425759055398
1304.7158114346591
1633.24

In [108]:
# sampling below acts as a product node (sample both branches since they have different scopes)
samples = []
for mixture in client_spns:
    cluster_idx = np.random.randint(0, 9)
    client_samples = mixture.einets[cluster_idx].sample(25)
    client_samples = client_samples.reshape((-1, 14, 28))
    samples.append(client_samples)

matched_samples = []
for i in range(25):
    cluster_idx = np.random.randint(0, 9)
    client_samples = [mixture.einets[cluster_idx].sample(1) for mixture in client_spns]
    client_samples = [cs.reshape(14, 28) for cs in client_samples]
    sample = np.concatenate(client_samples, axis=0)
    matched_samples.append(sample)
matched_samples = np.array(matched_samples)
matched_samples += 0.5
matched_samples[matched_samples < 1e-2] = 0
utils.save_image_stack(matched_samples, 5, 5, os.path.join(samples_dir, "samples.pdf"), margin_gray_val=0.)

In [110]:
for c, ds in enumerate(client_datasets):
    idx = torch.argwhere(ds.tensors[1] == 0).flatten()
    s = ds.tensors[0][idx[0]].numpy()
    s -= s.min()
    s = s / s.max()
    img = Image.fromarray(np.round(s * 255.).astype(np.uint8))

    #img = margin_gray_val * np.ones((height*num_rows + (num_rows-1)*margin, width*num_columns + (num_columns-1)*margin, 3))
    img.save(f'Client_{c}.pdf')

In [112]:
def distribute_to_clients_hybrid(dataset):
    # TODO: add support for more than 2 clients
    # upper 10 pixels -> client 1
    # lower 10 pixels -> client 2
    # remaining 8 -> split client 1 and 2
    
    datasets_client1 = []
    datasets_client2 = []
    data = dataset.data
    samples_client1 = np.random.choice(np.arange(len(data)), int(0.6*len(data)), False)
    samples_client2 = np.random.choice(np.arange(len(data)), int(0.6*len(data)), False)
    data_c1 = data[:, :10, :]
    data_c2 = data[:, 18:, :]
    data_both = data[:, 10:18, :]
    data_c1 = data_c1[samples_client1]
    data_c2 = data_c2[samples_client2]
    data_both_c1 = data_both[samples_client1]
    data_both_c2 = data_both[samples_client2]

    datasets_client1.append(TensorDataset(data_c1, dataset.targets[samples_client1]))
    datasets_client1.append(TensorDataset(data_both_c1, dataset.targets[samples_client1]))
    datasets_client2.append(TensorDataset(data_c2, dataset.targets[samples_client2]))
    datasets_client2.append(TensorDataset(data_both_c2, dataset.targets[samples_client2]))
    return datasets_client1, datasets_client2

In [116]:
def train_client(client_idx):
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    #exponential_family = EinsumNetwork.BinomialArray
    #exponential_family = EinsumNetwork.CategoricalArray
    exponential_family = EinsumNetwork.NormalArray

    classes = [7]
    # classes = [2, 3, 5, 7]
    # classes = None

    K = 10

    structure = 'poon-domingos'
    # structure = 'binary-trees'

    # 'poon-domingos'
    pd_num_pieces = [4]
    # pd_num_pieces = [7]
    # pd_num_pieces = [7, 28]

    # 'binary-trees'
    depth = 4
    num_repetitions = 20

    num_epochs = 10
    batch_size = 100
    online_em_frequency = 5
    online_em_stepsize = 0.5
    ############################################################################

    exponential_family_args = None
    if exponential_family == EinsumNetwork.BinomialArray:
        exponential_family_args = {'N': 255}
    if exponential_family == EinsumNetwork.CategoricalArray:
        exponential_family_args = {'K': 256}
    if exponential_family == EinsumNetwork.NormalArray:
        exponential_family_args = {'min_var': 1e-6, 'max_var': 0.1}

    client_datasets = distribute_to_clients_hybrid(mnist)
    client_datasets = client_datasets[client_idx]

    client_spns = []
    # Train
    ######################################
    for ds in client_datasets:
        width = 28
        height = ds.tensors[0].shape[1]
        pd_delta = [[height / d, width / d] for d in pd_num_pieces]
        cluster_spns = []
        for l in torch.unique(ds.tensors[1]):
            idx = torch.argwhere(ds.tensors[1] == l).flatten()
            subset = TensorDataset(ds.tensors[0][idx], ds.tensors[1][idx])
            graph = Graph.poon_domingos_structure(shape=(height, width), delta=pd_delta)
            args = EinsumNetwork.Args(
                    num_var=width*height,
                    num_dims=1,
                    num_classes=1,
                    num_sums=K,
                    num_input_distributions=K,
                    exponential_family=exponential_family,
                    exponential_family_args=exponential_family_args,
                    online_em_frequency=online_em_frequency,
                    online_em_stepsize=online_em_stepsize)

            einet = EinsumNetwork.EinsumNetwork(graph, args)
            einet.initialize()
            einet.to(device)

            loader = DataLoader(subset, batch_size)

            for _ in range(num_epochs):

                total_ll = 0.0
                for x,  y in loader:
                    x = x.to(device, dtype=torch.float32)
                    x = x.reshape(x.shape[0], width*height)
                    x = x / 255.
                    x -= 0.5
                    outputs = einet.forward(x)
                    ll_sample = EinsumNetwork.log_likelihoods(outputs)
                    log_likelihood = ll_sample.mean()
                    log_likelihood.backward()

                    einet.em_process_batch()
                    total_ll += log_likelihood.detach().item()
                
                print(total_ll / len(loader))

                einet.em_update()
            
            cluster_spns.append(einet)
        mixture = EinetMixture.EinetMixture([1/10]*10, cluster_spns)
        client_spns.append((mixture, ds.tensors[0].shape))

    return client_spns


In [117]:
client_1_spns = train_client(0)
client_2_spns = train_client(1)

{0}
48.79355286227332
658.681037902832
1156.6049431694878
1286.4626770019531
1291.511237250434
1295.4705742730034
1297.5634765625
1297.9241739908855
1297.739498562283
1298.117190890842
{0}
112.29085185527802
860.2025329589844
1428.509243774414
1498.2489959716797
1500.487322998047
1501.2830596923827
1502.2536376953126
1502.8927520751954
1503.0704650878906
1503.2081329345704
{0}
41.36551562945048
636.1538196139866
1122.5905541314019
1258.8404303656685
1267.6280551486545
1270.1524454752605
1271.5637580023872
1272.5049879286025
1273.7854749891494
1274.6893581814236
{0}
59.4661103583671
677.0873800741659
1173.5051319019215
1289.9712936813767
1297.8467192778717
1301.0644960145692
1302.6331127269848
1303.452646616343
1303.9792777396538
1304.653782200169
{0}
66.98354278670416
705.442975362142
1241.1230044894749
1383.5114440917969
1390.59864976671
1394.9786716037327
1396.6043802897136
1397.667731391059
1399.417704264323
1399.5789421929253
{0}
32.61877089037615
624.6619154986213
1121.06516580020

In [121]:
# sample from FedSPN
# FedSPN looks as following if we connect client networks:
"""
        +
        |
        x
    ----------
    |    |   |
  [:10]  +  [10:]
     --------
     |      |
 [10:18]  [10:18]
"""
# where  + is a sum node, x is a product node and [B:C] is Client SPN with scope from pixels B to C
# We sample SPN [:10] AND [10:] since they are connected via x
# then we sample either [10:18] of client 1 or client 2 according to how many samples they hold
(client_1_single, client_1_single_shape), (client_1_both, client_1_both_shape) = client_1_spns
(client_2_single, client_2_single_shape), (client_2_both, client_2_both_shape) = client_2_spns
# clients agree which label to sample
labels = np.random.randint(0, 9, size=25).flatten()
samples = []
for l in labels:
    # sample from l-branch both "single" networks
    single_sample_1 = client_1_single.einets[l].sample(1).numpy()
    single_sample_2 = client_2_single.einets[l].sample(1).numpy()
    # decide which client's SPN should be queried
    weights = np.array([client_1_both_shape[0], client_2_both_shape[0]])
    weights = weights / sum(weights)
    sample_idx = np.random.choice([0, 1], size=1, p=weights)[0]
    if sample_idx == 0:
        both_sample = client_1_both.einets[l].sample(1).numpy()
        both_sample = both_sample.reshape(client_1_both_shape[1], client_1_both_shape[2])
    else:
        both_sample = client_2_both.einets[l].sample(1).numpy()
        both_sample = both_sample.reshape(client_2_both_shape[1], client_2_both_shape[2])
    
    # glue samples together
    single_sample_1 = single_sample_1.reshape(client_1_single_shape[1], client_1_single_shape[2])
    single_sample_2 = single_sample_2.reshape(client_2_single_shape[1], client_2_single_shape[2])
    sample = np.vstack([single_sample_1, both_sample, single_sample_2])
    samples.append(sample)

samples = np.array(samples)
samples += 0.5
samples[samples < 1e-2] = 0
utils.save_image_stack(samples, 5, 5, os.path.join(samples_dir, "samples.pdf"), margin_gray_val=0.)

In [129]:
client_datasets = distribute_to_clients_hybrid(mnist)
for c, ds in enumerate(client_datasets):
    ds1, ds2 = ds
    idx = torch.argwhere(ds1.tensors[1] == 0)[0]
    s1 = ds1.tensors[0][idx].numpy()[0]
    s2 = ds2.tensors[0][idx].numpy()[0]
    s1 -= s1.min()
    s1 = s1 / s1.max()
    s2 -= s2.min()
    s2 = s2 / s2.max()
    img1 = Image.fromarray(np.round(s1 * 255.).astype(np.uint8))
    img2 = Image.fromarray(np.round(s2 * 255).astype(np.uint8))

    #img = margin_gray_val * np.ones((height*num_rows + (num_rows-1)*margin, width*num_columns + (num_columns-1)*margin, 3))
    img1.save(f'Client_{c}_1.pdf')
    img2.save(f'Client_{c}_2.pdf')

: 