In [1]:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader, Subset
import torch.nn as nn
import torch
import pandas as pd

### Download and Transform Test Data

In [2]:
# Test split of FashionMNIST (train=False), stored under `/data` and downloaded when requested by `download=True`.
# The only transform applied to each sample is `ToTensor`.
transform = transforms.Compose([
    transforms.ToTensor(),
])
test_data = datasets.FashionMNIST(
    root='/data',
    train=False,
    download=True,
    transform=transform,
)

### Gather the indices of each class

In [3]:
# Bucket the position of every test sample by its class label (0-9).
class_indices = {label: [] for label in range(10)}

for idx, target in enumerate(test_data.targets):
    bucket = int(target)
    class_indices[bucket].append(idx)

### Create each client's test subsets according to: 
#### ρ = (# of OOD samples / # of ID samples)

In [4]:
# Client 1: classes 0-4 are in-distribution (ID), classes 5-9 are out-of-distribution (OOD).
# Every list contains all ID samples; the r0_2 ... r0_8 variants add the first
# 200 / 400 / 600 / 800 samples of each OOD class, and r1 adds every OOD sample.
c1_r0_indices = [i for c in range(0, 5) for i in class_indices[c]]
c1_r0_2_indices = c1_r0_indices + [i for c in range(5, 10) for i in class_indices[c][:200]]
c1_r0_4_indices = c1_r0_indices + [i for c in range(5, 10) for i in class_indices[c][:400]]
c1_r0_6_indices = c1_r0_indices + [i for c in range(5, 10) for i in class_indices[c][:600]]
c1_r0_8_indices = c1_r0_indices + [i for c in range(5, 10) for i in class_indices[c][:800]]
c1_r1_indices = c1_r0_indices + [i for c in range(5, 10) for i in class_indices[c]]

# Sanity check: number of samples in each variant.
(
    len(c1_r0_indices),
    len(c1_r0_2_indices),
    len(c1_r0_4_indices),
    len(c1_r0_6_indices),
    len(c1_r0_8_indices),
    len(c1_r1_indices),
)

(5000, 6000, 7000, 8000, 9000, 10000)

In [5]:
# Client 2: classes 5-9 are in-distribution (ID), classes 0-4 are out-of-distribution (OOD).
# Every list contains all ID samples; the r0_2 ... r0_8 variants add the first
# 200 / 400 / 600 / 800 samples of each OOD class, and r1 adds every OOD sample.
c2_r0_indices = [i for c in range(5, 10) for i in class_indices[c]]
c2_r0_2_indices = c2_r0_indices + [i for c in range(0, 5) for i in class_indices[c][:200]]
c2_r0_4_indices = c2_r0_indices + [i for c in range(0, 5) for i in class_indices[c][:400]]
c2_r0_6_indices = c2_r0_indices + [i for c in range(0, 5) for i in class_indices[c][:600]]
c2_r0_8_indices = c2_r0_indices + [i for c in range(0, 5) for i in class_indices[c][:800]]
c2_r1_indices = c2_r0_indices + [i for c in range(0, 5) for i in class_indices[c]]

# Sanity check: number of samples in each variant.
(
    len(c2_r0_indices),
    len(c2_r0_2_indices),
    len(c2_r0_4_indices),
    len(c2_r0_6_indices),
    len(c2_r0_8_indices),
    len(c2_r1_indices),
)

(5000, 6000, 7000, 8000, 9000, 10000)

In [6]:
# One Subset view over the shared test set per client-1 index list.
(
    c1_r0_subset,
    c1_r0_2_subset,
    c1_r0_4_subset,
    c1_r0_6_subset,
    c1_r0_8_subset,
    c1_r1_subset,
) = (
    Subset(dataset=test_data, indices=sample_ids)
    for sample_ids in (
        c1_r0_indices,
        c1_r0_2_indices,
        c1_r0_4_indices,
        c1_r0_6_indices,
        c1_r0_8_indices,
        c1_r1_indices,
    )
)

# Sanity check: subset sizes.
(
    len(c1_r0_subset),
    len(c1_r0_2_subset),
    len(c1_r0_4_subset),
    len(c1_r0_6_subset),
    len(c1_r0_8_subset),
    len(c1_r1_subset),
)

(5000, 6000, 7000, 8000, 9000, 10000)

In [7]:
# One Subset view over the shared test set per client-2 index list.
(
    c2_r0_subset,
    c2_r0_2_subset,
    c2_r0_4_subset,
    c2_r0_6_subset,
    c2_r0_8_subset,
    c2_r1_subset,
) = (
    Subset(dataset=test_data, indices=sample_ids)
    for sample_ids in (
        c2_r0_indices,
        c2_r0_2_indices,
        c2_r0_4_indices,
        c2_r0_6_indices,
        c2_r0_8_indices,
        c2_r1_indices,
    )
)

# Sanity check: subset sizes.
(
    len(c2_r0_subset),
    len(c2_r0_2_subset),
    len(c2_r0_4_subset),
    len(c2_r0_6_subset),
    len(c2_r0_8_subset),
    len(c2_r1_subset),
)

(5000, 6000, 7000, 8000, 9000, 10000)

### Create a DataLoader for each subset

In [8]:
# Sequential (unshuffled) loaders with batches of 32 for every client-1 subset.
(
    c1_r0_dl,
    c1_r0_2_dl,
    c1_r0_4_dl,
    c1_r0_6_dl,
    c1_r0_8_dl,
    c1_r1_dl,
) = (
    DataLoader(dataset=subset, batch_size=32, shuffle=False, num_workers=0, pin_memory=True)
    for subset in (
        c1_r0_subset,
        c1_r0_2_subset,
        c1_r0_4_subset,
        c1_r0_6_subset,
        c1_r0_8_subset,
        c1_r1_subset,
    )
)

# Sanity check: number of batches per loader.
(
    len(c1_r0_dl),
    len(c1_r0_2_dl),
    len(c1_r0_4_dl),
    len(c1_r0_6_dl),
    len(c1_r0_8_dl),
    len(c1_r1_dl),
)

(157, 188, 219, 250, 282, 313)

In [9]:
# Sequential (unshuffled) loaders with batches of 32 for every client-2 subset.
(
    c2_r0_dl,
    c2_r0_2_dl,
    c2_r0_4_dl,
    c2_r0_6_dl,
    c2_r0_8_dl,
    c2_r1_dl,
) = (
    DataLoader(dataset=subset, batch_size=32, shuffle=False, num_workers=0, pin_memory=True)
    for subset in (
        c2_r0_subset,
        c2_r0_2_subset,
        c2_r0_4_subset,
        c2_r0_6_subset,
        c2_r0_8_subset,
        c2_r1_subset,
    )
)

# Sanity check: number of batches per loader.
(
    len(c2_r0_dl),
    len(c2_r0_2_dl),
    len(c2_r0_4_dl),
    len(c2_r0_6_dl),
    len(c2_r0_8_dl),
    len(c2_r1_dl),
)

(157, 188, 219, 250, 282, 313)

In [10]:
# Per-client loaders, ordered from the ID-only variant (r0) to the variant with all OOD samples (r1).
c1_dls = [
    c1_r0_dl,
    c1_r0_2_dl,
    c1_r0_4_dl,
    c1_r0_6_dl,
    c1_r0_8_dl,
    c1_r1_dl,
]
c2_dls = [
    c2_r0_dl,
    c2_r0_2_dl,
    c2_r0_4_dl,
    c2_r0_6_dl,
    c2_r0_8_dl,
    c2_r1_dl,
]

### Model Architectures

In [11]:
class ClientModel(nn.Module):
    """Client-side convolutional feature extractor.

    Four 3x3 convolutions (1 -> 32 -> 64 -> 128 -> 256 channels), each followed
    by ReLU, with 2x2 max-pooling after the second and the fourth convolution.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.conv4 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3)
        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2)

    def forward(self, x):
        """Return the feature maps produced for the input batch `x`."""
        # Stage 1: conv1 -> ReLU
        x = self.conv1(x)
        x = self.relu(x)
        # Stage 2: conv2 -> ReLU -> pool
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        # Stage 3: conv3 -> ReLU
        x = self.conv3(x)
        x = self.relu(x)
        # Stage 4: conv4 -> ReLU -> pool
        x = self.conv4(x)
        x = self.relu(x)
        x = self.pool(x)
        return x

class ClientClassifier(nn.Module):
    """Client-side classification head: a single linear layer from 4*4*256 features to 10 logits."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(in_features=4 * 4 * 256, out_features=10)

    def forward(self, x):
        """Flatten everything after the batch dimension and return the 10 logits per sample."""
        flat = torch.flatten(x, 1)
        return self.fc1(flat)

class ServerModel(nn.Module):
    """Server-side head applied to the client feature maps.

    One 3x3 convolution (256 -> 512 channels) with ReLU, followed by three
    linear layers (2*2*512 -> 256 -> 128 -> 10).
    """

    def __init__(self):
        super().__init__()
        self.conv5 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3)
        self.fc1 = nn.Linear(in_features=2 * 2 * 512, out_features=256)
        self.fc2 = nn.Linear(in_features=256, out_features=128)
        self.fc3 = nn.Linear(in_features=128, out_features=10)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Return the 10 logits per sample for a batch of client feature maps."""
        features = self.relu(self.conv5(x))
        flat = torch.flatten(features, 1)
        # NOTE(review): no activation is applied between the linear layers; kept as is
        # because the stored weights are loaded into exactly this architecture.
        hidden = self.fc1(flat)
        hidden = self.fc2(hidden)
        return self.fc3(hidden)

### Accuracy Test for each model combination

In [12]:
# Root folder of the baseline experiments and the sub-folder of each individual run.
general_path = 'fmnist-exp/Baseline_exps/'
experiment_paths = [
    'Base1/',
    'Base2/',
    'Base3/',
    'Base4/',
    'Base5/',
]

In [13]:
def _load_weights(path, checkpoint_dir='Seventh10/'):
    """Build the five networks and restore their saved weights.

    Args:
        path: Experiment folder, ending with a path separator
            (e.g. ``'fmnist-exp/Baseline_exps/Base1/'``).
        checkpoint_dir: Sub-folder of ``path`` that holds the ``*.pt`` files,
            also ending with a separator. Defaults to ``'Seventh10/'``.

    Returns:
        Tuple ``(client1_model, client1_classifier, client2_model,
        client2_classifier, server_model)`` with the weights loaded.
    """
    weights_dir = path + checkpoint_dir

    def _restore(module, filename):
        # map_location='cpu' lets checkpoints that were saved from a GPU be loaded on a
        # CPU-only machine; the evaluation in this notebook never moves data to a GPU.
        # NOTE: torch.load unpickles the file - only load checkpoints from a trusted source.
        state_dict = torch.load(weights_dir + filename, map_location='cpu')
        module.load_state_dict(state_dict)

    client1_model = ClientModel()
    client1_classifier = ClientClassifier()
    client2_model = ClientModel()
    client2_classifier = ClientClassifier()
    server_model = ServerModel()

    _restore(client1_model, 'client1_model_weights.pt')
    _restore(client2_model, 'client2_model_weights.pt')
    _restore(client1_classifier, 'client1_classifier_weights.pt')
    _restore(client2_classifier, 'client2_classifier_weights.pt')
    _restore(server_model, 'server_weights.pt')

    return client1_model, client1_classifier, client2_model, client2_classifier, server_model

In [15]:
def test_client(model, classifier, test_dl):
    model.eval()
    classifier.eval()
    total = 0
    correct = 0
    with torch.inference_mode():
        for i, data in enumerate(test_dl):
            inputs, labels = data
            model_outputs = model(inputs)
            classifier_outputs = classifier(model_outputs)
            _, predictions = torch.max(classifier_outputs, 1)
            total += labels.size(0)
            correct += (predictions == labels).sum().item()
            del model_outputs, classifier_outputs, inputs
    return round((correct / total)*100, 2)

In [15]:
def _get_test_acc(client1_model, client1_classifier, client2_model, client2_classifier, server_model):
    """Evaluate each client with its own classifier and with the server model.

    Client 1 is evaluated on every loader in ``c1_dls`` and client 2 on every
    loader in ``c2_dls``.

    Returns:
        Dict with keys ``'client1'``, ``'client1_server'``, ``'client2'`` and
        ``'client2_server'``; each value is the list of accuracies (one per loader).
    """
    # Keys are created up front so the dict keeps the same insertion order.
    test_accs = {'client1': [], 'client1_server': [], 'client2': [], 'client2_server': []}
    evaluation_plan = (
        ('client1', client1_model, client1_classifier, c1_dls),
        ('client2', client2_model, client2_classifier, c2_dls),
    )
    for name, extractor, own_head, loaders in evaluation_plan:
        for loader in loaders:
            test_accs[name].append(test_client(extractor, own_head, loader))
            test_accs[name + '_server'].append(test_client(extractor, server_model, loader))
    return test_accs

In [16]:
def _iterate_experiments(path):
    """Load the networks stored under `path` and return their test accuracies.

    The result is the dict produced by `_get_test_acc`.
    """
    networks = _load_weights(path)
    # Put every network into eval mode before evaluating.
    for network in networks:
        network.eval()
    return _get_test_acc(*networks)

In [17]:
# Accuracy results of every experiment, keyed by the experiment's position in `experiment_paths`.
exp_dict = {
    idx: _iterate_experiments(general_path + path)
    for idx, path in enumerate(experiment_paths)
}

In [19]:
# Rows: model combinations; columns: experiment index. Each cell holds the list of
# accuracies for the six test loaders of that client.
df = pd.DataFrame(exp_dict)
# Derive the output location from `general_path` instead of repeating the folder literal,
# so the results are always written next to the experiments they were computed from.
df.to_csv(general_path + 'test_accs.csv')
df

Unnamed: 0,0,1,2,3,4
client1,"[89.9, 83.53, 79.07, 75.52, 72.63, 70.42]","[86.8, 84.43, 82.56, 81.14, 80.0, 79.28]","[76.84, 78.27, 79.27, 79.66, 80.28, 80.86]","[90.42, 75.77, 65.3, 57.39, 51.17, 46.32]","[89.12, 83.88, 80.21, 77.33, 74.93, 73.22]"
client1_server,"[90.3, 83.28, 78.3, 74.61, 71.44, 69.22]","[90.28, 87.25, 84.83, 83.23, 81.84, 80.69]","[84.44, 84.5, 84.53, 84.39, 84.31, 84.21]","[91.74, 76.5, 65.66, 57.46, 51.11, 46.08]","[89.74, 85.25, 82.06, 79.83, 77.9, 76.42]"
client2,"[95.68, 82.47, 73.1, 65.95, 60.32, 55.8]","[93.74, 84.23, 77.1, 72.1, 67.97, 64.78]","[84.88, 83.92, 82.71, 82.12, 81.42, 80.86]","[95.62, 80.32, 69.66, 61.42, 55.1, 49.97]","[94.32, 83.78, 76.19, 70.7, 66.37, 62.82]"
client2_server,"[96.58, 82.18, 72.14, 64.44, 58.33, 53.55]","[94.9, 85.38, 78.41, 73.47, 69.7, 66.35]","[83.98, 84.07, 83.9, 84.04, 84.18, 84.21]","[96.44, 80.37, 68.89, 60.27, 53.58, 48.22]","[95.22, 84.67, 77.0, 71.5, 67.26, 63.57]"


In [15]:
def test_subset(client_model, classifier, server_model, threshold, OOD_labels, dataloader):
    entropies = []
    total = 0
    correct = 0
    entr_counter = 0
    correct_entr = 0
    with torch.inference_mode():
        for i, data in enumerate(dataloader):
            inputs, labels = data
            client_model_outputs = client_model(inputs)
            outputs = classifier(client_model_outputs)
            preds = outputs.clone().detach()
            pred_distr = nn.functional.softmax(torch.tensor(outputs).clone().detach(), dim=-1)
            entropies.append([(-torch.sum(tensor * torch.log2(tensor))) for i, tensor in enumerate(pred_distr)])
            for idx, entr in enumerate(entropies[-1]):
                if entr > threshold:
                    entr_counter += 1
                    correct_entr += 1 if labels[idx] in OOD_labels else 0
                    preds[idx] =  server_model(client_model_outputs)[idx]
            _, preds = torch.max(preds, 1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)
    return round(correct / total * 100, 2), entr_counter, correct_entr


In [16]:
def _iterate_entropy_experiments(path):
    """Run the entropy-routing evaluation of one experiment for several thresholds.

    Returns:
        Tuple ``(dict1, dict2)`` for client 1 and client 2. Each dict maps a
        threshold to the list of `test_subset` results, one per loader in
        ``c1_dls`` / ``c2_dls`` respectively.
    """
    client1_model, client1_classifier, client2_model, client2_classifier, server_model = _load_weights(path)
    for network in (client1_model, client2_model, client1_classifier, client2_classifier, server_model):
        network.eval()

    def _sweep(model, head, ood_labels, loaders, thr):
        # Evaluate one client on all of its loaders for a single threshold.
        return [
            test_subset(model, head, server_model, threshold=thr, OOD_labels=ood_labels, dataloader=dl)
            for dl in loaders
        ]

    dict1, dict2 = {}, {}
    for thr in [0.05, 0.4, 0.8, 1.2, 2.3]:
        dict1[thr] = _sweep(client1_model, client1_classifier, [5, 6, 7, 8, 9], c1_dls, thr)
        dict2[thr] = _sweep(client2_model, client2_classifier, [0, 1, 2, 3, 4], c2_dls, thr)
    return dict1, dict2

In [17]:
# Entropy-routing results of every experiment: experiment index -> (client-1 dict, client-2 dict).
exp_dict = {
    idx: _iterate_entropy_experiments(general_path + path)
    for idx, path in enumerate(experiment_paths)
}

  pred_distr = nn.functional.softmax(torch.tensor(outputs).clone().detach(), dim=-1)


In [19]:
# Same values as the path constants defined earlier in the notebook, repeated here
# so the export cells below can be run on their own.
general_path = 'fmnist-exp/Baseline_exps/'
experiment_paths = [
    'Base1/',
    'Base2/',
    'Base3/',
    'Base4/',
    'Base5/',
]

In [20]:
# Write one CSV per client into every experiment folder.
# Columns are the entropy thresholds; rows follow the order of the client's loaders,
# and the extra `r` column records the OOD ratio of each row.
for idx, path in enumerate(experiment_paths):
    df1 = pd.DataFrame(data=exp_dict[idx][0])
    df2 = pd.DataFrame(data=exp_dict[idx][1])
    df1['r'] = [0, 0.2, 0.4, 0.6, 0.8, 1]
    df2['r'] = [0, 0.2, 0.4, 0.6, 0.8, 1]
    df1.to_csv(general_path + path + 'c1_entropies_acc.csv')
    df2.to_csv(general_path + path + 'c2_entropies_acc.csv')

In [34]:
# Forward slashes: the backslash version contained invalid escape sequences (`\B`, `\c`)
# and only resolved on Windows, unlike every other path in this notebook.
pd.read_csv('fmnist-exp/Baseline_exps/Base5/c1_entropies_acc.csv')

Unnamed: 0.1,Unnamed: 0,0.05,0.4,0.8,1.2,2.3,r
0,0,"(89.76, 3578, 0)","(89.74, 1980, 0)","(89.92, 1249, 0)","(89.78, 582, 0)","(89.24, 11, 0)",0.0
1,1,"(85.27, 4497, 919)","(85.25, 2618, 638)","(85.35, 1707, 458)","(84.95, 888, 306)","(83.98, 13, 2)",0.2
2,2,"(82.07, 5427, 1849)","(81.99, 3250, 1270)","(82.16, 2154, 905)","(81.67, 1165, 583)","(80.31, 15, 4)",0.4
3,3,"(79.84, 6357, 2779)","(79.76, 3861, 1881)","(79.86, 2602, 1353)","(79.34, 1449, 867)","(77.45, 22, 11)",0.6
4,4,"(77.92, 7290, 3712)","(77.88, 4503, 2523)","(77.93, 3089, 1840)","(77.32, 1760, 1178)","(75.06, 28, 17)",0.8
5,5,"(76.44, 8219, 4641)","(76.41, 5130, 3150)","(76.51, 3546, 2297)","(75.86, 2054, 1472)","(73.36, 36, 25)",1.0


In [35]:
# Forward slashes: the backslash version contained invalid escape sequences (`\B`, `\c`)
# and only resolved on Windows, unlike every other path in this notebook.
pd.read_csv('fmnist-exp/Baseline_exps/Base5/c2_entropies_acc.csv')

Unnamed: 0.1,Unnamed: 0,0.05,0.4,0.8,1.2,2.3,r
0,0,"(95.22, 3330, 0)","(95.2, 1231, 0)","(95.18, 624, 0)","(94.82, 220, 0)","(94.34, 5, 0)",0.0
1,1,"(84.67, 4231, 901)","(84.67, 1752, 521)","(84.5, 951, 327)","(84.23, 389, 169)","(83.75, 19, 14)",0.2
2,2,"(77.0, 5109, 1779)","(76.99, 2261, 1030)","(76.8, 1256, 632)","(76.59, 539, 319)","(76.16, 35, 30)",0.4
3,3,"(71.51, 5992, 2662)","(71.46, 2794, 1563)","(71.35, 1589, 965)","(71.04, 696, 476)","(70.65, 46, 41)",0.6
4,4,"(67.27, 6872, 3542)","(67.21, 3309, 2078)","(67.06, 1886, 1262)","(66.67, 829, 609)","(66.3, 48, 43)",0.8
5,5,"(63.58, 7762, 4432)","(63.53, 3825, 2594)","(63.45, 2193, 1569)","(63.1, 988, 768)","(62.77, 52, 47)",1.0
